Product Catalog Generator
A product catalog generator for Persian products, hosted by Basalam.
This model is llava-hf/llava-1.5-7b-hf, fine-tuned on Basalam product data to extract the visual attributes of products. The outputs are in JSON format, so they can be parsed with any JSON parser.
Below is an example script to run generation in float16 precision on a GPU device:
import json

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "BaSalam/Llava-1.5-7b-hf-bslm-product-attributes-v0"

# Load the model in float16 and move it to the first GPU.
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to(0)
processor = AutoProcessor.from_pretrained(model_id)


def prompt_formatter(entity):
    # The Persian instruction asks the model to extract the product's visual
    # attributes as JSON in the structure below, and to answer in Persian
    # because the product comes from an Iranian online marketplace.
    json_format = """attributes': {'attribute_name_1' : <list of attribute values>, 'attribute_name_2': <list of attribute values>, ...}"""
    final_prompt = f"""برای محصول داده شده، ویژگیهای تصویری محصول را در قالب جیسون (json) استخراج کن. ساختار JSON باید به این شکل باشد: {json_format}. محصول از یک بازار اینترنتی ایرانی است پس خروجی Json باید به زبان فارسی باشد.
محصول: '{entity}'."""
    return final_prompt


# Product title: "men's T-shirt".
prompt = prompt_formatter(entity='تیشرت مردانه')
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            {"type": "image"},
        ],
    },
]
# Wrap the instruction in the model's chat template.
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

image_file = "https://statics.basalam.com/public-16/users/6eOEg/01-24/qJ34XziHu7Orp3GToVWTms1nKvCv0X86Ux7tQLtuRoyTXTxyQ4.jpg_800X800X70.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(images=raw_image, text=prompt, return_tensors='pt').to(0, torch.float16)

output = model.generate(**inputs, max_new_tokens=384, do_sample=False)
# Drop the echoed prompt from the decoded text (the original referenced an
# undefined name `text` here; the templated `prompt` is what is meant).
generated_title = processor.decode(output[0], skip_special_tokens=True)[len(prompt.replace('<image>', ' ')):]
output = generated_title.replace('ASSISTANT: ', '')
json_output = json.loads(output)
print(json_output)
Output:

[
    {
        "attributes": {
            "نوع": [
                "تیشرت مردانه"
            ],
            "طرح چاپی": [
                "MVP"
            ],
            "رنگ": [
                "زرد",
                "آبی",
                "سفید",
                "مشکی",
                "کرم",
                "سبز"
            ],
            "سایز": [
                "L",
                "XL",
                "2XL",
                "3XL"
            ]
        }
    }
]
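Since the output is meant to be parsed, it can help to guard the parsing step against malformed generations. The following is a minimal sketch, not part of the original model card: the helper name `parse_attributes` and the `ASSISTANT_MARKER` constant are hypothetical, and the code only assumes the output shape shown in the example above (an `attributes` object, optionally wrapped in a one-element list).

```python
import json

# Assumption: the decoded text may still contain the chat template's role tag.
ASSISTANT_MARKER = "ASSISTANT:"


def parse_attributes(generated_text):
    """Return {attribute_name: [values]} from decoded text, or None if it cannot be parsed."""
    # Keep only what follows the last role tag; if the tag is absent, keep everything.
    _, _, tail = generated_text.rpartition(ASSISTANT_MARKER)
    try:
        payload = json.loads(tail.strip())
    except json.JSONDecodeError:
        return None
    # The example output wraps the object in a list; accept both shapes.
    if isinstance(payload, list):
        payload = payload[0] if payload else {}
    attributes = payload.get("attributes") if isinstance(payload, dict) else None
    if not isinstance(attributes, dict):
        return None
    # Normalize every value to a list of strings.
    return {
        name: [str(v) for v in (values if isinstance(values, list) else [values])]
        for name, values in attributes.items()
    }


sample = 'USER: ... ASSISTANT: [{"attributes": {"رنگ": ["زرد", "آبی"], "سایز": ["L", "XL"]}}]'
attrs = parse_attributes(sample)
print(attrs)
```

Returning `None` instead of raising lets a batch job skip a bad generation and continue; whether that is the right policy depends on the pipeline.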
4-bit quantization through the bitsandbytes library
First make sure to install bitsandbytes (pip install bitsandbytes) and make sure you have access to a CUDA-compatible GPU device. Then change the model-loading call in the snippet above to:
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
+   load_in_4bit=True
)
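To give a rough sense of why 4-bit loading matters, the sketch below estimates the memory taken by the weights alone at each precision. The parameter count is an assumption read off the "7b" in the model name, and the estimate ignores activations, the KV cache, quantization metadata, and any layers that bitsandbytes leaves unquantized, so real usage will be higher.

```python
def weight_memory_gib(num_params, bits_per_param):
    """Approximate memory needed for the weights alone, in GiB."""
    return num_params * bits_per_param / 8 / 1024**3


# Assumption: roughly 7 billion parameters, as the model name suggests.
NUM_PARAMS = 7_000_000_000

fp16_gib = weight_memory_gib(NUM_PARAMS, 16)
int4_gib = weight_memory_gib(NUM_PARAMS, 4)
print(f"float16: ~{fp16_gib:.1f} GiB, 4-bit: ~{int4_gib:.1f} GiB")
```

Under these assumptions the weights shrink by a factor of four, which is often the difference between fitting on a consumer GPU and not.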
Flash Attention 2

First make sure to install flash-attn. Refer to the original Flash Attention repository for installation instructions. Then change the model-loading call in the snippet above to:
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
+   use_flash_attention_2=True
).to(0)
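Because flash-attn is an optional dependency, one option is to enable it only when the package is importable. This is a small sketch under that assumption; the helper name `attention_kwargs` is hypothetical, and the flag it returns is the one used in the snippet above.

```python
import importlib.util


def attention_kwargs():
    """Return extra from_pretrained kwargs, enabling Flash Attention only if flash_attn is installed."""
    if importlib.util.find_spec("flash_attn") is not None:
        # Flag taken from the snippet above.
        return {"use_flash_attention_2": True}
    # Fall back to the default attention implementation.
    return {}


extra_kwargs = attention_kwargs()
print(extra_kwargs)
```

The result can be passed as `**extra_kwargs` to `LlavaForConditionalGeneration.from_pretrained(...)`. Recent transformers releases document `attn_implementation="flash_attention_2"` as the replacement for this flag, so it is worth checking which form the installed version expects.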
Base model
llava-hf/llava-1.5-7b-hf