prithivMLmods committed
Commit 5d63d59 · verified · 1 Parent(s): 7c85957

Update app.py

Files changed (1):
  app.py: +95 -297
app.py CHANGED
@@ -1,22 +1,10 @@
  import gradio as gr
- import spaces
  from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
- from qwen_vl_utils import process_vision_info
- import torch
- from PIL import Image
- import os
- import uuid
- import io
  from threading import Thread
- from reportlab.lib.pagesizes import A4
- from reportlab.lib.styles import getSampleStyleSheet
- from reportlab.lib import colors
- from reportlab.platypus import SimpleDocTemplate, Image as RLImage, Paragraph, Spacer
- from reportlab.lib.units import inch
- from reportlab.pdfbase import pdfmetrics
- from reportlab.pdfbase.ttfonts import TTFont
- import docx
- from docx.enum.text import WD_ALIGN_PARAGRAPH

  # Define model options
  MODEL_OPTIONS = {
@@ -26,318 +14,128 @@ MODEL_OPTIONS = {
      "Text Analogy Ocrtest": "prithivMLmods/Qwen2-VL-Ocrtest-2B-Instruct"
  }

- # Preload models and processors into CUDA
- models = {}
- processors = {}
- for name, model_id in MODEL_OPTIONS.items():
-     print(f"Loading {name}...")
-     models[name] = Qwen2VLForConditionalGeneration.from_pretrained(
          model_id,
          trust_remote_code=True,
          torch_dtype=torch.float16
      ).to("cuda").eval()
-     processors[name] = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
-
- image_extensions = Image.registered_extensions()
-
- def identify_and_save_blob(blob_path):
-     """Identifies if the blob is an image and saves it."""
-     try:
-         with open(blob_path, 'rb') as file:
-             blob_content = file.read()
-         try:
-             Image.open(io.BytesIO(blob_content)).verify()  # Check if it's a valid image
-             extension = ".png"  # Default to PNG for saving
-             media_type = "image"
-         except (IOError, SyntaxError):
-             raise ValueError("Unsupported media type. Please upload a valid image.")
-
-         filename = f"temp_{uuid.uuid4()}_media{extension}"
-         with open(filename, "wb") as f:
-             f.write(blob_content)
-
-         return filename, media_type
-
-     except FileNotFoundError:
-         raise ValueError(f"The file {blob_path} was not found.")
-     except Exception as e:
-         raise ValueError(f"An error occurred while processing the file: {e}")

  @spaces.GPU
- def qwen_inference(model_name, media_input, text_input=None):
-     """Handles inference for the selected model."""
-     model = models[model_name]
-     processor = processors[model_name]
-
-     if isinstance(media_input, str):
-         media_path = media_input
-         if media_path.endswith(tuple([i for i in image_extensions.keys()])):
-             media_type = "image"
-         else:
-             try:
-                 media_path, media_type = identify_and_save_blob(media_input)
-             except Exception as e:
-                 raise ValueError("Unsupported media type. Please upload a valid image.")
-
      messages = [
          {
              "role": "user",
              "content": [
-                 {
-                     "type": media_type,
-                     media_type: media_path
-                 },
-                 {"type": "text", "text": text_input},
              ],
          }
      ]

-     text = processor.apply_chat_template(
-         messages, tokenize=False, add_generation_prompt=True
-     )
-     image_inputs, _ = process_vision_info(messages)
      inputs = processor(
-         text=[text],
-         images=image_inputs,
-         padding=True,
          return_tensors="pt",
      ).to("cuda")

-     streamer = TextIteratorStreamer(
-         processor.tokenizer, skip_prompt=True, skip_special_tokens=True
-     )
      generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)

      thread = Thread(target=model.generate, kwargs=generation_kwargs)
      thread.start()

      buffer = ""
      for new_text in streamer:
          buffer += new_text
-         # Remove <|im_end|> or similar tokens from the output
-         buffer = buffer.replace("<|im_end|>", "")
          yield buffer

- def format_plain_text(output_text):
-     """Formats the output text as plain text without LaTeX delimiters."""
-     # Remove LaTeX delimiters and convert to plain text
-     plain_text = output_text.replace("\\(", "").replace("\\)", "").replace("\\[", "").replace("\\]", "")
-     return plain_text

- def generate_document(media_path, output_text, file_format, font_choice, font_size, line_spacing, alignment, image_size):
-     """Generates a document with the input image and plain text output."""
-     plain_text = format_plain_text(output_text)
-     if file_format == "pdf":
-         return generate_pdf(media_path, plain_text, font_choice, font_size, line_spacing, alignment, image_size)
-     elif file_format == "docx":
-         return generate_docx(media_path, plain_text, font_choice, font_size, line_spacing, alignment, image_size)
-
- def generate_pdf(media_path, plain_text, font_choice, font_size, line_spacing, alignment, image_size):
-     """Generates a PDF document."""
-     filename = f"output_{uuid.uuid4()}.pdf"
-     doc = SimpleDocTemplate(
-         filename,
-         pagesize=A4,
-         rightMargin=inch,
-         leftMargin=inch,
-         topMargin=inch,
-         bottomMargin=inch
      )
-     styles = getSampleStyleSheet()
-     styles["Normal"].fontName = font_choice
-     styles["Normal"].fontSize = int(font_size)
-     styles["Normal"].leading = int(font_size) * line_spacing
-     styles["Normal"].alignment = {
-         "Left": 0,
-         "Center": 1,
-         "Right": 2,
-         "Justified": 4
-     }[alignment]
-
-     # Register font
-     font_path = f"font/{font_choice}"
-     pdfmetrics.registerFont(TTFont(font_choice, font_path))
-
-     story = []
-
-     # Add image with size adjustment
-     image_sizes = {
-         "Small": (200, 200),
-         "Medium": (400, 400),
-         "Large": (600, 600)
-     }
-     img = RLImage(media_path, width=image_sizes[image_size][0], height=image_sizes[image_size][1])
-     story.append(img)
-     story.append(Spacer(1, 12))

-     # Add plain text output
-     text = Paragraph(plain_text, styles["Normal"])
-     story.append(text)
-
-     doc.build(story)
-     return filename
-
- def generate_docx(media_path, plain_text, font_choice, font_size, line_spacing, alignment, image_size):
-     """Generates a DOCX document."""
-     filename = f"output_{uuid.uuid4()}.docx"
-     doc = docx.Document()
-
-     # Add image with size adjustment
-     image_sizes = {
-         "Small": docx.shared.Inches(2),
-         "Medium": docx.shared.Inches(4),
-         "Large": docx.shared.Inches(6)
-     }
-     doc.add_picture(media_path, width=image_sizes[image_size])
-     doc.add_paragraph()
-
-     # Add plain text output
-     paragraph = doc.add_paragraph()
-     paragraph.paragraph_format.line_spacing = line_spacing
-     paragraph.paragraph_format.alignment = {
-         "Left": WD_ALIGN_PARAGRAPH.LEFT,
-         "Center": WD_ALIGN_PARAGRAPH.CENTER,
-         "Right": WD_ALIGN_PARAGRAPH.RIGHT,
-         "Justified": WD_ALIGN_PARAGRAPH.JUSTIFY
-     }[alignment]
-     run = paragraph.add_run(plain_text)
-     run.font.name = font_choice
-     run.font.size = docx.shared.Pt(int(font_size))
-
-     doc.save(filename)
-     return filename
-
- # CSS for output styling
- css = """
- #output {
-     height: 500px;
-     overflow: auto;
-     border: 1px solid #ccc;
- }
- .submit-btn {
-     background-color: #cf3434 !important;
-     color: white !important;
- }
- .submit-btn:hover {
-     background-color: #ff2323 !important;
- }
- .download-btn {
-     background-color: #35a6d6 !important;
-     color: white !important;
- }
- .download-btn:hover {
-     background-color: #22bcff !important;
- }
- """
-
- # Gradio app setup
- with gr.Blocks(css=css) as demo:
-     gr.Markdown("# Qwen2VL Models: Vision and Language Processing")
-
-     with gr.Tab(label="Image Input"):
-
-         with gr.Row():
-             with gr.Column():
-                 model_choice = gr.Dropdown(
-                     label="Model Selection",
-                     choices=list(MODEL_OPTIONS.keys()),
-                     value="Latex OCR"
-                 )
-                 input_media = gr.File(
-                     label="Upload Image📸", type="filepath"
-                 )
-                 text_input = gr.Textbox(label="Question", placeholder="Ask a question about the image...")
-                 submit_btn = gr.Button(value="Submit", elem_classes="submit-btn")
-
-             with gr.Column():
-                 output_text = gr.Textbox(label="Output Text", lines=10)
-                 plain_text_output = gr.Textbox(label="Standardized Plain Text", lines=10)
-
-         submit_btn.click(
-             qwen_inference, [model_choice, input_media, text_input], [output_text]
-         ).then(
-             lambda output_text: format_plain_text(output_text), [output_text], [plain_text_output]
-         )
-
-         # Add examples directly usable by clicking
-         with gr.Row():
-             gr.Examples(
-                 examples=[
-                     ["examples/1.png", "summarize the letter", "Text Analogy Ocrtest"],
-                     ["examples/2.jpg", "Summarize the full image in detail", "Latex OCR"],
-                     ["examples/3.png", "Describe the photo", "Qwen2VL Base"],
-                     ["examples/4.png", "summarize and solve the problem", "Math Prase"],
-                 ],
-                 inputs=[input_media, text_input, model_choice],
-                 outputs=[output_text, plain_text_output],
-                 fn=lambda img, question, model: qwen_inference(model, img, question),
-                 cache_examples=False,
-             )
-
-         with gr.Row():
-             with gr.Column():
-                 line_spacing = gr.Dropdown(
-                     choices=[0.5, 1.0, 1.15, 1.5, 2.0, 2.5, 3.0],
-                     value=1.5,
-                     label="Line Spacing"
-                 )
-                 font_size = gr.Dropdown(
-                     choices=["8", "10", "12", "14", "16", "18", "20", "22", "24"],
-                     value="18",
-                     label="Font Size"
-                 )
-                 font_choice = gr.Dropdown(
-                     choices=[
-                         "DejaVuMathTeXGyre.ttf",
-                         "FiraCode-Medium.ttf",
-                         "InputMono-Light.ttf",
-                         "JetBrainsMono-Thin.ttf",
-                         "ProggyCrossed Regular Mac.ttf",
-                         "SourceCodePro-Black.ttf",
-                         "arial.ttf",
-                         "calibri.ttf",
-                         "mukta-malar-extralight.ttf",
-                         "noto-sans-arabic-medium.ttf",
-                         "times new roman.ttf",
-                         "ANGSA.ttf",
-                         "Book-Antiqua.ttf",
-                         "CONSOLA.TTF",
-                         "COOPBL.TTF",
-                         "Rockwell-Bold.ttf",
-                         "Candara Light.TTF",
-                         "Carlito-Regular.ttf Carlito-Regular.ttf",
-                         "Castellar.ttf",
-                         "Courier New.ttf",
-                         "LSANS.TTF",
-                         "Lucida Bright Regular.ttf",
-                         "TRTempusSansITC.ttf",
-                         "Verdana.ttf",
-                         "bell-mt.ttf",
-                         "eras-itc-light.ttf",
-                         "fonnts.com-aptos-light.ttf",
-                         "georgia.ttf",
-                         "segoeuithis.ttf",
-                         "youyuan.TTF",
-                         "TfPonetoneExpanded-7BJZA.ttf",
-                     ],
-                     value="youyuan.TTF",
-                     label="Font Choice"
-                 )
-                 alignment = gr.Dropdown(
-                     choices=["Left", "Center", "Right", "Justified"],
-                     value="Justified",
-                     label="Text Alignment"
-                 )
-                 image_size = gr.Dropdown(
-                     choices=["Small", "Medium", "Large"],
-                     value="Small",
-                     label="Image Size"
-                 )
-                 file_format = gr.Radio(["pdf", "docx"], label="File Format", value="pdf")
-                 get_document_btn = gr.Button(value="Get Document", elem_classes="download-btn")

-     get_document_btn.click(
-         generate_document, [input_media, output_text, file_format, font_choice, font_size, line_spacing, alignment, image_size], gr.File(label="Download Document")
-     )

  demo.launch(debug=True)
 
app.py (new version):

  import gradio as gr
  from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
+ from transformers.image_utils import load_image
  from threading import Thread
+ import time
+ import torch
+ import spaces

  # Define model options
  MODEL_OPTIONS = {
@@ -26,318 +14,128 @@ MODEL_OPTIONS = {
      "Text Analogy Ocrtest": "prithivMLmods/Qwen2-VL-Ocrtest-2B-Instruct"
  }

+ # Global variables for model and processor
+ model = None
+ processor = None
+
+ # Function to load the selected model
+ def load_model(model_name):
+     global model, processor
+     model_id = MODEL_OPTIONS[model_name]
+     print(f"Loading model: {model_id}")
+     processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+     model = Qwen2VLForConditionalGeneration.from_pretrained(
          model_id,
          trust_remote_code=True,
          torch_dtype=torch.float16
      ).to("cuda").eval()
+     print(f"Model {model_id} loaded successfully!")
+     return f"Model {model_name} loaded!"
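+ # Note: checkpoints are now loaded on demand, one at a time, rather than all
+ # at startup; the active model/processor live in the module-level globals.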

  @spaces.GPU
+ def model_inference(input_dict, history, model_choice):
+     global model, processor
+
+     # Load the selected model if not already loaded
+     if model is None or processor is None:
+         load_model(model_choice)
+
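+     # Note: once any model is loaded, the check above skips reloading, so a
+     # new dropdown choice only takes effect via the "Load Model" button.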
+     text = input_dict["text"]
+     files = input_dict["files"]
+
+     # Load images if provided
+     if len(files) > 1:
+         images = [load_image(image) for image in files]
+     elif len(files) == 1:
+         images = [load_image(files[0])]
+     else:
+         images = []
+
+     # Validate input (raise so the error actually surfaces in the UI)
+     if text == "" and not images:
+         raise gr.Error("Please input a query and optionally image(s).")
+     if text == "" and images:
+         raise gr.Error("Please input a text query along with the image(s).")
+
+     # Prepare messages for the model
      messages = [
          {
              "role": "user",
              "content": [
+                 *[{"type": "image", "image": image} for image in images],
+                 {"type": "text", "text": text},
              ],
          }
      ]

+     # Apply chat template and process inputs
+     prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
      inputs = processor(
+         text=[prompt],
+         images=images if images else None,
          return_tensors="pt",
+         padding=True,
      ).to("cuda")

+     # Set up streamer for real-time output
+     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
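+     # The processor stands in for a tokenizer here; Qwen2-VL's processor
+     # delegates decode() to its tokenizer, so this matches the old code's
+     # processor.tokenizer argument.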
 
      generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)

+     # Start generation in a separate thread
      thread = Thread(target=model.generate, kwargs=generation_kwargs)
      thread.start()

+     # Stream the output
      buffer = ""
+     yield "Thinking..."
      for new_text in streamer:
          buffer += new_text
+         time.sleep(0.01)
          yield buffer
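+         # (the 0.01 s sleep only paces UI updates; generation itself runs in
+         # the background thread started above)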

+ # Example inputs
+ examples = [
+     [{"text": "Describe the document?", "files": ["example_images/document.jpg"]}],
+     [{"text": "Describe this image.", "files": ["example_images/campeones.jpg"]}],
+     [{"text": "What does this say?", "files": ["example_images/math.jpg"]}],
+     [{"text": "What is this UI about?", "files": ["example_images/s2w_example.png"]}],
+     [{"text": "Can you describe this image?", "files": ["example_images/newyork.jpg"]}],
+     [{"text": "Can you describe this image?", "files": ["example_images/dogs.jpg"]}],
+     [{"text": "Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
+ ]
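+ # (These paths assume the example images are checked into the Space under
+ # example_images/.)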
+
+ # Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# **Qwen2-VL Models**")
+
+     # Model selection dropdown
+     model_choice = gr.Dropdown(
+         label="Model Selection",
+         choices=list(MODEL_OPTIONS.keys()),
+         value="Latex OCR"
      )

+     # Load model button
+     load_model_btn = gr.Button("Load Model")
+     load_model_output = gr.Textbox(label="Model Load Status")
+
+     # Chat interface
+     chat_interface = gr.ChatInterface(
+         fn=model_inference,
+         description="Interact with the selected Qwen2-VL model.",
+         examples=examples,
+         textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
+         stop_btn="Stop Generation",
+         multimodal=True,
+         cache_examples=False,
+         additional_inputs=[model_choice]  # Pass model_choice as an additional input
+     )
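+     # gr.ChatInterface calls fn as fn(message, history, *additional_inputs),
+     # so model_inference receives the dropdown value as its third argument.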

+     # Link the load model button to the load_model function
+     load_model_btn.click(load_model, inputs=model_choice, outputs=load_model_output)

+ # Launch the demo
  demo.launch(debug=True)