# burberry-vision / app.py
# Last commit by justinj92: deffbd8 ("Update app.py"), 4.85 kB.
import spaces
import os
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig, AutoProcessor
import gradio as gr
from threading import Thread
from PIL import Image
import subprocess
# Install flash-attn at startup; pip leaves it alone if it is already installed.
import sys  # local to this block: only used to locate the running interpreter

subprocess.run(
    # Run pip as a module of the current interpreter so the package lands in the
    # same environment this app runs in; a list argv needs no shell.
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    # Extend the parent environment instead of replacing it: passing only
    # {'FLASH_ATTENTION_SKIP_CUDA_BUILD': ...} stripped PATH, HOME and any
    # CUDA/HF variables from the child process. The flag presumably tells
    # flash-attn's setup to skip compiling its CUDA extension from source —
    # confirm against the flash-attn setup.py.
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    # Best effort, as before: the exit status is not checked. A failed install
    # surfaces later when a model is loaded with flash_attention_2.
    check=False,
)
# Checkpoint shared by the whole app, plus a 4-bit "chatbot" copy of it.
# NOTE(review): `tokenizer` and `model` defined here are not referenced by
# stream_vision, which looks models up in the `models`/`processors` dicts.
# This quantized copy therefore seems to cost only GPU memory and startup
# time — confirm nothing else imports these names before removing them.
MODEL_ID1 = "justinj92/phi-35-vision-burberry"
MODEL_LIST1 = [MODEL_ID1]

# Hub token from the environment; None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Preferred torch device. A GPU is effectively required by this app.
device = "cuda" if torch.cuda.is_available() else "cpu"

# bitsandbytes settings: 4-bit NF4 weights, double quantization,
# bfloat16 compute dtype.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Keyword arguments for the quantized load; device_map="auto" lets
# transformers/accelerate decide where the weights are placed.
_chat_load_kwargs = dict(
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID1, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID1, **_chat_load_kwargs)
# Vision model setup.
# Registries keyed by Hub model id. stream_vision fetches both the model and
# its processor using the id chosen in the UI dropdown.
_VISION_MODEL_ID = "justinj92/phi-35-vision-burberry"

models = {
    _VISION_MODEL_ID: AutoModelForCausalLM.from_pretrained(
        _VISION_MODEL_ID,
        trust_remote_code=True,  # allow the repo's custom model code
        torch_dtype="auto",  # use the dtype recorded in the checkpoint config
        _attn_implementation="flash_attention_2",
    ).cuda().eval(),  # move to the GPU and switch to eval mode for inference
}

processors = {
    _VISION_MODEL_ID: AutoProcessor.from_pretrained(
        _VISION_MODEL_ID,
        trust_remote_code=True,
    ),
}

# NOTE(review): these three strings are not referenced anywhere in the visible
# file. Prompt formatting is done by processor.tokenizer.apply_chat_template.
# They may once have held chat markers — verify before relying on them.
user_prompt = "\n"
assistant_prompt = "\n"
prompt_suffix = "\n"
# Vision model tab function
@spaces.GPU()
def stream_vision(image, model_id="justinj92/phi-35-vision-burberry",
                  text_input="What is shown in this image?"):
    """Describe one uploaded image with a registered vision model.

    Despite the name, nothing is streamed. The full answer is returned once
    generation finishes.

    Args:
        image: The uploaded image. It is converted with ``Image.fromarray``
            unless it is already a ``PIL.Image.Image``; presumably a NumPy
            array from ``gr.Image`` with its default settings. ``None`` when
            the user pressed the button without uploading anything.
        model_id: Key into the module-level ``models`` and ``processors``
            dicts.
        text_input: Question placed after the image tag in the user turn.
            The default is the prompt this app has always used. An empty
            string sends the image tag alone.

    Returns:
        The decoded answer, with the prompt tokens removed and special tokens
        skipped.

    Raises:
        gr.Error: If ``image`` is None, so the UI shows a readable message
            instead of an exception from PIL.
        KeyError: If ``model_id`` is not in the registries.
    """
    if image is None:
        raise gr.Error("Please upload a product image first.")

    model = models[model_id]
    processor = processors[model_id]

    # Use a PIL image as-is; otherwise build one from the array data.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    images = [image.convert("RGB")]

    # The single image in `images` is referenced in the prompt as <|image_1|>.
    placeholder = "<|image_1|>\n"
    messages = [
        {"role": "user", "content": placeholder + (text_input or "")},
    ]

    # Render the chat template to text and append the generation prompt.
    prompt = processor.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Put the tensors on the same device as the selected model (previously
    # hard-coded to "cuda:0").
    inputs = processor(prompt, images, return_tensors="pt").to(model.device)

    # Greedy decoding. `temperature` is no longer passed: it has no effect when
    # do_sample=False and only makes transformers emit a warning.
    generate_ids = model.generate(
        **inputs,
        eos_token_id=processor.tokenizer.eos_token_id,
        max_new_tokens=2000,
        do_sample=False,
    )

    # generate() returns prompt + completion; keep only the new tokens.
    generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]

    return processor.batch_decode(
        generate_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
# CSS for the interface.
# NOTE(review): no component in the visible file uses the "duplicate-button"
# class, so that rule appears unused. The h3 rule centres every <h3>.
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
}
"""

# Page heading rendered at the top of the app.
TITLE = "<h1><center>Burberry Product Categorizer</center></h1>"

# Short description shown under the heading.
EXPLANATION = """
<div style="text-align: center; margin-top: 20px;">
<p>App uses Microsoft Phi 3.5 Vision Model</p>
<p>Fine-Tuned version is built using open Burberry Product dataset.</p>
</div>
"""

# Footer with the author's link. The heart emoji was previously stored
# mis-encoded (the UTF-8 bytes of U+1F496 decoded with a legacy code page),
# so it rendered as garbage characters. It is now the real character.
footer = """
<div style="text-align: center; margin-top: 20px;">
<a href="https://www.linkedin.com/in/justin-j-4a77456b/" target="_blank">LinkedIn</a>
<br>
Made with 💖 by Justin J
</div>
"""
# Gradio UI: heading, description, a single "Burberry Vision" tab, then the footer.
_theme = gr.themes.Default(
    primary_hue=gr.themes.colors.red,
    secondary_hue=gr.themes.colors.pink,
)

with gr.Blocks(css=CSS, theme=_theme) as demo:
    gr.HTML(TITLE)
    gr.HTML(EXPLANATION)

    with gr.Tab("Burberry Vision"):
        # One component per row, stacked top to bottom.
        with gr.Row():
            input_img = gr.Image(label="Upload a Burberry Product Image")
        with gr.Row():
            model_selector = gr.Dropdown(
                choices=list(models),
                label="Model",
                value="justinj92/phi-35-vision-burberry",
            )
        with gr.Row():
            submit_btn = gr.Button(value="Tell me about this product")
        with gr.Row():
            output_text = gr.Textbox(label="Product Info")

        # Button -> stream_vision(image, model_id) -> answer textbox.
        submit_btn.click(
            fn=stream_vision,
            inputs=[input_img, model_selector],
            outputs=[output_text],
        )

    gr.HTML(footer)

# Start the app server (debug=True: see gr.Blocks.launch documentation).
demo.launch(debug=True)