prithivMLmods committed on
Commit 2b1f8da · verified · 1 Parent(s): 3df271a

Update app.py

Files changed (1)
  1. app.py +24 -112
app.py CHANGED
@@ -4,7 +4,6 @@ import uuid
 import json
 import time
 import asyncio
-import tempfile
 from threading import Thread
 
 import gradio as gr
@@ -13,7 +12,6 @@ import torch
 import numpy as np
 from PIL import Image
 import edge_tts
-import trimesh
 
 from transformers import (
     AutoModelForCausalLM,
@@ -23,75 +21,8 @@ from transformers import (
     AutoProcessor,
 )
 from transformers.image_utils import load_image
-
 from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
-from diffusers import ShapEImg2ImgPipeline, ShapEPipeline
-from diffusers.utils import export_to_ply
-
-
-MAX_SEED = np.iinfo(np.int32).max
-
-def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-    return seed
 
-class Model:
-    def __init__(self):
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.pipe = ShapEPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16)
-        self.pipe.to(self.device)
-        # Ensure the text encoder is in half precision to avoid dtype mismatches.
-        if torch.cuda.is_available():
-            try:
-                self.pipe.text_encoder = self.pipe.text_encoder.half()
-            except AttributeError:
-                pass
-
-        self.pipe_img = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img", torch_dtype=torch.float16)
-        self.pipe_img.to(self.device)
-        # Use getattr with a default value to avoid AttributeError if text_encoder is missing.
-        if torch.cuda.is_available():
-            text_encoder_img = getattr(self.pipe_img, "text_encoder", None)
-            if text_encoder_img is not None:
-                self.pipe_img.text_encoder = text_encoder_img.half()
-
-    def to_glb(self, ply_path: str) -> str:
-        mesh = trimesh.load(ply_path)
-        # Rotate the mesh for proper orientation
-        rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])
-        mesh.apply_transform(rot)
-        rot = trimesh.transformations.rotation_matrix(np.pi, [0, 1, 0])
-        mesh.apply_transform(rot)
-        mesh_path = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
-        mesh.export(mesh_path.name, file_type="glb")
-        return mesh_path.name
-
-    def run_text(self, prompt: str, seed: int = 0, guidance_scale: float = 15.0, num_steps: int = 64) -> str:
-        generator = torch.Generator(device=self.device).manual_seed(seed)
-        images = self.pipe(
-            prompt,
-            generator=generator,
-            guidance_scale=guidance_scale,
-            num_inference_steps=num_steps,
-            output_type="mesh",
-        ).images
-        ply_path = tempfile.NamedTemporaryFile(suffix=".ply", delete=False, mode="w+b")
-        export_to_ply(images[0], ply_path.name)
-        return self.to_glb(ply_path.name)
-
-    def run_image(self, image: Image.Image, seed: int = 0, guidance_scale: float = 3.0, num_steps: int = 64) -> str:
-        generator = torch.Generator(device=self.device).manual_seed(seed)
-        images = self.pipe_img(
-            image,
-            generator=generator,
-            guidance_scale=guidance_scale,
-            num_inference_steps=num_steps,
-            output_type="mesh",
-        ).images
-        ply_path = tempfile.NamedTemporaryFile(suffix=".ply", delete=False, mode="w+b")
-        export_to_ply(images[0], ply_path.name)
-        return self.to_glb(ply_path.name)
 
 DESCRIPTION = """
 # QwQ Edge 💬
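For context, the text-to-3D path removed above can still be run standalone. The sketch below condenses the deleted Model class into one function, assuming diffusers, trimesh, and torch are installed; it illustrates what the commit drops and is not code that remains in app.py.

# Minimal sketch of the removed text-to-3D flow (assumes diffusers + trimesh are available).
import tempfile

import numpy as np
import torch
import trimesh
from diffusers import ShapEPipeline
from diffusers.utils import export_to_ply

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = ShapEPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16).to(device)

def text_to_glb(prompt: str, seed: int = 0) -> str:
    generator = torch.Generator(device=device).manual_seed(seed)
    # output_type="mesh" returns mesh data that export_to_ply can serialize.
    mesh = pipe(prompt, generator=generator, guidance_scale=15.0,
                num_inference_steps=64, output_type="mesh").images[0]
    ply_file = tempfile.NamedTemporaryFile(suffix=".ply", delete=False)
    export_to_ply(mesh, ply_file.name)
    # Re-orient and convert to GLB so a viewer such as gr.Model3D can display it.
    scene = trimesh.load(ply_file.name)
    scene.apply_transform(trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0]))
    glb_file = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
    scene.export(glb_file.name, file_type="glb")
    return glb_file.name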
@@ -117,7 +48,7 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-# Load the text-only model and tokenizer (for pure text chat)
+# Load text-only model and tokenizer
 model_id = "prithivMLmods/FastThink-0.5B-Tiny"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
@@ -127,13 +58,11 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 model.eval()
 
-# Voices for text-to-speech
 TTS_VOICES = [
     "en-US-JennyNeural",  # @tts1
     "en-US-GuyNeural",    # @tts2
 ]
 
-# Load multimodal processor and model (e.g. for OCR and image processing)
 MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
 processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
 model_m = Qwen2VLForConditionalGeneration.from_pretrained(
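The multimodal branch is untouched by this commit. For reference, a single-image query against the Qwen2-VL OCR model loaded above would look roughly like this; it is a sketch following the usual Qwen2-VL processor conventions, and the message layout and file name are illustrative rather than taken from app.py.

# Rough sketch of one OCR-style query with the processor/model loaded above (illustrative only).
from PIL import Image

image = Image.open("examples/1.png")  # any local image; this path is just an example
messages = [{
    "role": "user",
    "content": [
        {"type": "image"},
        {"type": "text", "text": "summarize the letter"},
    ],
}]
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[prompt], images=[image], return_tensors="pt").to(model_m.device)
output_ids = model_m.generate(**inputs, max_new_tokens=256)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])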
@@ -159,12 +88,14 @@ def clean_chat_history(chat_history):
             cleaned.append(msg)
     return cleaned
 
+# Environment variables and parameters for Stable Diffusion XL
 MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # SDXL Model repository path via env variable
 MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
 USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # For batched image generation
 
+# Load the SDXL pipeline
 sd_pipe = StableDiffusionXLPipeline.from_pretrained(
     MODEL_ID_SD,
     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
@@ -173,21 +104,31 @@ sd_pipe = StableDiffusionXLPipeline.from_pretrained(
 ).to(device)
 sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
 
+# Ensure that the text encoder is in half-precision if using CUDA.
 if torch.cuda.is_available():
     sd_pipe.text_encoder = sd_pipe.text_encoder.half()
 
+# Optional: compile the model for speedup if enabled
 if USE_TORCH_COMPILE:
     sd_pipe.compile()
 
+# Optional: offload parts of the model to CPU if needed
 if ENABLE_CPU_OFFLOAD:
     sd_pipe.enable_model_cpu_offload()
 
+MAX_SEED = np.iinfo(np.int32).max
+
 def save_image(img: Image.Image) -> str:
     """Save a PIL image with a unique filename and return the path."""
     unique_name = str(uuid.uuid4()) + ".png"
     img.save(unique_name)
     return unique_name
 
+def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    return seed
+
 @spaces.GPU(duration=60, enable_queue=True)
 def generate_image_fn(
     prompt: str,
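The relocated MAX_SEED and randomize_seed_fn are the seeding knobs generate_image_fn relies on. A minimal seeded call outside the Gradio wrapper could look like the sketch below; the prompt, step count, and guidance value are illustrative, and `random` is assumed to be imported elsewhere in app.py.

# Sketch: one seeded SDXL generation using the globals defined above (not the Gradio path).
import contextlib

seed = randomize_seed_fn(seed=0, randomize_seed=True)          # pick a fresh random seed
generator = torch.Generator(device=device).manual_seed(seed)   # makes the run reproducible
autocast = torch.autocast("cuda", dtype=torch.float16) if device.type == "cuda" else contextlib.nullcontext()
with autocast:
    result = sd_pipe(
        prompt="a lighthouse at dusk, photorealistic",
        num_inference_steps=25,
        guidance_scale=6.0,
        generator=generator,
    )
print(save_image(result.images[0]), "seed:", seed)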
@@ -227,6 +168,7 @@ def generate_image_fn(
         batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
         if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
             batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
+        # Wrap the pipeline call in autocast if using CUDA
         if device.type == "cuda":
             with torch.autocast("cuda", dtype=torch.float16):
                 outputs = sd_pipe(**batch_options)
@@ -236,23 +178,6 @@
     image_paths = [save_image(img) for img in images]
     return image_paths, seed
 
-@spaces.GPU(duration=120, enable_queue=True)
-def generate_3d_fn(
-    prompt: str,
-    seed: int = 1,
-    guidance_scale: float = 15.0,
-    num_steps: int = 64,
-    randomize_seed: bool = False,
-):
-    """
-    Generate a 3D model from text using the ShapE pipeline.
-    Returns a tuple of (glb_file_path, used_seed).
-    """
-    seed = int(randomize_seed_fn(seed, randomize_seed))
-    model3d = Model()
-    glb_path = model3d.run_text(prompt, seed=seed, guidance_scale=guidance_scale, num_steps=num_steps)
-    return glb_path, seed
-
 @spaces.GPU
 def generate(
     input_dict: dict,
@@ -264,34 +189,16 @@ def generate(
     repetition_penalty: float = 1.2,
 ):
     """
-    Generates chatbot responses with support for multimodal input, TTS, image generation,
-    and 3D model generation.
-
+    Generates chatbot responses with support for multimodal input, TTS, and image generation.
     Special commands:
     - "@tts1" or "@tts2": triggers text-to-speech.
     - "@image": triggers image generation using the SDXL pipeline.
-    - "@3d": triggers 3D model generation using the ShapE pipeline.
     """
     text = input_dict["text"]
     files = input_dict.get("files", [])
 
-    # --- 3D Generation branch ---
-    if text.strip().lower().startswith("@3d"):
-        prompt = text[len("@3d"):].strip()
-        yield "Generating 3D model..."
-        glb_path, used_seed = generate_3d_fn(
-            prompt=prompt,
-            seed=1,
-            guidance_scale=15.0,
-            num_steps=64,
-            randomize_seed=True,
-        )
-        # Instead of returning as a file, yield a 3D model component so it displays inline.
-        yield gr.Model3D(value=glb_path, label="3D Model")
-        return
-
-    # --- Image Generation branch ---
     if text.strip().lower().startswith("@image"):
+        # Remove the "@image" tag and use the rest as prompt
         prompt = text[len("@image"):].strip()
         yield "Generating image..."
        image_paths, used_seed = generate_image_fn(
@@ -307,10 +214,10 @@
             use_resolution_binning=True,
             num_images=1,
         )
+        # Yield the generated image so that the chat interface displays it.
         yield gr.Image(image_paths[0])
-        return
+        return # Exit early
 
-    # --- Text and TTS branch ---
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
     voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
@@ -318,9 +225,11 @@
     if is_tts and voice_index:
         voice = TTS_VOICES[voice_index - 1]
         text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
+        # Clear previous chat history for a fresh TTS request.
         conversation = [{"role": "user", "content": text}]
     else:
         voice = None
+        # Remove any stray @tts tags and build the conversation history.
         text = text.replace(tts_prefix, "").strip()
         conversation = clean_chat_history(chat_history)
         conversation.append({"role": "user", "content": text})
@@ -354,6 +263,7 @@
             time.sleep(0.01)
             yield buffer
     else:
+
         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
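After truncating the prompt to MAX_INPUT_TOKEN_LENGTH, the text branch streams tokens back into the chat. The pattern underneath is transformers' TextIteratorStreamer driving model.generate on a background thread, roughly as sketched below; the generation parameters are illustrative defaults, not the exact values app.py passes through.

# Sketch of the streaming pattern the text branch builds on (TextIteratorStreamer + Thread).
from threading import Thread
from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
    input_ids=input_ids.to(model.device),
    streamer=streamer,
    max_new_tokens=1024,
    do_sample=True,
)
Thread(target=model.generate, kwargs=generation_kwargs).start()

buffer = ""
for new_text in streamer:    # chunks arrive as soon as they are generated
    buffer += new_text
    yield buffer             # inside generate(), the chat interface re-renders the growing reply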
@@ -382,6 +292,7 @@
         final_response = "".join(outputs)
         yield final_response
 
+    # If TTS was requested, convert the final response to speech.
     if is_tts and voice:
         output_file = asyncio.run(text_to_speech(final_response, voice))
         yield gr.Audio(output_file, autoplay=True)
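The text_to_speech coroutine called here is defined elsewhere in app.py and is not part of this diff. A plausible minimal implementation with edge-tts would look like the sketch below; the mp3 file naming is an assumption, not the Space's actual helper.

# Assumed shape of the text_to_speech helper used above (the real one is outside this diff).
import uuid
import edge_tts

async def text_to_speech(text: str, voice: str) -> str:
    output_file = f"{uuid.uuid4()}.mp3"
    communicate = edge_tts.Communicate(text, voice)
    await communicate.save(output_file)  # edge-tts writes the synthesized audio to disk
    return output_file

# Driven synchronously from the chat generator:
# output_file = asyncio.run(text_to_speech(final_response, "en-US-JennyNeural"))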
@@ -397,11 +308,12 @@ demo = gr.ChatInterface(
     ],
     examples=[
         ["@tts1 Who is Nikola Tesla, and why did he die?"],
-        ["@3d A birthday cupcake with cherry"],
+        [{"text": "Extract JSON from the image", "files": ["examples/document.jpg"]}],
         [{"text": "summarize the letter", "files": ["examples/1.png"]}],
         ["@image Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic"],
         ["Write a Python function to check if a number is prime."],
         ["@tts2 What causes rainbows to form?"],
+
     ],
     cache_examples=False,
     type="messages",
 