import gradio as gr
import torch
from Unet import UNet
import torchvision
from torchvision.transforms import functional as f
import os
from timeit import default_timer as timer
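# Run inference on CPU and load the trained U-Net weights
# (3 input channels, 3 segmentation classes; map_location keeps the load CPU-only).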
device = 'cpu'
model = UNet(device=device, in_channels=3, num_classes=3)
model.load_state_dict(torch.load("./data/models/Unet_v1.pth", map_location=torch.device('cpu')))
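# Preprocessing: resize to 128x128 (the resolution the model expects) and convert to a [0, 1] tensor.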
image_transforms = torchvision.transforms.Compose([
torchvision.transforms.Resize(size=(128, 128)),
torchvision.transforms.ToTensor()
])
def predict(img):
    """Segment the animal in a PIL image and return the mask plus the inference time."""
    start_time = timer()
    img_transformed = image_transforms(img).to(device)

    model.eval()
    with torch.inference_mode():
        # Add a batch dimension for the forward pass, then drop it again.
        y_logits = model(img_transformed.unsqueeze(dim=0)).squeeze(dim=0)

    # Per-pixel class index (0-2) with shape (128, 128).
    predicted_label = torch.argmax(y_logits, dim=0).to('cpu')

    # Broadcast the class map into all three channels so it can be rendered as an image
    # (vectorised equivalent of the original per-pixel loops).
    img_transformed = predicted_label.unsqueeze(dim=0).repeat(3, 1, 1).float()
    img_transformed = f.to_pil_image(img_transformed)

    return img_transformed, round(timer() - start_time, 3)
title = "Animal Segmentation"
description = "A UNet* computer vision model that segments the animal in an image.\nThe model works best on images that contain a single animal."
article = "U-Net: Convolutional Networks for Biomedical Image Segmentation (https://arxiv.org/abs/1505.04597)"
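# Build the example list from the images shipped in the "examples/" directory.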
example_list = [["examples/" + example] for example in os.listdir("examples")]
demo = gr.Interface(fn=predict,                      # mapping function from input to output
                    inputs=gr.Image(type="pil"),     # a single PIL image as input
                    outputs=[gr.Image(label="Segmentation"),
                             gr.Number(label="Prediction time (s)")],  # predict returns two values, so two output components
                    examples=example_list,
                    title=title,
                    description=description,
                    article=article)
demo.launch()