chanycha commited on
Commit
db551ba
1 Parent(s): fd3b413
Files changed (1) hide show
  1. app.py +20 -7
app.py CHANGED
@@ -483,18 +483,29 @@ class ExplainerCheckbox(Component):
483
 
484
  data_id = self.gallery.selected_index
485
 
486
- optimized, _, _ = self.experiment.optimize(
487
- data_id=data_id.value,
488
  explainer_id=self.default_exp_id,
489
  metric_id=self.obj_metric,
490
  direction='maximize',
491
  sampler=SAMPLE_METHOD,
492
  n_trials=OPT_N_TRIALS,
493
  )
 
494
 
495
- opt_explainer_id = optimized['explainer_id']
496
- opt_postprocessor_id = optimized['postprocessor_id']
497
-
 
 
 
 
 
 
 
 
 
 
498
  self.groups.insert_check(self.explainer_name, opt_explainer_id, opt_postprocessor_id)
499
  self.optimal_exp_id = opt_explainer_id
500
  checkbox = gr.update(label="Optimized Parameter (Optimal)", interactive=True)
@@ -628,6 +639,8 @@ from torch.utils.data import DataLoader
628
  from helpers import get_imagenet_dataset, get_torchvision_model, denormalize_image
629
 
630
  os.environ['GRADIO_TEMP_DIR'] = '.tmp'
 
 
631
 
632
  def target_visualizer(x): return dataset.dataset.idx_to_label(x.item())
633
 
@@ -637,7 +650,7 @@ model, transform = get_torchvision_model('resnet18')
637
  dataset = get_imagenet_dataset(transform)
638
  loader = DataLoader(dataset, batch_size=4, shuffle=False)
639
  experiment1 = AutoExplanationForImageClassification(
640
- model=model,
641
  data=loader,
642
  input_extractor=lambda batch: batch[0],
643
  label_extractor=lambda batch: batch[-1],
@@ -657,7 +670,7 @@ model, transform = get_torchvision_model('vit_b_16')
657
  dataset = get_imagenet_dataset(transform)
658
  loader = DataLoader(dataset, batch_size=4, shuffle=False)
659
  experiment2 = AutoExplanationForImageClassification(
660
- model=model,
661
  data=loader,
662
  input_extractor=lambda batch: batch[0],
663
  label_extractor=lambda batch: batch[-1],
 
483
 
484
  data_id = self.gallery.selected_index
485
 
486
+ opt_output = self.experiment.optimize(
487
+ data_ids=data_id.value,
488
  explainer_id=self.default_exp_id,
489
  metric_id=self.obj_metric,
490
  direction='maximize',
491
  sampler=SAMPLE_METHOD,
492
  n_trials=OPT_N_TRIALS,
493
  )
494
+
495
 
496
+ def get_str_ppid(pp_obj):
497
+ return pp_obj.pooling_fn.__class__.__name__ + pp_obj.normalization_fn.__class__.__name__
498
+
499
+ str_id = get_str_ppid(opt_output.postprocessor)
500
+ for pp_obj, pp_id in zip(*self.experiment.manager.get_postprocessors()):
501
+ if get_str_ppid(pp_obj) == str_id:
502
+ opt_postprocessor_id = pp_id
503
+ break
504
+
505
+ opt_explainer_id = max([x['id'] for x in self.groups.info]) + 1
506
+ opt_output.explainer.model = self.experiment.model
507
+ self.experiment.manager._explainers.append(opt_output.explainer)
508
+ self.experiment.manager._explainer_ids.append(opt_explainer_id)
509
  self.groups.insert_check(self.explainer_name, opt_explainer_id, opt_postprocessor_id)
510
  self.optimal_exp_id = opt_explainer_id
511
  checkbox = gr.update(label="Optimized Parameter (Optimal)", interactive=True)
 
639
  from helpers import get_imagenet_dataset, get_torchvision_model, denormalize_image
640
 
641
  os.environ['GRADIO_TEMP_DIR'] = '.tmp'
642
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
643
+ device = torch.device("cpu")
644
 
645
  def target_visualizer(x): return dataset.dataset.idx_to_label(x.item())
646
 
 
650
  dataset = get_imagenet_dataset(transform)
651
  loader = DataLoader(dataset, batch_size=4, shuffle=False)
652
  experiment1 = AutoExplanationForImageClassification(
653
+ model=model.to(device),
654
  data=loader,
655
  input_extractor=lambda batch: batch[0],
656
  label_extractor=lambda batch: batch[-1],
 
670
  dataset = get_imagenet_dataset(transform)
671
  loader = DataLoader(dataset, batch_size=4, shuffle=False)
672
  experiment2 = AutoExplanationForImageClassification(
673
+ model=model.to(device),
674
  data=loader,
675
  input_extractor=lambda batch: batch[0],
676
  label_extractor=lambda batch: batch[-1],