prithivMLmods committed on
Commit 29259b5 · verified · 1 Parent(s): e63c665

Update app.py

Files changed (1):
  app.py +66 -109
app.py CHANGED
@@ -1,144 +1,101 @@
  import gradio as gr
- import spaces
- from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
- from qwen_vl_utils import process_vision_info
- import torch
- from PIL import Image
- import subprocess
- import numpy as np
- import os
+ from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
+ from transformers.image_utils import load_image
  from threading import Thread
- import uuid
- import io
+ import time
+ import torch
+ import spaces

- # Model and Processor Loading (Done once at startup)
+ # Load the Qwen2.5-VL-3B-Instruct model and processor
  MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
      MODEL_ID,
      trust_remote_code=True,
      torch_dtype=torch.float16
  ).to("cuda").eval()
- processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
-
- DESCRIPTION = "# **Qwen2.5-VL-3B-Instruct**"
-
- image_extensions = Image.registered_extensions()
- video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg", "wav", "gif", "webm", "m4v", "3gp")
-
-
- def identify_and_save_blob(blob_path):
-     """Identifies if the blob is an image or video and saves it accordingly."""
-     try:
-         with open(blob_path, 'rb') as file:
-             blob_content = file.read()
-
-         # Try to identify if it's an image
-         try:
-             Image.open(io.BytesIO(blob_content)).verify()  # Check if it's a valid image
-             extension = ".png"  # Default to PNG for saving
-             media_type = "image"
-         except (IOError, SyntaxError):
-             # If it's not a valid image, assume it's a video
-             extension = ".mp4"  # Default to MP4 for saving
-             media_type = "video"
-
-         # Create a unique filename
-         filename = f"temp_{uuid.uuid4()}_media{extension}"
-         with open(filename, "wb") as f:
-             f.write(blob_content)
-
-         return filename, media_type
-
-     except FileNotFoundError:
-         raise ValueError(f"The file {blob_path} was not found.")
-     except Exception as e:
-         raise ValueError(f"An error occurred while processing the file: {e}")
-

  @spaces.GPU
- def qwen_inference(media_input, text_input=None):
-     if isinstance(media_input, str):  # If it's a filepath
-         media_path = media_input
-         if media_path.endswith(tuple([i for i, f in image_extensions.items()])):
-             media_type = "image"
-         elif media_path.endswith(video_extensions):
-             media_type = "video"
-         else:
-             try:
-                 media_path, media_type = identify_and_save_blob(media_input)
-                 print(media_path, media_type)
-             except Exception as e:
-                 print(e)
-                 raise ValueError(
-                     "Unsupported media type. Please upload an image or video."
-                 )
-
-     print(media_path)
-
+ def model_inference(input_dict, history):
+     text = input_dict["text"]
+     files = input_dict["files"]
+
+     # Load images if provided
+     if len(files) > 1:
+         images = [load_image(image) for image in files]
+     elif len(files) == 1:
+         images = [load_image(files[0])]
+     else:
+         images = []
+
+     # Validate input
+     if text == "" and not images:
+         raise gr.Error("Please input a query and optionally image(s).")
+     if text == "" and images:
+         raise gr.Error("Please input a text query along with the image(s).")
+
+     # Prepare messages for the model
      messages = [
          {
              "role": "user",
              "content": [
-                 {
-                     "type": media_type,
-                     media_type: media_path,
-                     **({"fps": 8.0} if media_type == "video" else {}),
-                 },
-                 {"type": "text", "text": text_input},
+                 *[{"type": "image", "image": image} for image in images],
+                 {"type": "text", "text": text},
              ],
          }
      ]

-     text = processor.apply_chat_template(
-         messages, tokenize=False, add_generation_prompt=True
-     )
-     image_inputs, video_inputs = process_vision_info(messages)
+     # Apply chat template and process inputs
+     prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
      inputs = processor(
-         text=[text],
-         images=image_inputs,
-         videos=video_inputs,
-         padding=True,
+         text=[prompt],
+         images=images if images else None,
          return_tensors="pt",
+         padding=True,
      ).to("cuda")

-     streamer = TextIteratorStreamer(
-         processor, skip_prompt=True, **{"skip_special_tokens": True}
-     )
+     # Set up streamer for real-time output
+     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
      generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)

+     # Start generation in a separate thread
      thread = Thread(target=model.generate, kwargs=generation_kwargs)
      thread.start()

+     # Stream the output
      buffer = ""
+     yield "..."
      for new_text in streamer:
          buffer += new_text
+         time.sleep(0.01)
          yield buffer

- css = """
- #output {
-     height: 500px;
-     overflow: auto;
-     border: 1px solid #ccc;
- }
- """
-
- with gr.Blocks(css=css) as demo:
-     gr.Markdown(DESCRIPTION)
-
-     with gr.Tab(label="Image/Video Input"):
-         with gr.Row():
-             with gr.Column():
-                 input_media = gr.File(
-                     label="Upload Image or Video", type="filepath"
-                 )
-                 text_input = gr.Textbox(label="Question")
-                 submit_btn = gr.Button(value="Submit")
-             with gr.Column():
-                 output_text = gr.Textbox(label="Output Text")
-
-     submit_btn.click(
-         qwen_inference, [input_media, text_input], [output_text]
-     )
+ # Example inputs
+ examples = [
+     [{"text": "Can you describe this image?", "files": ["example_images/newyork.jpg"]}],
+     [{"text": "Can you describe this image?", "files": ["example_images/dogs.jpg"]}],
+     [{"text": "Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
+     [{"text": "What art era do these artpieces belong to?", "files": ["example_images/rococo.jpg", "example_images/rococo_1.jpg"]}],
+     [{"text": "Describe this image.", "files": ["example_images/campeones.jpg"]}],
+     [{"text": "What does this say?", "files": ["example_images/math.jpg"]}],
+     [{"text": "What is the date in this document?", "files": ["example_images/document.jpg"]}],
+     [{"text": "What is this UI about?", "files": ["example_images/s2w_example.png"]}],
+ ]
+
+ # Gradio interface
+ demo = gr.ChatInterface(
+     fn=model_inference,
+     title="# **Qwen2.5-VL-3B-Instruct**",
+     description="Interact with [Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct) in this demo. Upload an image and text, or try one of the examples. Each chat starts a new conversation.",
+     examples=examples,
+     textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
+     stop_btn="Stop Generation",
+     multimodal=True,
+     cache_examples=False,
+ )
+
+ # Launch the demo
  demo.launch(debug=True)
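
For a quick sanity check outside the Gradio UI, the new streaming handler can be driven directly. A minimal sketch, assuming app.py is importable as `app` from the same directory, a CUDA GPU is available, and `example_images/dogs.jpg` exists (outside a Space the `@spaces.GPU` decorator should be inert):

# smoke_test.py -- minimal sketch; importing app loads the model onto the GPU.
from app import model_inference

query = {"text": "Describe this image.", "files": ["example_images/dogs.jpg"]}

# model_inference is a generator: each yield is the accumulated buffer so far
# (starting with the "..." placeholder), so the final value is the full reply.
reply = ""
for partial in model_inference(query, history=[]):
    reply = partial
print(reply)

This mirrors what gr.ChatInterface with multimodal=True passes to fn: a dict with "text" and "files" keys plus the chat history, which is why the handler indexes input_dict["text"] and input_dict["files"].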