Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -3,11 +3,9 @@ from PIL import Image
|
|
3 |
import torch
|
4 |
from threading import Thread
|
5 |
import gradio as gr
|
6 |
-
import spaces
|
7 |
import fitz # PyMuPDF
|
8 |
import io
|
9 |
import logging
|
10 |
-
from typing import List, Dict, Any
|
11 |
|
12 |
# Set up logging
|
13 |
logging.basicConfig(level=logging.INFO)
|
@@ -15,138 +13,209 @@ logger = logging.getLogger(__name__)
|
|
15 |
|
16 |
# Load model and processor
|
17 |
ckpt = "Qwen/Qwen2.5-VL-7B-Instruct"
|
18 |
-
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
19 |
-
ckpt,
|
20 |
-
torch_dtype=torch.bfloat16,
|
21 |
-
device_map="auto",
|
22 |
-
trust_remote_code=True
|
23 |
-
)
|
24 |
processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)
|
25 |
|
26 |
-
class
|
27 |
def __init__(self):
|
28 |
-
self.
|
29 |
-
self.
|
30 |
-
|
31 |
-
def add_message(self, role: str, content: Any, images: List[Image.Image] = None):
|
32 |
-
self.conversation_history.append({
|
33 |
-
"role": role,
|
34 |
-
"content": content,
|
35 |
-
"images": images or []
|
36 |
-
})
|
37 |
|
38 |
def clear(self):
|
39 |
-
self.
|
40 |
-
self.
|
41 |
-
|
42 |
-
|
|
|
43 |
|
44 |
-
def
|
45 |
-
"""
|
46 |
try:
|
47 |
doc = fitz.open(file_path)
|
48 |
images = []
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
doc.close()
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
56 |
except Exception as e:
|
57 |
-
logger.error(f"
|
58 |
-
|
59 |
|
60 |
-
def
|
61 |
-
"""
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
66 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
try:
|
68 |
-
|
|
|
69 |
except Exception as e:
|
70 |
-
|
71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
-
@spaces.GPU
|
74 |
-
def
|
75 |
try:
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
# Update chat state
|
81 |
-
chat_state.add_message("user", user_text, user_images)
|
82 |
-
|
83 |
-
# Build conversation history for model
|
84 |
messages = []
|
85 |
-
for msg in chat_state.conversation_history:
|
86 |
-
content = []
|
87 |
-
if msg["role"] == "user":
|
88 |
-
content.append({"type": "text", "text": msg["content"]})
|
89 |
-
for img in msg["images"]:
|
90 |
-
content.append({"type": "image"})
|
91 |
-
messages.append({"role": "user", "content": content})
|
92 |
-
else:
|
93 |
-
messages.append({"role": "assistant", "content": msg["content"]})
|
94 |
|
95 |
-
#
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
# Get all images from history
|
103 |
-
all_images = [img for msg in chat_state.conversation_history for img in msg["images"]]
|
104 |
-
|
105 |
-
inputs = processor(
|
106 |
-
text=model_inputs,
|
107 |
-
images=all_images if all_images else None,
|
108 |
-
return_tensors="pt"
|
109 |
-
).to(model.device)
|
110 |
|
111 |
-
#
|
112 |
-
|
113 |
-
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
|
114 |
-
|
115 |
-
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
116 |
-
thread.start()
|
117 |
-
|
118 |
-
buffer = ""
|
119 |
-
for new_text in streamer:
|
120 |
-
buffer += new_text
|
121 |
-
yield buffer
|
122 |
-
|
123 |
-
# Save final response
|
124 |
-
chat_state.add_message("assistant", buffer)
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
except Exception as e:
|
127 |
-
logger.error(f"
|
128 |
yield "An error occurred. Please try again."
|
129 |
|
130 |
-
def
|
131 |
-
"""Clear
|
132 |
-
|
133 |
-
return "
|
134 |
|
135 |
-
# Create Gradio interface
|
136 |
-
with gr.Blocks(
|
137 |
-
gr.Markdown("#
|
138 |
-
gr.Markdown("
|
139 |
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
)
|
148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
|
|
|
|
|
|
|
|
|
150 |
|
151 |
-
|
152 |
-
|
|
|
3 |
import torch
|
4 |
from threading import Thread
|
5 |
import gradio as gr
|
|
|
6 |
import fitz # PyMuPDF
|
7 |
import io
|
8 |
import logging
|
|
|
9 |
|
10 |
# Set up logging
|
11 |
logging.basicConfig(level=logging.INFO)
|
|
|
13 |
|
14 |
# Load model and processor
|
15 |
ckpt = "Qwen/Qwen2.5-VL-7B-Instruct"
|
16 |
+
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(ckpt, torch_dtype=torch.bfloat16, trust_remote_code=True).to("cuda")
|
|
|
|
|
|
|
|
|
|
|
17 |
processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)
|
18 |
|
19 |
+
class DocumentState:
|
20 |
def __init__(self):
|
21 |
+
self.current_doc_images = []
|
22 |
+
self.current_doc_text = ""
|
23 |
+
self.doc_type = None
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
def clear(self):
|
26 |
+
self.current_doc_images = []
|
27 |
+
self.current_doc_text = ""
|
28 |
+
self.doc_type = None
|
29 |
+
|
30 |
+
doc_state = DocumentState()
|
31 |
|
32 |
+
def process_pdf_file(file_path):
|
33 |
+
"""Convert PDF to images and extract text using PyMuPDF."""
|
34 |
try:
|
35 |
doc = fitz.open(file_path)
|
36 |
images = []
|
37 |
+
text = ""
|
38 |
+
|
39 |
+
for page_num in range(doc.page_count):
|
40 |
+
try:
|
41 |
+
page = doc[page_num]
|
42 |
+
page_text = page.get_text("text")
|
43 |
+
if page_text.strip():
|
44 |
+
text += f"Page {page_num + 1}:\n{page_text}\n\n"
|
45 |
+
|
46 |
+
zoom = 3
|
47 |
+
mat = fitz.Matrix(zoom, zoom)
|
48 |
+
pix = page.get_pixmap(matrix=mat, alpha=False)
|
49 |
+
img_data = pix.tobytes("png")
|
50 |
+
img = Image.open(io.BytesIO(img_data))
|
51 |
+
img = img.convert("RGB")
|
52 |
+
|
53 |
+
max_size = 1600
|
54 |
+
if max(img.size) > max_size:
|
55 |
+
ratio = max_size / max(img.size)
|
56 |
+
new_size = tuple(int(dim * ratio) for dim in img.size)
|
57 |
+
img = img.resize(new_size, Image.Resampling.LANCZOS)
|
58 |
+
|
59 |
+
images.append(img)
|
60 |
+
|
61 |
+
except Exception as e:
|
62 |
+
logger.error(f"Error processing page {page_num}: {str(e)}")
|
63 |
+
continue
|
64 |
+
|
65 |
doc.close()
|
66 |
+
|
67 |
+
if not images:
|
68 |
+
raise ValueError("No valid images could be extracted from the PDF")
|
69 |
+
|
70 |
+
return images, text
|
71 |
+
|
72 |
except Exception as e:
|
73 |
+
logger.error(f"Error processing PDF file: {str(e)}")
|
74 |
+
raise
|
75 |
|
76 |
+
def process_uploaded_file(file):
|
77 |
+
"""Process uploaded file and update document state."""
|
78 |
+
try:
|
79 |
+
doc_state.clear()
|
80 |
+
|
81 |
+
if file is None:
|
82 |
+
return "No file uploaded. Please upload a file."
|
83 |
+
|
84 |
+
if isinstance(file, dict):
|
85 |
+
file_path = file["name"]
|
86 |
else:
|
87 |
+
file_path = file.name
|
88 |
+
|
89 |
+
file_ext = file_path.lower().split('.')[-1]
|
90 |
+
image_extensions = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}
|
91 |
+
|
92 |
+
if file_ext == 'pdf':
|
93 |
+
doc_state.doc_type = 'pdf'
|
94 |
try:
|
95 |
+
doc_state.current_doc_images, doc_state.current_doc_text = process_pdf_file(file_path)
|
96 |
+
return f"PDF processed successfully. Total pages: {len(doc_state.current_doc_images)}. You can now ask questions about the content."
|
97 |
except Exception as e:
|
98 |
+
return f"Error processing PDF: {str(e)}. Please try a different PDF file."
|
99 |
+
elif file_ext in image_extensions:
|
100 |
+
doc_state.doc_type = 'image'
|
101 |
+
try:
|
102 |
+
img = Image.open(file_path).convert("RGB")
|
103 |
+
max_size = 1600
|
104 |
+
if max(img.size) > max_size:
|
105 |
+
ratio = max_size / max(img.size)
|
106 |
+
new_size = tuple(int(dim * ratio) for dim in img.size)
|
107 |
+
img = img.resize(new_size, Image.Resampling.LANCZOS)
|
108 |
+
doc_state.current_doc_images = [img]
|
109 |
+
return "Image loaded successfully. You can now ask questions about the content."
|
110 |
+
except Exception as e:
|
111 |
+
return f"Error processing image: {str(e)}. Please try a different image file."
|
112 |
+
else:
|
113 |
+
return f"Unsupported file type: {file_ext}. Please upload a PDF or image file."
|
114 |
+
except Exception as e:
|
115 |
+
logger.error(f"Error in process_file: {str(e)}")
|
116 |
+
return "An error occurred while processing the file. Please try again."
|
117 |
|
118 |
+
@spaces.GPU()
|
119 |
+
def bot_streaming(user_prompt, max_new_tokens=4096):
|
120 |
try:
|
121 |
+
if not user_prompt.strip():
|
122 |
+
yield "Please enter a valid prompt/question."
|
123 |
+
return
|
124 |
+
|
|
|
|
|
|
|
|
|
125 |
messages = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
+
# Include document context
|
128 |
+
if doc_state.current_doc_images:
|
129 |
+
context = f"\nDocument context:\n{doc_state.current_doc_text}" if doc_state.current_doc_text else ""
|
130 |
+
current_msg = f"{user_prompt}{context}"
|
131 |
+
messages.append({"role": "user", "content": [{"type": "text", "text": current_msg}, {"type": "image"}]})
|
132 |
+
else:
|
133 |
+
messages.append({"role": "user", "content": [{"type": "text", "text": user_prompt}]})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
|
135 |
+
# Process inputs
|
136 |
+
texts = processor.apply_chat_template(messages, add_generation_prompt=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
|
138 |
+
try:
|
139 |
+
if doc_state.current_doc_images:
|
140 |
+
inputs = processor(
|
141 |
+
text=texts,
|
142 |
+
images=doc_state.current_doc_images[0:1],
|
143 |
+
return_tensors="pt"
|
144 |
+
).to("cuda")
|
145 |
+
else:
|
146 |
+
inputs = processor(text=texts, return_tensors="pt").to("cuda")
|
147 |
+
|
148 |
+
streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
|
149 |
+
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
|
150 |
+
|
151 |
+
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
152 |
+
thread.start()
|
153 |
+
|
154 |
+
buffer = ""
|
155 |
+
for new_text in streamer:
|
156 |
+
buffer += new_text
|
157 |
+
time.sleep(0.01)
|
158 |
+
yield buffer
|
159 |
+
|
160 |
+
except Exception as e:
|
161 |
+
logger.error(f"Error in model processing: {str(e)}")
|
162 |
+
yield "An error occurred while processing your request. Please try again."
|
163 |
+
|
164 |
except Exception as e:
|
165 |
+
logger.error(f"Error in bot_streaming: {str(e)}")
|
166 |
yield "An error occurred. Please try again."
|
167 |
|
168 |
+
def clear_context():
|
169 |
+
"""Clear the current document context."""
|
170 |
+
doc_state.clear()
|
171 |
+
return "Document context cleared. You can upload a new document."
|
172 |
|
173 |
+
# Create the Gradio interface
|
174 |
+
with gr.Blocks() as demo:
|
175 |
+
gr.Markdown("# Document Analyzer with Custom Prompts")
|
176 |
+
gr.Markdown("Upload a document and enter your custom prompt/question about its contents.")
|
177 |
|
178 |
+
with gr.Row():
|
179 |
+
file_upload = gr.File(
|
180 |
+
label="Upload Document (PDF or Image)",
|
181 |
+
file_types=[".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"]
|
182 |
+
)
|
183 |
+
upload_status = gr.Textbox(
|
184 |
+
label="Upload Status",
|
185 |
+
interactive=False
|
186 |
+
)
|
187 |
+
|
188 |
+
with gr.Row():
|
189 |
+
user_prompt = gr.Textbox(
|
190 |
+
label="Enter your prompt/question",
|
191 |
+
placeholder="e.g., Explain this document...\nExtract key points...\nWhat is the main idea?",
|
192 |
+
lines=3
|
193 |
+
)
|
194 |
+
generate_btn = gr.Button("Generate")
|
195 |
+
|
196 |
+
clear_btn = gr.Button("Clear Document Context")
|
197 |
+
|
198 |
+
output_text = gr.Textbox(
|
199 |
+
label="Output",
|
200 |
+
interactive=False
|
201 |
)
|
202 |
|
203 |
+
file_upload.change(
|
204 |
+
fn=process_uploaded_file,
|
205 |
+
inputs=[file_upload],
|
206 |
+
outputs=[upload_status]
|
207 |
+
)
|
208 |
+
|
209 |
+
generate_btn.click(
|
210 |
+
fn=bot_streaming,
|
211 |
+
inputs=[user_prompt],
|
212 |
+
outputs=[output_text]
|
213 |
+
)
|
214 |
|
215 |
+
clear_btn.click(
|
216 |
+
fn=clear_context,
|
217 |
+
outputs=[upload_status]
|
218 |
+
)
|
219 |
|
220 |
+
# Launch the interface
|
221 |
+
demo.launch(debug=True)
|