chaojie777 commited on
Commit
193a9d6
·
1 Parent(s): 64b7d21

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -10
app.py CHANGED
@@ -1,17 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
- from transformers import pipeline
3
-
4
- pipe = pipeline("image-classification", model="chaojie777/google-vit-base-patch16-224-in21k")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  images = ["images/daisy.jpg", "images/dandelion.jpg", "images/rosa.jpg", "images/sunflower.jpg", "images/tulip.jpg"]
 
 
 
 
 
 
6
 
7
- iface = gr.Interface.from_pipeline(
8
- pipe,
9
- examples= [ [example] for example in images],
10
- description="Final project that labels flowers images into: Daisy, Dandelion, Rose, Sunflower, Tulip",
11
- title="Flower Classifier - Vit"
12
- )
13
 
14
- iface.launch()
15
  # import torch
16
  # import torchvision.transforms as transforms
17
  # from torchvision import models
 
1
+ # import gradio as gr
2
+ # from transformers import pipeline
3
+
4
+ # pipe = pipeline("image-classification", model="chaojie777/google-vit-base-patch16-224-in21k")
5
+ # images = ["images/daisy.jpg", "images/dandelion.jpg", "images/rosa.jpg", "images/sunflower.jpg", "images/tulip.jpg"]
6
+
7
+ # iface = gr.Interface.from_pipeline(
8
+ # pipe,
9
+ # examples= [ [example] for example in images],
10
+ # description="Final project that labels flowers images into: Daisy, Dandelion, Rose, Sunflower, Tulip",
11
+ # title="Flower Classifier - Vit"
12
+ # )
13
+
14
+ # iface.launch()
15
+ import torch
16
+ import torchvision.transforms as transforms
17
+ from torchvision import models
18
+ from torch import nn
19
+ import torch.nn.functional as F
20
  import gradio as gr
21
+ from PIL import Image
22
+ import json
23
+
24
+ # Get cpu or gpu device for training.
25
+ device = "cuda" if torch.cuda.is_available() else "cpu"
26
+ print(f"Using {device} device")
27
+
28
+ # 设置模型路径
29
+ model_path = './best.pth' # 替换为您训练的模型的路径
30
+ num_classes = 5
31
+ label_name_list = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
32
+
33
+ # 创建并加载模型
34
+ resnet18 = models.resnet18(pretrained=True)
35
+ num_ftrs = resnet18.fc.in_features
36
+ resnet18.fc = nn.Linear(num_ftrs, num_classes)
37
+ resnet18 = resnet18.to(device)
38
+
39
+ # 加载训练好的模型参数
40
+ resnet18.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
41
+ model = resnet18
42
+ model.eval()
43
+
44
+ # 图片转换
45
+ train_transform = transforms.Compose([
46
+ # transforms.RandomRotation(5),
47
+ # transforms.RandomHorizontalFlip(),
48
+ transforms.Resize((224, 224)),
49
+ transforms.ToTensor(),
50
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
51
+ ])
52
+
53
+ def predict(inp):
54
+ # 定义预处理变换
55
+ transform = train_transform
56
+
57
+ # 加载图片并进行预处理
58
+ image = transform(inp).unsqueeze(0).to(device)
59
+
60
+ # 使用模型进行预测
61
+ with torch.no_grad():
62
+ output = model(image)
63
+
64
+ # 数据后处理
65
+ # 计算预测概率
66
+ pred_score = nn.functional.softmax(output[0], dim=0)
67
+ pred_score = pred_score.cpu().numpy()
68
+
69
+ # 获取预测结果
70
+ pred_index = torch.argmax(output, dim=1).item()
71
+ pred_label = label_name_list[pred_index]
72
+
73
+ # 转为json字符串格式
74
+ result_dict = {'pred_score': str(max(pred_score)), 'pred_index': str(pred_index), 'pred_label': pred_label}
75
+ result_json = json.dumps(result_dict)
76
+
77
+ return result_json
78
  images = ["images/daisy.jpg", "images/dandelion.jpg", "images/rosa.jpg", "images/sunflower.jpg", "images/tulip.jpg"]
79
+ # 设置Gradio接口
80
+ demo = gr.Interface(fn=predict,
81
+ inputs=gr.Image(type="pil"),
82
+ outputs="text",
83
+ examples= [ [example] for example in images],
84
+ )
85
 
86
+ # 启动Gradio接口
87
+ demo.launch()
 
 
 
 
88
 
 
89
  # import torch
90
  # import torchvision.transforms as transforms
91
  # from torchvision import models