---
library_name: transformers
license: apache-2.0
language:
- fa
base_model: llava-hf/llava-1.5-7b-hf
datasets:
- BaSalam/vision-catalogs-llava-format-v3
pipeline_tag: image-text-to-text
---

# LLaVA Model Card

## Model details
This model is [`llava-hf/llava-1.5-7b-hf`](https://huggingface.co/llava-hf/llava-1.5-7b-hf) fine-tuned on Basalam product data ([`BaSalam/vision-catalogs-llava-format-v3`](https://huggingface.co/datasets/BaSalam/vision-catalogs-llava-format-v3)) to extract the visual attributes of products. It returns its output as JSON, which can be parsed directly (for example with `json.loads`).
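Because the output is plain JSON, downstream code can post-process it with the standard library alone. The sketch below assumes the output shape shown in the example output in this card (a list of objects, each holding an `attributes` mapping from an attribute name to a list of values). `merge_attributes` is a hypothetical helper, not part of this model or of transformers, and the input in the usage lines is made up for illustration.

```python
from typing import Any


def merge_attributes(parsed: Any) -> dict[str, list[str]]:
    """Collapse the parsed model output into one {attribute_name: [values]} dict.

    Assumes the shape seen in this card's example: either a single
    {"attributes": {...}} object or a list of such objects.
    """
    # Accept both a bare object and a list of objects.
    items = parsed if isinstance(parsed, list) else [parsed]
    merged: dict[str, list[str]] = {}
    for item in items:
        attributes = item.get("attributes", {}) if isinstance(item, dict) else {}
        for name, values in attributes.items():
            # Tolerate a scalar value where a list was expected.
            if not isinstance(values, list):
                values = [values]
            bucket = merged.setdefault(name, [])
            for value in values:
                # Keep first-seen order and drop duplicate values.
                if value not in bucket:
                    bucket.append(value)
    return merged


# Made-up input in the same shape as the example output (رنگ = color, سایز = size).
example = [
    {"attributes": {"رنگ": ["زرد", "آبی"], "سایز": ["L", "XL"]}},
    {"attributes": {"رنگ": ["آبی", "سفید"]}},
]
print(merge_attributes(example))
# {'رنگ': ['زرد', 'آبی', 'سفید'], 'سایز': ['L', 'XL']}
```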

## How to use the model
Below is an example script to run generation in `float16` precision on a GPU device:

```python
import requests
from PIL import Image
import torch
import json

from transformers import AutoProcessor, LlavaForConditionalGeneration
model_id = "BaSalam/Llava-1.5-7b-hf-bslm-product-attributes-v0"
model = LlavaForConditionalGeneration.from_pretrained(
    model_id, 
    torch_dtype=torch.float16, 
    low_cpu_mem_usage=True, 
).to(0)
processor = AutoProcessor.from_pretrained(model_id)

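def prompt_formatter(entity):
    # Builds the Persian instruction prompt used in this example. In English it reads:
    # "For the given product, extract the product's visual attributes in JSON format.
    # The JSON structure must be like this: {json_format}. The product is from an
    # Iranian online marketplace, so the JSON output must be in Persian.
    # Product: '{entity}'."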
    json_format = """attributes': {'attribute_name_1' : <list of attribute values>, 'attribute_name_2': <list of attribute values>, ...}"""
    final_prompt = f"""برای محصول داده شده، ویژگی‌های تصویری محصول را در قالب جیسون (json) استخراج کن. ساختار JSON باید به این شکل باشد: {json_format}. محصول از یک بازار اینترنتی ایرانی است پس خروجی Json باید به زبان فارسی باشد.
محصول: '{entity}'."""
    return final_prompt

prompt = prompt_formatter(entity='تیشرت مردانه')  # entity: "men's t-shirt"
conversation = [
    {
      "role": "user",
      "content": [
          {"type": "text", "text": prompt},
          {"type": "image"},
        ],
    },
]

prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
image_file = "https://statics.basalam.com/public-16/users/6eOEg/01-24/qJ34XziHu7Orp3GToVWTms1nKvCv0X86Ux7tQLtuRoyTXTxyQ4.jpg_800X800X70.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)

inputs = processor(images=raw_image, text=prompt, return_tensors='pt').to(0, torch.float16)

output = model.generate(**inputs, max_new_tokens=200, do_sample=False)
# Decode only the newly generated tokens, skipping the echoed prompt.
generated_ids = output[0][inputs["input_ids"].shape[1]:]
generated_text = processor.decode(generated_ids, skip_special_tokens=True).strip()
json_output = json.loads(generated_text)
print(json.dumps(json_output, ensure_ascii=False, indent=2))
```

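Example output for the image above. The attribute names are in Persian: نوع (type), طرح چاپی (print design), رنگ (color), and سایز (size).

```json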
[
  {
    "attributes": {
      "نوع": [
        "تیشرت مردانه"
      ],
      "طرح چاپی": [
        "MVP"
      ],
      "رنگ": [
        "زرد",
        "آبی",
        "سفید",
        "مشکی",
        "کرم",
        "سبز"
      ],
      "سایز": [
        "L",
        "XL",
        "2XL",
        "3XL"
      ]
    }
  }
]
```
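The script passes the decoded text straight to `json.loads`, which raises an error if the reply contains anything besides the JSON value, or if generation stops at `max_new_tokens` before the JSON is closed. A more tolerant sketch, using only the standard library: the hypothetical helper `extract_first_json` starts at the first `[` or `{` and uses `json.JSONDecoder.raw_decode` to parse a single JSON value while ignoring any trailing text. A truncated reply still raises `json.JSONDecodeError` (a subclass of `ValueError`), so the caller can decide whether to retry with a larger `max_new_tokens`.

```python
import json
from typing import Any


def extract_first_json(text: str) -> Any:
    """Parse the first JSON array or object that appears in `text`.

    Leading text (e.g. a stray "ASSISTANT:" prefix) and anything after the
    first complete JSON value are ignored. Raises ValueError if `text` has no
    "[" or "{", and json.JSONDecodeError (a ValueError subclass) if the JSON
    starting there is incomplete, e.g. cut off by max_new_tokens.
    """
    starts = [i for i in (text.find("["), text.find("{")) if i != -1]
    if not starts:
        raise ValueError("no JSON array or object found in the generated text")
    # raw_decode parses one JSON value from the start of the string and
    # returns (value, end_index) instead of failing on trailing text.
    value, _end = json.JSONDecoder().raw_decode(text[min(starts):])
    return value


# In the script above, call extract_first_json(...) on the decoded text
# instead of json.loads(...).
print(extract_first_json('ASSISTANT: [{"attributes": {"سایز": ["L", "XL"]}}] </s>'))
```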

### Model optimization

#### 4-bit quantization through `bitsandbytes` library

First, install `bitsandbytes` (`pip install bitsandbytes`) and make sure you have access to a CUDA-compatible GPU. Then change the model-loading code in the snippet above as follows:

```diff
+from transformers import BitsAndBytesConfig
+
 model = LlavaForConditionalGeneration.from_pretrained(
     model_id,
     torch_dtype=torch.float16,
     low_cpu_mem_usage=True,
-).to(0)
+    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
+)
```
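A rough, back-of-the-envelope check of why 4-bit loading helps: weight memory is about (number of parameters) × (bits per parameter) / 8 bytes. The sketch below assumes the nominal 7 billion parameters implied by the "7b" in the model name and counts the weights only; activations, the KV cache during generation, and any modules that stay in 16-bit add to the real footprint. `weight_memory_gb` is a hypothetical helper for illustration.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Approximate memory for the weights alone, in GB (1 GB = 10**9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


# Assumption: ~7e9 parameters, taken from the "7b" in the model name.
NUM_PARAMS = 7e9
for label, bits in (("float16", 16), ("4-bit", 4)):
    print(f"{label}: ~{weight_memory_gb(NUM_PARAMS, bits):.1f} GB")
# float16: ~14.0 GB
# 4-bit: ~3.5 GB
```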

#### Use Flash-Attention 2 to further speed up generation

First, install `flash-attn`; see the [original Flash Attention repository](https://github.com/Dao-AILab/flash-attention) for installation instructions. Then change the model-loading code in the snippet above as follows:

```diff
 model = LlavaForConditionalGeneration.from_pretrained(
     model_id,
     torch_dtype=torch.float16,
     low_cpu_mem_usage=True,
+    attn_implementation="flash_attention_2",
 ).to(0)
```
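If the same script has to run on machines with and without `flash-attn`, one option is to request Flash-Attention 2 only when the package is installed. The helper `attention_kwargs` below is hypothetical, not part of transformers. It only checks that `flash_attn` is importable (via `importlib.util.find_spec`), not that the GPU supports it; otherwise it returns no extra arguments, so transformers keeps its default attention implementation.

```python
import importlib.util


def attention_kwargs(package: str = "flash_attn") -> dict[str, str]:
    """Extra from_pretrained() kwargs: Flash-Attention 2 only if `package` is importable.

    `package` is a parameter only so the check can be exercised with other names.
    """
    if importlib.util.find_spec(package) is not None:
        return {"attn_implementation": "flash_attention_2"}
    # Package not installed: pass nothing and let transformers pick its default.
    return {}


# Usage in the loading code above (sketch):
# model = LlavaForConditionalGeneration.from_pretrained(
#     model_id,
#     torch_dtype=torch.float16,
#     low_cpu_mem_usage=True,
#     **attention_kwargs(),
# ).to(0)
print(attention_kwargs())
```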