chaojie777 commited on
Commit
64b7d21
·
1 Parent(s): 5f31b46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -1
app.py CHANGED
@@ -8,7 +8,83 @@ iface = gr.Interface.from_pipeline(
8
  pipe,
9
  examples= [ [example] for example in images],
10
  description="Final project that labels flowers images into: Daisy, Dandelion, Rose, Sunflower, Tulip",
11
- title="Vit"
12
  )
13
 
14
  iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  pipe,
9
  examples= [ [example] for example in images],
10
  description="Final project that labels flowers images into: Daisy, Dandelion, Rose, Sunflower, Tulip",
11
+ title="Flower Classifier - Vit"
12
  )
13
 
14
  iface.launch()
15
+ # import torch
16
+ # import torchvision.transforms as transforms
17
+ # from torchvision import models
18
+ # from torch import nn
19
+ # import torch.nn.functional as F
20
+ # import gradio as gr
21
+ # import json
22
+
23
+
24
+ # # Get cpu or gpu device for training.
25
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
26
+ # print(f"Using {device} device")
27
+
28
+ # model_path = './resnet18_flower.pth'
29
+ # num_classes = 5
30
+ # label_name_list = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
31
+
32
+ # # step1: 创建并加载模型
33
+ # resnet18 = models.resnet18(pretrained=True)
34
+ # num_ftrs = resnet18.fc.in_features
35
+ # resnet18.fc = nn.Linear(num_ftrs, num_classes)
36
+ # resnet18 = resnet18.to(device)
37
+
38
+ # resnet18.load_state_dict(torch.load(model_path,map_location=torch.device('cpu')))
39
+
40
+ # model = resnet18
41
+ # # model = base_model_vgg13
42
+ # model.eval()
43
+
44
+
45
+ # # step2: 图片转换
46
+ # train_transform = transforms.Compose([
47
+ # # transforms.RandomRotation(5),
48
+ # # transforms.RandomHorizontalFlip(),
49
+ # transforms.Resize((224,224)),
50
+ # transforms.ToTensor(),
51
+ # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
52
+ # ])
53
+
54
+
55
+ # def predict(inp):
56
+ # # 定义预处理变换
57
+ # transform = train_transform
58
+
59
+ # # 加载图片并进行预处理
60
+ # # image = Image.open(image_path)
61
+ # image = transform(inp).unsqueeze(0).to(device)
62
+
63
+ # # step3:使用模型进行预测
64
+ # with torch.no_grad():
65
+ # output = model(image)
66
+
67
+ # # step4:数据后处理
68
+ # # 计算预测概率
69
+ # pred_score = nn.functional.softmax(output[0], dim=0)
70
+ # pred_score = pred_score.cpu().numpy()
71
+
72
+ # # 获取预测结果
73
+ # pred_index = torch.argmax(output, dim=1).item()
74
+ # pred_label = label_name_list[pred_index]
75
+
76
+ # # 转为json字符串格式
77
+ # result_dict = {'pred_score':str(max(pred_score)),'pred_index':str(pred_index),'pred_label':pred_label }
78
+ # result_json = json.dumps(result_dict)
79
+
80
+ # return result_json
81
+
82
+
83
+ # demo = gr.Interface(fn=predict,
84
+ # inputs=gr.Image(type="pil"),
85
+ # outputs="text",
86
+ # examples=["./592px-Red_sunflower.jpg"],
87
+ # )
88
+
89
+ # # demo.launch(debug=True)
90
+ # demo.launch()