"""Gradio demo that classifies a flower photo with a fine-tuned ResNet-18.

The script loads a five-class ResNet-18 checkpoint, wraps single-image
inference in ``predict`` and serves it through a ``gr.Interface``.
"""

import json

import gradio as gr
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from torch import nn
from torchvision import models

# Run inference on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# Path to the trained checkpoint; replace with the path of your own model.
model_path = './best.pth'
num_classes = 5
# The predicted class index is used to look up the name in this list.
label_name_list = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']

# Build the architecture without fetching ImageNet weights: load_state_dict
# is strict by default, so every parameter and buffer is overwritten by the
# checkpoint anyway and the download would be wasted work.
resnet18 = models.resnet18()
num_ftrs = resnet18.fc.in_features
resnet18.fc = nn.Linear(num_ftrs, num_classes)

# Load the trained parameters directly onto the target device.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
resnet18.load_state_dict(torch.load(model_path, map_location=device))
resnet18 = resnet18.to(device)
model = resnet18
model.eval()  # inference mode: fixes batch-norm statistics

# Image preprocessing applied to every input before it reaches the model.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def predict(inp: Image.Image) -> str:
    """Classify one image and return the result as a JSON string.

    Args:
        inp: PIL image to classify. It is converted to RGB first because the
            normalisation step is defined for exactly three channels.

    Returns:
        A JSON object string with the keys ``pred_score`` (highest softmax
        probability), ``pred_index`` (index of the winning class) and
        ``pred_label`` (its name from ``label_name_list``). The score and the
        index are encoded as strings.

    Raises:
        ValueError: If ``inp`` is ``None``, e.g. the form was submitted
            without an image.
    """
    if inp is None:
        raise ValueError("No image was provided.")

    # Preprocess and add the batch dimension the model expects.
    image = train_transform(inp.convert("RGB")).unsqueeze(0).to(device)

    # Forward pass only; gradients are never needed for inference.
    with torch.no_grad():
        output = model(image)

    # Turn the logits of the single batch entry into class probabilities.
    pred_score = F.softmax(output[0], dim=0).cpu().numpy()

    # Pick the most likely class and look up its human-readable name.
    pred_index = torch.argmax(output, dim=1).item()
    pred_label = label_name_list[pred_index]

    result_dict = {
        'pred_score': str(pred_score.max()),
        'pred_index': str(pred_index),
        'pred_label': pred_label,
    }
    return json.dumps(result_dict)


# Example images offered in the UI.
images = [
    "images/daisy.jpg",
    "images/dandelion.jpg",
    "images/rosa.jpg",
    "images/sunflower.jpg",
    "images/tulip.jpg",
]

# Configure the Gradio interface.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    examples=[[example] for example in images],
)

# Start the Gradio server.
demo.launch()