Spaces:
Runtime error
Runtime error
# import gradio as gr | |
# from transformers import pipeline | |
# pipe = pipeline("image-classification", model="chaojie777/google-vit-base-patch16-224-in21k") | |
# images = ["images/daisy.jpg", "images/dandelion.jpg", "images/rosa.jpg", "images/sunflower.jpg", "images/tulip.jpg"] | |
# iface = gr.Interface.from_pipeline( | |
# pipe, | |
# examples= [ [example] for example in images], | |
# description="Final project that labels flowers images into: Daisy, Dandelion, Rose, Sunflower, Tulip", | |
# title="Flower Classifier - Vit" | |
# ) | |
# iface.launch() | |
import torch | |
import torchvision.transforms as transforms | |
from torchvision import models | |
from torch import nn | |
import torch.nn.functional as F | |
import gradio as gr | |
from PIL import Image | |
import json | |
# Pick the GPU when one is available; this script only runs inference.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# Path to the fine-tuned weights (replace with the path to your trained model).
model_path = './best.pth'
num_classes = 5
# predict() maps the argmax logit index to an entry of this list, so the order
# presumably must match the class order used when best.pth was trained — confirm.
label_name_list = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']

# Build a ResNet-18 with a num_classes-way classification head.
# No pretrained ImageNet weights are requested: load_state_dict below (strict by
# default) overwrites every parameter and buffer, so downloading them would only
# slow startup and fail outright when the host has no network access.
resnet18 = models.resnet18()
num_ftrs = resnet18.fc.in_features
resnet18.fc = nn.Linear(num_ftrs, num_classes)
resnet18 = resnet18.to(device)

# Load the trained parameters, mapping the stored tensors directly onto `device`.
resnet18.load_state_dict(torch.load(model_path, map_location=device))
model = resnet18
model.eval()  # inference mode for BatchNorm/Dropout layers

# Preprocessing used by predict(). Despite the historical name, this is the
# deterministic inference transform: the random augmentations are commented out.
# Resize to 224x224, convert to a float tensor, and normalize with the standard
# ImageNet per-channel mean/std.
train_transform = transforms.Compose([
    # transforms.RandomRotation(5),
    # transforms.RandomHorizontalFlip(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
def predict(inp):
    """Classify a flower image and return the prediction as a JSON string.

    Args:
        inp: PIL image from the Gradio ``gr.Image(type="pil")`` input. It may
            be ``None`` when the form is submitted without an image.

    Returns:
        JSON string with keys ``pred_score`` (softmax probability of the
        top class, stringified), ``pred_index`` (class index, stringified)
        and ``pred_label`` (the matching entry of ``label_name_list``).

    Raises:
        ValueError: if no image was provided.
    """
    if inp is None:
        raise ValueError("Please upload an image.")

    # Normalize in train_transform expects exactly 3 channels; RGBA, palette
    # and grayscale images (e.g. PNGs with transparency) would otherwise fail,
    # so force RGB before preprocessing. Add a batch dimension of 1.
    image = train_transform(inp.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(image)

    # Class probabilities for the single image in the batch.
    pred_score = nn.functional.softmax(output[0], dim=0).cpu().numpy()
    # argmax over logits picks the same class as argmax over probabilities.
    pred_index = torch.argmax(output, dim=1).item()
    pred_label = label_name_list[pred_index]

    result_dict = {
        'pred_score': str(max(pred_score)),
        'pred_index': str(pred_index),
        'pred_label': pred_label,
    }
    return json.dumps(result_dict)
# Sample images offered as clickable examples; Gradio expects one row
# (a list of input values) per example.
images = [
    "images/daisy.jpg",
    "images/dandelion.jpg",
    "images/rosa.jpg",
    "images/sunflower.jpg",
    "images/tulip.jpg",
]
example_rows = [[path] for path in images]

# Gradio UI: a PIL image goes in, predict()'s JSON string is shown as text.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    examples=example_rows,
)

# Start the Gradio server.
demo.launch()
# import torch | |
# import torchvision.transforms as transforms | |
# from torchvision import models | |
# from torch import nn | |
# import torch.nn.functional as F | |
# import gradio as gr | |
# import json | |
# # Get cpu or gpu device for training. | |
# device = "cuda" if torch.cuda.is_available() else "cpu" | |
# print(f"Using {device} device") | |
# model_path = './resnet18_flower.pth' | |
# num_classes = 5 | |
# label_name_list = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'] | |
# # step1: 创建并加载模型 | |
# resnet18 = models.resnet18(pretrained=True) | |
# num_ftrs = resnet18.fc.in_features | |
# resnet18.fc = nn.Linear(num_ftrs, num_classes) | |
# resnet18 = resnet18.to(device) | |
# resnet18.load_state_dict(torch.load(model_path,map_location=torch.device('cpu'))) | |
# model = resnet18 | |
# # model = base_model_vgg13 | |
# model.eval() | |
# # step2: 图片转换 | |
# train_transform = transforms.Compose([ | |
# # transforms.RandomRotation(5), | |
# # transforms.RandomHorizontalFlip(), | |
# transforms.Resize((224,224)), | |
# transforms.ToTensor(), | |
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | |
# ]) | |
# def predict(inp): | |
# # 定义预处理变换 | |
# transform = train_transform | |
# # 加载图片并进行预处理 | |
# # image = Image.open(image_path) | |
# image = transform(inp).unsqueeze(0).to(device) | |
# # step3:使用模型进行预测 | |
# with torch.no_grad(): | |
# output = model(image) | |
# # step4:数据后处理 | |
# # 计算预测概率 | |
# pred_score = nn.functional.softmax(output[0], dim=0) | |
# pred_score = pred_score.cpu().numpy() | |
# # 获取预测结果 | |
# pred_index = torch.argmax(output, dim=1).item() | |
# pred_label = label_name_list[pred_index] | |
# # 转为json字符串格式 | |
# result_dict = {'pred_score':str(max(pred_score)),'pred_index':str(pred_index),'pred_label':pred_label } | |
# result_json = json.dumps(result_dict) | |
# return result_json | |
# demo = gr.Interface(fn=predict, | |
# inputs=gr.Image(type="pil"), | |
# outputs="text", | |
# examples=["./592px-Red_sunflower.jpg"], | |
# ) | |
# # demo.launch(debug=True) | |
# demo.launch() | |