import os
import torch
import gradio as gr
import torchvision
from PIL import Image
from utils import *
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt  # used for the dashboard plots below
from huggingface_hub import Repository, upload_file
from torch.utils.data import Dataset
import numpy as np
from collections import Counter
with open('app.css', 'r') as f:
    BLOCK_CSS = f.read()
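# Training hyperparameters, local paths, and the Hugging Face Hub repos that store
# the crowd-sourced adversarial dataset and the model weights/metrics.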
n_epochs = 10
batch_size_train = 128
batch_size_test = 1000
learning_rate = 0.01
adv_learning_rate = 0.001
momentum = 0.5
log_interval = 10
random_seed = 1

TRAIN_CUTOFF = 10
TEST_PER_SAMPLE = 5000
DASHBOARD_EXPLANATION = DASHBOARD_EXPLANATION.format(TEST_PER_SAMPLE=TEST_PER_SAMPLE)
WHAT_TO_DO = WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF)

MODEL_PATH = 'model'
METRIC_PATH = os.path.join(MODEL_PATH, 'metrics.json')
MODEL_WEIGHTS_PATH = os.path.join(MODEL_PATH, 'mnist_model.pth')
OPTIMIZER_PATH = os.path.join(MODEL_PATH, 'optimizer.pth')
REPOSITORY_DIR = "data"
LOCAL_DIR = 'data_local'
os.makedirs(LOCAL_DIR, exist_ok=True)

GET_STATISTICS_MESSAGE = "Get Statistics"
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_REPO = 'mnist-adversarial-model'
HF_DATASET = "mnist-adversarial-dataset"
DATASET_REPO_URL = f"https://huggingface.co./datasets/chrisjay/{HF_DATASET}"
MODEL_REPO_URL = f"https://huggingface.co./chrisjay/{MODEL_REPO}"  # model repos live at the root namespace (no /model/ prefix)
repo = Repository(
    local_dir="data_mnist", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
)
repo.git_pull()

model_repo = Repository(
    local_dir=MODEL_PATH, clone_from=MODEL_REPO_URL, use_auth_token=HF_TOKEN, repo_type="model"
)
model_repo.git_pull()

torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)
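# Dataset of user-flagged ("adversarial") drawings pulled from the Hub dataset repo.
# Each sample lives in its own folder containing image.png and a metadata.jsonl
# recording the digit the drawing was supposed to be.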
class MNISTAdversarial_Dataset(Dataset):

    def __init__(self, data_dir, transform):
        repo.git_pull()
        self.data_dir = os.path.join(data_dir, 'data')
        self.transform = transform
        files = [f.name for f in os.scandir(self.data_dir)]
        self.images = []
        self.numbers = []
        for f in files:
            self.FOLDER = os.path.join(self.data_dir, f)
            metadata_path = os.path.join(self.FOLDER, 'metadata.jsonl')
            image_path = os.path.join(self.FOLDER, 'image.png')
            if os.path.exists(image_path) and os.path.exists(metadata_path):
                metadata = read_json_lines(metadata_path)
                if metadata is not None:
                    img = Image.open(image_path)
                    self.images.append(img)
                    self.numbers.append(metadata[0]['correct_number'])
        assert len(self.images) == len(self.numbers), f"Length of images and numbers must be the same. Got {len(self.images)} for images and {len(self.numbers)} for numbers."

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img, label = self.images[idx], self.numbers[idx]
        img = self.transform(img)
        return img, label
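# MNIST-C (corrupted MNIST) test images filtered to a single digit,
# used to track per-digit accuracy on the dashboard.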
class MNISTCorrupted_By_Digit(Dataset):

    def __init__(self, transform, digit, limit=TEST_PER_SAMPLE):
        self.transform = transform
        self.digit = digit
        corrupted_dir = "./mnist_c"
        files = [f.name for f in os.scandir(corrupted_dir)]
        images = [np.load(os.path.join(corrupted_dir, f, 'test_images.npy')) for f in files]
        labels = [np.load(os.path.join(corrupted_dir, f, 'test_labels.npy')) for f in files]
        self.data = np.vstack(images)
        self.labels = np.hstack(labels)
        assert self.data.shape[0] == self.labels.shape[0]

        mask = self.labels == self.digit
        data_masked = self.data[mask]
        # Just to be on the safe side, ensure limit does not exceed the number of available samples
        limit = min(limit, data_masked.shape[0])
        self.data_for_use = data_masked[:limit]
        self.labels_for_use = self.labels[mask][:limit]
        assert self.data_for_use.shape[0] == self.labels_for_use.shape[0]

    def __len__(self):
        return len(self.data_for_use)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self.data_for_use[idx]
        label = self.labels_for_use[idx]
        if self.transform:
            image_pil = torchvision.transforms.ToPILImage()(image)  # convert to PIL before applying the default transforms
            image = self.transform(image_pil)
        return image, label
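# Full MNIST-C test set: the first TEST_PER_SAMPLE images of every corruption type,
# stacked into a single evaluation set.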
class MNISTCorrupted(Dataset):

    def __init__(self, transform):
        self.transform = transform
        corrupted_dir = "./mnist_c"
        files = [f.name for f in os.scandir(corrupted_dir)]
        images = [np.load(os.path.join(corrupted_dir, f, 'test_images.npy'))[:TEST_PER_SAMPLE] for f in files]
        labels = [np.load(os.path.join(corrupted_dir, f, 'test_labels.npy'))[:TEST_PER_SAMPLE] for f in files]
        self.data = np.vstack(images)
        self.labels = np.hstack(labels)
        assert self.data.shape[0] == self.labels.shape[0]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self.data[idx]
        label = self.labels[idx]
        if self.transform:
            image_pil = torchvision.transforms.ToPILImage()(image)  # convert to PIL before applying the default transforms
            image = self.transform(image_pil)
        return image, label
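# Standard MNIST normalization (dataset mean 0.1307, std 0.3081).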
TRAIN_TRANSFORM = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))
])

test_loader = torch.utils.data.DataLoader(MNISTCorrupted(TRAIN_TRANSFORM),
                                          batch_size=batch_size_test, shuffle=False)
# Source: https://nextjournal.com/gkoehler/pytorch-mnist
class MNIST_Model(nn.Module):

    def __init__(self):
        super(MNIST_Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)  # log-probabilities over the 10 classes
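# Shape check for MNIST_Model above: a 1x28x28 input becomes 10x12x12 after conv1 + max-pool,
# then 20x4x4 after conv2 + max-pool, i.e. 20 * 4 * 4 = 320 features, which is why fc1 expects 320 inputs.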
def train(epochs, network, optimizer, train_loader):
    train_losses = []
    network.train()
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = network(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
                train_losses.append(loss.item())
                torch.save(network.state_dict(), MODEL_WEIGHTS_PATH)
                torch.save(optimizer.state_dict(), OPTIMIZER_PATH)
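# Evaluate the current network on the MNIST-C test_loader defined above;
# returns a formatted metric string and the overall accuracy.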
def test():
    test_losses = []
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up the batch loss
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    acc = 100. * correct / len(test_loader.dataset)
    acc = acc.item()
    test_metric = 'Current test metric -> Avg. loss: `{:.4f}`, Accuracy: `{:.0f}%`\n'.format(
        test_loss, acc)
    print(test_metric)
    return test_metric, acc
random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

network = MNIST_Model()
optimizer = optim.SGD(network.parameters(), lr=learning_rate,
                      momentum=momentum)

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./files/', train=True, download=True,
                               transform=TRAIN_TRANSFORM),
    batch_size=batch_size_train, shuffle=True)

test_iid_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./files/', train=False, download=True,
                               transform=TRAIN_TRANSFORM),
    batch_size=batch_size_test, shuffle=True)

model_state_dict = MODEL_WEIGHTS_PATH
optimizer_state_dict = OPTIMIZER_PATH
if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict):
    network_state_dict = torch.load(model_state_dict)
    network.load_state_dict(network_state_dict)
    optimizer_state_dict = torch.load(optimizer_state_dict)
    optimizer.load_state_dict(optimizer_state_dict)

# Train model
#n_epochs=20
#train(n_epochs,network,optimizer,train_loader)
#test()
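# Optionally fine-tune on the crowd-sourced adversarial dataset, then record the overall
# and per-digit MNIST-C accuracy in metrics.json and push everything to the model repo.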
def train_and_test(train_model=True):
    if train_model:
        # Train on the adversarial data and test
        train_dataset = MNISTAdversarial_Dataset('./data_mnist', TRAIN_TRANSFORM)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size_test, shuffle=True)
        train(n_epochs, network, optimizer, train_loader)

    test_metric, test_acc = test()
    network.eval()

    if os.path.exists(METRIC_PATH):
        metric_dict = read_json(METRIC_PATH)
    else:
        metric_dict = {}
    metric_dict['all'] = metric_dict.get('all', []) + [test_acc]

    for i in range(10):
        data_per_digit = MNISTCorrupted_By_Digit(TRAIN_TRANSFORM, i)
        dataloader_per_digit = torch.utils.data.DataLoader(data_per_digit, batch_size=len(data_per_digit), shuffle=False)
        data_per_digit, label_per_digit = next(iter(dataloader_per_digit))
        output = network(data_per_digit)
        pred = output.data.max(1, keepdim=True)[1]
        correct = pred.eq(label_per_digit.data.view_as(pred)).sum()
        acc = 100. * correct / len(data_per_digit)
        acc = acc.item()
        metric_dict.setdefault(str(i), []).append(acc)

    dump_json(thing=metric_dict, file=METRIC_PATH)
    # Push model weights and metrics to the hub
    model_repo.push_to_hub()
    return test_metric
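# On startup the app prefers weights pulled from the model repo; if none exist yet
# it falls back to the bundled best_weights checkpoint, then makes sure metrics.json exists.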
# Update model weights again
model_state_dict = MODEL_WEIGHTS_PATH
optimizer_state_dict = OPTIMIZER_PATH
model_repo.git_pull()
if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict):
    network_state_dict = torch.load(model_state_dict)
    network.load_state_dict(network_state_dict)
    optimizer_state_dict = torch.load(optimizer_state_dict)
    optimizer.load_state_dict(optimizer_state_dict)
else:
    # Use best weights
    BEST_WEIGHTS_MODEL = "best_weights/mnist_model.pth"
    BEST_WEIGHTS_OPTIMIZER = "best_weights/optimizer.pth"
    network_state_dict = torch.load(BEST_WEIGHTS_MODEL)
    network.load_state_dict(network_state_dict)
    optimizer_state_dict = torch.load(BEST_WEIGHTS_OPTIMIZER)
    optimizer.load_state_dict(optimizer_state_dict)

if not os.path.exists(METRIC_PATH):
    _ = train_and_test(False)
def image_classifier(inp):
    """
    Loads the latest model weights from the model repository and uses them to make a
    prediction on the input image.

    :param inp: the image to be classified
    :return: a dictionary of the form {class_number: confidence}
    """
    # Get latest model weights ----------------
    model_repo.git_pull()
    model_state_dict = MODEL_WEIGHTS_PATH
    optimizer_state_dict = OPTIMIZER_PATH
    which_weights = ''
    if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict):
        which_weights = "Using weights from model repo"
        network_state_dict = torch.load(model_state_dict)
        network.load_state_dict(network_state_dict)
        optimizer_state_dict = torch.load(optimizer_state_dict)
        optimizer.load_state_dict(optimizer_state_dict)
    else:
        # Use best weights
        which_weights = "Using default best weights"
        BEST_WEIGHTS_MODEL = "best_weights/mnist_model.pth"
        BEST_WEIGHTS_OPTIMIZER = "best_weights/optimizer.pth"
        network.load_state_dict(torch.load(BEST_WEIGHTS_MODEL))
        optimizer.load_state_dict(torch.load(BEST_WEIGHTS_OPTIMIZER))

    network.eval()
    input_image = TRAIN_TRANSFORM(inp).unsqueeze(0)
    with torch.no_grad():
        # The network outputs log-probabilities, so exponentiate to recover class probabilities.
        prediction = torch.exp(network(input_image)[0])
        #pred_number = prediction.data.max(1, keepdim=True)[1]
        sorted_prediction = torch.sort(prediction, descending=True)
        confidences = {}
        for s, v in zip(sorted_prediction.indices.numpy().tolist(), sorted_prediction.values.numpy().tolist()):
            confidences.update({s: v})
    return confidences
def flag(input_image, correct_result, adversarial_number):
    """
    Takes in an image, the correct label, and the number of adversarial images flagged so
    far. It saves the image and metadata to a local directory, uploads both to the hub, and
    then pulls the dataset back down. Whenever the number of samples in the dataset reaches
    a multiple of TRAIN_CUTOFF, the model is retrained on the adversarial data.

    :param input_image: the adversarial image to save
    :param correct_result: the correct digit that the image represents
    :param adversarial_number: the number of adversarial examples uploaded so far
    :return: an HTML status message and the updated adversarial_number
    """
    adversarial_number = 0 if adversarial_number is None else adversarial_number

    metadata_name = get_unique_name()
    SAVE_FILE_DIR = os.path.join(LOCAL_DIR, metadata_name)
    os.makedirs(SAVE_FILE_DIR, exist_ok=True)
    image_output_filename = os.path.join(SAVE_FILE_DIR, 'image.png')
    try:
        input_image.save(image_output_filename)
    except Exception as e:
        raise Exception("Had issues saving PIL image to file") from e

    # Write metadata.jsonl to file
    json_file_path = os.path.join(SAVE_FILE_DIR, 'metadata.jsonl')
    metadata = {'id': metadata_name, 'file_name': 'image.png',
                'correct_number': correct_result
                }
    dump_json(metadata, json_file_path)

    # Simply upload the image file and metadata using the hub's upload_file
    # Upload the image
    repo_image_path = os.path.join(REPOSITORY_DIR, metadata_name, 'image.png')
    _ = upload_file(path_or_fileobj=image_output_filename,
                    path_in_repo=repo_image_path,
                    repo_id=f'chrisjay/{HF_DATASET}',
                    repo_type='dataset',
                    token=HF_TOKEN
                    )

    # Upload the metadata
    repo_json_path = os.path.join(REPOSITORY_DIR, metadata_name, 'metadata.jsonl')
    _ = upload_file(path_or_fileobj=json_file_path,
                    path_in_repo=repo_json_path,
                    repo_id=f'chrisjay/{HF_DATASET}',
                    repo_type='dataset',
                    token=HF_TOKEN
                    )

    adversarial_number += 1
    output = f'<div> ✅ ({adversarial_number}) Successfully saved your adversarial data. </div>'

    repo.git_pull()
    length_of_dataset = len([f for f in os.scandir("./data_mnist/data")])
    test_metric = f"<html> {DEFAULT_TEST_METRIC} </html>"
    if length_of_dataset % TRAIN_CUTOFF == 0:
        test_metric_ = train_and_test()
        test_metric = f"<html> {test_metric_} </html>"
        output = f'<div> ✅ ({adversarial_number}) Successfully saved your adversarial data and trained the model on adversarial data! </div>'
    return output, adversarial_number
def get_number_dict(DATA_DIR):
    """
    Scans DATA_DIR and counts how many times each digit appears in the metadata.jsonl
    files, returning the digits and their counts.

    :param DATA_DIR: the directory where the data is stored
    """
    files = [f.name for f in os.scandir(DATA_DIR)]
    metadata_jsons = [read_json_lines(os.path.join(DATA_DIR, f, 'metadata.jsonl')) for f in files]
    numbers = [m[0]['correct_number'] for m in metadata_jsons if m is not None]
    numbers_count = Counter(numbers)
    numbers_count_keys = list(numbers_count.keys())
    numbers_count_values = [numbers_count[k] for k in numbers_count_keys]
    return numbers_count_keys, numbers_count_values
def get_statistics():
    """
    Loads the latest model and optimizer state dicts, pulls the latest data from the repo,
    gets the number of adversarial samples per digit, plots the distribution of adversarial
    samples per digit, plots the test accuracy per digit per train step, and plots the test
    accuracy for all digits per train step.

    :return: the per-digit bar plot, the per-digit and overall accuracy figures, a status
             HTML string, and the stats explanation text
    """
    model_repo.git_pull()
    model_state_dict = MODEL_WEIGHTS_PATH
    optimizer_state_dict = OPTIMIZER_PATH
    if os.path.exists(model_state_dict):
        network_state_dict = torch.load(model_state_dict)
        network.load_state_dict(network_state_dict)
    if os.path.exists(optimizer_state_dict):
        optimizer_state_dict = torch.load(optimizer_state_dict)
        optimizer.load_state_dict(optimizer_state_dict)

    repo.git_pull()
    DATA_DIR = './data_mnist/data'
    numbers_count_keys, numbers_count_values = get_number_dict(DATA_DIR)

    STATS_EXPLANATION_ = STATS_EXPLANATION.format(num_adv_samples=sum(numbers_count_values))

    plt_digits = plot_bar(numbers_count_values, numbers_count_keys, 'Number of adversarial samples', "Digit", "Distribution of adversarial samples per digit", True)

    fig_d, ax_d = plt.subplots(tight_layout=True)
    if os.path.exists(METRIC_PATH):
        metric_dict = read_json(METRIC_PATH)
        for digit in range(10):
            try:
                steps = [s + 1 for s in range(len(metric_dict[str(digit)]))]
                ax_d.plot(steps, metric_dict[str(digit)], label=str(digit))
            except Exception:
                continue
        ax_d.set_xticks(range(0, len(metric_dict['0']) + 1, 1))
    else:
        metric_dict = {}

    fig_d.legend()
    ax_d.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy', title="Test Accuracy over digits per train step")

    done_html = f"""<div style="color: green">
                    <p> ✅ Statistics loaded successfully! Click `{GET_STATISTICS_MESSAGE}` to reload.</p>
                    </div>
                 """

    # Plot the total test accuracy over all digits
    fig_all, ax_all = plt.subplots(tight_layout=True)
    x_i = [i + 1 for i in range(len(metric_dict['all']))]
    ax_all.plot(x_i, metric_dict['all'])
    ax_all.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy', title="Test Accuracy for all digits")
    ax_all.set_xticks(range(0, x_i[-1] + 1, 1))

    return plt_digits, ax_d.figure, ax_all.figure, done_html, STATS_EXPLANATION_
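# Gradio UI: an 'MNIST' tab with a drawing canvas, live classification and a Flag button
# for submitting adversarial examples, plus a 'Dashboard' tab that renders the statistics plots.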
def main():
    block = gr.Blocks(css=BLOCK_CSS)

    with block:
        gr.Markdown(TITLE)
        gr.Markdown(description)

        with gr.Tabs():
            with gr.TabItem('MNIST'):
                gr.Markdown(WHAT_TO_DO)
                #test_metric = gr.outputs.HTML("")
                with gr.Row():
                    image_input = gr.inputs.Image(source="canvas", shape=(28, 28), invert_colors=True, image_mode="L", type="pil")
                    label_output = gr.outputs.Label(num_top_classes=2)
                gr.Markdown(MODEL_IS_WRONG)
                number_dropdown = gr.Dropdown(choices=[i for i in range(10)], type='value', default=None, label="What was the correct prediction?")
                gr.Markdown('Please wait a while after you press `Flag`. It takes time.')
                flag_btn = gr.Button("Flag")
                output_result = gr.outputs.HTML()
                adversarial_number = gr.Variable(value=0)

                image_input.change(image_classifier, inputs=[image_input], outputs=[label_output])
                flag_btn.click(flag, inputs=[image_input, number_dropdown, adversarial_number], outputs=[output_result, adversarial_number])

            with gr.TabItem('Dashboard') as dashboard:
                get_stat = gr.Button(f'{GET_STATISTICS_MESSAGE}')
                notification = gr.HTML(f"""<div style="color: green">
                                        <p> ✅ Click `{GET_STATISTICS_MESSAGE}` to generate statistics... </p>
                                        </div>
                                        """)
                stats = gr.Markdown()
                stat_adv_image = gr.Plot(type="matplotlib")
                gr.Markdown(DASHBOARD_EXPLANATION)
                test_results = gr.Plot(type="matplotlib")
                gr.Markdown(DASHBOARD_EXPLANATION_TEST)
                test_results_all = gr.Plot(type="matplotlib")

                #dashboard.select(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,notification,stats])
                get_stat.click(get_statistics, inputs=[], outputs=[stat_adv_image, test_results, test_results_all, notification, stats])

    block.launch()


if __name__ == "__main__":
    main()