Spaces:
Sleeping
Sleeping
dev merge
Browse files- app.py +11 -15
- requirements.txt +1 -1
app.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
import time
|
3 |
import os
|
4 |
import gradio as gr
|
5 |
-
from pnpxai.core.experiment import
|
6 |
from pnpxai.core.detector import extract_graph_data, symbolic_trace
|
7 |
import matplotlib.pyplot as plt
|
8 |
import plotly.graph_objects as go
|
@@ -467,22 +467,24 @@ class ExplainerCheckbox(Component):
|
|
467 |
|
468 |
|
469 |
def optimize(self):
|
470 |
-
if self.explainer_name in ["Lime", "KernelShap", "IntegratedGradients"]:
|
471 |
-
|
472 |
-
|
473 |
|
474 |
data_id = self.gallery.selected_index
|
475 |
-
|
476 |
-
|
477 |
data_id=data_id.value,
|
478 |
explainer_id=self.default_exp_id,
|
479 |
metric_id=self.obj_metric,
|
480 |
direction='maximize',
|
481 |
sampler=SAMPLE_METHOD,
|
482 |
n_trials=OPT_N_TRIALS,
|
483 |
-
return_study=False,
|
484 |
)
|
485 |
|
|
|
|
|
|
|
486 |
self.groups.insert_check(self.explainer_name, opt_explainer_id, opt_postprocessor_id)
|
487 |
self.optimal_exp_id = opt_explainer_id
|
488 |
checkbox = gr.update(label="Optimized Parameter (Optimal)", interactive=True)
|
@@ -620,12 +622,9 @@ experiments = {}
|
|
620 |
model, transform = get_torchvision_model('resnet18')
|
621 |
dataset = get_imagenet_dataset(transform)
|
622 |
loader = DataLoader(dataset, batch_size=4, shuffle=False)
|
623 |
-
experiment1 =
|
624 |
model=model,
|
625 |
data=loader,
|
626 |
-
modality='image',
|
627 |
-
question='why',
|
628 |
-
evaluator_enabled=True,
|
629 |
input_extractor=lambda batch: batch[0],
|
630 |
label_extractor=lambda batch: batch[-1],
|
631 |
target_extractor=lambda outputs: outputs.argmax(-1),
|
@@ -643,12 +642,9 @@ experiments['experiment1'] = {
|
|
643 |
model, transform = get_torchvision_model('vit_b_16')
|
644 |
dataset = get_imagenet_dataset(transform)
|
645 |
loader = DataLoader(dataset, batch_size=4, shuffle=False)
|
646 |
-
experiment2 =
|
647 |
model=model,
|
648 |
data=loader,
|
649 |
-
modality='image',
|
650 |
-
question='why',
|
651 |
-
evaluator_enabled=True,
|
652 |
input_extractor=lambda batch: batch[0],
|
653 |
label_extractor=lambda batch: batch[-1],
|
654 |
target_extractor=lambda outputs: outputs.argmax(-1),
|
|
|
2 |
import time
|
3 |
import os
|
4 |
import gradio as gr
|
5 |
+
from pnpxai.core.experiment import AutoExplanationForImageClassification
|
6 |
from pnpxai.core.detector import extract_graph_data, symbolic_trace
|
7 |
import matplotlib.pyplot as plt
|
8 |
import plotly.graph_objects as go
|
|
|
467 |
|
468 |
|
469 |
def optimize(self):
|
470 |
+
# if self.explainer_name in ["Lime", "KernelShap", "IntegratedGradients"]:
|
471 |
+
# gr.Info("Lime, KernelShap and IntegratedGradients currently do not support hyperparameter optimization.")
|
472 |
+
# return [gr.update()] * 2
|
473 |
|
474 |
data_id = self.gallery.selected_index
|
475 |
+
|
476 |
+
optimized, _, _ = self.experiment.optimize(
|
477 |
data_id=data_id.value,
|
478 |
explainer_id=self.default_exp_id,
|
479 |
metric_id=self.obj_metric,
|
480 |
direction='maximize',
|
481 |
sampler=SAMPLE_METHOD,
|
482 |
n_trials=OPT_N_TRIALS,
|
|
|
483 |
)
|
484 |
|
485 |
+
opt_explainer_id = optimized['explainer_id']
|
486 |
+
opt_postprocessor_id = optimized['postprocessor_id']
|
487 |
+
|
488 |
self.groups.insert_check(self.explainer_name, opt_explainer_id, opt_postprocessor_id)
|
489 |
self.optimal_exp_id = opt_explainer_id
|
490 |
checkbox = gr.update(label="Optimized Parameter (Optimal)", interactive=True)
|
|
|
622 |
model, transform = get_torchvision_model('resnet18')
|
623 |
dataset = get_imagenet_dataset(transform)
|
624 |
loader = DataLoader(dataset, batch_size=4, shuffle=False)
|
625 |
+
experiment1 = AutoExplanationForImageClassification(
|
626 |
model=model,
|
627 |
data=loader,
|
|
|
|
|
|
|
628 |
input_extractor=lambda batch: batch[0],
|
629 |
label_extractor=lambda batch: batch[-1],
|
630 |
target_extractor=lambda outputs: outputs.argmax(-1),
|
|
|
642 |
model, transform = get_torchvision_model('vit_b_16')
|
643 |
dataset = get_imagenet_dataset(transform)
|
644 |
loader = DataLoader(dataset, batch_size=4, shuffle=False)
|
645 |
+
experiment2 = AutoExplanationForImageClassification(
|
646 |
model=model,
|
647 |
data=loader,
|
|
|
|
|
|
|
648 |
input_extractor=lambda batch: batch[0],
|
649 |
label_extractor=lambda batch: batch[-1],
|
650 |
target_extractor=lambda outputs: outputs.argmax(-1),
|
requirements.txt
CHANGED
@@ -25,4 +25,4 @@ optuna
|
|
25 |
transformers>=4.0.0
|
26 |
gensim>=4.0.0
|
27 |
|
28 |
-
git+https://github.com/OpenXAIProject/pnpxai.git@
|
|
|
25 |
transformers>=4.0.0
|
26 |
gensim>=4.0.0
|
27 |
|
28 |
+
git+https://github.com/OpenXAIProject/pnpxai.git@dev#egg=pnpxai
|