chanycha commited on
Commit
29fa28e
1 Parent(s): 5269caa

reconstruc baseline, mask

Browse files
Files changed (1) hide show
  1. app.py +54 -12
app.py CHANGED
@@ -7,6 +7,8 @@ import gradio as gr
7
  import spaces
8
  from pnpxai.core.experiment.auto_explanation import AutoExplanationForImageClassification
9
  from pnpxai.core.detector.detector import extract_graph_data, symbolic_trace
 
 
10
  import matplotlib.pyplot as plt
11
  import plotly.graph_objects as go
12
  import plotly.express as px
@@ -541,10 +543,9 @@ class ExplainerCheckbox(Component):
541
  break
542
 
543
  opt_exp_id = max([x['id'] for x in checkbox_group_info]) + 1
544
- # opt_output.explainer.model = self.experiment.model
545
- # self.experiment.manager._explainers.append(opt_output.explainer)
546
- # self.experiment.manager._explainer_ids.append(opt_exp_id)
547
 
 
 
548
  opt_res = {
549
  'id': opt_exp_id,
550
  'class': opt_output.explainer.__class__,
@@ -558,15 +559,56 @@ class ExplainerCheckbox(Component):
558
  return [opt_res, checkbox_group_info, checkbox, bttn]
559
 
560
  def update_exp(exp_res):
561
- kwargs = {}
562
- for k,v in exp_res['params'].items():
563
- if "explainer" in k:
564
- _key = k.split("explainer.")[1]
565
- kwargs[_key] = v
566
-
567
- kwargs['model'] = self.experiment.model
568
- explainer = exp_res['class'](**kwargs)
569
- _id = exp_res['id']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
570
  self.experiment.manager._explainers.append(explainer)
571
  self.experiment.manager._explainer_ids.append(_id)
572
 
 
7
  import spaces
8
  from pnpxai.core.experiment.auto_explanation import AutoExplanationForImageClassification
9
  from pnpxai.core.detector.detector import extract_graph_data, symbolic_trace
10
+ from pnpxai.explainers.utils.baselines import BASELINE_FUNCTIONS_FOR_IMAGE
11
+ from pnpxai.explainers.utils.feature_masks import FEATURE_MASK_FUNCTIONS_FOR_IMAGE
12
  import matplotlib.pyplot as plt
13
  import plotly.graph_objects as go
14
  import plotly.express as px
 
543
  break
544
 
545
  opt_exp_id = max([x['id'] for x in checkbox_group_info]) + 1
 
 
 
546
 
547
+ # Deliver the parameter and class and reconstruct
548
+ # It should be done because spaces.GPU cannot pickle the class object
549
  opt_res = {
550
  'id': opt_exp_id,
551
  'class': opt_output.explainer.__class__,
 
559
  return [opt_res, checkbox_group_info, checkbox, bttn]
560
 
561
  def update_exp(exp_res):
562
+ try:
563
+ kwargs = {}
564
+ has_baseline = False
565
+ has_feature_mask = False
566
+ for k,v in exp_res['params'].items():
567
+ if "explainer" in k:
568
+ _key = k.split("explainer.")[1]
569
+ kwargs[_key] = v
570
+ if "baseline_fn" in _key:
571
+ has_baseline = True
572
+ if "feature_mask_fn" in _key:
573
+ has_feature_mask = True
574
+
575
+ # Reconstruct baseline object
576
+ if has_baseline:
577
+ method = kwargs['baseline_fn.method']
578
+ del kwargs['baseline_fn.method']
579
+ baseline_kwargs = {}
580
+ keys = list(kwargs.keys())
581
+ for k in keys:
582
+ v = kwargs[k]
583
+ if "baseline_fn" in k:
584
+ baseline_kwargs[k.split("baseline_fn.")[1]] = v
585
+ del kwargs[k]
586
+ if method == "mean":
587
+ baseline_kwargs['dim'] = 1 # Set arbitrary value
588
+ baseline_fn = BASELINE_FUNCTIONS_FOR_IMAGE[method](**baseline_kwargs)
589
+ kwargs['baseline_fn'] = baseline_fn
590
+
591
+ # Reconstruct feature_mask object
592
+ if has_feature_mask:
593
+ method = kwargs['feature_mask_fn.method']
594
+ del kwargs['feature_mask_fn.method']
595
+ mask_kwargs = {}
596
+ keys = list(kwargs.keys())
597
+ for k in keys:
598
+ v = kwargs[k]
599
+ if "feature_mask_fn" in k:
600
+ mask_kwargs[k.split("feature_mask_fn.")[1]] = v
601
+ del kwargs[k]
602
+ mask_fn = FEATURE_MASK_FUNCTIONS_FOR_IMAGE[method](**mask_kwargs)
603
+ kwargs['feature_mask_fn'] = mask_fn
604
+
605
+ kwargs['model'] = self.experiment.model
606
+ explainer = exp_res['class'](**kwargs)
607
+ _id = exp_res['id']
608
+ except Exception as e:
609
+ # If the optimization is failed, use the default parameter explainer as optimal
610
+ explainer = self.experiment.manager._explainers[self.default_exp_id]
611
+
612
  self.experiment.manager._explainers.append(explainer)
613
  self.experiment.manager._explainer_ids.append(_id)
614