Realcat commited on
Commit
4a7fc02
·
1 Parent(s): 260ecba

update: rerun ransac

Browse files
Files changed (3) hide show
  1. common/app_class.py +25 -2
  2. common/utils.py +59 -10
  3. common/viz.py +35 -1
common/app_class.py CHANGED
@@ -9,6 +9,7 @@ from common.utils import (
9
  load_config,
10
  get_matcher_zoo,
11
  run_matching,
 
12
  gen_examples,
13
  GRADIO_VERSION,
14
  )
@@ -159,7 +160,9 @@ class ImageMatchingApp:
159
  label="Ransac Iterations",
160
  value=self.cfg["defaults"]["ransac_max_iter"],
161
  )
162
-
 
 
163
  with gr.Accordion("Geometry Setting", open=False):
164
  with gr.Row(equal_height=False):
165
  choice_geometry_type = gr.Radio(
@@ -171,6 +174,7 @@ class ImageMatchingApp:
171
  )
172
 
173
  # collect inputs
 
174
  inputs = [
175
  input_image0,
176
  input_image1,
@@ -184,6 +188,7 @@ class ImageMatchingApp:
184
  ransac_max_iter,
185
  choice_geometry_type,
186
  gr.State(self.matcher_zoo),
 
187
  ]
188
 
189
  # Add some examples
@@ -207,7 +212,8 @@ class ImageMatchingApp:
207
  with gr.Column():
208
  output_keypoints = gr.Image(label="Keypoints", type="numpy")
209
  output_matches_raw = gr.Image(
210
- label="Raw Matches", type="numpy"
 
211
  )
212
  output_matches_ransac = gr.Image(
213
  label="Ransac Matches", type="numpy"
@@ -254,6 +260,7 @@ class ImageMatchingApp:
254
  matcher_info,
255
  geometry_result,
256
  output_wrapped,
 
257
  ]
258
  # button callbacks
259
  button_run.click(
@@ -288,6 +295,22 @@ class ImageMatchingApp:
288
  fn=self.ui_reset_state, inputs=None, outputs=reset_outputs
289
  )
290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  # estimate geo
292
  choice_geometry_type.change(
293
  fn=generate_warp_images,
 
9
  load_config,
10
  get_matcher_zoo,
11
  run_matching,
12
+ run_ransac,
13
  gen_examples,
14
  GRADIO_VERSION,
15
  )
 
160
  label="Ransac Iterations",
161
  value=self.cfg["defaults"]["ransac_max_iter"],
162
  )
163
+ button_ransac = gr.Button(
164
+ value="Rerun RANSAC", variant="primary"
165
+ )
166
  with gr.Accordion("Geometry Setting", open=False):
167
  with gr.Row(equal_height=False):
168
  choice_geometry_type = gr.Radio(
 
174
  )
175
 
176
  # collect inputs
177
+ state_cache = gr.State({})
178
  inputs = [
179
  input_image0,
180
  input_image1,
 
188
  ransac_max_iter,
189
  choice_geometry_type,
190
  gr.State(self.matcher_zoo),
191
+ # state_cache,
192
  ]
193
 
194
  # Add some examples
 
212
  with gr.Column():
213
  output_keypoints = gr.Image(label="Keypoints", type="numpy")
214
  output_matches_raw = gr.Image(
215
+ label="Raw Matches",
216
+ type="numpy",
217
  )
218
  output_matches_ransac = gr.Image(
219
  label="Ransac Matches", type="numpy"
 
260
  matcher_info,
261
  geometry_result,
262
  output_wrapped,
263
+ state_cache,
264
  ]
265
  # button callbacks
266
  button_run.click(
 
295
  fn=self.ui_reset_state, inputs=None, outputs=reset_outputs
296
  )
297
 
298
+ # run ransac button action
299
+ button_ransac.click(
300
+ fn=run_ransac,
301
+ inputs=[
302
+ ransac_method,
303
+ ransac_reproj_threshold,
304
+ ransac_confidence,
305
+ ransac_max_iter,
306
+ state_cache,
307
+ ],
308
+ outputs=[
309
+ output_matches_ransac,
310
+ matches_result_info,
311
+ ],
312
+ )
313
+
314
  # estimate geo
315
  choice_geometry_type.change(
316
  fn=generate_warp_images,
common/utils.py CHANGED
@@ -265,12 +265,13 @@ def filter_matches(
265
  mask = np.array(mask.ravel().astype("bool"), dtype="bool")
266
  if H is not None:
267
  if feature_type == "KEYPOINT":
268
- pred["keypoints0_orig"] = mkpts0[mask]
269
- pred["keypoints1_orig"] = mkpts1[mask]
270
- pred["mconf"] = pred["mconf"][mask]
271
  elif feature_type == "LINE":
272
- pred["line_keypoints0_orig"] = mkpts0[mask]
273
- pred["line_keypoints1_orig"] = mkpts1[mask]
 
274
  return pred
275
 
276
 
@@ -440,6 +441,50 @@ def generate_warp_images(
440
  return None, None
441
 
442
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
  def run_matching(
444
  image0: np.ndarray,
445
  image1: np.ndarray,
@@ -496,7 +541,7 @@ def run_matching(
496
  output_matches_ransac = None
497
 
498
  # super slow!
499
- if "roma" in key.lower():
500
  gr.Info(
501
  f"Success! Please be patient and allow for about 2-3 minutes."
502
  f" Due to CPU inference, {key} is quiet slow."
@@ -592,7 +637,7 @@ def run_matching(
592
  "Image 1 - Ransac matched keypoints",
593
  ]
594
  output_matches_ransac, num_matches_ransac = display_matches(
595
- pred, titles=titles
596
  )
597
  gr.Info(f"Display matches done using: {time.time()-t1:.3f}s")
598
  logger.info(f"Display matches done using: {time.time()-t1:.3f}s")
@@ -607,17 +652,20 @@ def run_matching(
607
  choice_geometry_type,
608
  )
609
  plt.close("all")
610
- del pred
611
  logger.info(f"TOTAL time: {time.time()-t0:.3f}s")
612
  gr.Info(f"In summary, total time: {time.time()-t0:.3f}s")
613
 
 
 
 
614
  return (
615
  output_keypoints,
616
  output_matches_raw,
617
  output_matches_ransac,
618
  {
619
- "number raw matches": num_matches_raw,
620
- "number ransac matches": num_matches_ransac,
621
  },
622
  {
623
  "match_conf": match_conf,
@@ -627,6 +675,7 @@ def run_matching(
627
  "geom_info": geom_info,
628
  },
629
  output_wrapped,
 
630
  )
631
 
632
 
 
265
  mask = np.array(mask.ravel().astype("bool"), dtype="bool")
266
  if H is not None:
267
  if feature_type == "KEYPOINT":
268
+ pred["mkeypoints0_orig"] = mkpts0[mask]
269
+ pred["mkeypoints1_orig"] = mkpts1[mask]
270
+ pred["mmconf"] = pred["mconf"][mask]
271
  elif feature_type == "LINE":
272
+ pred["mline_keypoints0_orig"] = mkpts0[mask]
273
+ pred["mline_keypoints1_orig"] = mkpts1[mask]
274
+ pred["H"] = H
275
  return pred
276
 
277
 
 
441
  return None, None
442
 
443
 
444
+ def run_ransac(
445
+ ransac_method: str = DEFAULT_RANSAC_METHOD,
446
+ ransac_reproj_threshold: int = DEFAULT_RANSAC_REPROJ_THRESHOLD,
447
+ ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
448
+ ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
449
+ state_cache: Dict[str, Any] = None,
450
+ ):
451
+ t1 = time.time()
452
+ logger.info(
453
+ f"Run RANSAC matches using: {ransac_method} with threshold: {ransac_reproj_threshold}"
454
+ )
455
+ logger.info(
456
+ f"Run RANSAC matches using: {ransac_confidence} with iter: {ransac_max_iter}"
457
+ )
458
+ # if enable_ransac:
459
+ filter_matches(
460
+ state_cache,
461
+ ransac_method=ransac_method,
462
+ ransac_reproj_threshold=ransac_reproj_threshold,
463
+ ransac_confidence=ransac_confidence,
464
+ ransac_max_iter=ransac_max_iter,
465
+ )
466
+ gr.Info(f"RANSAC matches done using: {time.time()-t1:.3f}s")
467
+ logger.info(f"RANSAC matches done using: {time.time()-t1:.3f}s")
468
+ t1 = time.time()
469
+
470
+ # plot images with ransac matches
471
+ titles = [
472
+ "Image 0 - Ransac matched keypoints",
473
+ "Image 1 - Ransac matched keypoints",
474
+ ]
475
+ output_matches_ransac, num_matches_ransac = display_matches(
476
+ state_cache, titles=titles, tag="KPTS_RANSAC"
477
+ )
478
+ gr.Info(f"Display matches done using: {time.time()-t1:.3f}s")
479
+ logger.info(f"Display matches done using: {time.time()-t1:.3f}s")
480
+ t1 = time.time()
481
+ num_matches_raw = state_cache["num_matches_raw"]
482
+ return output_matches_ransac, {
483
+ "num_matches_raw": num_matches_raw,
484
+ "num_matches_ransac": num_matches_ransac,
485
+ }
486
+
487
+
488
  def run_matching(
489
  image0: np.ndarray,
490
  image1: np.ndarray,
 
541
  output_matches_ransac = None
542
 
543
  # super slow!
544
+ if "roma" in key.lower() and device == "cpu":
545
  gr.Info(
546
  f"Success! Please be patient and allow for about 2-3 minutes."
547
  f" Due to CPU inference, {key} is quiet slow."
 
637
  "Image 1 - Ransac matched keypoints",
638
  ]
639
  output_matches_ransac, num_matches_ransac = display_matches(
640
+ pred, titles=titles, tag="KPTS_RANSAC"
641
  )
642
  gr.Info(f"Display matches done using: {time.time()-t1:.3f}s")
643
  logger.info(f"Display matches done using: {time.time()-t1:.3f}s")
 
652
  choice_geometry_type,
653
  )
654
  plt.close("all")
655
+ # del pred
656
  logger.info(f"TOTAL time: {time.time()-t0:.3f}s")
657
  gr.Info(f"In summary, total time: {time.time()-t0:.3f}s")
658
 
659
+ state_cache = pred
660
+ state_cache["num_matches_raw"] = num_matches_raw
661
+ state_cache["num_matches_ransac"] = num_matches_ransac
662
  return (
663
  output_keypoints,
664
  output_matches_raw,
665
  output_matches_ransac,
666
  {
667
+ "num_raw_matches": num_matches_raw,
668
+ "num_ransac_matches": num_matches_ransac,
669
  },
670
  {
671
  "match_conf": match_conf,
 
675
  "geom_info": geom_info,
676
  },
677
  output_wrapped,
678
+ state_cache,
679
  )
680
 
681
 
common/viz.py CHANGED
@@ -156,7 +156,11 @@ def make_matching_figure(
156
  axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c="w", s=5)
157
 
158
  # draw matches
159
- if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
 
 
 
 
160
  fig.canvas.draw()
161
  transFigure = fig.transFigure.inverted()
162
  fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
@@ -377,6 +381,7 @@ def display_matches(
377
  titles: List[str] = [],
378
  texts: List[str] = [],
379
  dpi: int = 300,
 
380
  ) -> Tuple[np.ndarray, int]:
381
  """
382
  Displays the matches between two images.
@@ -393,11 +398,13 @@ def display_matches(
393
  img1 = pred["image1_orig"]
394
 
395
  num_inliers = 0
 
396
  if (
397
  "keypoints0_orig" in pred
398
  and "keypoints1_orig" in pred
399
  and pred["keypoints0_orig"] is not None
400
  and pred["keypoints1_orig"] is not None
 
401
  ):
402
  mkpts0 = pred["keypoints0_orig"]
403
  mkpts1 = pred["keypoints1_orig"]
@@ -417,11 +424,38 @@ def display_matches(
417
  texts=texts,
418
  )
419
  fig = fig_mkpts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
  if (
421
  "line0_orig" in pred
422
  and "line1_orig" in pred
423
  and pred["line0_orig"] is not None
424
  and pred["line1_orig"] is not None
 
425
  ):
426
  # lines
427
  mtlines0 = pred["line0_orig"]
 
156
  axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c="w", s=5)
157
 
158
  # draw matches
159
+ if (
160
+ mkpts0.shape[0] != 0
161
+ and mkpts1.shape[0] != 0
162
+ and mkpts0.shape == mkpts1.shape
163
+ ):
164
  fig.canvas.draw()
165
  transFigure = fig.transFigure.inverted()
166
  fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
 
381
  titles: List[str] = [],
382
  texts: List[str] = [],
383
  dpi: int = 300,
384
+ tag: str = "KPTS_RAW", # KPTS_RAW, KPTS_RANSAC, LINES_RAW, LINES_RANSAC,
385
  ) -> Tuple[np.ndarray, int]:
386
  """
387
  Displays the matches between two images.
 
398
  img1 = pred["image1_orig"]
399
 
400
  num_inliers = 0
401
+ # draw raw matches
402
  if (
403
  "keypoints0_orig" in pred
404
  and "keypoints1_orig" in pred
405
  and pred["keypoints0_orig"] is not None
406
  and pred["keypoints1_orig"] is not None
407
+ and tag == "KPTS_RAW"
408
  ):
409
  mkpts0 = pred["keypoints0_orig"]
410
  mkpts1 = pred["keypoints1_orig"]
 
424
  texts=texts,
425
  )
426
  fig = fig_mkpts
427
+ elif (
428
+ "mkeypoints0_orig" in pred
429
+ and "mkeypoints1_orig" in pred
430
+ and pred["mkeypoints0_orig"] is not None
431
+ and pred["mkeypoints1_orig"] is not None
432
+ and tag == "KPTS_RANSAC"
433
+ ): # draw ransac matches
434
+ mkpts0 = pred["mkeypoints0_orig"]
435
+ mkpts1 = pred["mkeypoints1_orig"]
436
+ num_inliers = len(mkpts0)
437
+ if "mmconf" in pred:
438
+ mmconf = pred["mmconf"]
439
+ else:
440
+ mmconf = np.ones(len(mkpts0))
441
+ fig_mkpts = draw_matches_core(
442
+ mkpts0,
443
+ mkpts1,
444
+ img0,
445
+ img1,
446
+ mmconf,
447
+ dpi=dpi,
448
+ titles=titles,
449
+ texts=texts,
450
+ )
451
+ fig = fig_mkpts
452
+ # TODO: draw lines
453
  if (
454
  "line0_orig" in pred
455
  and "line1_orig" in pred
456
  and pred["line0_orig"] is not None
457
  and pred["line1_orig"] is not None
458
+ # and (tag == "LINES_RAW" or tag == "LINES_RANSAC")
459
  ):
460
  # lines
461
  mtlines0 = pred["line0_orig"]