# anon5's picture
# Update app.py
# a8d4562 verified
import gradio as gr
import torch
from torch import nn
from PIL import Image
from torchvision import transforms
import torchvision.models as models
import torchvision.transforms as transforms
# Class labels, in the order of the classifier's output units.
CLASSES = ['guro', 'pigs', 'proofs', 'protyk', 'safe', 'shit']
NUM_CLASSES = len(CLASSES)

# Inference runs on the CPU only. To use a GPU when one is present, select
# 'cuda' based on torch.cuda.is_available() instead.
device = torch.device('cpu')

# Build the bare ResNet-18 architecture without fetching ImageNet weights:
# load_state_dict() below (strict by default) overwrites every parameter and
# buffer from the fine-tuned checkpoint, so a pretrained download would only
# add startup latency and a network dependency.
model = models.resnet18()
# Replace the classification head so its output size matches our label set.
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.to(device)
# Switch to inference mode (fixed batch-norm statistics, no dropout).
model.eval()
# Image preprocessing pipeline: resize the shorter side to 256 px, take the
# central 224x224 crop, convert to a tensor and normalise each channel.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
    ]
)
# Prediction function
def predict(img):
    """Classify an image and return the probability of every class.

    Args:
        img: The input image as a NumPy array (it is passed to
            ``Image.fromarray``), or ``None`` when no image was supplied.

    Returns:
        A dict mapping each label in ``CLASSES`` to its softmax probability
        as a plain Python ``float``, or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        # Nothing was submitted: return an empty result instead of letting
        # PIL raise on a missing array.
        return None

    # The normalisation step is defined for exactly three channels, so coerce
    # grayscale / RGBA inputs to RGB before preprocessing.
    image = Image.fromarray(img).convert('RGB')
    # The model expects a batch dimension, hence unsqueeze(0).
    batch = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)

    # Take the single batch row and convert to built-in floats so the result
    # is JSON-serialisable without relying on NumPy scalar handling.
    probabilities = torch.softmax(logits, dim=1)[0].cpu().tolist()
    return dict(zip(CLASSES, probabilities))
# Gradio interface: an image goes in, a label with class probabilities comes out.
demo = gr.Interface(fn=predict, inputs="image", outputs="label")
demo.launch()