prithivMLmods committed
Commit 9e64b5c · verified
Parent(s): 3c9b81c

Update app.py

Files changed (1):
  1. app.py +92 -122
app.py CHANGED
@@ -1,7 +1,6 @@
 import os
 from collections.abc import Iterator
 from threading import Thread
-
 import gradio as gr
 import spaces
 import torch
@@ -11,6 +10,7 @@ from PIL import Image
 import uuid
 import io
 
+# Text-only model setup
 DESCRIPTION = """
 # GWQ PREV
 """
@@ -21,25 +21,24 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-# Load Gemma model for text-only inputs
-gemma_model_id = "prithivMLmods/GWQ2b"
-gemma_tokenizer = AutoTokenizer.from_pretrained(gemma_model_id)
-gemma_model = AutoModelForCausalLM.from_pretrained(
-    gemma_model_id,
+model_id = "prithivMLmods/GWQ2b"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
     device_map="auto",
     torch_dtype=torch.bfloat16,
 )
-gemma_model.config.sliding_window = 4096
-gemma_model.eval()
+model.config.sliding_window = 4096
+model.eval()
 
-# Load Qwen model for multimodal inputs
-qwen_model_id = "Qwen/Qwen2-VL-2B-Instruct"
-qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
-    qwen_model_id,
+# Multimodal model setup
+MULTIMODAL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
+multimodal_model = Qwen2VLForConditionalGeneration.from_pretrained(
+    MULTIMODAL_MODEL_ID,
     trust_remote_code=True,
     torch_dtype=torch.float16
 ).to("cuda").eval()
-qwen_processor = AutoProcessor.from_pretrained(qwen_model_id, trust_remote_code=True)
+multimodal_processor = AutoProcessor.from_pretrained(MULTIMODAL_MODEL_ID, trust_remote_code=True)
 
 image_extensions = Image.registered_extensions()
 video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg", "wav", "gif", "webm", "m4v", "3gp")
@@ -72,19 +71,6 @@ def identify_and_save_blob(blob_path):
     except Exception as e:
         raise ValueError(f"An error occurred while processing the file: {e}")
 
-def process_vision_info(messages):
-    """Processes vision information (images or videos) from messages."""
-    image_inputs = []
-    video_inputs = []
-    for message in messages:
-        for content in message["content"]:
-            if content["type"] == "image":
-                image = Image.open(content["image"])
-                image_inputs.append(image)
-            elif content["type"] == "video":
-                video_inputs.append(content["video"])
-    return image_inputs, video_inputs
-
 @spaces.GPU()
 def generate(
     message: str,
@@ -94,26 +80,21 @@ def generate(
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
-    media_input: str = None,
+    files: list = None,
 ) -> Iterator[str]:
-    if media_input:
-        # Use Qwen model for multimodal inputs
-        if isinstance(media_input, str):  # If it's a filepath
-            media_path = media_input
-            if media_path.endswith(tuple([i for i, f in image_extensions.items()])):
-                media_type = "image"
-            elif media_path.endswith(video_extensions):
-                media_type = "video"
-            else:
-                try:
-                    media_path, media_type = identify_and_save_blob(media_input)
-                    print(media_path, media_type)
-                except Exception as e:
-                    print(e)
-                    raise ValueError(
-                        "Unsupported media type. Please upload an image or video."
-                    )
-
+    if files and len(files) > 0:
+        # Multimodal input
+        media_path = files[0]
+        if media_path.endswith(tuple([i for i, f in image_extensions.items()])):
+            media_type = "image"
+        elif media_path.endswith(video_extensions):
+            media_type = "video"
+        else:
+            try:
+                media_path, media_type = identify_and_save_blob(media_path)
+            except Exception as e:
+                raise ValueError("Unsupported media type. Please upload an image or video.")
+
         messages = [
             {
                 "role": "user",
@@ -128,11 +109,11 @@
             }
        ]
 
-        text = qwen_processor.apply_chat_template(
+        text = multimodal_processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
-        inputs = qwen_processor(
+        inputs = multimodal_processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
@@ -141,11 +122,11 @@
        ).to("cuda")
 
        streamer = TextIteratorStreamer(
-            qwen_processor, skip_prompt=True, **{"skip_special_tokens": True}
+            multimodal_processor, skip_prompt=True, **{"skip_special_tokens": True}
        )
        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
 
-        thread = Thread(target=qwen_model.generate, kwargs=generation_kwargs)
+        thread = Thread(target=multimodal_model.generate, kwargs=generation_kwargs)
        thread.start()
 
        buffer = ""
@@ -153,17 +134,17 @@
            buffer += new_text
            yield buffer
    else:
-        # Use Gemma model for text-only inputs
+        # Text-only input
        conversation = chat_history.copy()
        conversation.append({"role": "user", "content": message})
 
-        input_ids = gemma_tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+        input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
            gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-        input_ids = input_ids.to(gemma_model.device)
+        input_ids = input_ids.to(model.device)
 
-        streamer = TextIteratorStreamer(gemma_tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+        streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = dict(
            {"input_ids": input_ids},
            streamer=streamer,
@@ -175,7 +156,7 @@
            num_beams=1,
            repetition_penalty=repetition_penalty,
        )
-        t = Thread(target=gemma_model.generate, kwargs=generate_kwargs)
+        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()
 
        outputs = []
@@ -183,72 +164,61 @@
            outputs.append(text)
            yield "".join(outputs)
 
-css = """
-#output {
-    height: 500px;
-    overflow: auto;
-    border: 1px solid #ccc;
-}
-"""
-
-with gr.Blocks(css=css) as demo:
-    gr.Markdown(DESCRIPTION)
-
-    with gr.Tab(label="Chat Interface"):
-        chat_interface = gr.ChatInterface(
-            fn=generate,
-            additional_inputs=[
-                gr.Slider(
-                    label="Max new tokens",
-                    minimum=1,
-                    maximum=MAX_MAX_NEW_TOKENS,
-                    step=1,
-                    value=DEFAULT_MAX_NEW_TOKENS,
-                ),
-                gr.Slider(
-                    label="Temperature",
-                    minimum=0.1,
-                    maximum=4.0,
-                    step=0.1,
-                    value=0.6,
-                ),
-                gr.Slider(
-                    label="Top-p (nucleus sampling)",
-                    minimum=0.05,
-                    maximum=1.0,
-                    step=0.05,
-                    value=0.9,
-                ),
-                gr.Slider(
-                    label="Top-k",
-                    minimum=1,
-                    maximum=1000,
-                    step=1,
-                    value=50,
-                ),
-                gr.Slider(
-                    label="Repetition penalty",
-                    minimum=1.0,
-                    maximum=2.0,
-                    step=0.05,
-                    value=1.2,
-                ),
-            ],
-            stop_btn=None,
-            examples=[
-                ["Hello there! How are you doing?"],
-                ["Can you explain briefly to me what is the Python programming language?"],
-                ["Explain the plot of Cinderella in a sentence."],
-                ["How many hours does it take a man to eat a Helicopter?"],
-                ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
-            ],
-            cache_examples=False,
-            type="messages",
-            description=DESCRIPTION,
-            css_paths="style.css",
-            fill_height=True,
-            textbox=gr.MultimodalTextbox(),
-            multimodal=True,
-        )
+demo = gr.ChatInterface(
+    fn=generate,
+    additional_inputs=[
+        gr.Slider(
+            label="Max new tokens",
+            minimum=1,
+            maximum=MAX_MAX_NEW_TOKENS,
+            step=1,
+            value=DEFAULT_MAX_NEW_TOKENS,
+        ),
+        gr.Slider(
+            label="Temperature",
+            minimum=0.1,
+            maximum=4.0,
+            step=0.1,
+            value=0.6,
+        ),
+        gr.Slider(
+            label="Top-p (nucleus sampling)",
+            minimum=0.05,
+            maximum=1.0,
+            step=0.05,
+            value=0.9,
+        ),
+        gr.Slider(
+            label="Top-k",
+            minimum=1,
+            maximum=1000,
+            step=1,
+            value=50,
+        ),
+        gr.Slider(
+            label="Repetition penalty",
+            minimum=1.0,
+            maximum=2.0,
+            step=0.05,
+            value=1.2,
+        ),
+    ],
+    stop_btn=None,
+    examples=[
+        ["Hello there! How are you doing?"],
+        ["Can you explain briefly to me what is the Python programming language?"],
+        ["Explain the plot of Cinderella in a sentence."],
+        ["How many hours does it take a man to eat a Helicopter?"],
+        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+    ],
+    cache_examples=False,
+    type="messages",
+    description=DESCRIPTION,
+    css_paths="style.css",
+    fill_height=True,
+    multimodal=True,
+    textbox=gr.MultimodalTextbox(),
+)
 
-demo.launch(debug=True)
+if __name__ == "__main__":
+    demo.queue(max_size=20).launch()
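
For reference, below is a minimal sketch of driving the text-only branch of the new generate() generator outside of Gradio. It is illustrative only: the import from app, the empty chat history, the positional order of the first two arguments, and running outside a Hugging Face Space (where the @spaces.GPU() decorator should have no effect) are assumptions, not part of this commit.

# Hypothetical smoke test for the text-only branch of generate() in app.py.
# Assumes app.py is importable (importing it loads both models, including the
# Qwen2-VL .to("cuda") call, so a CUDA GPU must be available) and that
# chat_history is the second positional parameter (it is referenced by name
# inside the function body).
from app import generate

chat_history = []  # no prior turns; with type="messages", turns would be role/content dicts
last = ""
for partial in generate("Explain the plot of Cinderella in a sentence.", chat_history):
    last = partial  # each yield is the full response accumulated so far
print(last)

Each yielded string is cumulative, so only the last value needs to be kept; gr.ChatInterface streams these intermediate values to the UI as they arrive.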