rerun
Browse files
app.py
CHANGED
@@ -488,37 +488,6 @@ class ExplainerCheckbox(Component):
|
|
488 |
def get_str_ppid(self, pp_obj):
|
489 |
return pp_obj.pooling_fn.__class__.__name__ + pp_obj.normalization_fn.__class__.__name__
|
490 |
|
491 |
-
@spaces.GPU
|
492 |
-
def optimize(self):
|
493 |
-
data_id = self.gallery.selected_index
|
494 |
-
|
495 |
-
opt_output = self.experiment.optimize(
|
496 |
-
data_ids=data_id.value,
|
497 |
-
explainer_id=self.default_exp_id,
|
498 |
-
metric_id=self.obj_metric,
|
499 |
-
direction='maximize',
|
500 |
-
sampler=SAMPLE_METHOD,
|
501 |
-
n_trials=OPT_N_TRIALS,
|
502 |
-
)
|
503 |
-
|
504 |
-
str_id = self.get_str_ppid(opt_output.postprocessor)
|
505 |
-
for pp_obj, pp_id in zip(*self.experiment.manager.get_postprocessors()):
|
506 |
-
if self.get_str_ppid(pp_obj) == str_id:
|
507 |
-
opt_postprocessor_id = pp_id
|
508 |
-
break
|
509 |
-
|
510 |
-
opt_explainer_id = max([x['id'] for x in self.groups.info]) + 1
|
511 |
-
opt_output.explainer.model = self.experiment.model
|
512 |
-
self.experiment.manager._explainers.append(opt_output.explainer)
|
513 |
-
self.experiment.manager._explainer_ids.append(opt_explainer_id)
|
514 |
-
self.groups.insert_check(self.explainer_name, opt_explainer_id, opt_postprocessor_id)
|
515 |
-
self.optimal_exp_id = opt_explainer_id
|
516 |
-
checkbox = gr.update(label="Optimized Parameter (Optimal)", interactive=True)
|
517 |
-
bttn = gr.update(value="Optimized", variant="secondary")
|
518 |
-
|
519 |
-
return [checkbox, bttn]
|
520 |
-
|
521 |
-
|
522 |
def default_on_select(self, evt: gr.EventData):
|
523 |
self.groups.update_check(self.default_exp_id, evt._data['value'])
|
524 |
|
|
|
488 |
def get_str_ppid(self, pp_obj):
|
489 |
return pp_obj.pooling_fn.__class__.__name__ + pp_obj.normalization_fn.__class__.__name__
|
490 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
491 |
def default_on_select(self, evt: gr.EventData):
|
492 |
self.groups.update_check(self.default_exp_id, evt._data['value'])
|
493 |
|