metadata
library_name: transformers
license: apache-2.0
language:
- fa
base_model: llava-hf/llava-1.5-7b-hf
datasets:
- BaSalam/vision-catalogs-llava-format-v3
pipeline_tag: image-text-to-text
LLaVA Model Card
Model details
This model is "llava-hf/llava-1.5-7b-hf", fine-tuned on "Basalam product" data to extract the visual attributes of products. It returns its output as JSON, which can be parsed directly.
How to use the model
Below is an example script to run generation in float16 precision on a GPU device:
import requests
from PIL import Image
import torch
import json
from transformers import AutoProcessor, LlavaForConditionalGeneration
model_id = "BaSalam/Llava-1.5-7b-hf-bslm-product-attributes-v0"
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to(0)
processor = AutoProcessor.from_pretrained(model_id)
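def prompt_formatter(entity):
    json_format = """attributes': {'attribute_name_1' : <list of attribute values>, 'attribute_name_2': <list of attribute values>, ...}"""
    # The prompt is kept in Persian. In English it reads: "For the given product, extract the
    # product's visual attributes in JSON format. The JSON structure should look like this:
    # {json_format}. The product is from an Iranian online marketplace, so the JSON output
    # must be in Persian. Product: '{entity}'."
    final_prompt = f"""برای محصول داده شده، ویژگیهای تصویری محصول را در قالب جیسون (json) استخراج کن. ساختار JSON باید به این شکل باشد: {json_format}. محصول از یک بازار اینترنتی ایرانی است پس خروجی Json باید به زبان فارسی باشد.
محصول: '{entity}'."""
    return final_prompt

# The entity is the product name; 'تیشرت مردانه' means "men's T-shirt".
prompt = prompt_formatter(entity='تیشرت مردانه')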
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
image_file = "https://statics.basalam.com/public-16/users/6eOEg/01-24/qJ34XziHu7Orp3GToVWTms1nKvCv0X86Ux7tQLtuRoyTXTxyQ4.jpg_800X800X70.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(images=raw_image, text=prompt, return_tensors='pt').to(0, torch.float16)
output = model.generate(**inputs, max_new_tokens=200, do_sample=False)
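# Strip the echoed prompt from the decoded text (the original referenced an undefined name `text`).
generated_title = processor.decode(output[0], skip_special_tokens=True)[len(prompt.replace('<image>', ' ')):]
output = generated_title.replace('ASSISTANT: ', '')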
json_output = json.loads(output)
print(json_output)
[
    {
        "attributes": {
            "نوع": [
                "تیشرت مردانه"
            ],
            "طرح چاپی": [
                "MVP"
            ],
            "رنگ": [
                "زرد",
                "آبی",
                "سفید",
                "مشکی",
                "کرم",
                "سبز"
            ],
            "سایز": [
                "L",
                "XL",
                "2XL",
                "3XL"
            ]
        }
    }
]
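Because generation is capped at max_new_tokens=200, a long attribute list may be cut off before the JSON closes, and json.loads would then raise. The following is a minimal sketch of a more defensive parser, not part of the original card: the helper name extract_attributes and the shortened sample string are hypothetical, and it assumes the output has the shape shown above (a one-element list, or a bare object, with an "attributes" key).

```python
import json
from typing import Dict, List, Optional


def extract_attributes(generated: str) -> Optional[Dict[str, List[str]]]:
    """Return the attribute mapping from a decoded generation, or None if it cannot be parsed."""
    # Skip anything before the first JSON bracket, e.g. a leftover "ASSISTANT: " prefix.
    starts = [i for i in (generated.find("["), generated.find("{")) if i != -1]
    if not starts:
        return None
    try:
        parsed = json.loads(generated[min(starts):])
    except json.JSONDecodeError:
        # Typically a generation that was truncated by max_new_tokens.
        return None
    # The sample output wraps the object in a one-element list; accept both shapes.
    if isinstance(parsed, list):
        parsed = parsed[0] if parsed else {}
    if not isinstance(parsed, dict):
        return None
    attributes = parsed.get("attributes")
    return attributes if isinstance(attributes, dict) else None


# Hypothetical, shortened generation used only for illustration.
sample = 'ASSISTANT: [{"attributes": {"رنگ": ["زرد", "آبی"], "سایز": ["L", "XL"]}}]'
attrs = extract_attributes(sample)
print(attrs["سایز"])  # prints ['L', 'XL']
```

Returning None instead of raising lets a batch job skip or retry a product (for example with a larger max_new_tokens) without aborting the whole run.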
Model optimization
4-bit quantization through bitsandbytes library
First, install bitsandbytes with pip install bitsandbytes and make sure you have access to a CUDA-compatible GPU device. Then change the model-loading call in the snippet above to:
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
+   load_in_4bit=True
)
Use Flash-Attention 2 to further speed up generation
First, install flash-attn; refer to the original Flash Attention repository for installation instructions. Then change the model-loading call in the snippet above to:
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
+   use_flash_attention_2=True
).to(0)
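Depending on the installed transformers version, the load_in_4bit and use_flash_attention_2 shortcuts may emit deprecation warnings; recent releases take the same options as quantization_config=BitsAndBytesConfig(load_in_4bit=True) and attn_implementation="flash_attention_2". The sketch below collects both into keyword arguments for from_pretrained. The helper name build_load_kwargs is hypothetical and not part of the original card, and whether these argument names apply is an assumption about your installed version.

```python
from typing import Any, Dict


def build_load_kwargs(quantize_4bit: bool = False, flash_attention: bool = False) -> Dict[str, Any]:
    """Collect optional keyword arguments for LlavaForConditionalGeneration.from_pretrained."""
    kwargs: Dict[str, Any] = {"low_cpu_mem_usage": True}
    if quantize_4bit:
        # Imported lazily so the helper itself has no hard dependency on transformers.
        from transformers import BitsAndBytesConfig
        kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
    if flash_attention:
        # Requires the flash-attn package and a supported GPU.
        kwargs["attn_implementation"] = "flash_attention_2"
    return kwargs


# Usage, with model_id and the imports from the main example (commented out: needs a GPU):
# model = LlavaForConditionalGeneration.from_pretrained(
#     model_id,
#     torch_dtype=torch.float16,
#     **build_load_kwargs(flash_attention=True),
# ).to(0)
```

As in the 4-bit snippet above, omit the trailing .to(0) when quantize_4bit=True, since a bitsandbytes-quantized model is placed on the GPU at load time.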