prithivMLmods committed
Commit a29f7c2 · verified · 1 parent: 3238705

Update app.py

Files changed (1): app.py (+169 −27)
app.py CHANGED
@@ -4,7 +4,9 @@ import uuid
 import json
 import time
 import asyncio
+import tempfile
 from threading import Thread
+import base64
 
 import gradio as gr
 import spaces
@@ -12,6 +14,7 @@ import torch
 import numpy as np
 from PIL import Image
 import edge_tts
+import trimesh
 
 from transformers import (
     AutoModelForCausalLM,
@@ -21,11 +24,99 @@ from transformers import (
     AutoProcessor,
 )
 from transformers.image_utils import load_image
+
 from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
+from diffusers import ShapEImg2ImgPipeline, ShapEPipeline
+from diffusers.utils import export_to_ply
+
+# -----------------------------------------------------------------------------
+# Global constants and helper functions
+# -----------------------------------------------------------------------------
+
+MAX_SEED = np.iinfo(np.int32).max
+
+def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    return seed
 
+def glb_to_data_url(glb_path: str) -> str:
+    """
+    Reads a GLB file from disk and returns a data URL with a base64 encoded representation.
+    This data URL can be used as the `src` for an HTML <model-viewer> tag.
+    """
+    with open(glb_path, "rb") as f:
+        data = f.read()
+    b64_data = base64.b64encode(data).decode("utf-8")
+    return f"data:model/gltf-binary;base64,{b64_data}"
+
+# -----------------------------------------------------------------------------
+# Model class for Text-to-3D Generation (ShapE)
+# -----------------------------------------------------------------------------
+
+class Model:
+    def __init__(self):
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.pipe = ShapEPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16)
+        self.pipe.to(self.device)
+        # Ensure the text encoder is in half precision to avoid dtype mismatches.
+        if torch.cuda.is_available():
+            try:
+                self.pipe.text_encoder = self.pipe.text_encoder.half()
+            except AttributeError:
+                pass
+
+        self.pipe_img = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img", torch_dtype=torch.float16)
+        self.pipe_img.to(self.device)
+        # Use getattr with a default value to avoid AttributeError if text_encoder is missing.
+        if torch.cuda.is_available():
+            text_encoder_img = getattr(self.pipe_img, "text_encoder", None)
+            if text_encoder_img is not None:
+                self.pipe_img.text_encoder = text_encoder_img.half()
+
+    def to_glb(self, ply_path: str) -> str:
+        mesh = trimesh.load(ply_path)
+        # Rotate the mesh for proper orientation
+        rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])
+        mesh.apply_transform(rot)
+        rot = trimesh.transformations.rotation_matrix(np.pi, [0, 1, 0])
+        mesh.apply_transform(rot)
+        mesh_path = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
+        mesh.export(mesh_path.name, file_type="glb")
+        return mesh_path.name
+
+    def run_text(self, prompt: str, seed: int = 0, guidance_scale: float = 15.0, num_steps: int = 64) -> str:
+        generator = torch.Generator(device=self.device).manual_seed(seed)
+        images = self.pipe(
+            prompt,
+            generator=generator,
+            guidance_scale=guidance_scale,
+            num_inference_steps=num_steps,
+            output_type="mesh",
+        ).images
+        ply_path = tempfile.NamedTemporaryFile(suffix=".ply", delete=False, mode="w+b")
+        export_to_ply(images[0], ply_path.name)
+        return self.to_glb(ply_path.name)
+
+    def run_image(self, image: Image.Image, seed: int = 0, guidance_scale: float = 3.0, num_steps: int = 64) -> str:
+        generator = torch.Generator(device=self.device).manual_seed(seed)
+        images = self.pipe_img(
+            image,
+            generator=generator,
+            guidance_scale=guidance_scale,
+            num_inference_steps=num_steps,
+            output_type="mesh",
+        ).images
+        ply_path = tempfile.NamedTemporaryFile(suffix=".ply", delete=False, mode="w+b")
+        export_to_ply(images[0], ply_path.name)
+        return self.to_glb(ply_path.name)
+
+# -----------------------------------------------------------------------------
+# Gradio UI configuration
+# -----------------------------------------------------------------------------
 
 DESCRIPTION = """
-# SmolLM Edge 🌠
+# QwQ Edge 💬
 """
 
 css = '''
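The new 3D path above chains ShapE inference (`run_text` / `run_image`) into `to_glb`, and `glb_to_data_url` then inlines the resulting GLB so it can be served inside an HTML `<model-viewer>` tag. A minimal sketch of how the pieces compose outside the chat flow, assuming this `app.py` context, a CUDA device, and that the `openai/shap-e` weights download (`preview.html` is a hypothetical output path):

```python
# Sketch: text -> mesh (.ply) -> .glb -> base64 data URL, using the helpers above.
model3d = Model()
glb_path = model3d.run_text("a red toy car", seed=42)  # temporary .glb on disk
data_url = glb_to_data_url(glb_path)                   # "data:model/gltf-binary;base64,..."

# Hypothetical standalone preview page embedding the generated model.
with open("preview.html", "w") as f:
    f.write(
        f'<model-viewer src="{data_url}" camera-controls auto-rotate></model-viewer>\n'
        '<script type="module" src="https://unpkg.com/@google/model-viewer/dist/model-viewer.min.js"></script>'
    )
```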
@@ -48,9 +139,12 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-# Load text-only model and tokenizer
-#model_id = "prithivMLmods/FastThink-0.5B-Tiny"
-model_id = "prithivMLmods/SmolLM2_135M_Grpo_Gsm8k"
+# -----------------------------------------------------------------------------
+# Load Models and Pipelines for Chat, Image, and Multimodal Processing
+# -----------------------------------------------------------------------------
+
+# Load the text-only model and tokenizer (for pure text chat)
+model_id = "prithivMLmods/FastThink-0.5B-Tiny"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
@@ -59,11 +153,13 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 model.eval()
 
+# Voices for text-to-speech
 TTS_VOICES = [
     "en-US-JennyNeural",  # @tts1
     "en-US-GuyNeural",    # @tts2
 ]
 
+# Load multimodal processor and model (e.g. for OCR and image processing)
 MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
 processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
 model_m = Qwen2VLForConditionalGeneration.from_pretrained(
@@ -72,12 +168,20 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()
 
+# -----------------------------------------------------------------------------
+# Asynchronous text-to-speech
+# -----------------------------------------------------------------------------
+
 async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
     """Convert text to speech using Edge TTS and save as MP3"""
     communicate = edge_tts.Communicate(text, voice)
     await communicate.save(output_file)
     return output_file
 
+# -----------------------------------------------------------------------------
+# Utility function to clean conversation history
+# -----------------------------------------------------------------------------
+
 def clean_chat_history(chat_history):
     """
     Filter out any chat entries whose "content" is not a string.
@@ -89,14 +193,16 @@ def clean_chat_history(chat_history):
         cleaned.append(msg)
     return cleaned
 
-# Environment variables and parameters for Stable Diffusion XL
+# -----------------------------------------------------------------------------
+# Stable Diffusion XL Pipeline for Image Generation
+# -----------------------------------------------------------------------------
+
 MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # SDXL Model repository path via env variable
 MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
 USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # For batched image generation
 
-# Load the SDXL pipeline
 sd_pipe = StableDiffusionXLPipeline.from_pretrained(
     MODEL_ID_SD,
     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
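The SDXL pipeline in this hunk is configured entirely through environment variables, and `MODEL_VAL_PATH` has no default: if it is unset, `from_pretrained` receives `None` and fails at startup. A hedged sketch of a working configuration, set before the module is imported; the checkpoint id here is an illustrative assumption, any SDXL repository should work:

```python
import os

# Illustrative values only; MODEL_VAL_PATH must point at a real SDXL checkpoint.
os.environ.setdefault("MODEL_VAL_PATH", "stabilityai/stable-diffusion-xl-base-1.0")  # assumed example
os.environ.setdefault("MAX_IMAGE_SIZE", "4096")
os.environ.setdefault("BATCH_SIZE", "1")
os.environ.setdefault("USE_TORCH_COMPILE", "0")   # "1" enables sd_pipe.compile()
os.environ.setdefault("ENABLE_CPU_OFFLOAD", "0")  # "1" enables model CPU offload
```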
@@ -105,31 +211,21 @@ sd_pipe = StableDiffusionXLPipeline.from_pretrained(
 ).to(device)
 sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
 
-# Ensure that the text encoder is in half-precision if using CUDA.
 if torch.cuda.is_available():
     sd_pipe.text_encoder = sd_pipe.text_encoder.half()
 
-# Optional: compile the model for speedup if enabled
 if USE_TORCH_COMPILE:
     sd_pipe.compile()
 
-# Optional: offload parts of the model to CPU if needed
 if ENABLE_CPU_OFFLOAD:
     sd_pipe.enable_model_cpu_offload()
 
-MAX_SEED = np.iinfo(np.int32).max
-
 def save_image(img: Image.Image) -> str:
     """Save a PIL image with a unique filename and return the path."""
     unique_name = str(uuid.uuid4()) + ".png"
     img.save(unique_name)
     return unique_name
 
-def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-    return seed
-
 @spaces.GPU(duration=60, enable_queue=True)
 def generate_image_fn(
     prompt: str,
@@ -169,7 +265,6 @@ def generate_image_fn(
         batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
         if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
             batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
-        # Wrap the pipeline call in autocast if using CUDA
         if device.type == "cuda":
             with torch.autocast("cuda", dtype=torch.float16):
                 outputs = sd_pipe(**batch_options)
@@ -179,6 +274,31 @@ def generate_image_fn(
     image_paths = [save_image(img) for img in images]
     return image_paths, seed
 
+# -----------------------------------------------------------------------------
+# Text-to-3D Generation using the ShapE Pipeline
+# -----------------------------------------------------------------------------
+
+@spaces.GPU(duration=120, enable_queue=True)
+def generate_3d_fn(
+    prompt: str,
+    seed: int = 1,
+    guidance_scale: float = 15.0,
+    num_steps: int = 64,
+    randomize_seed: bool = False,
+):
+    """
+    Generate a 3D model from text using the ShapE pipeline.
+    Returns a tuple of (glb_file_path, used_seed).
+    """
+    seed = int(randomize_seed_fn(seed, randomize_seed))
+    model3d = Model()
+    glb_path = model3d.run_text(prompt, seed=seed, guidance_scale=guidance_scale, num_steps=num_steps)
+    return glb_path, seed
+
+# -----------------------------------------------------------------------------
+# Chat Generation Function with support for @tts, @image, and @3d commands
+# -----------------------------------------------------------------------------
+
 @spaces.GPU
 def generate(
     input_dict: dict,
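`generate_3d_fn` routes its `seed` / `randomize_seed` pair through the relocated `randomize_seed_fn`, so callers can either pin the seed or have one drawn from `[0, MAX_SEED]`; the seed actually used is returned alongside the GLB path. A small sketch, assuming a running GPU Space where the ShapE pipelines load:

```python
# Pinned seed: reproducible mesh; the returned seed echoes the input.
glb_path, used_seed = generate_3d_fn(prompt="a birthday cupcake", seed=7, randomize_seed=False)
assert used_seed == 7

# Randomized: a fresh seed is drawn and returned, useful for logging and reproduction.
glb_path, used_seed = generate_3d_fn(prompt="a birthday cupcake", randomize_seed=True)
```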
@@ -190,16 +310,39 @@ def generate(
     repetition_penalty: float = 1.2,
 ):
     """
-    Generates chatbot responses with support for multimodal input, TTS, and image generation.
+    Generates chatbot responses with support for multimodal input, TTS, image generation,
+    and 3D model generation.
+
     Special commands:
     - "@tts1" or "@tts2": triggers text-to-speech.
     - "@image": triggers image generation using the SDXL pipeline.
+    - "@3d": triggers 3D model generation using the ShapE pipeline.
     """
     text = input_dict["text"]
     files = input_dict.get("files", [])
 
+    # --- 3D Generation branch ---
+    if text.strip().lower().startswith("@3d"):
+        prompt = text[len("@3d"):].strip()
+        yield "Generating 3D model..."
+        glb_path, used_seed = generate_3d_fn(
+            prompt=prompt,
+            seed=1,
+            guidance_scale=15.0,
+            num_steps=64,
+            randomize_seed=True,
+        )
+        # Convert the GLB file to a base64 data URL and embed it in an HTML <model-viewer> tag.
+        data_url = glb_to_data_url(glb_path)
+        html_output = f'''
+            <model-viewer src="{data_url}" alt="3D Model" auto-rotate camera-controls style="width: 100%; height: 400px;"></model-viewer>
+            <script type="module" src="https://unpkg.com/@google/model-viewer/dist/model-viewer.min.js"></script>
+        '''
+        yield gr.HTML(html_output)
+        return
+
+    # --- Image Generation branch ---
     if text.strip().lower().startswith("@image"):
-        # Remove the "@image" tag and use the rest as prompt
         prompt = text[len("@image"):].strip()
         yield "Generating image..."
        image_paths, used_seed = generate_image_fn(
@@ -215,10 +358,10 @@ def generate(
             use_resolution_binning=True,
             num_images=1,
         )
-        # Yield the generated image so that the chat interface displays it.
         yield gr.Image(image_paths[0])
-        return  # Exit early
+        return
 
+    # --- Text and TTS branch ---
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
     voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
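Dispatch inside `generate` is purely prefix-based on the raw message text: `@3d` is checked first, then `@image`, then the `@tts1` / `@tts2` prefixes, and anything else falls through to plain chat. A self-contained sketch mirroring that classification logic:

```python
def classify(text: str) -> str:
    """Mirror of the prefix dispatch used by generate() above."""
    t = text.strip().lower()
    if t.startswith("@3d"):
        return "3D generation (ShapE)"
    if t.startswith("@image"):
        return "image generation (SDXL)"
    if any(t.startswith(f"@tts{i}") for i in range(1, 3)):
        return "text-to-speech (Edge TTS)"
    return "plain chat"

for sample in ["@3d A birthday cupcake with cherry", "@image a donut", "@tts2 hello", "hi"]:
    print(sample, "->", classify(sample))
```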
@@ -226,11 +369,9 @@ def generate(
     if is_tts and voice_index:
         voice = TTS_VOICES[voice_index - 1]
         text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
-        # Clear previous chat history for a fresh TTS request.
         conversation = [{"role": "user", "content": text}]
     else:
         voice = None
-        # Remove any stray @tts tags and build the conversation history.
         text = text.replace(tts_prefix, "").strip()
         conversation = clean_chat_history(chat_history)
         conversation.append({"role": "user", "content": text})
@@ -264,7 +405,6 @@ def generate(
             time.sleep(0.01)
             yield buffer
     else:
-
         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
@@ -293,11 +433,14 @@ def generate(
     final_response = "".join(outputs)
     yield final_response
 
-    # If TTS was requested, convert the final response to speech.
     if is_tts and voice:
         output_file = asyncio.run(text_to_speech(final_response, voice))
         yield gr.Audio(output_file, autoplay=True)
 
+# -----------------------------------------------------------------------------
+# Gradio Chat Interface Setup and Launch
+# -----------------------------------------------------------------------------
+
 demo = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
@@ -309,12 +452,11 @@ demo = gr.ChatInterface(
     ],
     examples=[
         ["@tts1 Who is Nikola Tesla, and why did he die?"],
-        [{"text": "Extract JSON from the image", "files": ["examples/document.jpg"]}],
+        ["@3d A birthday cupcake with cherry"],
         [{"text": "summarize the letter", "files": ["examples/1.png"]}],
         ["@image Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic"],
         ["Write a Python function to check if a number is prime."],
         ["@tts2 What causes rainbows to form?"],
-
     ],
     cache_examples=False,
     type="messages",