Spaces:
No application file
No application file
Upload 3 files
Browse files- mush_classifier_20230801.pth +3 -0
- mushroom_class_load_predict.py +73 -0
- mushroom_class_train.py +155 -0
mush_classifier_20230801.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:851cac064760eab030275140b81d7e04595a6bb283154f87f72c1b6bc5cf6166
|
3 |
+
size 94141978
|
mushroom_class_load_predict.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torchvision import transforms
|
2 |
+
from PIL import Image
|
3 |
+
from PIL import ImageFile
|
4 |
+
from torch.utils.data import Dataset,DataLoader
|
5 |
+
from transformers import AutoImageProcessor, BitModel, AdamW
|
6 |
+
import torch
|
7 |
+
from datasets import load_dataset
|
8 |
+
from torch import Tensor, nn
|
9 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
10 |
+
import os
|
11 |
+
import numpy as np
|
12 |
+
from sklearn import metrics, model_selection
|
13 |
+
from collections import Counter
|
14 |
+
|
15 |
+
# model class
|
16 |
+
class mush_root_model(torch.nn.Module):
|
17 |
+
def __init__(self, num_labels=1):
|
18 |
+
super(mush_root_model, self).__init__()
|
19 |
+
self.model = BitModel.from_pretrained("google/bit-50")
|
20 |
+
self.classifier = nn.Sequential(
|
21 |
+
nn.Flatten(),
|
22 |
+
nn.Linear(2048, num_labels),
|
23 |
+
)
|
24 |
+
|
25 |
+
def forward(self, input):
|
26 |
+
outputs = self.model(**input).pooler_output
|
27 |
+
#print(outputs.shape)
|
28 |
+
logits = self.classifier(outputs)
|
29 |
+
return logits
|
30 |
+
|
31 |
+
# load model
|
32 |
+
model_path="/kaggle/input/mush-room-model-class/mush_classifier_20230801.pth"
|
33 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
34 |
+
model = mush_root_model(num_labels=9)
|
35 |
+
model.load_state_dict(torch.load(model_path))
|
36 |
+
model.to(device)
|
37 |
+
image_processor = AutoImageProcessor.from_pretrained("google/bit-50")
|
38 |
+
|
39 |
+
# label setting
|
40 |
+
labels = ['Amanita', 'Suillus', 'Boletus', 'Lactarius', 'Agaricus', 'Hygrocybe', 'Cortinarius', 'Russula', 'Entoloma']
|
41 |
+
toxic_labels = {'Amanita': 1, 'Suillus': 0, 'Boletus': 0, 'Lactarius': 0, 'Agaricus': 0, 'Hygrocybe': 1, 'Cortinarius': 0, 'Russula': 0, 'Entoloma': 1}
|
42 |
+
mushroom_address_list = [
|
43 |
+
"Amanita毒蝇伞,伞菌目,鹅膏菌科,鹅膏菌属,主要分布于我国黑龙江、吉林、四川、西藏、云南等地,有毒",
|
44 |
+
"Suillus乳牛肝菌,牛肝菌目,乳牛肝菌科,乳牛肝菌属,分布于吉林、辽宁、山西、安徽、江西、浙江、湖南、四川、贵州等地,无毒",
|
45 |
+
"Boletus丽柄牛肝菌,伞菌目,牛肝菌科,牛肝菌属,分布于云南、陕西、甘肃、西藏等地,有毒",
|
46 |
+
"Lactarius松乳菇,红菇目,红菇科,乳菇属,广泛分布于亚热带松林地,无毒",
|
47 |
+
"Agaricus双孢蘑菇,伞菌目,蘑菇科,蘑菇属,广泛分布于北半球温带,无毒",
|
48 |
+
"Hygrocybe浅黄褐湿伞,伞菌目,蜡伞科,湿伞属,分布于香港(见于松仔园),有毒",
|
49 |
+
"Cortinarius掷丝膜菌,伞菌目,丝膜菌科,丝膜菌属,分布于湖南等地(夏秋季在山毛等阔叶林地上生长)",
|
50 |
+
"Russula褪色红菇,伞菌目,红菇科,红菇属,分布于河北、吉林、四川、江苏、西藏等地,无毒",
|
51 |
+
"Entoloma霍氏粉褶菌,伞菌目,粉褶菌科,粉褶菌属,主要分布于新西兰北岛和南岛西部,有毒",
|
52 |
+
]
|
53 |
+
def image_process(image_path):
|
54 |
+
image = Image.open(image_path)
|
55 |
+
image_pt = image_processor(image,return_tensors="pt")
|
56 |
+
return image_pt
|
57 |
+
|
58 |
+
def predict(image_path):
|
59 |
+
image_pt = image_process(image_path)
|
60 |
+
images = image_pt.to(device)
|
61 |
+
#print(images['pixel_values'].shape)
|
62 |
+
outputs = torch.squeeze(model(images))
|
63 |
+
output = torch.sigmoid(outputs).cpu().detach().numpy().tolist()
|
64 |
+
label_id = np.argmax(output, axis=-1)
|
65 |
+
label_score = output[label_id]
|
66 |
+
return label_id, label_score
|
67 |
+
|
68 |
+
|
69 |
+
image_path="/kaggle/input/mush-dataset/dataset1/无毒类/Cortinarius/000_Pw3qUBVmwN8.jpg"
|
70 |
+
mushroom_class, confidence = predict(image_path)
|
71 |
+
toxic = toxic_labels[labels[mushroom_class]]
|
72 |
+
address = mushroom_address_list[mushroom_class]
|
73 |
+
print(f"the class of mushroom is { labels[mushroom_class]}, its confidence is {confidence} and it is {bool(toxic)} toxic")
|
mushroom_class_train.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torchvision import transforms
|
2 |
+
from PIL import Image
|
3 |
+
from PIL import ImageFile
|
4 |
+
from torch.utils.data import Dataset,DataLoader
|
5 |
+
from transformers import AutoImageProcessor, BitModel, AdamW
|
6 |
+
import torch
|
7 |
+
from datasets import load_dataset
|
8 |
+
from torch import Tensor, nn
|
9 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
10 |
+
import os
|
11 |
+
import numpy as np
|
12 |
+
from sklearn import metrics, model_selection
|
13 |
+
from collections import Counter
|
14 |
+
|
15 |
+
|
16 |
+
# cofig_size
|
17 |
+
configs = dict()
|
18 |
+
configs['batch_size'] = 32
|
19 |
+
configs['EPOCHS'] = 10
|
20 |
+
configs['LEARNING_RATE'] = 2e-5
|
21 |
+
configs['split_size'] = 0.1
|
22 |
+
|
23 |
+
# get class label
|
24 |
+
label_info = set()
|
25 |
+
dataset_path = "/kaggle/input/mush-dataset/dataset1"
|
26 |
+
def recur_label(dataset_path):
|
27 |
+
if os.path.isfile(dataset_path):
|
28 |
+
if ".docx" in dataset_path:
|
29 |
+
return
|
30 |
+
label_info.add(dataset_path.split("/")[-2])
|
31 |
+
else:
|
32 |
+
for file_name in os.listdir(dataset_path):
|
33 |
+
new_path = dataset_path + "/" + file_name
|
34 |
+
recur_label(new_path)
|
35 |
+
return
|
36 |
+
recur_label(dataset_path)
|
37 |
+
print(label_info)
|
38 |
+
dict_label = {}
|
39 |
+
for i, val in enumerate(label_info):
|
40 |
+
dict_label[val] = i
|
41 |
+
print(dict_label)
|
42 |
+
|
43 |
+
# deal with dataset, each image_path and label will be in the "all_data" and "all_label"
|
44 |
+
all_data = []
|
45 |
+
all_label = []
|
46 |
+
dataset_path = "/kaggle/input/mush-dataset/dataset1"
|
47 |
+
def recur_data(dataset_path, dict_label):
|
48 |
+
if os.path.isfile(dataset_path):
|
49 |
+
if ".docx" in dataset_path:
|
50 |
+
return
|
51 |
+
all_data.append(dataset_path)
|
52 |
+
label_name = dataset_path.split("/")[-2]
|
53 |
+
all_label.append(dict_label[label_name])
|
54 |
+
else:
|
55 |
+
for file_name in os.listdir(dataset_path):
|
56 |
+
new_path = dataset_path + "/" + file_name
|
57 |
+
recur_data(new_path, dict_label)
|
58 |
+
return
|
59 |
+
|
60 |
+
recur_data(dataset_path,dict_label)
|
61 |
+
|
62 |
+
# split data to train and test
|
63 |
+
train_data, test_data, train_label, test_label = model_selection.train_test_split(all_data, all_label, test_size=configs['split_size'], shuffle=True)
|
64 |
+
print("train:", len(train_data), len(train_label), Counter(train_label))
|
65 |
+
print("test:", len(test_data), len(test_label), Counter(test_label))
|
66 |
+
|
67 |
+
# trrain and test data loader
|
68 |
+
class mushroom_Dataset(Dataset):
|
69 |
+
def __init__(self, data,label, transform):
|
70 |
+
self.data = data[:]
|
71 |
+
self.label = label[:]
|
72 |
+
self.transform = transform
|
73 |
+
|
74 |
+
def __len__(self):
|
75 |
+
return len(self.data)
|
76 |
+
|
77 |
+
def __getitem__(self, index):
|
78 |
+
image_path = self.data[index]
|
79 |
+
label = self.label[index]
|
80 |
+
image = Image.open(image_path)
|
81 |
+
image = self.transform(image,return_tensors="pt")
|
82 |
+
return image, label, image_path
|
83 |
+
|
84 |
+
image_processor = AutoImageProcessor.from_pretrained("google/bit-50")
|
85 |
+
train_dataset = mushroom_Dataset(train_data, train_label, image_processor )
|
86 |
+
train_loader = DataLoader(train_dataset, batch_size=configs['batch_size'],
|
87 |
+
num_workers=4, shuffle=True, pin_memory=True)
|
88 |
+
test_dataset = mushroom_Dataset(test_data, test_label, image_processor )
|
89 |
+
test_loader = DataLoader(test_dataset, batch_size=configs['batch_size'],
|
90 |
+
num_workers=4, shuffle=True, pin_memory=True)
|
91 |
+
|
92 |
+
|
93 |
+
# model class
|
94 |
+
class mush_root_model(torch.nn.Module):
|
95 |
+
def __init__(self, num_labels=1):
|
96 |
+
super(mush_root_model, self).__init__()
|
97 |
+
self.model = BitModel.from_pretrained("google/bit-50")
|
98 |
+
self.classifier = nn.Sequential(
|
99 |
+
nn.Flatten(),
|
100 |
+
nn.Linear(2048, num_labels),
|
101 |
+
)
|
102 |
+
|
103 |
+
def forward(self, input):
|
104 |
+
outputs = self.model(**input).pooler_output
|
105 |
+
#print(outputs.shape)
|
106 |
+
logits = self.classifier(outputs)
|
107 |
+
return logits
|
108 |
+
|
109 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
110 |
+
model = mush_root_model(num_labels=len(dict_label))
|
111 |
+
model.to(device);
|
112 |
+
|
113 |
+
|
114 |
+
# train
|
115 |
+
def loss_fn(outputs, targets):
|
116 |
+
return torch.nn.CrossEntropyLoss()(outputs, targets)
|
117 |
+
|
118 |
+
optimizer = AdamW(params = model.parameters(), lr=configs['LEARNING_RATE'], weight_decay=1e-6)
|
119 |
+
|
120 |
+
def validation():
|
121 |
+
model.eval()
|
122 |
+
fin_targets=[]
|
123 |
+
fin_outputs=[]
|
124 |
+
with torch.no_grad():
|
125 |
+
for _, data in enumerate(test_loader, 0):
|
126 |
+
images = data[0].to(device)
|
127 |
+
images['pixel_values'] = torch.squeeze(images['pixel_values'])
|
128 |
+
targets = data[1].to(device, dtype = torch.float)
|
129 |
+
outputs = torch.squeeze(model(images))
|
130 |
+
fin_targets.extend(targets.cpu().detach().numpy().tolist())
|
131 |
+
fin_outputs.extend(torch.sigmoid(outputs).cpu().detach().numpy().tolist())
|
132 |
+
outputs1 = np.argmax(fin_outputs, axis=-1)
|
133 |
+
accuracy = metrics.accuracy_score(fin_targets, outputs1)
|
134 |
+
return accuracy
|
135 |
+
|
136 |
+
def train(epoch):
|
137 |
+
model.train()
|
138 |
+
for i in range(epoch):
|
139 |
+
for _,data in enumerate(train_loader, 0):
|
140 |
+
images = data[0].to(device)
|
141 |
+
images['pixel_values'] = torch.squeeze(images['pixel_values'])
|
142 |
+
targets = data[1].to(device, dtype = torch.int64)
|
143 |
+
outputs = torch.squeeze(model(images))
|
144 |
+
loss = loss_fn(outputs, targets)
|
145 |
+
if _%50 == 0:
|
146 |
+
acc = validation()
|
147 |
+
print(f'Epoch: {i}, Loss: {loss.item()}, val acc: {acc}')
|
148 |
+
model.train()
|
149 |
+
loss.backward()
|
150 |
+
optimizer.step()
|
151 |
+
optimizer.zero_grad()
|
152 |
+
|
153 |
+
train(configs['EPOCHS'])
|
154 |
+
save_path = "/kaggle/working/mush_classifier_20230801.pth"
|
155 |
+
torch.save(model.state_dict(), save_path)
|