|
import base64
import html
import io
import os
import subprocess
import sys

import gradio as gr
import spaces
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
|
# Best-effort install of flash-attn at startup. FLASH_ATTENTION_SKIP_CUDA_BUILD
# presumably tells flash-attn's build to skip compiling CUDA kernels from source.
# The variable is merged into the current environment instead of replacing it:
# passing a bare dict as `env` would drop PATH etc. Running pip via
# sys.executable installs into the interpreter that will later import the package.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,  # failures are ignored, as before; the app still starts
)

model_id = 'J-LAB/Florence-vl3'

# trust_remote_code=True executes modelling/processing code shipped in the Hub
# repository, so only point model_id at repositories you trust.
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to("cuda").eval()
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Markdown header rendered at the top of the UI.
DESCRIPTION = "# Product Describe by Fluxi IA\n### Base Model [Florence-2](https://huggingface.co/microsoft/Florence-2-large)"
|
|
|
@spaces.GPU
def run_example(task_prompt, image):
    """Run a single generation pass of the loaded model for one task.

    The ``@spaces.GPU`` decorator (Hugging Face ``spaces`` package) presumably
    attaches a GPU for the duration of the call on ZeroGPU Spaces.

    Args:
        task_prompt: Task token passed both as the text prompt and as the
            ``task`` for post-processing (e.g. '<OCR>').
        image: Image passed to the processor; must expose ``.width`` and
            ``.height`` (a PIL image in this app).

    Returns:
        Whatever ``processor.post_process_generation`` returns. The caller
        looks the answer up by ``task_prompt`` in it, so it is presumably a
        dict keyed by the task token.
    """
    batch = processor(text=task_prompt, images=image, return_tensors="pt").to("cuda")

    # Deterministic beam search: 3 beams, no sampling, at most 1024 new tokens.
    token_ids = model.generate(
        input_ids=batch["input_ids"],
        pixel_values=batch["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )

    # Special tokens are kept in the decoded text; presumably the post-processor
    # needs them to parse the task output.
    raw_text = processor.batch_decode(token_ids, skip_special_tokens=False)[0]

    size = (image.width, image.height)
    return processor.post_process_generation(raw_text, task=task_prompt, image_size=size)
|
|
|
# Dropdown label -> task token understood by the model. Labels that are not
# listed here are passed to the model unchanged.
_TASK_TOKENS = {
    'Product Caption': '<MORE_DETAILED_CAPTION>',
    'OCR': '<OCR>',
}


def _to_pil_image(image):
    """Normalise a UI/API image input to an RGB PIL image.

    Accepts a PIL image, an array accepted by ``Image.fromarray`` (presumably
    what ``gr.Image`` delivers with its default type), or a base64
    ``data:image/...`` URL string.

    Raises:
        gr.Error: if no image was provided or a string is not a base64
            ``data:image`` URL.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    if isinstance(image, str):
        header, sep, payload = image.partition(',')
        if not (sep and header.startswith('data:image/') and header.endswith(';base64')):
            raise gr.Error("Unsupported image input: expected a base64 data:image URL.")
        image = Image.open(io.BytesIO(base64.b64decode(payload)))
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Give the model a consistent 3-channel input, e.g. for RGBA PNGs
    # or grayscale arrays.
    return image.convert('RGB')


def process_image(image, task_prompt):
    """Run the selected task on an image and return the answer as HTML.

    Args:
        image: Input image: a PIL image, an array accepted by
            ``Image.fromarray``, or a base64 ``data:image/...`` URL string.
        task_prompt: Dropdown label ('Product Caption' or 'OCR') or a raw
            task token.

    Returns:
        The model's answer for the task, HTML-escaped, with newlines rendered
        as ``<br>``. Returns "" when the result has no entry for the task.

    Raises:
        gr.Error: if no image was given or a string input is not a base64
            ``data:image`` URL.
    """
    pil_image = _to_pil_image(image)
    task_token = _TASK_TOKENS.get(task_prompt, task_prompt)

    results = run_example(task_token, pil_image)
    if results and task_token in results:
        answer = results[task_token]
    else:
        answer = ""

    # The answer is shown in a gr.HTML component and is derived from
    # user-supplied images (e.g. OCR text), so escape it before adding markup.
    return html.escape(str(answer)).replace("\n", "<br>")
|
|
|
# Styling for the HTML output panel (gr.HTML with elem_id="output").
css = """
#output {
    overflow: auto;
    border: 1px solid #ccc;
    padding: 10px;
    background-color: rgb(31 41 55);
    color: #fff;
}
"""

# Page-load script that keeps the #output panel sized to its content.
# Gradio wraps an event's `js` string as `(<js>)(...args)`, so the whole script
# must be ONE function expression. The previous version was a function
# declaration followed by a loose statement, which does not parse in that
# context. It also bound to document.querySelector('button'), the first button
# on the page, which is not necessarily Submit. Watching the panel's content
# instead works however the output gets updated.
js = """
() => {
    const output = document.getElementById('output');
    if (!output) {
        return;
    }
    const adjustHeight = () => {
        output.style.height = 'auto';  // reset so scrollHeight reflects the content
        output.style.height = output.scrollHeight + 'px';
    };
    // Attributes are not observed, so the style writes above do not
    // retrigger the observer.
    new MutationObserver(adjustHeight).observe(output, {
        childList: true,
        subtree: true,
        characterData: true,
    });
}
"""

# Choices offered in the "Task Prompt" dropdown; process_image maps them to
# model task tokens.
single_task_list = [
    'Product Caption', 'OCR',
]
|
|
|
# Usage notes rendered below the tool. The string is kept flush-left so the
# markdown (including the fenced code blocks) renders without stray indentation.
API_USAGE_MD = """
## How to use via API

To use this model via API, install the client:

```bash
pip install gradio_client
```

and then call the endpoint:

```python
from gradio_client import Client, handle_file

client = Client("J-LAB/Fluxi-IA")
result = client.predict(
    image=handle_file('https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png'),
    task_prompt="Product Caption",
    api_name="/process_image"
)
print(result)
```
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tab(label="Product Image Select"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Picture")
                task_prompt = gr.Dropdown(
                    choices=single_task_list, label="Task Prompt", value="Product Caption"
                )
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                # elem_id "output" is the element targeted by the custom css/js.
                output_text = gr.HTML(label="Output Text", elem_id="output")

    gr.Markdown(API_USAGE_MD)

    # With no explicit api_name, Gradio names the endpoint after the function
    # ("/process_image"), which is what the API example above calls.
    submit_btn.click(process_image, [input_img, task_prompt], [output_text])
    # No-op backend call; it exists only to run the custom `js` on page load.
    demo.load(lambda: None, inputs=None, outputs=None, js=js)

demo.launch(debug=True)