prithivMLmods committed
Commit eab6c4d · verified · 1 Parent(s): 528a976

Upload 2 files

Files changed (2)
  1. app.py +330 -0
  2. requirements.txt +23 -0
app.py ADDED
@@ -0,0 +1,330 @@
+ import os
+ import random
+ import uuid
+ import json
+ import time
+ import asyncio
+ from threading import Thread
+
+ import gradio as gr
+ import spaces
+ import torch
+ import numpy as np
+ from PIL import Image
+ import edge_tts
+
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     TextIteratorStreamer,
+     Qwen2VLForConditionalGeneration,
+     AutoProcessor,
+ )
+ from transformers.image_utils import load_image
+ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
+
+
+ DESCRIPTION = """
+ # QwQ Edge 💬
+ """
+
+ css = '''
+ h1 {
+   text-align: center;
+   display: block;
+ }
+
+ #duplicate-button {
+   margin: auto;
+   color: #fff;
+   background: #1565c0;
+   border-radius: 100vh;
+ }
+ '''
+
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ # Load text-only model and tokenizer
+ model_id = "prithivMLmods/FastThink-0.5B-Tiny"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
+     torch_dtype=torch.bfloat16,
+ )
+ model.eval()
+
+ TTS_VOICES = [
+     "en-US-JennyNeural",  # @tts1
+     "en-US-GuyNeural",    # @tts2
+ ]
+
+ MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
+ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
+     MODEL_ID,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to("cuda").eval()
+
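+ # Note: the vision-language model is pinned to CUDA; on a CPU-only host the
+ # .to("cuda") call above would fail, so moving it to `device` instead is an
+ # option if GPU availability is not guaranteed.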
+ async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
+     """Convert text to speech using Edge TTS and save as MP3"""
+     communicate = edge_tts.Communicate(text, voice)
+     await communicate.save(output_file)
+     return output_file
+
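+ # Example usage (for illustration only): asyncio.run(text_to_speech("Hello", TTS_VOICES[0]))
+ # writes "Hello" as speech to output.mp3; the chat handler below calls it the same way
+ # on the final reply when a @tts command was given.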
+ def clean_chat_history(chat_history):
+     """
+     Filter out any chat entries whose "content" is not a string.
+     This helps prevent errors when concatenating previous messages.
+     """
+     cleaned = []
+     for msg in chat_history:
+         if isinstance(msg, dict) and isinstance(msg.get("content"), str):
+             cleaned.append(msg)
+     return cleaned
+
+ # Environment variables and parameters for Stable Diffusion XL
+ MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # SDXL model repository path via env variable
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
+ BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # For batched image generation
+
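+ # Note: MODEL_VAL_PATH must point to a valid SDXL checkpoint repository; if it is
+ # unset, MODEL_ID_SD is None and the from_pretrained() call below will fail at startup.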
+ # Load the SDXL pipeline
+ sd_pipe = StableDiffusionXLPipeline.from_pretrained(
+     MODEL_ID_SD,
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+     use_safetensors=True,
+     add_watermarker=False,
+ ).to(device)
+ sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
+
+ # Ensure that the text encoder is in half-precision if using CUDA.
+ if torch.cuda.is_available():
+     sd_pipe.text_encoder = sd_pipe.text_encoder.half()
+
+ # Optional: compile the model for speedup if enabled
+ if USE_TORCH_COMPILE:
+     sd_pipe.compile()
+
+ # Optional: offload parts of the model to CPU if needed
+ if ENABLE_CPU_OFFLOAD:
+     sd_pipe.enable_model_cpu_offload()
+
+ MAX_SEED = np.iinfo(np.int32).max
+
+ def save_image(img: Image.Image) -> str:
+     """Save a PIL image with a unique filename and return the path."""
+     unique_name = str(uuid.uuid4()) + ".png"
+     img.save(unique_name)
+     return unique_name
+
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     return seed
+
+ @spaces.GPU(duration=60, enable_queue=True)
+ def generate_image_fn(
+     prompt: str,
+     negative_prompt: str = "",
+     use_negative_prompt: bool = False,
+     seed: int = 1,
+     width: int = 1024,
+     height: int = 1024,
+     guidance_scale: float = 3,
+     num_inference_steps: int = 25,
+     randomize_seed: bool = False,
+     use_resolution_binning: bool = True,
+     num_images: int = 1,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     """Generate images using the SDXL pipeline."""
+     seed = int(randomize_seed_fn(seed, randomize_seed))
+     generator = torch.Generator(device=device).manual_seed(seed)
+
+     options = {
+         "prompt": [prompt] * num_images,
+         "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
+         "width": width,
+         "height": height,
+         "guidance_scale": guidance_scale,
+         "num_inference_steps": num_inference_steps,
+         "generator": generator,
+         "output_type": "pil",
+     }
+     if use_resolution_binning:
+         options["use_resolution_binning"] = True
+
+     images = []
+     # Process in batches
+     for i in range(0, num_images, BATCH_SIZE):
+         batch_options = options.copy()
+         batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
+         if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
+             batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
+         # Wrap the pipeline call in autocast if using CUDA
+         if device.type == "cuda":
+             with torch.autocast("cuda", dtype=torch.float16):
+                 outputs = sd_pipe(**batch_options)
+         else:
+             outputs = sd_pipe(**batch_options)
+         images.extend(outputs.images)
+     image_paths = [save_image(img) for img in images]
+     return image_paths, seed
+
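+ # Chat entry point: "@image ..." is routed to the SDXL pipeline above, "@tts1"/"@tts2"
+ # trigger Edge TTS on the final reply, uploaded images go to the Qwen2-VL model, and
+ # plain text goes to the FastThink text-only model.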
+ @spaces.GPU
+ def generate(
+     input_dict: dict,
+     chat_history: list[dict],
+     max_new_tokens: int = 1024,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+ ):
+     """
+     Generates chatbot responses with support for multimodal input, TTS, and image generation.
+     Special commands:
+     - "@tts1" or "@tts2": triggers text-to-speech.
+     - "@image": triggers image generation using the SDXL pipeline.
+     """
+     text = input_dict["text"]
+     files = input_dict.get("files", [])
+
+     if text.strip().lower().startswith("@image"):
+         # Remove the "@image" tag and use the rest as prompt
+         prompt = text[len("@image"):].strip()
+         yield "Generating image..."
+         image_paths, used_seed = generate_image_fn(
+             prompt=prompt,
+             negative_prompt="",
+             use_negative_prompt=False,
+             seed=1,
+             width=1024,
+             height=1024,
+             guidance_scale=3,
+             num_inference_steps=25,
+             randomize_seed=True,
+             use_resolution_binning=True,
+             num_images=1,
+         )
+         # Yield the generated image so that the chat interface displays it.
+         yield gr.Image(image_paths[0])
+         return  # Exit early
+
+     tts_prefix = "@tts"
+     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
+     voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
+
+     if is_tts and voice_index:
+         voice = TTS_VOICES[voice_index - 1]
+         text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
+         # Clear previous chat history for a fresh TTS request.
+         conversation = [{"role": "user", "content": text}]
+     else:
+         voice = None
+         # Remove any stray @tts tags and build the conversation history.
+         text = text.replace(tts_prefix, "").strip()
+         conversation = clean_chat_history(chat_history)
+         conversation.append({"role": "user", "content": text})
+
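+     # Multimodal branch: uploaded files are loaded as images and sent to the Qwen2-VL
+     # model together with the prompt; otherwise the text-only model streams the reply.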
+     if files:
+         if len(files) > 1:
+             images = [load_image(image) for image in files]
+         elif len(files) == 1:
+             images = [load_image(files[0])]
+         else:
+             images = []
+         messages = [{
+             "role": "user",
+             "content": [
+                 *[{"type": "image", "image": image} for image in images],
+                 {"type": "text", "text": text},
+             ]
+         }]
+         prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
+         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+         generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
+         thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
+         thread.start()
+
+         buffer = ""
+         yield "Thinking..."
+         for new_text in streamer:
+             buffer += new_text
+             buffer = buffer.replace("<|im_end|>", "")
+             time.sleep(0.01)
+             yield buffer
+     else:
+         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+             gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+         input_ids = input_ids.to(model.device)
+         streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+         generation_kwargs = {
+             "input_ids": input_ids,
+             "streamer": streamer,
+             "max_new_tokens": max_new_tokens,
+             "do_sample": True,
+             "top_p": top_p,
+             "top_k": top_k,
+             "temperature": temperature,
+             "num_beams": 1,
+             "repetition_penalty": repetition_penalty,
+         }
+         t = Thread(target=model.generate, kwargs=generation_kwargs)
+         t.start()
+
+         outputs = []
+         for new_text in streamer:
+             outputs.append(new_text)
+             yield "".join(outputs)
+
+         final_response = "".join(outputs)
+         yield final_response
+
+         # If TTS was requested, convert the final response to speech.
+         if is_tts and voice:
+             output_file = asyncio.run(text_to_speech(final_response, voice))
+             yield gr.Audio(output_file, autoplay=True)
+
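+ # Gradio chat UI: the sliders expose the text-generation sampling parameters, and the
+ # multimodal textbox accepts image uploads alongside text.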
+ demo = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=[
+         gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
+         gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
+         gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
+         gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
+         gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
+     ],
+     examples=[
+         ["@tts1 Who is Nikola Tesla, and why did he die?"],
+         [{"text": "Extract JSON from the image", "files": ["examples/document.jpg"]}],
+         [{"text": "summarize the letter", "files": ["examples/1.png"]}],
+         ["@image Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic"],
+         ["Write a Python function to check if a number is prime."],
+         ["@tts2 What causes rainbows to form?"],
+     ],
+     cache_examples=False,
+     type="messages",
+     description=DESCRIPTION,
+     css=css,
+     fill_height=True,
+     textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
+     stop_btn="Stop Generation",
+     multimodal=True,
+ )
+
+ if __name__ == "__main__":
+     # To create a public link, set share=True in launch().
+     demo.queue(max_size=20).launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,23 @@
+ huggingface_hub
+ git+https://github.com/huggingface/transformers.git
+ gradio_client==1.3.0
+ qwen-vl-utils==0.0.2
+ transformers-stream-generator==0.0.4
+ accelerate
+ diffusers
+ peft
+ trimesh
+ torch==2.4.0
+ torchvision==0.19.0
+ sentencepiece
+ spaces
+ requests
+ safetensors
+ edge-tts
+ audiosegment
+ asyncio
+ scipy
+ librosa
+ pydub
+ ffmpeg-python
+ av
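+ # Note: asyncio ships with the Python standard library, so the PyPI package of the
+ # same name listed above is likely unnecessary.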