rahulvenkk commited on
Commit
4d601e2
1 Parent(s): ee3d5e7
app.py CHANGED
@@ -17,6 +17,7 @@ dot_radius = 7 # Radius for the dots
17
  dot_thickness = -1 # Thickness for solid circle (-1 fills the circle)
18
  from PIL import Image
19
  import torch
 
20
  #load model
21
  from cwm.model.model_factory import model_factory
22
 
@@ -141,9 +142,42 @@ with gr.Blocks() as demo:
141
  def load_img(evt: gr.SelectData):
142
  img_path = evt.value['image']['path']
143
  img = np.array(Image.open(img_path))
 
 
 
 
 
 
144
  # print(f"Image uploaded with shape: {input.shape}")
145
  resized_img = resize_to_square(img)
146
- return resized_img, resized_img, img, []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
 
149
  def store_img(img):
@@ -154,7 +188,7 @@ with gr.Blocks() as demo:
154
 
155
  with gr.Row():
156
  with gr.Column():
157
- gallery = gr.Gallery( ["./assets/desk_1.jpg", "./assets/glasses.jpg", "./assets/watering_pot.jpg", "./assets/jordon.jpeg", "./assets/bird.jpg", "./assets/bread.jpg", "./assets/ducks.jpg", "./assets/robot_arm.jpg",], columns=5, allow_preview=False, label="Select an example image to test")
158
  # examples = gr.Examples(
159
  # examples=[
160
  # ["./assets/desk_1.jpg", "./assets/desk_1.jpg"],
@@ -228,13 +262,13 @@ with gr.Blocks() as demo:
228
  # Draw arrow
229
 
230
  # Draw dots at start and end points
231
- cv2.circle(temp, start_point, dot_radius, color, dot_thickness)
232
- cv2.circle(temp, end_point, dot_radius, color, dot_thickness)
233
 
234
  # If there is an odd number of points (e.g., only a start point), draw a dot for it
235
  if len(sel_pix) == 1:
236
  start_point = sel_pix[0]
237
- cv2.circle(temp, start_point, dot_radius, dot_color, dot_thickness)
238
 
239
  return temp if isinstance(temp, np.ndarray) else np.array(temp)
240
 
 
17
  dot_thickness = -1 # Thickness for solid circle (-1 fills the circle)
18
  from PIL import Image
19
  import torch
20
+ import json
21
  #load model
22
  from cwm.model.model_factory import model_factory
23
 
 
142
  def load_img(evt: gr.SelectData):
143
  img_path = evt.value['image']['path']
144
  img = np.array(Image.open(img_path))
145
+ # print(f"Image uploaded with shape: {input.shape}")
146
+ with open('./assets/intervention_test_images/annot.json', 'r') as f:
147
+ points_json = json.load(f)
148
+
149
+ points_json = points_json[os.path.basename(img_path)]
150
+
151
  # print(f"Image uploaded with shape: {input.shape}")
152
  resized_img = resize_to_square(img)
153
+
154
+ temp = resized_img.copy()
155
+
156
+ # Redraw all remaining arrows and dots
157
+ for i in range(0, len(points_json), 2):
158
+ start_point = points_json[i]
159
+ end_point = points_json[i + 1]
160
+ if start_point == end_point:
161
+ # Zero-length vector: Draw a dot
162
+ color = dot_color_fixed
163
+ else:
164
+ cv2.arrowedLine(temp, start_point, end_point, arrow_color, thickness, tipLength=tip_length,
165
+ line_type=cv2.LINE_AA)
166
+ color = arrow_color
167
+ # Draw arrow
168
+
169
+ # Draw dots at start and end points
170
+ cv2.circle(temp, start_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)
171
+ cv2.circle(temp, end_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)
172
+
173
+ # If there is an odd number of points (e.g., only a start point), draw a dot for it
174
+ if len(points_json) == 1:
175
+ start_point = points_json[0]
176
+ cv2.circle(temp, start_point, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)
177
+
178
+
179
+
180
+ return temp, resized_img, img, points_json
181
 
182
 
183
  def store_img(img):
 
188
 
189
  with gr.Row():
190
  with gr.Column():
191
+ gallery = gr.Gallery( ["./assets/ducks.jpg", "./assets/robot_arm.jpg", "./assets/bread.jpg", "./assets/bird.jpg", "./assets/desk_1.jpg", "./assets/glasses.jpg", "./assets/watering_pot.jpg"], columns=5, allow_preview=False, label="Select an example image to test")
192
  # examples = gr.Examples(
193
  # examples=[
194
  # ["./assets/desk_1.jpg", "./assets/desk_1.jpg"],
 
262
  # Draw arrow
263
 
264
  # Draw dots at start and end points
265
+ cv2.circle(temp, start_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)
266
+ cv2.circle(temp, end_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)
267
 
268
  # If there is an odd number of points (e.g., only a start point), draw a dot for it
269
  if len(sel_pix) == 1:
270
  start_point = sel_pix[0]
271
+ cv2.circle(temp, start_point, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)
272
 
273
  return temp if isinstance(temp, np.ndarray) else np.array(temp)
274
 
assets/intervention_test_images/annot.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bread.jpg": [[120, 257], [175, 269], [328, 375], [266, 353], [410, 217], [341, 248], [228, 149], [248, 211], [152, 129], [152, 129], [108, 51], [108, 51], [342, 39], [342, 39], [479, 93], [479, 93], [477, 390], [477, 390], [229, 486], [229, 486], [58, 442], [58, 442]]}