reconstruc baseline, mask
Browse files
app.py
CHANGED
@@ -7,6 +7,8 @@ import gradio as gr
|
|
7 |
import spaces
|
8 |
from pnpxai.core.experiment.auto_explanation import AutoExplanationForImageClassification
|
9 |
from pnpxai.core.detector.detector import extract_graph_data, symbolic_trace
|
|
|
|
|
10 |
import matplotlib.pyplot as plt
|
11 |
import plotly.graph_objects as go
|
12 |
import plotly.express as px
|
@@ -541,10 +543,9 @@ class ExplainerCheckbox(Component):
|
|
541 |
break
|
542 |
|
543 |
opt_exp_id = max([x['id'] for x in checkbox_group_info]) + 1
|
544 |
-
# opt_output.explainer.model = self.experiment.model
|
545 |
-
# self.experiment.manager._explainers.append(opt_output.explainer)
|
546 |
-
# self.experiment.manager._explainer_ids.append(opt_exp_id)
|
547 |
|
|
|
|
|
548 |
opt_res = {
|
549 |
'id': opt_exp_id,
|
550 |
'class': opt_output.explainer.__class__,
|
@@ -558,15 +559,56 @@ class ExplainerCheckbox(Component):
|
|
558 |
return [opt_res, checkbox_group_info, checkbox, bttn]
|
559 |
|
560 |
def update_exp(exp_res):
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
570 |
self.experiment.manager._explainers.append(explainer)
|
571 |
self.experiment.manager._explainer_ids.append(_id)
|
572 |
|
|
|
7 |
import spaces
|
8 |
from pnpxai.core.experiment.auto_explanation import AutoExplanationForImageClassification
|
9 |
from pnpxai.core.detector.detector import extract_graph_data, symbolic_trace
|
10 |
+
from pnpxai.explainers.utils.baselines import BASELINE_FUNCTIONS_FOR_IMAGE
|
11 |
+
from pnpxai.explainers.utils.feature_masks import FEATURE_MASK_FUNCTIONS_FOR_IMAGE
|
12 |
import matplotlib.pyplot as plt
|
13 |
import plotly.graph_objects as go
|
14 |
import plotly.express as px
|
|
|
543 |
break
|
544 |
|
545 |
opt_exp_id = max([x['id'] for x in checkbox_group_info]) + 1
|
|
|
|
|
|
|
546 |
|
547 |
+
# Deliver the parameter and class and reconstruct
|
548 |
+
# It should be done because spaces.GPU cannot pickle the class object
|
549 |
opt_res = {
|
550 |
'id': opt_exp_id,
|
551 |
'class': opt_output.explainer.__class__,
|
|
|
559 |
return [opt_res, checkbox_group_info, checkbox, bttn]
|
560 |
|
561 |
def update_exp(exp_res):
|
562 |
+
try:
|
563 |
+
kwargs = {}
|
564 |
+
has_baseline = False
|
565 |
+
has_feature_mask = False
|
566 |
+
for k,v in exp_res['params'].items():
|
567 |
+
if "explainer" in k:
|
568 |
+
_key = k.split("explainer.")[1]
|
569 |
+
kwargs[_key] = v
|
570 |
+
if "baseline_fn" in _key:
|
571 |
+
has_baseline = True
|
572 |
+
if "feature_mask_fn" in _key:
|
573 |
+
has_feature_mask = True
|
574 |
+
|
575 |
+
# Reconstruct baseline object
|
576 |
+
if has_baseline:
|
577 |
+
method = kwargs['baseline_fn.method']
|
578 |
+
del kwargs['baseline_fn.method']
|
579 |
+
baseline_kwargs = {}
|
580 |
+
keys = list(kwargs.keys())
|
581 |
+
for k in keys:
|
582 |
+
v = kwargs[k]
|
583 |
+
if "baseline_fn" in k:
|
584 |
+
baseline_kwargs[k.split("baseline_fn.")[1]] = v
|
585 |
+
del kwargs[k]
|
586 |
+
if method == "mean":
|
587 |
+
baseline_kwargs['dim'] = 1 # Set arbitrary value
|
588 |
+
baseline_fn = BASELINE_FUNCTIONS_FOR_IMAGE[method](**baseline_kwargs)
|
589 |
+
kwargs['baseline_fn'] = baseline_fn
|
590 |
+
|
591 |
+
# Reconstruct feature_mask object
|
592 |
+
if has_feature_mask:
|
593 |
+
method = kwargs['feature_mask_fn.method']
|
594 |
+
del kwargs['feature_mask_fn.method']
|
595 |
+
mask_kwargs = {}
|
596 |
+
keys = list(kwargs.keys())
|
597 |
+
for k in keys:
|
598 |
+
v = kwargs[k]
|
599 |
+
if "feature_mask_fn" in k:
|
600 |
+
mask_kwargs[k.split("feature_mask_fn.")[1]] = v
|
601 |
+
del kwargs[k]
|
602 |
+
mask_fn = FEATURE_MASK_FUNCTIONS_FOR_IMAGE[method](**mask_kwargs)
|
603 |
+
kwargs['feature_mask_fn'] = mask_fn
|
604 |
+
|
605 |
+
kwargs['model'] = self.experiment.model
|
606 |
+
explainer = exp_res['class'](**kwargs)
|
607 |
+
_id = exp_res['id']
|
608 |
+
except Exception as e:
|
609 |
+
# If the optimization is failed, use the default parameter explainer as optimal
|
610 |
+
explainer = self.experiment.manager._explainers[self.default_exp_id]
|
611 |
+
|
612 |
self.experiment.manager._explainers.append(explainer)
|
613 |
self.experiment.manager._explainer_ids.append(_id)
|
614 |
|