Spaces:
Runtime error
Runtime error
File size: 21,714 Bytes
80442c0 866cafe f240072 80442c0 f240072 866cafe f240072 80442c0 ee4b6af 866cafe 80442c0 c3fd70a 80442c0 35ee063 c3fd70a c4c6bd6 866cafe 46eef9f f240072 866cafe e4a62fe f240072 46eef9f f240072 866cafe 7583157 46eef9f 866cafe f240072 46eef9f 7583157 46eef9f 80442c0 866cafe e4a62fe 866cafe c4c6bd6 866cafe c4c6bd6 866cafe 80442c0 866cafe 80442c0 866cafe 80442c0 866cafe 80442c0 46eef9f 80442c0 866cafe 35ee063 c3fd70a 866cafe 80442c0 c3fd70a 90e82c9 80442c0 90e82c9 5bd1489 80442c0 603879a 5bd1489 f240072 866cafe c3fd70a 866cafe 603879a 866cafe 46eef9f 866cafe 5bd1489 c3fd70a 5bd1489 c4c6bd6 c3fd70a 5bd1489 c4c6bd6 5bd1489 c4c6bd6 5bd1489 c4c6bd6 c3fd70a c4c6bd6 c3fd70a c4c6bd6 c3fd70a c4c6bd6 c3fd70a 5bd1489 866cafe c4c6bd6 866cafe f240072 866cafe f240072 866cafe 35ee063 866cafe e4a62fe 866cafe e4a62fe 866cafe c4c6bd6 46eef9f 866cafe 35ee063 e4a62fe f240072 35ee063 820ea84 866cafe e4a62fe 5bd1489 866cafe f240072 866cafe e4a62fe 866cafe 5bd1489 e4a62fe 5bd1489 dc69193 866cafe ee4b6af f240072 866cafe f240072 866cafe 35ee063 f240072 4db7f81 2dc35ff f240072 e4a62fe 2dc35ff 866cafe f240072 866cafe d4894f5 35ee063 f240072 866cafe e4a62fe 866cafe 2dc35ff 5bd1489 866cafe e4a62fe 866cafe f240072 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 |
import os
import torch
import gradio as gr
import torchvision
from PIL import Image
from utils import *
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from huggingface_hub import Repository, upload_file
from torch.utils.data import Dataset
import numpy as np
from collections import Counter
# Load the custom stylesheet that main() hands to gr.Blocks(css=...).
# The encoding is pinned so the read does not depend on the host locale.
with open('app.css', 'r', encoding='utf-8') as f:
    BLOCK_CSS = f.read()
# ---- Training / evaluation settings ----
n_epochs = 10  # epochs passed to train() by train_and_test()
batch_size_train = 128  # batch size of the clean-MNIST train loader
batch_size_test = 1000  # batch size of the test loaders (also reused for the adversarial train loader)
learning_rate = 0.01  # SGD learning rate
adv_learning_rate= 0.001  # NOTE(review): not referenced anywhere in the visible code
momentum = 0.5  # SGD momentum
log_interval = 10  # train() logs and checkpoints every `log_interval` batches
random_seed = 1
TRAIN_CUTOFF = 10  # flag() retrains when the number of dataset folders is a multiple of this
TEST_PER_SAMPLE = 5000  # cap on test samples taken per corruption folder / per digit
# Fill the placeholders of the text templates.
# Presumably the templates come from `from utils import *` -- verify.
DASHBOARD_EXPLANATION = DASHBOARD_EXPLANATION.format(TEST_PER_SAMPLE=TEST_PER_SAMPLE)
WHAT_TO_DO=WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF)
# ---- Paths ----
MODEL_PATH = 'model'  # local clone directory of the model repo
METRIC_PATH = os.path.join(MODEL_PATH,'metrics.json')
MODEL_WEIGHTS_PATH = os.path.join(MODEL_PATH,'mnist_model.pth')
OPTIMIZER_PATH = os.path.join(MODEL_PATH,'optimizer.pth')
REPOSITORY_DIR = "data"  # path prefix used for files uploaded by flag()
LOCAL_DIR = 'data_local'  # local scratch directory where flag() saves samples before upload
os.makedirs(LOCAL_DIR,exist_ok=True)
GET_STATISTICS_MESSAGE = "Get Statistics"  # label of the dashboard button
# ---- Hugging Face Hub repositories ----
HF_TOKEN = os.getenv("HF_TOKEN")  # None when the environment variable is not set
MODEL_REPO = 'mnist-adversarial-model'
HF_DATASET ="mnist-adversarial-dataset"
# NOTE(review): the host is written as `huggingface.co.` (trailing dot) and the model URL
# contains a `/model/` path segment -- confirm both are intended.
DATASET_REPO_URL = f"https://huggingface.co./datasets/chrisjay/{HF_DATASET}"
MODEL_REPO_URL = f"https://huggingface.co./model/chrisjay/{MODEL_REPO}"
# Clone the adversarial dataset repo into ./data_mnist and bring it up to date.
repo = Repository(
local_dir="data_mnist", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
)
repo.git_pull()
# Clone the model repo (weights, optimizer state, metrics) into MODEL_PATH and bring it up to date.
model_repo = Repository(
local_dir=MODEL_PATH, clone_from=MODEL_REPO_URL, use_auth_token=HF_TOKEN, repo_type="model"
)
model_repo.git_pull()
# Disable cuDNN and seed torch's RNG.
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)
class MNISTAdversarial_Dataset(Dataset):
    """Dataset of flagged adversarial digits stored in the dataset repo clone.

    Every sub-folder of ``<data_dir>/data`` is expected to hold an
    ``image.png`` and a ``metadata.jsonl`` whose first record has a
    ``correct_number`` entry used as the label. Folders missing either file,
    or whose metadata is ``None`` or empty, are skipped.
    """

    def __init__(self,data_dir,transform):
        """Pull the dataset repo and load every usable sample into memory.

        :param data_dir: root of the local dataset clone (samples are read
            from its ``data`` sub-directory)
        :param transform: callable applied to each PIL image in ``__getitem__``
        """
        repo.git_pull()
        self.data_dir = os.path.join(data_dir,'data')
        self.transform = transform
        files = [f.name for f in os.scandir(self.data_dir)]
        self.images = []
        self.numbers = []
        for f in files:
            # Kept as an instance attribute for backward compatibility with the
            # original code; it always points at the last folder visited.
            self.FOLDER = os.path.join(self.data_dir, f)
            metadata_path = os.path.join(self.FOLDER,'metadata.jsonl')
            image_path = os.path.join(self.FOLDER,'image.png')
            if not (os.path.exists(image_path) and os.path.exists(metadata_path)):
                continue
            metadata = read_json_lines(metadata_path)
            # An empty metadata file would otherwise raise IndexError on metadata[0].
            if metadata is None or len(metadata) == 0:
                continue
            # Image.open is lazy: force the pixel data in and release the file
            # handle now instead of keeping one descriptor open per sample.
            with Image.open(image_path) as img:
                img.load()
            self.images.append(img)
            self.numbers.append(metadata[0]['correct_number'])
        assert len(self.images)==len(self.numbers), f"Length of images and numbers must be the same. Got {len(self.images)} for images and {len(self.numbers)} for numbers."

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.images)

    def __getitem__(self,idx):
        """Return ``(transformed_image, label)`` for sample ``idx``."""
        img, label = self.images[idx], self.numbers[idx]
        img = self.transform(img)
        return img, label
class MNISTCorrupted_By_Digit(Dataset):
    """Corrupted test samples from ``./mnist_c`` restricted to one digit.

    The ``test_images.npy`` / ``test_labels.npy`` arrays of every entry in
    ``./mnist_c`` are stacked, filtered down to the requested digit and
    truncated to at most ``limit`` samples.
    """

    def __init__(self,transform,digit,limit=TEST_PER_SAMPLE):
        """Load, stack and filter the corrupted test arrays.

        :param transform: optional callable applied to the PIL version of each image
        :param digit: label value to keep
        :param limit: maximum number of samples to keep for this digit
        """
        self.transform = transform
        self.digit = digit
        corrupted_dir="./mnist_c"
        folder_names = [entry.name for entry in os.scandir(corrupted_dir)]
        image_arrays = [
            np.load(os.path.join(corrupted_dir, folder, 'test_images.npy'))
            for folder in folder_names
        ]
        label_arrays = [
            np.load(os.path.join(corrupted_dir, folder, 'test_labels.npy'))
            for folder in folder_names
        ]
        self.data = np.vstack(image_arrays)
        self.labels = np.hstack(label_arrays)
        assert (self.data.shape[0] == self.labels.shape[0])

        digit_mask = self.labels == self.digit
        matching_images = self.data[digit_mask]
        matching_labels = self.labels[digit_mask]
        # Never ask for more samples than exist for this digit.
        limit = min(limit, matching_images.shape[0])
        self.data_for_use = matching_images[:limit]
        self.labels_for_use = matching_labels[:limit]
        assert (self.data_for_use.shape[0] == self.labels_for_use.shape[0])

    def __len__(self):
        """Return the number of samples kept for the digit."""
        return len(self.data_for_use)

    def __getitem__(self,idx):
        """Return ``(image, label)``; the image is transformed when a transform is set."""
        index = idx.tolist() if torch.is_tensor(idx) else idx
        image = self.data_for_use[index]
        label = self.labels_for_use[index]
        if not self.transform:
            return image, label
        # The transform pipeline is fed a PIL image, so convert the array first.
        to_pil = torchvision.transforms.ToPILImage()
        return self.transform(to_pil(image)), label
class MNISTCorrupted(Dataset):
    """Corrupted test set assembled from every entry in ``./mnist_c``.

    The first ``TEST_PER_SAMPLE`` rows of each entry's ``test_images.npy`` and
    ``test_labels.npy`` are taken and stacked into one array each.
    """

    def __init__(self,transform):
        """Load and stack the truncated test arrays.

        :param transform: optional callable applied to the PIL version of each image
        """
        self.transform = transform
        corrupted_dir="./mnist_c"
        folder_names = [entry.name for entry in os.scandir(corrupted_dir)]
        image_arrays = [
            np.load(os.path.join(corrupted_dir, folder, 'test_images.npy'))[:TEST_PER_SAMPLE]
            for folder in folder_names
        ]
        label_arrays = [
            np.load(os.path.join(corrupted_dir, folder, 'test_labels.npy'))[:TEST_PER_SAMPLE]
            for folder in folder_names
        ]
        self.data = np.vstack(image_arrays)
        self.labels = np.hstack(label_arrays)
        assert (self.data.shape[0] == self.labels.shape[0])

    def __len__(self):
        """Return the total number of stacked samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is transformed when a transform is set."""
        index = idx.tolist() if torch.is_tensor(idx) else idx
        image = self.data[index]
        label = self.labels[index]
        if not self.transform:
            return image, label
        # The transform pipeline is fed a PIL image, so convert the array first.
        to_pil = torchvision.transforms.ToPILImage()
        return self.transform(to_pil(image)), label
# Preprocessing shared by every dataset/loader in this file and by image_classifier():
# convert the image to a tensor, then normalise with mean 0.1307 and std 0.3081.
TRAIN_TRANSFORM = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,))
])
# Unshuffled loader over the corrupted test set; test() iterates it.
test_loader = torch.utils.data.DataLoader(MNISTCorrupted(TRAIN_TRANSFORM),
batch_size=batch_size_test, shuffle=False)
# Source: https://nextjournal.com/gkoehler/pytorch-mnist
class MNIST_Model(nn.Module):
    """Small CNN: two 5x5 conv layers followed by two linear layers.

    The flatten step uses 320 features per sample, which matches a
    1x28x28 input (20 channels x 4 x 4 after the two conv + pool stages).
    The output is a row of 10 log-probabilities per sample.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return log-probabilities of shape ``(batch, 10)`` for input ``x``."""
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # x is always 2-D here, so the class axis is dim 1. Passing it
        # explicitly avoids the deprecated implicit-dim code path.
        return F.log_softmax(x, dim=1)
def train(epochs,network,optimizer,train_loader):
    """Train ``network`` on ``train_loader`` and checkpoint it to disk.

    Every ``log_interval`` batches the current loss is printed and recorded,
    and the model and optimizer state dicts are written to
    ``MODEL_WEIGHTS_PATH`` / ``OPTIMIZER_PATH``. If optimiser steps were taken
    after the last such checkpoint, a final checkpoint is written when
    training ends so the saved weights always match the trained model.

    :param epochs: number of passes over ``train_loader``
    :param network: model to train (put into train mode)
    :param optimizer: optimizer updating ``network``'s parameters
    :param train_loader: iterable of ``(data, target)`` batches
    :return: list of the loss values recorded at logging steps
    """
    def _save_checkpoint():
        # Persist model and optimizer so a later reload can resume from here.
        torch.save(network.state_dict(), MODEL_WEIGHTS_PATH)
        torch.save(optimizer.state_dict(), OPTIMIZER_PATH)

    train_losses = []
    unsaved_steps = False
    network.train()
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = network(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
            unsaved_steps = True
            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
                train_losses.append(loss.item())
                _save_checkpoint()
                unsaved_steps = False
    # Steps taken after the last logging batch would otherwise never be saved.
    if unsaved_steps:
        _save_checkpoint()
    return train_losses
def test():
    """Evaluate the global ``network`` on the global ``test_loader``.

    :return: ``(test_metric, acc)`` where ``test_metric`` is a formatted
        summary string (average loss and accuracy) and ``acc`` is the accuracy
        in percent as a Python float
    """
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            # reduction='sum' replaces the deprecated size_average=False:
            # accumulate the summed loss, then average over the dataset below.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    num_samples = len(test_loader.dataset)
    test_loss /= num_samples
    acc = 100. * correct / num_samples
    acc = acc.item()
    test_metric = '〽Current test metric -> Avg. loss: `{:.4f}`, Accuracy: `{:.0f}%`\n'.format(
        test_loss,acc)
    print(test_metric)
    return test_metric,acc
# Set the seed and disable cuDNN right before the model is built.
random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)
# Global model and optimizer shared by test(), train_and_test(), image_classifier()
# and get_statistics().
network = MNIST_Model()
optimizer = optim.SGD(network.parameters(), lr=learning_rate,
momentum=momentum)
# Loaders over the standard MNIST train / test splits (downloaded into ./files/ if needed).
# NOTE(review): neither loader is referenced by the visible code except in the
# commented-out training call below.
train_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('./files/', train=True, download=True,
transform=TRAIN_TRANSFORM),
batch_size=batch_size_train, shuffle=True)
test_iid_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('./files/', train=False, download=True,
transform=TRAIN_TRANSFORM),
batch_size=batch_size_test, shuffle=True)
# Restore the checkpoint from the model repo clone when both files are present.
# Note: `optimizer_state_dict` first holds the file path and is then rebound to the loaded dict.
model_state_dict = MODEL_WEIGHTS_PATH
optimizer_state_dict = OPTIMIZER_PATH
if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict):
    network_state_dict = torch.load(model_state_dict)
    network.load_state_dict(network_state_dict)
    optimizer_state_dict = torch.load(optimizer_state_dict)
    optimizer.load_state_dict(optimizer_state_dict)
# Train model
#n_epochs=20
#train(n_epochs,network,optimizer,train_loader)
#test()
def train_and_test(train_model=True):
    """Optionally fine-tune on the adversarial data, then evaluate and record metrics.

    The overall accuracy from ``test()`` is appended under the ``'all'`` key of
    the metrics file, and the accuracy for each digit 0-9 (measured on
    ``MNISTCorrupted_By_Digit``) under the key ``str(digit)``. The metrics are
    written to ``METRIC_PATH`` and the model repo is pushed to the hub.

    :param train_model: when true, train on ``MNISTAdversarial_Dataset`` for
        ``n_epochs`` epochs before evaluating
    :return: the summary string produced by ``test()``
    """
    if train_model:
        # Fine-tune on the flagged adversarial samples.
        train_dataset = MNISTAdversarial_Dataset('./data_mnist',TRAIN_TRANSFORM)
        train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size_test, shuffle=True)
        train(n_epochs,network,optimizer,train_loader)
    test_metric,test_acc = test()
    network.eval()

    metric_dict = read_json(METRIC_PATH) if os.path.exists(METRIC_PATH) else {}
    # setdefault tolerates a metrics file that exists but lacks a key.
    metric_dict.setdefault('all', []).append(test_acc)

    # Evaluation only: no autograd graph is needed for these large batches.
    with torch.no_grad():
        for i in range(10):
            data_per_digit = MNISTCorrupted_By_Digit(TRAIN_TRANSFORM,i)
            # One batch holding every sample of this digit.
            dataloader_per_digit = torch.utils.data.DataLoader(data_per_digit,batch_size=len(data_per_digit), shuffle=False)
            # The builtin next() works on every PyTorch version; the Python-2
            # style `.next()` method is not available on newer loader iterators.
            images, labels = next(iter(dataloader_per_digit))
            output = network(images)
            pred = output.data.max(1, keepdim=True)[1]
            correct = pred.eq(labels.data.view_as(pred)).sum()
            acc = 100. * correct / len(images)
            acc = acc.item()
            metric_dict.setdefault(str(i), []).append(acc)

    dump_json(thing=metric_dict,file=METRIC_PATH)
    # Push models and metrics to hub
    model_repo.push_to_hub()
    return test_metric
# Update model weights again
# Pull the model repo, then load its checkpoint when both files exist; otherwise
# fall back to the checkpoint under best_weights/.
model_state_dict = MODEL_WEIGHTS_PATH
optimizer_state_dict = OPTIMIZER_PATH
model_repo.git_pull()
if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict):
    network_state_dict = torch.load(model_state_dict)
    network.load_state_dict(network_state_dict)
    optimizer_state_dict = torch.load(optimizer_state_dict)
    optimizer.load_state_dict(optimizer_state_dict)
else:
    # Use best weights
    BEST_WEIGHTS_MODEL = "best_weights/mnist_model.pth"
    BEST_WEIGHTS_OPTIMIZER = "best_weights/optimizer.pth"
    network_state_dict = torch.load(BEST_WEIGHTS_MODEL)
    network.load_state_dict(network_state_dict)
    optimizer_state_dict = torch.load(BEST_WEIGHTS_OPTIMIZER)
    optimizer.load_state_dict(optimizer_state_dict)
# No metrics file yet: evaluate once without training so one gets written
# (train_and_test also pushes the model repo to the hub).
if not os.path.exists(METRIC_PATH):
    _ = train_and_test(False)
def image_classifier(inp):
    """
    Classify one drawn digit with the most recent model weights.

    The model repo is pulled first. If both the model and optimizer checkpoints
    exist there they are loaded, otherwise the ones under ``best_weights/`` are used.
    :param inp: the image to be classified
    :return: A dictionary of the form {class_number: confidence}, ordered from the
    most to the least confident class
    """
    # Get latest model weights ----------------
    model_repo.git_pull()
    weights_file = MODEL_WEIGHTS_PATH
    optimizer_file = OPTIMIZER_PATH
    repo_checkpoint_available = os.path.exists(weights_file) and os.path.exists(optimizer_file)
    if repo_checkpoint_available:
        weights_source = "Using weights from model repo"
    else:
        # Use best weights
        weights_source = "Using default best weights"
        weights_file = "best_weights/mnist_model.pth"
        optimizer_file = "best_weights/optimizer.pth"
    network.load_state_dict(torch.load(weights_file))
    optimizer.load_state_dict(torch.load(optimizer_file))

    network.eval()
    # Add the batch dimension the network expects.
    batch = TRAIN_TRANSFORM(inp).unsqueeze(0)
    with torch.no_grad():
        first_output = network(batch)[0]
        prediction = torch.nn.functional.softmax(first_output, dim=0)
    ranked = torch.sort(prediction,descending=True)
    class_ids = ranked.indices.numpy().tolist()
    scores = ranked.values.numpy().tolist()
    return dict(zip(class_ids, scores))
def flag(input_image,correct_result,adversarial_number):
    """
    Save a flagged adversarial sample, upload it to the dataset repo and
    retrain the model when enough samples have accumulated.

    The image and its metadata are written to a fresh folder under LOCAL_DIR,
    uploaded to the hub dataset with `upload_file`, and the dataset repo is
    pulled. If the number of entries in ./data_mnist/data is then a multiple of
    TRAIN_CUTOFF, `train_and_test()` is run.
    :param input_image: The adversarial image (PIL) that you want to save
    :param correct_result: The correct number that the image represents
    :param adversarial_number: Number of adversarial examples uploaded so far in this
    session; None is treated as 0
    :return: A tuple (html_message, updated_adversarial_number)
    :raises Exception: if the image cannot be saved to disk
    """
    # The original `0 if None else adversarial_number` never applied the default.
    adversarial_number = 0 if adversarial_number is None else adversarial_number

    if correct_result is None:
        # Without a label the sample would be stored with `correct_number: None`,
        # which cannot serve as a training target, so nothing is saved.
        output = '<div> ✖ Please select the correct number before flagging. </div>'
        return output,adversarial_number

    metadata_name = get_unique_name()
    SAVE_FILE_DIR = os.path.join(LOCAL_DIR,metadata_name)
    os.makedirs(SAVE_FILE_DIR,exist_ok=True)
    image_output_filename = os.path.join(SAVE_FILE_DIR,'image.png')
    try:
        input_image.save(image_output_filename)
    except Exception as err:
        # Keep the original message but chain the cause for debugging.
        raise Exception("Had issues saving PIL image to file") from err

    # Write metadata.json to file
    json_file_path = os.path.join(SAVE_FILE_DIR,'metadata.jsonl')
    metadata= {'id':metadata_name,'file_name':'image.png',
               'correct_number':correct_result
               }
    dump_json(metadata,json_file_path)

    # Simply upload the image file and metadata using the hub's upload_file
    # Upload the image
    repo_image_path = os.path.join(REPOSITORY_DIR,os.path.join(metadata_name,'image.png'))
    _ = upload_file(path_or_fileobj = image_output_filename,
                    path_in_repo =repo_image_path,
                    repo_id=f'chrisjay/{HF_DATASET}',
                    repo_type='dataset',
                    token=HF_TOKEN
                    )
    # Upload the metadata
    repo_json_path = os.path.join(REPOSITORY_DIR,os.path.join(metadata_name,'metadata.jsonl'))
    _ = upload_file(path_or_fileobj = json_file_path,
                    path_in_repo =repo_json_path,
                    repo_id=f'chrisjay/{HF_DATASET}',
                    repo_type='dataset',
                    token=HF_TOKEN
                    )

    adversarial_number+=1
    output = f'<div> ✔ ({adversarial_number}) Successfully saved your adversarial data. </div>'

    # Refresh the local clone so the count below includes the new upload.
    repo.git_pull()
    length_of_dataset = sum(1 for _ in os.scandir("./data_mnist/data"))
    if length_of_dataset % TRAIN_CUTOFF ==0:
        train_and_test()
        output = f'<div> ✔ ({adversarial_number}) Successfully saved your adversarial data and trained the model on adversarial data! </div>'
    return output,adversarial_number
def get_number_dict(DATA_DIR):
    """
    Count how often each label occurs in the metadata.jsonl files found in the
    sub-folders of a directory.

    The label is the `correct_number` entry of the first record of each metadata file;
    files for which `read_json_lines` returns None are ignored.
    :param DATA_DIR: The directory where the data is stored
    :return: A tuple (labels, counts) of two parallel lists
    """
    folder_names = [entry.name for entry in os.scandir(DATA_DIR)]

    # Read every metadata file first, then extract the labels.
    all_records = []
    for folder in folder_names:
        metadata_file = os.path.join(DATA_DIR, folder, 'metadata.jsonl')
        all_records.append(read_json_lines(metadata_file))

    labels = []
    for records in all_records:
        if records is None:
            continue
        labels.append(records[0]['correct_number'])

    tally = Counter(labels)
    return list(tally.keys()), list(tally.values())
def get_statistics():
    """
    Build the dashboard plots.

    It pulls the model repo and reloads the model / optimizer state dicts when present,
    pulls the dataset repo, counts the adversarial samples per digit and plots that
    distribution, then plots the recorded test accuracy per digit and for all digits
    over the train steps stored in the metrics file. Missing or empty metric series
    are skipped, so the corresponding plot is simply left empty.
    :return: A tuple (digit distribution plot, per-digit accuracy figure,
    overall accuracy figure, status html, statistics explanation text)
    """
    model_repo.git_pull()
    if os.path.exists(MODEL_WEIGHTS_PATH):
        network.load_state_dict(torch.load(MODEL_WEIGHTS_PATH))
    if os.path.exists(OPTIMIZER_PATH):
        optimizer.load_state_dict(torch.load(OPTIMIZER_PATH))

    repo.git_pull()
    DATA_DIR = './data_mnist/data'
    numbers_count_keys,numbers_count_values = get_number_dict(DATA_DIR)
    STATS_EXPLANATION_ = STATS_EXPLANATION.format(num_adv_samples = sum(numbers_count_values))
    plt_digits = plot_bar(numbers_count_values,numbers_count_keys,'Number of adversarial samples',"Digit",f"Distribution of adversarial samples per digit",True)

    metric_dict = read_json(METRIC_PATH) if os.path.exists(METRIC_PATH) else {}

    # Plot of test accuracy per digit.
    # NOTE(review): `plt` is not imported in this file; presumably it is provided
    # by `from utils import *` -- verify.
    fig_d, ax_d = plt.subplots(tight_layout=True)
    longest_history = 0
    for digit in range(10):
        history = metric_dict.get(str(digit))
        if not history:
            # No recorded accuracy for this digit: nothing to draw.
            continue
        steps = list(range(1, len(history) + 1))
        ax_d.plot(steps, history,label=str(digit))
        longest_history = max(longest_history, len(history))
    if longest_history:
        ax_d.set_xticks(range(0, longest_history+1, 1))
        fig_d.legend()
    ax_d.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy',title="Test Accuracy over digits per train step")

    done_html = f"""<div style="color: green">
<p> ✅ Statistics loaded successfully! Click `{GET_STATISTICS_MESSAGE}`to reload.</p>
</div>
"""
    # Plot for total test accuracy for all digits
    fig_all, ax_all = plt.subplots(tight_layout=True)
    all_history = metric_dict.get('all') or []
    x_all = list(range(1, len(all_history) + 1))
    ax_all.plot(x_all, all_history)
    ax_all.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy',title="Test Accuracy for all digits")
    if x_all:
        ax_all.set_xticks(range(0, x_all[-1]+1, 1))
    return plt_digits,ax_d.figure,ax_all.figure,done_html,STATS_EXPLANATION_
def main():
    """Build the Gradio Blocks UI (an 'MNIST' tab and a 'Dashboard' tab) and launch it."""
    block = gr.Blocks(css=BLOCK_CSS)
    with block:
        gr.Markdown(TITLE)
        gr.Markdown(description)
        with gr.Tabs():
            # Tab 1: draw a digit, see the prediction, flag wrong predictions.
            with gr.TabItem('MNIST'):
                gr.Markdown(WHAT_TO_DO)
                #test_metric = gr.outputs.HTML("")
                # NOTE(review): indentation was reconstructed; confirm that only the
                # canvas and the label belong to this Row.
                with gr.Row():
                    image_input =gr.inputs.Image(source="canvas",shape=(28,28),invert_colors=True,image_mode="L",type="pil")
                    label_output = gr.outputs.Label(num_top_classes=2)
                gr.Markdown(MODEL_IS_WRONG)
                number_dropdown = gr.Dropdown(choices=[i for i in range(10)],type='value',default=None,label="What was the correct prediction?")
                gr.Markdown('Please wait a while after you press `Flag`. It takes time.')
                flag_btn = gr.Button("Flag")
                output_result = gr.outputs.HTML()
                # Counter passed into flag() and updated from its second return value.
                adversarial_number = gr.Variable(value=0)
                # Re-classify whenever the drawing changes.
                image_input.change(image_classifier,inputs = [image_input],outputs=[label_output])
                flag_btn.click(flag,inputs=[image_input,number_dropdown,adversarial_number],outputs=[output_result,adversarial_number])
            # Tab 2: statistics plots, generated on demand by the button.
            with gr.TabItem('Dashboard') as dashboard:
                get_stat = gr.Button(f'{GET_STATISTICS_MESSAGE}')
                notification = gr.HTML(f"""<div style="color: green">
<p> ⌛ Click `{GET_STATISTICS_MESSAGE}` to generate statistics... </p>
</div>
""")
                stats = gr.Markdown()
                stat_adv_image =gr.Plot(type="matplotlib")
                gr.Markdown(DASHBOARD_EXPLANATION)
                test_results=gr.Plot(type="matplotlib")
                gr.Markdown(DASHBOARD_EXPLANATION_TEST)
                test_results_all=gr.Plot(type="matplotlib")
                #dashboard.select(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,notification,stats])
                # The output order matches the tuple returned by get_statistics().
                get_stat.click(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,test_results_all,notification,stats])
    block.launch()
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    main()