chanycha commited on
Commit
b2f2d4b
1 Parent(s): 74e6982

applying gr.State

Browse files
Files changed (1) hide show
  1. app.py +46 -36
app.py CHANGED
@@ -239,7 +239,7 @@ class Experiment(Component):
239
  idx = [metric.__class__.__name__ for metric in metric_info[0]].index(metric_name)
240
  return metric_info[1][idx]
241
 
242
- def generate_record(self, data_id, metric_names):
243
  record = {}
244
  _base = self.experiment.run_batch([data_id], 0, 0, 0)
245
  record['data_id'] = data_id
@@ -252,7 +252,7 @@ class Experiment(Component):
252
  metrics_ids = [self.get_metric_id_by_name(metric_nm) for metric_nm in metric_names]
253
 
254
  cnt = 0
255
- for info in self.explainer_checkbox_group.info:
256
  if info['checked']:
257
  base = self.experiment.run_batch([data_id], info['id'], info['pp_id'], 0)
258
  record['explanations'].append({
@@ -334,9 +334,9 @@ class Experiment(Component):
334
  plot = gr.Image(value=None, label="Blank", visible=False)
335
  plots.append(plot)
336
 
337
- def show_plots():
338
  _plots = [gr.Textbox(label="Prediction result", visible=False)]
339
- num_plots = sum([1 for info in self.explainer_checkbox_group.info if info['checked']])
340
  n_rows = num_plots // PLOT_PER_LINE
341
  n_rows = n_rows + 1 if num_plots % PLOT_PER_LINE != 0 else n_rows
342
  _plots += [gr.Image(value=None, label="Blank", visible=True)] * (n_rows * PLOT_PER_LINE)
@@ -344,7 +344,7 @@ class Experiment(Component):
344
  return _plots
345
 
346
  @spaces.GPU
347
- def render_plots(data_id, *metric_inputs):
348
  # Clear Cache Files
349
  # print(f"GPU Check: {torch.cuda.is_available()}")
350
  # print("Which GPU: ", torch.cuda.current_device())
@@ -360,12 +360,15 @@ class Experiment(Component):
360
  if metric:
361
  metric_input += metric
362
 
363
- record = self.generate_record(data_id, metric_input)
364
 
365
  pred = self.get_prediction(record)
366
  plots = [gr.Textbox(label="Prediction result", value=pred, visible=True)]
367
 
368
- num_plots = sum([1 for info in self.explainer_checkbox_group.info if info['checked']])
 
 
 
369
  n_rows = num_plots // PLOT_PER_LINE
370
  n_rows = n_rows + 1 if num_plots % PLOT_PER_LINE != 0 else n_rows
371
 
@@ -383,8 +386,8 @@ class Experiment(Component):
383
 
384
  return plots
385
 
386
- bttn.click(show_plots, outputs=plots)
387
- bttn.click(render_plots, inputs=[data_id] + metric_inputs, outputs=plots)
388
 
389
 
390
 
@@ -397,30 +400,33 @@ class ExplainerCheckboxGroup(Component):
397
  self.gallery = gallery
398
  explainers, exp_ids = self.experiment.manager.get_explainers()
399
 
400
- self.info = []
401
  for exp, exp_id in zip(explainers, exp_ids):
402
  exp_nm = exp.__class__.__name__
403
  if exp_nm in DEFAULT_EXPLAINER:
404
  checked = True
405
  else:
406
  checked = False
407
- self.info.append({'nm': exp_nm, 'id': exp_id, 'pp_id' : 0, 'mode': 'default', 'checked': checked})
 
 
408
 
409
- def update_check(self, exp_id, val=None):
410
- for info in self.info:
411
  if info['id'] == exp_id:
412
  if val is not None:
413
  info['checked'] = val
414
  else:
415
  info['checked'] = not info['checked']
 
416
 
417
- def insert_check(self, exp_nm, exp_id, pp_id):
418
- if exp_id in [info['id'] for info in self.info]:
419
  return
 
 
420
 
421
- self.info.append({'nm': exp_nm, 'id': exp_id, 'pp_id' : pp_id, 'mode': 'optimal', 'checked': False})
422
-
423
- def update_gallery_change(self):
424
  checkboxes = []
425
  bttns = []
426
  for exp in self.explainer_objs:
@@ -431,10 +437,10 @@ class ExplainerCheckboxGroup(Component):
431
 
432
  for exp in self.explainer_objs:
433
  val = exp.explainer_name in DEFAULT_EXPLAINER
434
- self.update_check(exp.default_exp_id, val)
435
  if hasattr(exp, "optimal_exp_id"):
436
- self.update_check(exp.optimal_exp_id, False)
437
- return checkboxes + bttns
438
 
439
  def get_checkboxes(self):
440
  checkboxes = []
@@ -447,11 +453,10 @@ class ExplainerCheckboxGroup(Component):
447
 
448
  def show(self):
449
  cnt = 0
450
- sorted_info = sorted(self.info, key=lambda x: (x['nm'] not in DEFAULT_EXPLAINER, x['nm']))
451
  with gr.Accordion("Explainers", open=True):
452
  while cnt * PLOT_PER_LINE < len(self.explainer_names):
453
  with gr.Row():
454
- for info in sorted_info[cnt*PLOT_PER_LINE:(cnt+1)*PLOT_PER_LINE]:
455
  explainer_obj = ExplainerCheckbox(info['nm'], self, self.experiment, self.gallery)
456
  self.explainer_objs.append(explainer_obj)
457
  explainer_obj.show()
@@ -461,7 +466,8 @@ class ExplainerCheckboxGroup(Component):
461
  bttns = self.get_bttns()
462
  self.gallery.gallery_obj.select(
463
  fn=self.update_gallery_change,
464
- outputs=checkboxes + bttns
 
465
  )
466
 
467
 
@@ -488,28 +494,31 @@ class ExplainerCheckbox(Component):
488
  def get_str_ppid(self, pp_obj):
489
  return pp_obj.pooling_fn.__class__.__name__ + pp_obj.normalization_fn.__class__.__name__
490
 
491
- def default_on_select(self, evt: gr.EventData):
492
- self.groups.update_check(self.default_exp_id, evt._data['value'])
 
493
 
494
- def optimal_on_select(self, evt: gr.EventData):
495
  if hasattr(self, "optimal_exp_id"):
496
- self.groups.update_check(self.optimal_exp_id, evt._data['value'])
497
  else:
498
  raise ValueError("Optimal explainer id is not found.")
 
499
 
500
  def show(self):
501
  val = self.explainer_name in DEFAULT_EXPLAINER
502
  with gr.Accordion(self.explainer_name, open=val):
503
- checked = next(filter(lambda x: x['nm'] == self.explainer_name, self.groups.info))['checked']
504
  self.default_check = gr.Checkbox(label="Default Parameter", value=checked, interactive=True)
505
  self.opt_check = gr.Checkbox(label="Optimized Parameter (Not Optimal)", interactive=False)
506
 
507
- self.default_check.select(self.default_on_select)
508
- self.opt_check.select(self.optimal_on_select)
509
 
510
  self.bttn = gr.Button(value="Optimize", size="sm", variant="primary")
511
 
512
- def optimize():
 
513
  data_id = self.gallery.selected_index
514
 
515
  opt_output = self.experiment.optimize(
@@ -527,18 +536,18 @@ class ExplainerCheckbox(Component):
527
  opt_postprocessor_id = pp_id
528
  break
529
 
530
- opt_explainer_id = max([x['id'] for x in self.groups.info]) + 1
531
  opt_output.explainer.model = self.experiment.model
532
  self.experiment.manager._explainers.append(opt_output.explainer)
533
  self.experiment.manager._explainer_ids.append(opt_explainer_id)
534
- self.groups.insert_check(self.explainer_name, opt_explainer_id, opt_postprocessor_id)
535
  self.optimal_exp_id = opt_explainer_id
536
  checkbox = gr.update(label="Optimized Parameter (Optimal)", interactive=True)
537
  bttn = gr.update(value="Optimized", variant="secondary")
538
 
539
- return [checkbox, bttn]
540
 
541
- self.bttn.click(optimize, outputs=[self.opt_check, self.bttn], queue=True, concurrency_limit=1)
542
 
543
 
544
  class ExpRes(Component):
@@ -692,3 +701,4 @@ experiments['experiment2'] = {
692
  app = ImageClsApp(experiments)
693
  demo = app.launch()
694
  demo.launch(favicon_path=f"static/XAI-Top-PnP.svg")
 
 
239
  idx = [metric.__class__.__name__ for metric in metric_info[0]].index(metric_name)
240
  return metric_info[1][idx]
241
 
242
+ def generate_record(self, checkbox_group_info, data_id, metric_names):
243
  record = {}
244
  _base = self.experiment.run_batch([data_id], 0, 0, 0)
245
  record['data_id'] = data_id
 
252
  metrics_ids = [self.get_metric_id_by_name(metric_nm) for metric_nm in metric_names]
253
 
254
  cnt = 0
255
+ for info in checkbox_group_info:
256
  if info['checked']:
257
  base = self.experiment.run_batch([data_id], info['id'], info['pp_id'], 0)
258
  record['explanations'].append({
 
334
  plot = gr.Image(value=None, label="Blank", visible=False)
335
  plots.append(plot)
336
 
337
+ def show_plots(checkbox_group_info):
338
  _plots = [gr.Textbox(label="Prediction result", visible=False)]
339
+ num_plots = sum([1 for info in checkbox_group_info if info['checked']])
340
  n_rows = num_plots // PLOT_PER_LINE
341
  n_rows = n_rows + 1 if num_plots % PLOT_PER_LINE != 0 else n_rows
342
  _plots += [gr.Image(value=None, label="Blank", visible=True)] * (n_rows * PLOT_PER_LINE)
 
344
  return _plots
345
 
346
  @spaces.GPU
347
+ def render_plots(data_id, checkbox_group_info, *metric_inputs):
348
  # Clear Cache Files
349
  # print(f"GPU Check: {torch.cuda.is_available()}")
350
  # print("Which GPU: ", torch.cuda.current_device())
 
360
  if metric:
361
  metric_input += metric
362
 
363
+ record = self.generate_record(checkbox_group_info, data_id, metric_input)
364
 
365
  pred = self.get_prediction(record)
366
  plots = [gr.Textbox(label="Prediction result", value=pred, visible=True)]
367
 
368
+ # for info in checkbox_group_info:
369
+ # if info['checked']:
370
+ # print(info)
371
+ num_plots = sum([1 for info in checkbox_group_info if info['checked']])
372
  n_rows = num_plots // PLOT_PER_LINE
373
  n_rows = n_rows + 1 if num_plots % PLOT_PER_LINE != 0 else n_rows
374
 
 
386
 
387
  return plots
388
 
389
+ bttn.click(show_plots, inputs=[self.explainer_checkbox_group.info], outputs=plots)
390
+ bttn.click(render_plots, inputs=[data_id, self.explainer_checkbox_group.info] + metric_inputs, outputs=plots)
391
 
392
 
393
 
 
400
  self.gallery = gallery
401
  explainers, exp_ids = self.experiment.manager.get_explainers()
402
 
403
+ info = []
404
  for exp, exp_id in zip(explainers, exp_ids):
405
  exp_nm = exp.__class__.__name__
406
  if exp_nm in DEFAULT_EXPLAINER:
407
  checked = True
408
  else:
409
  checked = False
410
+ info.append({'nm': exp_nm, 'id': exp_id, 'pp_id' : 0, 'mode': 'default', 'checked': checked})
411
+ self.static_info = sorted(info, key=lambda x: (x['nm'] not in DEFAULT_EXPLAINER, x['nm']))
412
+ self.info = gr.State(info)
413
 
414
+ def update_check(self, checkbox_group_info, exp_id, val=None):
415
+ for info in checkbox_group_info:
416
  if info['id'] == exp_id:
417
  if val is not None:
418
  info['checked'] = val
419
  else:
420
  info['checked'] = not info['checked']
421
+ return checkbox_group_info
422
 
423
+ def insert_check(self, checkbox_group_info, exp_nm, exp_id, pp_id):
424
+ if exp_id in [info['id'] for info in checkbox_group_info]:
425
  return
426
+ checkbox_group_info.append({'nm': exp_nm, 'id': exp_id, 'pp_id' : pp_id, 'mode': 'optimal', 'checked': False})
427
+ return checkbox_group_info
428
 
429
+ def update_gallery_change(self, checkbox_group_info):
 
 
430
  checkboxes = []
431
  bttns = []
432
  for exp in self.explainer_objs:
 
437
 
438
  for exp in self.explainer_objs:
439
  val = exp.explainer_name in DEFAULT_EXPLAINER
440
+ checkbox_group_info = self.update_check(checkbox_group_info, exp.default_exp_id, val)
441
  if hasattr(exp, "optimal_exp_id"):
442
+ checkbox_group_info = self.update_check(checkbox_group_info, exp.optimal_exp_id, False)
443
+ return checkboxes + bttns + [checkbox_group_info]
444
 
445
  def get_checkboxes(self):
446
  checkboxes = []
 
453
 
454
  def show(self):
455
  cnt = 0
 
456
  with gr.Accordion("Explainers", open=True):
457
  while cnt * PLOT_PER_LINE < len(self.explainer_names):
458
  with gr.Row():
459
+ for info in self.static_info[cnt*PLOT_PER_LINE:(cnt+1)*PLOT_PER_LINE]:
460
  explainer_obj = ExplainerCheckbox(info['nm'], self, self.experiment, self.gallery)
461
  self.explainer_objs.append(explainer_obj)
462
  explainer_obj.show()
 
466
  bttns = self.get_bttns()
467
  self.gallery.gallery_obj.select(
468
  fn=self.update_gallery_change,
469
+ inputs=self.info,
470
+ outputs=checkboxes + bttns + [self.info],
471
  )
472
 
473
 
 
494
  def get_str_ppid(self, pp_obj):
495
  return pp_obj.pooling_fn.__class__.__name__ + pp_obj.normalization_fn.__class__.__name__
496
 
497
+ def default_on_select(self, evt: gr.EventData, checkbox_group_info):
498
+ checkbox_group_info = self.groups.update_check(checkbox_group_info, self.default_exp_id, evt._data['value'])
499
+ return checkbox_group_info
500
 
501
+ def optimal_on_select(self, evt: gr.EventData, checkbox_group_info):
502
  if hasattr(self, "optimal_exp_id"):
503
+ checkbox_group_info = self.groups.update_check(checkbox_group_info, self.optimal_exp_id, evt._data['value'])
504
  else:
505
  raise ValueError("Optimal explainer id is not found.")
506
+ return checkbox_group_info
507
 
508
  def show(self):
509
  val = self.explainer_name in DEFAULT_EXPLAINER
510
  with gr.Accordion(self.explainer_name, open=val):
511
+ checked = next(filter(lambda x: x['nm'] == self.explainer_name, self.groups.static_info))['checked']
512
  self.default_check = gr.Checkbox(label="Default Parameter", value=checked, interactive=True)
513
  self.opt_check = gr.Checkbox(label="Optimized Parameter (Not Optimal)", interactive=False)
514
 
515
+ self.default_check.select(self.default_on_select, self.groups.info, self.groups.info)
516
+ self.opt_check.select(self.optimal_on_select, self.groups.info, self.groups.info)
517
 
518
  self.bttn = gr.Button(value="Optimize", size="sm", variant="primary")
519
 
520
+ @spaces.GPU
521
+ def optimize(checkbox_group_info):
522
  data_id = self.gallery.selected_index
523
 
524
  opt_output = self.experiment.optimize(
 
536
  opt_postprocessor_id = pp_id
537
  break
538
 
539
+ opt_explainer_id = max([x['id'] for x in checkbox_group_info]) + 1
540
  opt_output.explainer.model = self.experiment.model
541
  self.experiment.manager._explainers.append(opt_output.explainer)
542
  self.experiment.manager._explainer_ids.append(opt_explainer_id)
543
+ self.groups.insert_check(checkbox_group_info, self.explainer_name, opt_explainer_id, opt_postprocessor_id)
544
  self.optimal_exp_id = opt_explainer_id
545
  checkbox = gr.update(label="Optimized Parameter (Optimal)", interactive=True)
546
  bttn = gr.update(value="Optimized", variant="secondary")
547
 
548
+ return [checkbox_group_info, checkbox, bttn]
549
 
550
+ self.bttn.click(optimize, inputs=[self.groups.info], outputs=[self.groups.info, self.opt_check, self.bttn], queue=True, concurrency_limit=1)
551
 
552
 
553
  class ExpRes(Component):
 
701
  app = ImageClsApp(experiments)
702
  demo = app.launch()
703
  demo.launch(favicon_path=f"static/XAI-Top-PnP.svg")
704
+ # demo.launch(favicon_path=f"static/XAI-Top-PnP.svg", share=True)