Shak33l-UiRev committed
Commit d1abdf9 · verified · 1 Parent(s): e5a11be

Removed BROS model & added OmniParser


Main changes include:

OmniParser Integration:

Added YOLO model loading for icon detection
Added Florence-2 model for captioning
Proper handling of both models in the pipeline (loading sketch below)
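
A condensed view of the new loading branch (the full version is in the diff below). load_omniparser is just an illustrative wrapper name, and note that ultralytics' YOLO() normally expects a local weights file, so the Hub repo id used in the commit may need to be resolved to a downloaded .pt checkpoint first.

import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from ultralytics import YOLO

def load_omniparser():
    # Icon/element detector; the commit passes the Hub repo id straight to YOLO()
    yolo_model = YOLO('microsoft/OmniParser', task='detect')

    # Florence-2 processor plus the OmniParser captioning checkpoint
    processor = AutoProcessor.from_pretrained(
        "microsoft/Florence-2-base", trust_remote_code=True
    )
    caption_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/OmniParser", torch_dtype=torch.float16, trust_remote_code=True
    )

    # Unlike the other branches, which return (model, processor), this branch
    # hands back all three components in a dict that analyze_document later
    # indexes as model['yolo'] and model['model'].
    return {'yolo': yolo_model, 'processor': processor, 'model': caption_model}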


Analysis Pipeline:

Object detection with YOLO
Caption generation for detected elements
Structured output with bounding boxes and descriptions (pipeline sketch below)
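
The detection-plus-captioning loop, condensed from the analyze_document changes in the diff below. parse_screen is an illustrative name, the thresholds are the committed defaults, and the "<CAPTION>" prompt is an assumption: Florence-2-style models usually expect a task prompt alongside the image, which the committed code omits.

def parse_screen(image, omni, box_threshold=0.05, iou_threshold=0.1):
    """Detect UI elements with YOLO, then caption each cropped region."""
    # ultralytics accepts PIL images directly; the commit writes a temp PNG instead
    detections = omni['yolo'](image, conf=box_threshold, iou=iou_threshold)

    elements = []
    for x1, y1, x2, y2, conf, cls in detections[0].boxes.data.tolist():
        roi = image.crop((x1, y1, x2, y2))  # region to describe

        # Caption the crop ("<CAPTION>" prompt assumed; not in the committed code)
        inputs = omni['processor'](text="<CAPTION>", images=roi, return_tensors="pt")
        out = omni['model'].generate(**inputs, max_length=50)
        caption = omni['processor'].decode(out[0], skip_special_tokens=True)

        elements.append({
            "bbox": [x1, y1, x2, y2],
            "confidence": conf,
            "class": int(cls),
            "caption": caption,
        })

    return {"detected_elements": len(elements), "elements": elements}

The committed version additionally selects a CUDA or CPU device for the detector explicitly and returns the same {"detected_elements": ..., "elements": ...} structure shown here.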


User Interface:

Updated model information
Added UI-specific strengths and capabilities
Proper debug information for UI parsing

Files changed (1)
  1. app.py +96 -46
app.py CHANGED
@@ -1,20 +1,20 @@
 import streamlit as st
 from PIL import Image
 import torch
-import json
 from transformers import (
     DonutProcessor,
     VisionEncoderDecoderModel,
     LayoutLMv3Processor,
     LayoutLMv3ForSequenceClassification,
-    BrosProcessor,
-    BrosForTokenClassification,
-    LlavaProcessor,
-    LlavaForConditionalGeneration
+    AutoProcessor,
+    AutoModelForCausalLM
 )
+from ultralytics import YOLO
+import io
+import base64
+import json
 from datetime import datetime

-# Cache the model loading to improve performance
 @st.cache_resource
 def load_model(model_name):
     """Load the selected model and processor"""
@@ -31,13 +31,21 @@ def load_model(model_name):
             processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
             model = LayoutLMv3ForSequenceClassification.from_pretrained("microsoft/layoutlmv3-base")

-        elif model_name == "BROS":
-            processor = BrosProcessor.from_pretrained("microsoft/bros-base")
-            model = BrosForTokenClassification.from_pretrained("microsoft/bros-base")
-
-        elif model_name == "LLaVA-1.5":
-            processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
-            model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
+        elif model_name == "OmniParser":
+            # Load YOLO model for icon detection
+            yolo_model = YOLO('microsoft/OmniParser', task='detect')
+            # Load Florence-2 model for captioning
+            processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
+            model = AutoModelForCausalLM.from_pretrained(
+                "microsoft/OmniParser",
+                torch_dtype=torch.float16,
+                trust_remote_code=True
+            )
+            return {
+                'yolo': yolo_model,
+                'processor': processor,
+                'model': model
+            }

         return model, processor
     except Exception as e:
@@ -47,14 +55,54 @@ def load_model(model_name):
 def analyze_document(image, model_name, model, processor):
     """Analyze document using selected model"""
     try:
-        # Process image according to model requirements
-        if model_name == "Donut":
-            # Prepare input with task prompt
+        if model_name == "OmniParser":
+            # Save image temporarily
+            temp_path = "temp_image.png"
+            image.save(temp_path)
+
+            # Configure box detection parameters
+            box_threshold = 0.05  # Can be made configurable
+            iou_threshold = 0.1  # Can be made configurable
+
+            # Run YOLO detection
+            yolo_results = model['yolo'](
+                temp_path,
+                conf=box_threshold,
+                iou=iou_threshold,
+                device='cpu' if not torch.cuda.is_available() else 'cuda'
+            )
+
+            # Process detections
+            results = []
+            for det in yolo_results[0].boxes.data:
+                x1, y1, x2, y2, conf, cls = det
+
+                # Get region of interest
+                roi = image.crop((x1, y1, x2, y2))
+
+                # Generate caption using Florence-2
+                inputs = processor(images=roi, return_tensors="pt")
+                outputs = model['model'].generate(**inputs, max_length=50)
+                caption = processor.decode(outputs[0], skip_special_tokens=True)
+
+                results.append({
+                    "bbox": [float(x) for x in [x1, y1, x2, y2]],
+                    "confidence": float(conf),
+                    "class": int(cls),
+                    "caption": caption
+                })
+
+            return {
+                "detected_elements": len(results),
+                "elements": results
+            }
+
+        # [Previous model handling remains the same...]
+        elif model_name == "Donut":
             pixel_values = processor(image, return_tensors="pt").pixel_values
             task_prompt = "<s_cord>analyze the document and extract information</s_cord>"
             decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

-            # Generate output with improved parameters
             outputs = model.generate(
                 pixel_values,
                 decoder_input_ids=decoder_input_ids,
@@ -68,31 +116,41 @@ def analyze_document(image, model_name, model, processor):
                 return_dict_in_generate=True
             )

-            # Process and clean the output
             sequence = processor.batch_decode(outputs.sequences)[0]
             sequence = sequence.replace(task_prompt, "").replace("</s_cord>", "").strip()

-            # Try to parse as JSON, fallback to raw text
             try:
                 result = json.loads(sequence)
             except json.JSONDecodeError:
                 result = {"raw_text": sequence}

         elif model_name == "LayoutLMv3":
-            inputs = processor(image, return_tensors="pt")
-            outputs = model(**inputs)
-            result = {"logits": outputs.logits.tolist()} # Convert tensor to list for JSON serialization
+            encoded_inputs = processor(
+                image,
+                return_tensors="pt",
+                add_special_tokens=True,
+                return_offsets_mapping=True
+            )

-        elif model_name == "BROS":
-            inputs = processor(image, return_tensors="pt")
-            outputs = model(**inputs)
-            result = {"predictions": outputs.logits.tolist()}
+            outputs = model(**encoded_inputs)
+            predictions = outputs.logits.argmax(-1).squeeze().tolist()
+
+            words = processor.tokenizer.convert_ids_to_tokens(
+                encoded_inputs.input_ids.squeeze().tolist()
+            )
+
+            result = {
+                "predictions": [
+                    {
+                        "text": word,
+                        "label": pred
+                    }
+                    for word, pred in zip(words, predictions)
+                    if word not in ["<s>", "</s>", "<pad>"]
+                ],
+                "confidence_scores": outputs.logits.softmax(-1).max(-1).values.squeeze().tolist()
+            }

-        elif model_name == "LLaVA-1.5":
-            inputs = processor(image, return_tensors="pt")
-            outputs = model.generate(**inputs, max_length=256)
-            result = {"generated_text": processor.decode(outputs[0], skip_special_tokens=True)}
-
         return result

     except Exception as e:
@@ -157,26 +215,18 @@ with col2:
         "Donut": {
             "description": "Best for structured OCR and document format understanding",
             "memory": "6-8GB",
-            "strengths": ["Structured OCR", "Memory efficient", "Good with fixed formats"],
-            "best_for": ["Invoices", "Forms", "Structured documents"]
+            "strengths": ["Structured OCR", "Memory efficient", "Good with fixed formats"]
         },
         "LayoutLMv3": {
             "description": "Strong layout understanding with reasoning capabilities",
             "memory": "12-15GB",
-            "strengths": ["Layout understanding", "Reasoning", "Pre-trained knowledge"],
-            "best_for": ["Complex layouts", "Mixed content", "Tables"]
-        },
-        "BROS": {
-            "description": "Memory efficient with fast inference",
-            "memory": "4-6GB",
-            "strengths": ["Fast inference", "Memory efficient", "Easy fine-tuning"],
-            "best_for": ["Simple documents", "Quick analysis", "Basic OCR"]
+            "strengths": ["Layout understanding", "Reasoning", "Pre-trained knowledge"]
         },
-        "LLaVA-1.5": {
-            "description": "Comprehensive OCR with strong reasoning",
-            "memory": "25-40GB",
-            "strengths": ["Strong reasoning", "Zero-shot capable", "Visual understanding"],
-            "best_for": ["Complex documents", "Natural language understanding", "Visual QA"]
+        "OmniParser": {
+            "description": "General screen parsing tool for UI understanding",
+            "memory": "8-10GB",
+            "strengths": ["UI element detection", "Interactive element recognition", "Function description"],
+            "best_for": ["Screenshots", "UI analysis", "Interactive elements"]
         }
     }