Daemontatox commited on
Commit
0cac8da
·
verified ·
1 Parent(s): 00e7090

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -107
app.py CHANGED
@@ -3,11 +3,9 @@ from PIL import Image
3
  import torch
4
  from threading import Thread
5
  import gradio as gr
6
- import spaces
7
  import fitz # PyMuPDF
8
  import io
9
  import logging
10
- from typing import List, Dict, Any
11
 
12
  # Set up logging
13
  logging.basicConfig(level=logging.INFO)
@@ -15,138 +13,209 @@ logger = logging.getLogger(__name__)
15
 
16
  # Load model and processor
17
  ckpt = "Qwen/Qwen2.5-VL-7B-Instruct"
18
- model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
19
- ckpt,
20
- torch_dtype=torch.bfloat16,
21
- device_map="auto",
22
- trust_remote_code=True
23
- )
24
  processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)
25
 
26
- class ChatState:
27
  def __init__(self):
28
- self.conversation_history: List[Dict[str, Any]] = []
29
- self.current_images: List[Image.Image] = []
30
-
31
- def add_message(self, role: str, content: Any, images: List[Image.Image] = None):
32
- self.conversation_history.append({
33
- "role": role,
34
- "content": content,
35
- "images": images or []
36
- })
37
 
38
  def clear(self):
39
- self.conversation_history = []
40
- self.current_images = []
41
-
42
- chat_state = ChatState()
 
43
 
44
- def process_pdf(file_path: str, max_pages: int = 200) -> List[Image.Image]:
45
- """Process PDF file into images (first 5 pages max for demo)"""
46
  try:
47
  doc = fitz.open(file_path)
48
  images = []
49
- for page_num in range(min(doc.page_count, max_pages)):
50
- page = doc.load_page(page_num)
51
- pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))
52
- img_data = pix.tobytes("ppm")
53
- images.append(Image.open(io.BytesIO(img_data)).convert("RGB"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  doc.close()
55
- return images
 
 
 
 
 
56
  except Exception as e:
57
- logger.error(f"PDF processing error: {str(e)}")
58
- return []
59
 
60
- def handle_file_upload(files: List[str]) -> List[Image.Image]:
61
- """Handle uploaded files (PDF or images)"""
62
- images = []
63
- for file_path in files:
64
- if file_path.lower().endswith('pdf'):
65
- images.extend(process_pdf(file_path))
 
 
 
 
66
  else:
 
 
 
 
 
 
 
67
  try:
68
- images.append(Image.open(file_path).convert("RGB"))
 
69
  except Exception as e:
70
- logger.error(f"Image processing error: {str(e)}")
71
- return images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- @spaces.GPU
74
- def chat_streaming(message: Dict, history: List, max_new_tokens: int = 1024):
75
  try:
76
- # Process user input
77
- user_text = message["text"]
78
- user_images = handle_file_upload([f["path"] for f in message["files"]]) if message["files"] else []
79
-
80
- # Update chat state
81
- chat_state.add_message("user", user_text, user_images)
82
-
83
- # Build conversation history for model
84
  messages = []
85
- for msg in chat_state.conversation_history:
86
- content = []
87
- if msg["role"] == "user":
88
- content.append({"type": "text", "text": msg["content"]})
89
- for img in msg["images"]:
90
- content.append({"type": "image"})
91
- messages.append({"role": "user", "content": content})
92
- else:
93
- messages.append({"role": "assistant", "content": msg["content"]})
94
 
95
- # Prepare model inputs
96
- model_inputs = processor.apply_chat_template(
97
- messages,
98
- add_generation_prompt=True,
99
- tokenize=False
100
- )
101
-
102
- # Get all images from history
103
- all_images = [img for msg in chat_state.conversation_history for img in msg["images"]]
104
-
105
- inputs = processor(
106
- text=model_inputs,
107
- images=all_images if all_images else None,
108
- return_tensors="pt"
109
- ).to(model.device)
110
 
111
- # Stream response
112
- streamer = TextIteratorStreamer(processor, skip_special_tokens=True)
113
- generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
114
-
115
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
116
- thread.start()
117
-
118
- buffer = ""
119
- for new_text in streamer:
120
- buffer += new_text
121
- yield buffer
122
-
123
- # Save final response
124
- chat_state.add_message("assistant", buffer)
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  except Exception as e:
127
- logger.error(f"Chat error: {str(e)}")
128
  yield "An error occurred. Please try again."
129
 
130
- def clear_chat():
131
- """Clear chat history"""
132
- chat_state.clear()
133
- return "Chat history cleared. Start a new conversation."
134
 
135
- # Create Gradio interface
136
- with gr.Blocks(title="Multimodal Chat Assistant") as demo:
137
- gr.Markdown("# Multimodal Chat Assistant")
138
- gr.Markdown("Chat with documents and images! Upload PDFs or images and ask questions.")
139
 
140
- chat_interface = gr.ChatInterface(
141
- fn=chat_streaming,
142
- additional_inputs=[
143
- gr.Slider(100, 4096, value=1024, label="Max Response Length"),
144
- gr.File(file_count="multiple", file_types=["image", "pdf", "text"], label="Upload Files")
145
- ],
146
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  )
148
 
 
 
 
 
 
 
 
 
 
 
 
149
 
 
 
 
 
150
 
151
- if __name__ == "__main__":
152
- demo.launch(debug=True)
 
3
  import torch
4
  from threading import Thread
5
  import gradio as gr
 
6
  import fitz # PyMuPDF
7
  import io
8
  import logging
 
9
 
10
  # Set up logging
11
  logging.basicConfig(level=logging.INFO)
 
13
 
14
  # Load model and processor
15
  ckpt = "Qwen/Qwen2.5-VL-7B-Instruct"
16
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(ckpt, torch_dtype=torch.bfloat16, trust_remote_code=True).to("cuda")
 
 
 
 
 
17
  processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)
18
 
19
+ class DocumentState:
20
  def __init__(self):
21
+ self.current_doc_images = []
22
+ self.current_doc_text = ""
23
+ self.doc_type = None
 
 
 
 
 
 
24
 
25
  def clear(self):
26
+ self.current_doc_images = []
27
+ self.current_doc_text = ""
28
+ self.doc_type = None
29
+
30
+ doc_state = DocumentState()
31
 
32
+ def process_pdf_file(file_path):
33
+ """Convert PDF to images and extract text using PyMuPDF."""
34
  try:
35
  doc = fitz.open(file_path)
36
  images = []
37
+ text = ""
38
+
39
+ for page_num in range(doc.page_count):
40
+ try:
41
+ page = doc[page_num]
42
+ page_text = page.get_text("text")
43
+ if page_text.strip():
44
+ text += f"Page {page_num + 1}:\n{page_text}\n\n"
45
+
46
+ zoom = 3
47
+ mat = fitz.Matrix(zoom, zoom)
48
+ pix = page.get_pixmap(matrix=mat, alpha=False)
49
+ img_data = pix.tobytes("png")
50
+ img = Image.open(io.BytesIO(img_data))
51
+ img = img.convert("RGB")
52
+
53
+ max_size = 1600
54
+ if max(img.size) > max_size:
55
+ ratio = max_size / max(img.size)
56
+ new_size = tuple(int(dim * ratio) for dim in img.size)
57
+ img = img.resize(new_size, Image.Resampling.LANCZOS)
58
+
59
+ images.append(img)
60
+
61
+ except Exception as e:
62
+ logger.error(f"Error processing page {page_num}: {str(e)}")
63
+ continue
64
+
65
  doc.close()
66
+
67
+ if not images:
68
+ raise ValueError("No valid images could be extracted from the PDF")
69
+
70
+ return images, text
71
+
72
  except Exception as e:
73
+ logger.error(f"Error processing PDF file: {str(e)}")
74
+ raise
75
 
76
+ def process_uploaded_file(file):
77
+ """Process uploaded file and update document state."""
78
+ try:
79
+ doc_state.clear()
80
+
81
+ if file is None:
82
+ return "No file uploaded. Please upload a file."
83
+
84
+ if isinstance(file, dict):
85
+ file_path = file["name"]
86
  else:
87
+ file_path = file.name
88
+
89
+ file_ext = file_path.lower().split('.')[-1]
90
+ image_extensions = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}
91
+
92
+ if file_ext == 'pdf':
93
+ doc_state.doc_type = 'pdf'
94
  try:
95
+ doc_state.current_doc_images, doc_state.current_doc_text = process_pdf_file(file_path)
96
+ return f"PDF processed successfully. Total pages: {len(doc_state.current_doc_images)}. You can now ask questions about the content."
97
  except Exception as e:
98
+ return f"Error processing PDF: {str(e)}. Please try a different PDF file."
99
+ elif file_ext in image_extensions:
100
+ doc_state.doc_type = 'image'
101
+ try:
102
+ img = Image.open(file_path).convert("RGB")
103
+ max_size = 1600
104
+ if max(img.size) > max_size:
105
+ ratio = max_size / max(img.size)
106
+ new_size = tuple(int(dim * ratio) for dim in img.size)
107
+ img = img.resize(new_size, Image.Resampling.LANCZOS)
108
+ doc_state.current_doc_images = [img]
109
+ return "Image loaded successfully. You can now ask questions about the content."
110
+ except Exception as e:
111
+ return f"Error processing image: {str(e)}. Please try a different image file."
112
+ else:
113
+ return f"Unsupported file type: {file_ext}. Please upload a PDF or image file."
114
+ except Exception as e:
115
+ logger.error(f"Error in process_file: {str(e)}")
116
+ return "An error occurred while processing the file. Please try again."
117
 
118
+ @spaces.GPU()
119
+ def bot_streaming(user_prompt, max_new_tokens=4096):
120
  try:
121
+ if not user_prompt.strip():
122
+ yield "Please enter a valid prompt/question."
123
+ return
124
+
 
 
 
 
125
  messages = []
 
 
 
 
 
 
 
 
 
126
 
127
+ # Include document context
128
+ if doc_state.current_doc_images:
129
+ context = f"\nDocument context:\n{doc_state.current_doc_text}" if doc_state.current_doc_text else ""
130
+ current_msg = f"{user_prompt}{context}"
131
+ messages.append({"role": "user", "content": [{"type": "text", "text": current_msg}, {"type": "image"}]})
132
+ else:
133
+ messages.append({"role": "user", "content": [{"type": "text", "text": user_prompt}]})
 
 
 
 
 
 
 
 
134
 
135
+ # Process inputs
136
+ texts = processor.apply_chat_template(messages, add_generation_prompt=True)
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
+ try:
139
+ if doc_state.current_doc_images:
140
+ inputs = processor(
141
+ text=texts,
142
+ images=doc_state.current_doc_images[0:1],
143
+ return_tensors="pt"
144
+ ).to("cuda")
145
+ else:
146
+ inputs = processor(text=texts, return_tensors="pt").to("cuda")
147
+
148
+ streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
149
+ generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
150
+
151
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
152
+ thread.start()
153
+
154
+ buffer = ""
155
+ for new_text in streamer:
156
+ buffer += new_text
157
+ time.sleep(0.01)
158
+ yield buffer
159
+
160
+ except Exception as e:
161
+ logger.error(f"Error in model processing: {str(e)}")
162
+ yield "An error occurred while processing your request. Please try again."
163
+
164
  except Exception as e:
165
+ logger.error(f"Error in bot_streaming: {str(e)}")
166
  yield "An error occurred. Please try again."
167
 
168
+ def clear_context():
169
+ """Clear the current document context."""
170
+ doc_state.clear()
171
+ return "Document context cleared. You can upload a new document."
172
 
173
+ # Create the Gradio interface
174
+ with gr.Blocks() as demo:
175
+ gr.Markdown("# Document Analyzer with Custom Prompts")
176
+ gr.Markdown("Upload a document and enter your custom prompt/question about its contents.")
177
 
178
+ with gr.Row():
179
+ file_upload = gr.File(
180
+ label="Upload Document (PDF or Image)",
181
+ file_types=[".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"]
182
+ )
183
+ upload_status = gr.Textbox(
184
+ label="Upload Status",
185
+ interactive=False
186
+ )
187
+
188
+ with gr.Row():
189
+ user_prompt = gr.Textbox(
190
+ label="Enter your prompt/question",
191
+ placeholder="e.g., Explain this document...\nExtract key points...\nWhat is the main idea?",
192
+ lines=3
193
+ )
194
+ generate_btn = gr.Button("Generate")
195
+
196
+ clear_btn = gr.Button("Clear Document Context")
197
+
198
+ output_text = gr.Textbox(
199
+ label="Output",
200
+ interactive=False
201
  )
202
 
203
+ file_upload.change(
204
+ fn=process_uploaded_file,
205
+ inputs=[file_upload],
206
+ outputs=[upload_status]
207
+ )
208
+
209
+ generate_btn.click(
210
+ fn=bot_streaming,
211
+ inputs=[user_prompt],
212
+ outputs=[output_text]
213
+ )
214
 
215
+ clear_btn.click(
216
+ fn=clear_context,
217
+ outputs=[upload_status]
218
+ )
219
 
220
+ # Launch the interface
221
+ demo.launch(debug=True)