chanycha commited on
Commit
d9eb6c9
1 Parent(s): bb63bcd
Files changed (2) hide show
  1. app.py +11 -15
  2. requirements.txt +1 -1
app.py CHANGED
@@ -2,7 +2,7 @@
2
  import time
3
  import os
4
  import gradio as gr
5
- from pnpxai.core.experiment import AutoExplanation
6
  from pnpxai.core.detector import extract_graph_data, symbolic_trace
7
  import matplotlib.pyplot as plt
8
  import plotly.graph_objects as go
@@ -467,22 +467,24 @@ class ExplainerCheckbox(Component):
467
 
468
 
469
  def optimize(self):
470
- if self.explainer_name in ["Lime", "KernelShap", "IntegratedGradients"]:
471
- gr.Info("Lime, KernelShap and IntegratedGradients currently do not support hyperparameter optimization.")
472
- return [gr.update()] * 2
473
 
474
  data_id = self.gallery.selected_index
475
-
476
- opt_explainer_id, opt_postprocessor_id = self.experiment.optimize(
477
  data_id=data_id.value,
478
  explainer_id=self.default_exp_id,
479
  metric_id=self.obj_metric,
480
  direction='maximize',
481
  sampler=SAMPLE_METHOD,
482
  n_trials=OPT_N_TRIALS,
483
- return_study=False,
484
  )
485
 
 
 
 
486
  self.groups.insert_check(self.explainer_name, opt_explainer_id, opt_postprocessor_id)
487
  self.optimal_exp_id = opt_explainer_id
488
  checkbox = gr.update(label="Optimized Parameter (Optimal)", interactive=True)
@@ -620,12 +622,9 @@ experiments = {}
620
  model, transform = get_torchvision_model('resnet18')
621
  dataset = get_imagenet_dataset(transform)
622
  loader = DataLoader(dataset, batch_size=4, shuffle=False)
623
- experiment1 = AutoExplanation(
624
  model=model,
625
  data=loader,
626
- modality='image',
627
- question='why',
628
- evaluator_enabled=True,
629
  input_extractor=lambda batch: batch[0],
630
  label_extractor=lambda batch: batch[-1],
631
  target_extractor=lambda outputs: outputs.argmax(-1),
@@ -643,12 +642,9 @@ experiments['experiment1'] = {
643
  model, transform = get_torchvision_model('vit_b_16')
644
  dataset = get_imagenet_dataset(transform)
645
  loader = DataLoader(dataset, batch_size=4, shuffle=False)
646
- experiment2 = AutoExplanation(
647
  model=model,
648
  data=loader,
649
- modality='image',
650
- question='why',
651
- evaluator_enabled=True,
652
  input_extractor=lambda batch: batch[0],
653
  label_extractor=lambda batch: batch[-1],
654
  target_extractor=lambda outputs: outputs.argmax(-1),
 
2
  import time
3
  import os
4
  import gradio as gr
5
+ from pnpxai.core.experiment import AutoExplanationForImageClassification
6
  from pnpxai.core.detector import extract_graph_data, symbolic_trace
7
  import matplotlib.pyplot as plt
8
  import plotly.graph_objects as go
 
467
 
468
 
469
  def optimize(self):
470
+ # if self.explainer_name in ["Lime", "KernelShap", "IntegratedGradients"]:
471
+ # gr.Info("Lime, KernelShap and IntegratedGradients currently do not support hyperparameter optimization.")
472
+ # return [gr.update()] * 2
473
 
474
  data_id = self.gallery.selected_index
475
+
476
+ optimized, _, _ = self.experiment.optimize(
477
  data_id=data_id.value,
478
  explainer_id=self.default_exp_id,
479
  metric_id=self.obj_metric,
480
  direction='maximize',
481
  sampler=SAMPLE_METHOD,
482
  n_trials=OPT_N_TRIALS,
 
483
  )
484
 
485
+ opt_explainer_id = optimized['explainer_id']
486
+ opt_postprocessor_id = optimized['postprocessor_id']
487
+
488
  self.groups.insert_check(self.explainer_name, opt_explainer_id, opt_postprocessor_id)
489
  self.optimal_exp_id = opt_explainer_id
490
  checkbox = gr.update(label="Optimized Parameter (Optimal)", interactive=True)
 
622
  model, transform = get_torchvision_model('resnet18')
623
  dataset = get_imagenet_dataset(transform)
624
  loader = DataLoader(dataset, batch_size=4, shuffle=False)
625
+ experiment1 = AutoExplanationForImageClassification(
626
  model=model,
627
  data=loader,
 
 
 
628
  input_extractor=lambda batch: batch[0],
629
  label_extractor=lambda batch: batch[-1],
630
  target_extractor=lambda outputs: outputs.argmax(-1),
 
642
  model, transform = get_torchvision_model('vit_b_16')
643
  dataset = get_imagenet_dataset(transform)
644
  loader = DataLoader(dataset, batch_size=4, shuffle=False)
645
+ experiment2 = AutoExplanationForImageClassification(
646
  model=model,
647
  data=loader,
 
 
 
648
  input_extractor=lambda batch: batch[0],
649
  label_extractor=lambda batch: batch[-1],
650
  target_extractor=lambda outputs: outputs.argmax(-1),
requirements.txt CHANGED
@@ -25,4 +25,4 @@ optuna
25
  transformers>=4.0.0
26
  gensim>=4.0.0
27
 
28
- git+https://github.com/OpenXAIProject/pnpxai.git@feat/image/tutorial#egg=pnpxai
 
25
  transformers>=4.0.0
26
  gensim>=4.0.0
27
 
28
+ git+https://github.com/OpenXAIProject/pnpxai.git@dev#egg=pnpxai