h-siyuan committed
Commit 1aac498 · verified · Parent: fff26a5

Update app.py

Files changed (1): app.py (+83, -54)
app.py CHANGED
@@ -81,47 +81,76 @@ def upload_to_s3(file_name, bucket, object_name=None):
     except NoCredentialsError:
         return False
 
+def crop_image(image_path, click_xy, crop_factor=0.5):
+    """Crop the image around the click point."""
+    image = Image.open(image_path)
+    width, height = image.size
+    crop_width, crop_height = int(width * crop_factor), int(height * crop_factor)
+
+    center_x, center_y = int(click_xy[0] * width), int(click_xy[1] * height)
+    left = max(center_x - crop_width // 2, 0)
+    upper = max(center_y - crop_height // 2, 0)
+    right = min(center_x + crop_width // 2, width)
+    lower = min(center_y + crop_height // 2, height)
+
+    cropped_image = image.crop((left, upper, right, lower))
+    cropped_image_path = f"cropped_{os.path.basename(image_path)}"
+    cropped_image.save(cropped_image_path)
+
+    return cropped_image_path
+
 @spaces.GPU
-def run_showui(image, query, session_id):
-    """Main function for inference."""
+def run_showui(image, query, session_id, iterations=2):
+    """Main function for iterative inference."""
     image_path = array_to_image_path(image, session_id)
 
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {"type": "text", "text": _SYSTEM},
-                {"type": "image", "image": image_path, "min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS},
-                {"type": "text", "text": query}
-            ],
-        }
-    ]
-
-    global model
-    model = model.to("cuda")
-
-    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    image_inputs, video_inputs = process_vision_info(messages)
-    inputs = processor(
-        text=[text],
-        images=image_inputs,
-        videos=video_inputs,
-        padding=True,
-        return_tensors="pt"
-    )
-    inputs = inputs.to("cuda")
-
-    generated_ids = model.generate(**inputs, max_new_tokens=128)
-    generated_ids_trimmed = [
-        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-    ]
-    output_text = processor.batch_decode(
-        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-    )[0]
-
-    click_xy = ast.literal_eval(output_text)
-    result_image = draw_point(image_path, click_xy, radius=10)
-    return result_image, str(click_xy), image_path
+    click_xy = None
+    images_during_iterations = []  # List to store images at each step
+
+    for _ in range(iterations):
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": _SYSTEM},
+                    {"type": "image", "image": image_path, "min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS},
+                    {"type": "text", "text": query}
+                ],
+            }
+        ]
+
+        global model
+        model = model.to("cuda")
+
+        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        image_inputs, video_inputs = process_vision_info(messages)
+        inputs = processor(
+            text=[text],
+            images=image_inputs,
+            videos=video_inputs,
+            padding=True,
+            return_tensors="pt"
+        )
+        inputs = inputs.to("cuda")
+
+        generated_ids = model.generate(**inputs, max_new_tokens=128)
+        generated_ids_trimmed = [
+            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        ]
+        output_text = processor.batch_decode(
+            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )[0]
+
+        click_xy = ast.literal_eval(output_text)
+
+        # Draw point on the current image
+        result_image = draw_point(image_path, click_xy, radius=10)
+        images_during_iterations.append(result_image)  # Store the current image
+
+        # Crop the image for the next iteration
+        image_path = crop_image(image_path, click_xy)
+
+    return images_during_iterations, str(click_xy)
 
 def save_and_upload_data(image_path, query, session_id, is_example_image, votes=None):
     """Save the data to a JSON file and upload to S3."""
@@ -221,6 +250,10 @@ def build_demo(embed_mode, concurrency_count=1):
 
     Then upload/paste from clipboard 🤗
     """)
+
+    # Add a slider for iteration count
+    iteration_slider = gr.Slider(minimum=1, maximum=3, step=1, value=1, label="Refinement Steps")
+
     textbox = gr.Textbox(
         show_label=True,
         placeholder="Enter a query (e.g., 'Click Nahant')",
@@ -258,13 +291,9 @@ def build_demo(embed_mode, concurrency_count=1):
     )
 
     with gr.Column(scale=8):
-        output_img = gr.Image(type="pil", label="Output Image")
-        gr.HTML(
-            """
-            <p><strong>Note:</strong> The <span style="color: red;">red point</span> on the output image represents the predicted clickable coordinates.</p>
-            """
-        )
-        output_coords = gr.Textbox(label="Clickable Coordinates")
+        # output_gallery = gr.Gallery(label="Iterative Refinement", object_fit="contain")
+        output_gallery = gr.Gallery(label="Iterative Refinement")
+        output_coords = gr.Textbox(label="Final Clickable Coordinates")
 
     gr.HTML(
         """
@@ -276,28 +305,28 @@ def build_demo(embed_mode, concurrency_count=1):
     downvote_btn = gr.Button(value="👎 Too bad!", variant="secondary")
     clear_btn = gr.Button(value="🗑️ Clear", interactive=True)
 
-    def on_submit(image, query, is_example_image):
+    def on_submit(image, query, iterations, is_example_image):
         if image is None:
             raise ValueError("No image provided. Please upload an image before submitting.")
 
         session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
 
-        result_image, click_coords, image_path = run_showui(image, query, session_id)
+        images_during_iterations, click_coords = run_showui(image, query, session_id, iterations)
 
-        save_and_upload_data(image_path, query, session_id, is_example_image)
+        save_and_upload_data(images_during_iterations[-1], query, session_id, is_example_image)
 
-        return result_image, click_coords, image_path, session_id
+        return images_during_iterations, click_coords, session_id
 
     submit_btn.click(
         on_submit,
-        [imagebox, textbox, is_example_dropdown],
-        [output_img, output_coords, state_image_path, state_session_id],
+        [imagebox, textbox, iteration_slider, is_example_dropdown],
+        [output_gallery, output_coords, state_session_id],
     )
 
     clear_btn.click(
-        lambda: (None, None, None, None, None, None),
+        lambda: (None, None, None, None),
         inputs=None,
-        outputs=[imagebox, textbox, output_img, output_coords, state_image_path, state_session_id],
+        outputs=[imagebox, textbox, output_gallery, output_coords, state_session_id],
         queue=False
     )
 
@@ -324,4 +353,4 @@ if __name__ == "__main__":
         server_port=7860,
         ssr_mode=False,
         debug=True,
-    )
+    )
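
On the UI side, the commit swaps the single gr.Image output for a gr.Gallery so every refinement pass can be shown, and adds a "Refinement Steps" slider that feeds the iterations argument. Below is a minimal, self-contained Gradio sketch of that Slider-to-Gallery pattern; the run function is a toy stand-in, not the app's handler.

    import gradio as gr
    from PIL import Image

    def run(steps):
        # One placeholder image per refinement step; gr.Gallery accepts a list of PIL images.
        steps = int(steps)
        return [Image.new("RGB", (64, 64), (min(255, 80 * i), 0, 0)) for i in range(1, steps + 1)]

    with gr.Blocks() as demo:
        steps_slider = gr.Slider(minimum=1, maximum=3, step=1, value=1, label="Refinement Steps")
        gallery = gr.Gallery(label="Iterative Refinement")
        run_btn = gr.Button("Run")
        run_btn.click(run, [steps_slider], [gallery])

    demo.launch()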