# test / app.py
# Author: chaojie777
# Last commit: "Update app.py"
# import gradio as gr
# from transformers import pipeline
# pipe = pipeline("image-classification", model="chaojie777/google-vit-base-patch16-224-in21k")
# images = ["images/daisy.jpg", "images/dandelion.jpg", "images/rosa.jpg", "images/sunflower.jpg", "images/tulip.jpg"]
# iface = gr.Interface.from_pipeline(
# pipe,
# examples= [ [example] for example in images],
# description="Final project that labels flowers images into: Daisy, Dandelion, Rose, Sunflower, Tulip",
# title="Flower Classifier - Vit"
# )
# iface.launch()
import torch
import torchvision.transforms as transforms
from torchvision import models
from torch import nn
import torch.nn.functional as F
import gradio as gr
from PIL import Image
import json
# Pick the device used for inference: the GPU when PyTorch can see one,
# otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using {device} device")
# 设置模型路径
model_path = './best.pth' # 替换为您训练的模型的路径
num_classes = 5
label_name_list = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
# 创建并加载模型
resnet18 = models.resnet18(pretrained=True)
num_ftrs = resnet18.fc.in_features
resnet18.fc = nn.Linear(num_ftrs, num_classes)
resnet18 = resnet18.to(device)
# 加载训练好的模型参数
resnet18.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model = resnet18
# 图片转换
train_transform = transforms.Compose([
# transforms.RandomRotation(5),
# transforms.RandomHorizontalFlip(),
transforms.Resize((224, 224)),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
def predict(inp):
# 定义预处理变换
transform = train_transform
# 加载图片并进行预处理
image = transform(inp).unsqueeze(0).to(device)
# 使用模型进行预测
with torch.no_grad():
output = model(image)
# 数据后处理
# 计算预测概率
pred_score = nn.functional.softmax(output[0], dim=0)
pred_score = pred_score.cpu().numpy()
# 获取预测结果
pred_index = torch.argmax(output, dim=1).item()
pred_label = label_name_list[pred_index]
# 转为json字符串格式
result_dict = {'pred_score': str(max(pred_score)), 'pred_index': str(pred_index), 'pred_label': pred_label}
result_json = json.dumps(result_dict)
return result_json
images = ["images/daisy.jpg", "images/dandelion.jpg", "images/rosa.jpg", "images/sunflower.jpg", "images/tulip.jpg"]
# 设置Gradio接口
demo = gr.Interface(fn=predict,
examples= [ [example] for example in images],
# 启动Gradio接口
# import torch
# import torchvision.transforms as transforms
# from torchvision import models
# from torch import nn
# import torch.nn.functional as F
# import gradio as gr
# import json
# # Get cpu or gpu device for training.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# print(f"Using {device} device")
# model_path = './resnet18_flower.pth'
# num_classes = 5
# label_name_list = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
# # step1: 创建并加载模型
# resnet18 = models.resnet18(pretrained=True)
# num_ftrs = resnet18.fc.in_features
# resnet18.fc = nn.Linear(num_ftrs, num_classes)
# resnet18 = resnet18.to(device)
# resnet18.load_state_dict(torch.load(model_path,map_location=torch.device('cpu')))
# model = resnet18
# # model = base_model_vgg13
# model.eval()
# # step2: 图片转换
# train_transform = transforms.Compose([
# # transforms.RandomRotation(5),
# # transforms.RandomHorizontalFlip(),
# transforms.Resize((224,224)),
# transforms.ToTensor(),
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])
# def predict(inp):
# # 定义预处理变换
# transform = train_transform
# # 加载图片并进行预处理
# # image = Image.open(image_path)
# image = transform(inp).unsqueeze(0).to(device)
# # step3:使用模型进行预测
# with torch.no_grad():
# output = model(image)
# # step4:数据后处理
# # 计算预测概率
# pred_score = nn.functional.softmax(output[0], dim=0)
# pred_score = pred_score.cpu().numpy()
# # 获取预测结果
# pred_index = torch.argmax(output, dim=1).item()
# pred_label = label_name_list[pred_index]
# # 转为json字符串格式
# result_dict = {'pred_score':str(max(pred_score)),'pred_index':str(pred_index),'pred_label':pred_label }
# result_json = json.dumps(result_dict)
# return result_json
# demo = gr.Interface(fn=predict,
# inputs=gr.Image(type="pil"),
# outputs="text",
# examples=["./592px-Red_sunflower.jpg"],
# )
# # demo.launch(debug=True)
# demo.launch()