Spaces: Runtime error
Update app.py
app.py CHANGED

Old version (removed lines marked with -):
@@ -1,7 +1,6 @@
 import os
 from collections.abc import Iterator
 from threading import Thread
-
 import gradio as gr
 import spaces
 import torch
@@ -11,6 +10,7 @@ from PIL import Image
 import uuid
 import io

 DESCRIPTION = """
 # GWQ PREV
 """
@@ -21,25 +21,24 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

- ... (4 removed lines, old 24-27, content not recovered)
-    gemma_model_id,
     device_map="auto",
     torch_dtype=torch.bfloat16,
 )
-
-

-#
-
-
-
     trust_remote_code=True,
     torch_dtype=torch.float16
 ).to("cuda").eval()
-

 image_extensions = Image.registered_extensions()
 video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg", "wav", "gif", "webm", "m4v", "3gp")
@@ -72,19 +71,6 @@ def identify_and_save_blob(blob_path):
     except Exception as e:
         raise ValueError(f"An error occurred while processing the file: {e}")

-def process_vision_info(messages):
-    """Processes vision information (images or videos) from messages."""
-    image_inputs = []
-    video_inputs = []
-    for message in messages:
-        for content in message["content"]:
-            if content["type"] == "image":
-                image = Image.open(content["image"])
-                image_inputs.append(image)
-            elif content["type"] == "video":
-                video_inputs.append(content["video"])
-    return image_inputs, video_inputs
-
 @spaces.GPU()
 def generate(
     message: str,
@@ -94,26 +80,21 @@ def generate(
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
-
 ) -> Iterator[str]:
-    if
-    #
- ... (11 removed lines, old 101-111, content not recovered)
-            print(e)
-            raise ValueError(
-                "Unsupported media type. Please upload an image or video."
-            )
-
     messages = [
         {
             "role": "user",
@@ -128,11 +109,11 @@ def generate(
             }
         ]

-        text =
             messages, tokenize=False, add_generation_prompt=True
         )
         image_inputs, video_inputs = process_vision_info(messages)
-        inputs =
             text=[text],
             images=image_inputs,
             videos=video_inputs,
@@ -141,11 +122,11 @@ def generate(
         ).to("cuda")

         streamer = TextIteratorStreamer(
-
         )
         generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)

-        thread = Thread(target=
         thread.start()

         buffer = ""
@@ -153,17 +134,17 @@ def generate(
             buffer += new_text
             yield buffer
     else:
-        #
         conversation = chat_history.copy()
         conversation.append({"role": "user", "content": message})

-        input_ids =
         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
             gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-        input_ids = input_ids.to(

-        streamer = TextIteratorStreamer(
         generate_kwargs = dict(
             {"input_ids": input_ids},
             streamer=streamer,
@@ -175,7 +156,7 @@ def generate(
             num_beams=1,
             repetition_penalty=repetition_penalty,
         )
-        t = Thread(target=
         t.start()

         outputs = []
@@ -183,72 +164,61 @@ def generate(
             outputs.append(text)
             yield "".join(outputs)

- ... (55 removed lines, old 186-240, content not recovered)
-        ["Explain the plot of Cinderella in a sentence."],
-        ["How many hours does it take a man to eat a Helicopter?"],
-        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
-    ],
-    cache_examples=False,
-    type="messages",
-    description=DESCRIPTION,
-    css_paths="style.css",
-    fill_height=True,
-    textbox=gr.MultimodalTextbox(),
-    multimodal=True,
-)

-

New version (added lines marked with +; unchanged lines between hunks shown as ...):
 import os
 from collections.abc import Iterator
 from threading import Thread
 import gradio as gr
 import spaces
 import torch
...
 import uuid
 import io

+# Text-only model setup
 DESCRIPTION = """
 # GWQ PREV
 """
...

 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

+model_id = "prithivMLmods/GWQ2b"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
     device_map="auto",
     torch_dtype=torch.bfloat16,
 )
+model.config.sliding_window = 4096
+model.eval()

+# Multimodal model setup
+MULTIMODAL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
+multimodal_model = Qwen2VLForConditionalGeneration.from_pretrained(
+    MULTIMODAL_MODEL_ID,
     trust_remote_code=True,
     torch_dtype=torch.float16
 ).to("cuda").eval()
+multimodal_processor = AutoProcessor.from_pretrained(MULTIMODAL_MODEL_ID, trust_remote_code=True)

 image_extensions = Image.registered_extensions()
 video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg", "wav", "gif", "webm", "m4v", "3gp")
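Aside (not part of the commit): a minimal sketch of exercising the text-only model (GWQ2b) configured in the hunk above, reusing the model id, dtype, and chat-template call from the diff; the prompt and generation settings here are illustrative.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "prithivMLmods/GWQ2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
).eval()

# Format one user turn with the model's chat template, then generate a short reply.
conversation = [{"role": "user", "content": "Hello there! How are you doing?"}]
input_ids = tokenizer.apply_chat_template(
    conversation, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
output_ids = model.generate(input_ids, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True))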
...
     except Exception as e:
         raise ValueError(f"An error occurred while processing the file: {e}")

 @spaces.GPU()
 def generate(
     message: str,
...
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
+    files: list = None,
 ) -> Iterator[str]:
+    if files and len(files) > 0:
+        # Multimodal input
+        media_path = files[0]
+        if media_path.endswith(tuple([i for i, f in image_extensions.items()])):
+            media_type = "image"
+        elif media_path.endswith(video_extensions):
+            media_type = "video"
+        else:
+            try:
+                media_path, media_type = identify_and_save_blob(media_path)
+            except Exception as e:
+                raise ValueError("Unsupported media type. Please upload an image or video.")
+
         messages = [
             {
                 "role": "user",
...
             }
         ]

+        text = multimodal_processor.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
         image_inputs, video_inputs = process_vision_info(messages)
+        inputs = multimodal_processor(
             text=[text],
             images=image_inputs,
             videos=video_inputs,
...
         ).to("cuda")

         streamer = TextIteratorStreamer(
+            multimodal_processor, skip_prompt=True, **{"skip_special_tokens": True}
         )
         generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)

+        thread = Thread(target=multimodal_model.generate, kwargs=generation_kwargs)
         thread.start()

         buffer = ""
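Aside (not part of the commit): the multimodal branch above still calls process_vision_info(messages) (new line 115), while this same commit deletes the local definition of that helper (see the removed lines earlier). Unless the function is imported somewhere in the unchanged lines, this branch would raise a NameError, which would be consistent with the Space's "Runtime error" status. A sketch of keeping the helper available, reconstructed from the removed lines:

from PIL import Image

def process_vision_info(messages):
    """Collects PIL images and raw video references from chat messages."""
    image_inputs, video_inputs = [], []
    for message in messages:
        for content in message["content"]:
            if content["type"] == "image":
                image_inputs.append(Image.open(content["image"]))
            elif content["type"] == "video":
                video_inputs.append(content["video"])
    return image_inputs, video_inputs

# Alternative (an assumption, not shown in this diff): import the equivalent helper
# shipped with the Qwen2-VL examples, if it is in the Space's requirements:
# from qwen_vl_utils import process_vision_info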
...
             buffer += new_text
             yield buffer
     else:
+        # Text-only input
         conversation = chat_history.copy()
         conversation.append({"role": "user", "content": message})

+        input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
             gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+        input_ids = input_ids.to(model.device)

+        streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
         generate_kwargs = dict(
             {"input_ids": input_ids},
             streamer=streamer,
...
             num_beams=1,
             repetition_penalty=repetition_penalty,
         )
+        t = Thread(target=model.generate, kwargs=generate_kwargs)
         t.start()

         outputs = []
...
             outputs.append(text)
             yield "".join(outputs)

+demo = gr.ChatInterface(
+    fn=generate,
+    additional_inputs=[
+        gr.Slider(
+            label="Max new tokens",
+            minimum=1,
+            maximum=MAX_MAX_NEW_TOKENS,
+            step=1,
+            value=DEFAULT_MAX_NEW_TOKENS,
+        ),
+        gr.Slider(
+            label="Temperature",
+            minimum=0.1,
+            maximum=4.0,
+            step=0.1,
+            value=0.6,
+        ),
+        gr.Slider(
+            label="Top-p (nucleus sampling)",
+            minimum=0.05,
+            maximum=1.0,
+            step=0.05,
+            value=0.9,
+        ),
+        gr.Slider(
+            label="Top-k",
+            minimum=1,
+            maximum=1000,
+            step=1,
+            value=50,
+        ),
+        gr.Slider(
+            label="Repetition penalty",
+            minimum=1.0,
+            maximum=2.0,
+            step=0.05,
+            value=1.2,
+        ),
+    ],
+    stop_btn=None,
+    examples=[
+        ["Hello there! How are you doing?"],
+        ["Can you explain briefly to me what is the Python programming language?"],
+        ["Explain the plot of Cinderella in a sentence."],
+        ["How many hours does it take a man to eat a Helicopter?"],
+        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+    ],
+    cache_examples=False,
+    type="messages",
+    description=DESCRIPTION,
+    css_paths="style.css",
+    fill_height=True,
+    multimodal=True,
+    textbox=gr.MultimodalTextbox(),
+)

+if __name__ == "__main__":
+    demo.queue(max_size=20).launch()
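For reference, a self-contained sketch of the multimodal path that the new generate() assembles, reduced to a single non-streaming call. The model id, processor usage, and chat-template call come from the diff; the image path, prompt, and the padding/return_tensors arguments are illustrative assumptions.

import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID, trust_remote_code=True, torch_dtype=torch.float16
).to("cuda").eval()

# One user turn with an image plus a text instruction (example.jpg is a placeholder path).
messages = [{
    "role": "user",
    "content": [
        {"type": "image", "image": "example.jpg"},
        {"type": "text", "text": "Describe this image."},
    ],
}]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(
    text=[text],
    images=[Image.open("example.jpg")],
    videos=None,
    padding=True,
    return_tensors="pt",
).to("cuda")

output_ids = model.generate(**inputs, max_new_tokens=128)
reply = processor.batch_decode(
    output_ids[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
)[0]
print(reply)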