CHCZHC committed on
Commit
0612bb5
1 Parent(s): 56f6da9

Upload 3 files

Browse files
mush_classifier_20230801.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:851cac064760eab030275140b81d7e04595a6bb283154f87f72c1b6bc5cf6166
3
+ size 94141978
mushroom_class_load_predict.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision import transforms
2
+ from PIL import Image
3
+ from PIL import ImageFile
4
+ from torch.utils.data import Dataset,DataLoader
5
+ from transformers import AutoImageProcessor, BitModel, AdamW
6
+ import torch
7
+ from datasets import load_dataset
8
+ from torch import Tensor, nn
9
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
10
+ import os
11
+ import numpy as np
12
+ from sklearn import metrics, model_selection
13
+ from collections import Counter
14
+
15
+ # model class
16
class mush_root_model(torch.nn.Module):
    """Image classifier: pretrained ``google/bit-50`` backbone plus a linear head.

    The backbone's pooled output is flattened and projected from 2048
    features to ``num_labels`` logits.
    """

    def __init__(self, num_labels=1):
        super().__init__()
        self.model = BitModel.from_pretrained("google/bit-50")
        # Layer order matters: state-dict keys are "classifier.0" / "classifier.1".
        head_layers = [nn.Flatten(), nn.Linear(2048, num_labels)]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input):
        """Return the raw logits for the keyword inputs held in ``input``."""
        pooled = self.model(**input).pooler_output
        return self.classifier(pooled)
30
+
31
# load model
model_path = "/kaggle/input/mush-room-model-class/mush_classifier_20230801.pth"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = mush_root_model(num_labels=9)
# map_location remaps the stored tensors onto the device chosen above, so a
# checkpoint written on a GPU machine can still be restored on a CPU-only host.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
# This script only runs inference; make sure no layer is left in training mode.
model.eval()
image_processor = AutoImageProcessor.from_pretrained("google/bit-50")
38
+
39
# label setting
# Genus names; the position of a name is the class index used to look it up.
labels = [
    'Amanita',
    'Suillus',
    'Boletus',
    'Lactarius',
    'Agaricus',
    'Hygrocybe',
    'Cortinarius',
    'Russula',
    'Entoloma',
]

# Toxicity flag per genus: 1 for the genera listed below, 0 for every other one.
# NOTE(review): the Boletus description further down is marked toxic while its
# flag here is 0, and the Cortinarius description carries no toxicity remark at
# all -- confirm which source is authoritative.
_toxic_genera = ('Amanita', 'Hygrocybe', 'Entoloma')
toxic_labels = {genus: int(genus in _toxic_genera) for genus in labels}

# One human-readable description per class, in the same order as `labels`.
mushroom_address_list = [
    "Amanita毒蝇伞,伞菌目,鹅膏菌科,鹅膏菌属,主要分布于我国黑龙江、吉林、四川、西藏、云南等地,有毒",
    "Suillus乳牛肝菌,牛肝菌目,乳牛肝菌科,乳牛肝菌属,分布于吉林、辽宁、山西、安徽、江西、浙江、湖南、四川、贵州等地,无毒",
    "Boletus丽柄牛肝菌,伞菌目,牛肝菌科,牛肝菌属,分布于云南、陕西、甘肃、西藏等地,有毒",
    "Lactarius松乳菇,红菇目,红菇科,乳菇属,广泛分布于亚热带松林地,无毒",
    "Agaricus双孢蘑菇,伞菌目,蘑菇科,蘑菇属,广泛分布于北半球温带,无毒",
    "Hygrocybe浅黄褐湿伞,伞菌目,蜡伞科,湿伞属,分布于香港(见于松仔园),有毒",
    "Cortinarius掷丝膜菌,伞菌目,丝膜菌科,丝膜菌属,分布于湖南等地(夏秋季在山毛等阔叶林地上生长)",
    "Russula褪色红菇,伞菌目,红菇科,红菇属,分布于河北、吉林、四川、江苏、西藏等地,无毒",
    "Entoloma霍氏粉褶菌,伞菌目,粉褶菌科,粉褶菌属,主要分布于新西兰北岛和南岛西部,有毒",
]
53
def image_process(image_path):
    """Load one image from disk and preprocess it for the model.

    Args:
        image_path: Path of an image file that PIL can open.

    Returns:
        The output of ``image_processor`` for that image, holding PyTorch
        tensors (``return_tensors="pt"``).
    """
    # The context manager closes the file handle. convert() loads the pixel
    # data before the file is closed and turns grayscale / RGBA / palette
    # images into plain 3-channel RGB.
    with Image.open(image_path) as image:
        rgb_image = image.convert("RGB")
    return image_processor(rgb_image, return_tensors="pt")
57
+
58
def predict(image_path):
    """Classify a single image.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        A ``(label_id, label_score)`` tuple: the index of the most probable
        class as a plain ``int`` and its softmax probability as a ``float``.
    """
    image_pt = image_process(image_path)
    images = image_pt.to(device)
    model.eval()
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        # One image produces a batch of one; drop only that batch axis.
        logits = model(images).squeeze(0)
    # The training script optimises CrossEntropyLoss, so softmax (not a
    # per-class sigmoid) is what turns these logits into class probabilities.
    probabilities = torch.softmax(logits, dim=-1)
    label_score, label_id = torch.max(probabilities, dim=-1)
    return int(label_id), float(label_score)
67
+
68
+
69
# Demo run: classify one sample image, then look up its description and
# toxicity flag and print the result.
image_path = "/kaggle/input/mush-dataset/dataset1/无毒类/Cortinarius/000_Pw3qUBVmwN8.jpg"
mushroom_class, confidence = predict(image_path)
address = mushroom_address_list[mushroom_class]
toxic = toxic_labels[labels[mushroom_class]]
print(f"the class of mushroom is { labels[mushroom_class]}, its confidence is {confidence} and it is {bool(toxic)} toxic")
mushroom_class_train.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision import transforms
2
+ from PIL import Image
3
+ from PIL import ImageFile
4
+ from torch.utils.data import Dataset,DataLoader
5
+ from transformers import AutoImageProcessor, BitModel, AdamW
6
+ import torch
7
+ from datasets import load_dataset
8
+ from torch import Tensor, nn
9
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
10
+ import os
11
+ import numpy as np
12
+ from sklearn import metrics, model_selection
13
+ from collections import Counter
14
+
15
+
16
# Training configuration.
configs = {
    'batch_size': 32,        # samples per DataLoader batch
    'EPOCHS': 10,            # passes over the training set
    'LEARNING_RATE': 2e-5,   # optimizer learning rate
    'split_size': 0.1,       # fraction of the data held out for testing
}
22
+
23
# get class label
label_info = set()
dataset_path = "/kaggle/input/mush-dataset/dataset1"


def recur_label(dataset_path):
    """Walk ``dataset_path`` and record every class name in ``label_info``.

    The class name of a file is the name of the directory that directly
    contains it. Files with a ``.docx`` extension are ignored.

    Args:
        dataset_path: A file or directory path; directories are descended
            into recursively.
    """
    if os.path.isfile(dataset_path):
        # Check the extension only: a plain substring test would also skip
        # every image stored below a directory whose name contains ".docx".
        if dataset_path.lower().endswith(".docx"):
            return
        label_info.add(os.path.basename(os.path.dirname(dataset_path)))
        return
    # Sorted so the traversal order does not depend on the file system.
    for file_name in sorted(os.listdir(dataset_path)):
        recur_label(os.path.join(dataset_path, file_name))
36
recur_label(dataset_path)
print(label_info)
# Number the classes in sorted order. Iteration order of a set of strings
# changes from one interpreter run to the next (hash randomisation), so
# enumerating the raw set would give a different name -> index mapping per run
# and make saved checkpoints impossible to interpret reliably.
# NOTE: inference code that hard-codes a label list must use this same order.
dict_label = {name: index for index, name in enumerate(sorted(label_info))}
print(dict_label)
42
+
43
# deal with dataset, each image_path and label will be in the "all_data" and "all_label"
all_data = []
all_label = []
dataset_path = "/kaggle/input/mush-dataset/dataset1"


def recur_data(dataset_path, dict_label):
    """Walk ``dataset_path`` and collect file paths with their class indices.

    For every file that does not have a ``.docx`` extension, the path is
    appended to ``all_data`` and the index of its class (looked up in
    ``dict_label`` by the name of the containing directory) is appended to
    ``all_label``.

    Args:
        dataset_path: A file or directory path; directories are descended
            into recursively.
        dict_label: Mapping from class (directory) name to integer index.

    Raises:
        KeyError: If a file sits in a directory whose name is not a key of
            ``dict_label``.
    """
    if os.path.isfile(dataset_path):
        # Check the extension only: a plain substring test would also skip
        # every image stored below a directory whose name contains ".docx".
        if dataset_path.lower().endswith(".docx"):
            return
        label_name = os.path.basename(os.path.dirname(dataset_path))
        # Look the label up first so a failure leaves both lists aligned.
        label_index = dict_label[label_name]
        all_data.append(dataset_path)
        all_label.append(label_index)
        return
    # Sorted so the sample order does not depend on the file system.
    for file_name in sorted(os.listdir(dataset_path)):
        recur_data(os.path.join(dataset_path, file_name), dict_label)
59
+
60
recur_data(dataset_path, dict_label)

# split data to train and test
# A fixed seed keeps the train/test partition the same from run to run, so
# validation accuracies of different runs are measured on the same images.
# Set configs['seed'] to choose another partition.
train_data, test_data, train_label, test_label = model_selection.train_test_split(
    all_data,
    all_label,
    test_size=configs['split_size'],
    shuffle=True,
    random_state=configs.get('seed', 42),
)
print("train:", len(train_data), len(train_label), Counter(train_label))
print("test:", len(test_data), len(test_label), Counter(test_label))
66
+
67
# train and test data loader
class mushroom_Dataset(Dataset):
    """Dataset of image files with integer labels.

    Args:
        data: Sequence of image file paths.
        label: Sequence of labels, aligned index-by-index with ``data``.
        transform: Callable applied to each loaded image as
            ``transform(image, return_tensors="pt")``.
    """

    def __init__(self, data, label, transform):
        # Copies, so later changes to the caller's sequences do not leak in.
        self.data = list(data)
        self.label = list(label)
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(transformed_image, label, image_path)`` for ``index``."""
        image_path = self.data[index]
        label = self.label[index]
        # The context manager closes the file handle. convert() loads the
        # pixel data before the file is closed and turns grayscale / RGBA /
        # palette images into plain 3-channel RGB.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        image = self.transform(image, return_tensors="pt")
        return image, label, image_path
83
+
84
# The image processor doubles as the per-sample transform of both datasets.
image_processor = AutoImageProcessor.from_pretrained("google/bit-50")
train_dataset = mushroom_Dataset(train_data, train_label, image_processor)
test_dataset = mushroom_Dataset(test_data, test_label, image_processor)

# Both loaders use identical batching, worker and shuffling settings.
_loader_kwargs = dict(
    batch_size=configs['batch_size'],
    num_workers=4,
    shuffle=True,
    pin_memory=True,
)
train_loader = DataLoader(train_dataset, **_loader_kwargs)
test_loader = DataLoader(test_dataset, **_loader_kwargs)
91
+
92
+
93
+ # model class
94
class mush_root_model(torch.nn.Module):
    """Pretrained ``google/bit-50`` feature extractor with a linear classifier.

    ``num_labels`` sets the number of output logits; the classifier maps the
    flattened 2048-feature pooled output of the backbone to that many values.
    """

    def __init__(self, num_labels=1):
        super().__init__()
        self.model = BitModel.from_pretrained("google/bit-50")
        # Keep Flatten first and Linear second so the state-dict keys stay
        # "classifier.0" / "classifier.1".
        flatten = nn.Flatten()
        projection = nn.Linear(2048, num_labels)
        self.classifier = nn.Sequential(flatten, projection)

    def forward(self, input):
        """Run the backbone on the keyword inputs in ``input``; return logits."""
        backbone_output = self.model(**input)
        return self.classifier(backbone_output.pooler_output)
108
+
109
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# One output logit per discovered class; Module.to() returns the module itself.
model = mush_root_model(num_labels=len(dict_label)).to(device)
112
+
113
+
114
+ # train
115
def loss_fn(outputs, targets):
    """Return the cross-entropy loss of logits ``outputs`` against ``targets``.

    Uses the default settings of PyTorch's cross-entropy.
    """
    return torch.nn.functional.cross_entropy(outputs, targets)
117
+
118
# Use PyTorch's own AdamW: the AdamW class shipped with transformers was
# deprecated by that library in favour of torch.optim.AdamW. eps=1e-6 keeps the
# epsilon that the transformers optimizer used by default.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=configs['LEARNING_RATE'],
    weight_decay=1e-6,
    eps=1e-6,
)
119
+
120
def validation():
    """Measure classification accuracy of ``model`` on ``test_loader``.

    Leaves the model in eval mode; callers that continue training must switch
    it back with ``model.train()``.

    Returns:
        The accuracy computed by ``metrics.accuracy_score`` over every sample
        yielded by ``test_loader``.
    """
    model.eval()
    fin_targets = []
    fin_predictions = []
    with torch.no_grad():
        for data in test_loader:
            images = data[0].to(device)
            # Assumes pixel_values arrive as (batch, 1, C, H, W) because each
            # sample was processed on its own -- TODO confirm. Remove only
            # that axis: a bare squeeze() would also drop the batch axis
            # whenever a batch holds a single sample.
            images['pixel_values'] = images['pixel_values'].squeeze(1)
            logits = model(images)
            fin_targets.extend(data[1].tolist())
            # argmax over the logits picks the same class as argmax over any
            # monotonic transform of them, so no sigmoid/softmax is needed.
            fin_predictions.extend(logits.argmax(dim=-1).cpu().tolist())
    return metrics.accuracy_score(fin_targets, fin_predictions)
135
+
136
def train(epoch):
    """Train ``model`` on ``train_loader`` for ``epoch`` epochs.

    Every 50th batch of each epoch (starting with the first), the current
    batch loss and the validation accuracy are printed.

    Args:
        epoch: Number of passes over the training data.
    """
    model.train()
    for epoch_index in range(epoch):
        for step, data in enumerate(train_loader):
            images = data[0].to(device)
            # Assumes pixel_values arrive as (batch, 1, C, H, W) because each
            # sample was processed on its own -- TODO confirm. Remove only
            # that axis: a bare squeeze() would also drop the batch axis
            # whenever a batch holds a single sample.
            images['pixel_values'] = images['pixel_values'].squeeze(1)
            targets = data[1].to(device, dtype=torch.int64)
            logits = model(images)
            loss = loss_fn(logits, targets)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if step % 50 == 0:
                # Validate after the update so the autograd graph of this
                # batch is already released while the test set is processed.
                acc = validation()
                print(f'Epoch: {epoch_index}, Loss: {loss.item()}, val acc: {acc}')
                # validation() leaves the model in eval mode.
                model.train()
152
+
153
# Destination of the checkpoint; only the state dict (the learned parameters)
# is written, not the model object itself.
save_path = "/kaggle/working/mush_classifier_20230801.pth"

# Run the configured number of epochs, then persist the weights.
train(configs['EPOCHS'])
torch.save(model.state_dict(), save_path)