Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from torch import nn | |
from PIL import Image | |
from torchvision import transforms | |
import torchvision.models as models | |
import torchvision.transforms as transforms | |
# Class labels: predict() reports model output index i as CLASSES[i], so this
# order must match the label order used when best_model.pth was trained.
CLASSES = ['guro', 'pigs', 'proofs', 'protyk', 'safe', 'shit']
NUM_CLASSES = len(CLASSES)

# Inference is pinned to CPU. To use a GPU when available, switch to:
#   torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')

# ResNet-18 with its final fully-connected layer resized to NUM_CLASSES outputs.
# No pretrained weights are requested: load_state_dict below is strict by
# default and overwrites every parameter, so downloading ImageNet weights at
# startup would only add a network dependency and wasted time.
model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
# NOTE(review): torch.load unpickles the checkpoint; only ship files you trust.
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.to(device)
model.eval()  # inference mode (e.g. BatchNorm uses its running statistics)
# Image preprocessing pipeline: resize so the shorter side is 256 px, take the
# central 224x224 crop, convert to a float tensor, then normalize each channel
# (the mean/std values are the usual ImageNet per-channel statistics).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
# Prediction function
def predict(img):
    """Classify an image and return per-class probabilities.

    Args:
        img: The image passed by the Gradio ``"image"`` input, converted with
            ``PIL.Image.fromarray`` (presumably an HxWxC uint8 numpy array —
            confirm against the installed Gradio version), or ``None`` when the
            user submits without an image.

    Returns:
        A dict mapping every label in ``CLASSES`` to its softmax probability
        (Python floats), as consumed by the Gradio ``"label"`` output, or
        ``None`` when ``img`` is ``None``.
    """
    if img is None:
        # Image.fromarray(None) would raise; report "no result" instead.
        return None
    # Force 3 channels: the Normalize step in `transform` has exactly three
    # mean/std values, so grayscale or RGBA input would otherwise fail there.
    pil_img = Image.fromarray(img).convert('RGB')
    batch = transform(pil_img).unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        logits = model(batch)
    # Single-image batch: take row 0 and convert to plain Python floats.
    probabilities = torch.softmax(logits, dim=1)[0].cpu().tolist()
    return dict(zip(CLASSES, probabilities))
# Gradio interface: an image input wired to predict(), whose class-probability
# dict is displayed by a label output.
demo = gr.Interface(fn=predict, inputs="image", outputs="label")

# Launch only when executed as a script, so importing this module (e.g. from
# tests, or Gradio's reload mode, which looks for a `demo` object) does not
# start a blocking web server.
if __name__ == "__main__":
    demo.launch()