|
import argparse |
|
import os |
|
import torch |
|
from model.detector import * |
|
from model.backbone import * |
|
from model.data import Therin |
|
import datetime |
|
|
|
from model.detector.fasterRCNN import FasterRCNN |
|
from model.backbone.densenet import DenseNet |
|
from model.utils.engine import * |
|
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone, _resnet_fpn_extractor |
|
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_V2_Weights |
|
from torchvision import transforms as T |
|
from PIL import Image, ImageDraw |
|
|
|
# Run on the first CUDA GPU when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
def label_to_text_en(l):
    """Return the English posture name for the integer class label ``l``.

    Labels 0-4 map, in order, to "creeping", "crawling", "stooping",
    "climbing" and "other". Any other key raises ``KeyError``.
    """
    names_by_label = dict(enumerate(("creeping", "crawling", "stooping", "climbing", "other")))
    return names_by_label[l]
|
|
|
def label_to_text_ja(l):
    """Return the Japanese posture name for the integer class label ``l``.

    Label ids are the same as in ``label_to_text_en``. Any key outside 0-4
    raises ``KeyError``.
    """
    # These literals must be real UTF-8 Japanese text. An earlier copy held
    # mojibake (UTF-8 bytes mis-decoded as a Thai single-byte codec), which
    # rendered as unreadable Thai characters.
    d = {
        0: "しのびこんでいる",  # creeping (lit. "sneaking in")
        1: "這っている",  # crawling
        2: "かがんでいる",  # stooping
        3: "よじ登っている",  # climbing
        4: "その他",  # other
    }
    return d[l]
|
|
|
def show_bb(img, x, y, w, h, text, textcolor, bbcolor):
    """Draw a captioned bounding box onto ``img`` in place.

    Args:
        img: PIL image to draw on; it is modified in place.
        x, y: Top-left corner of the box, in pixels.
        w, h: Width and height of the box, in pixels.
        text: Caption drawn on a filled background at the box's top-left.
        textcolor: Fill colour of the caption text.
        bbcolor: Colour of the box outline and of the caption background.
    """
    draw = ImageDraw.Draw(img)

    # ImageDraw.textsize() was removed in Pillow 10. textbbox() (Pillow >= 8)
    # returns (left, top, right, bottom) for text anchored at the given point,
    # so right/bottom are the caption's extent when drawn at an origin.
    if hasattr(draw, "textbbox"):
        _, _, text_w, text_h = draw.textbbox((0, 0), text)
    else:
        text_w, text_h = draw.textsize(text)

    # Place the caption just above the box when there is room; otherwise keep
    # it inside the top of the box so it stays on the image.
    label_y = y if y <= text_h else y - text_h

    # The box always starts at y. Only the caption moves; previously the box
    # was drawn from label_y and so was shifted up by the caption height.
    draw.rectangle((x, y, x + w, y + h), outline=bbcolor)
    draw.rectangle((x, label_y, x + text_w, label_y + text_h), outline=bbcolor, fill=bbcolor)
    draw.text((x, label_y), text, fill=textcolor)
|
|
|
def postprocess(true_image, o):
    """Annotate a copy of ``true_image`` with its highest-scoring detection.

    Args:
        true_image: PIL image the detections were computed on; not modified.
        o: Detector output sequence. Only ``o[0]`` is used; it must have
            ``"boxes"``, ``"labels"`` and ``"scores"`` entries supporting
            ``.tolist()`` (e.g. tensors).

    Returns:
        ``(labels, scores, indices)``: one-element lists holding the Japanese
        label text, the score formatted to three decimals, and the index of
        the selected detection in ``o[0]``. All three are empty when there
        are no detections.

    Side effects:
        Opens the annotated image in a viewer (``show``) and writes it to
        ``img/detected.png``, creating the ``img`` directory if needed.
    """
    copy_im = true_image.copy()

    data = o[0]
    boxes = data["boxes"].tolist()
    labels = data["labels"].tolist()
    scores = data["scores"].tolist()

    selected_labels = []
    selected_scores = []
    selected_indices = []

    if scores:
        # Compute the argmax once. max() returns the first of equal scores,
        # which matches the previous scores.index(max(scores)).
        best = max(range(len(scores)), key=scores.__getitem__)

        # NOTE(review): boxes are assumed to be [x1, y1, x2, y2] (torchvision
        # FasterRCNN convention); show_bb expects width/height, so convert.
        x1, y1, x2, y2 = boxes[best]
        show_bb(copy_im, x1, y1, x2 - x1, y2 - y1, label_to_text_en(labels[best]),
                (255, 255, 255), (255, 0, 0))

        selected_labels.append(label_to_text_ja(labels[best]))
        selected_scores.append('{:.3f}'.format(scores[best]))
        selected_indices.append(best)

    copy_im.show()

    # save() fails if the target directory is missing.
    os.makedirs("img", exist_ok=True)
    copy_im.save("img/detected.png")

    return selected_labels, selected_scores, selected_indices
|
|
|
|
|
# Loaded detectors keyed by (checkpoint path, class count), so repeated
# inference() calls do not rebuild the network and re-read weights from disk.
_MODEL_CACHE = {}


def _load_model(checkpoint_path, num_classes=5):
    """Build the detector, load ``checkpoint_path`` into it and cache it on ``device``."""
    key = (checkpoint_path, num_classes)
    model = _MODEL_CACHE.get(key)
    if model is None:
        # weights=None: no pretrained backbone; every parameter comes from the checkpoint.
        backbone = resnet_fpn_backbone(backbone_name="resnet18", weights=None)
        model = FasterRCNN(backbone, num_classes)
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
        # NOTE(review): the file name says "densenet" but the backbone is resnet18 — confirm.
        state_dict = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(state_dict["model"])
        model.to(device)
        model.eval()
        _MODEL_CACHE[key] = model
    return model


def inference(image_pil, checkpoint_path='model/model/densenet-model-9-mAp--1.0.pth'):
    """Run the detector on one PIL image and post-process the result.

    Args:
        image_pil: Input PIL image; converted to RGB before inference.
        checkpoint_path: Checkpoint whose ``"model"`` entry holds the state
            dict. Defaults to the bundled model file.

    Returns:
        ``(output, res)``: the raw detector output, with tensors moved to the
        CPU, and the tuple returned by ``postprocess``.
    """
    model = _load_model(checkpoint_path)

    _transform = T.Compose([T.ToTensor()])
    image = _transform(image_pil.convert("RGB")).to(device)

    with torch.no_grad():
        output = model([image])

    # The model may run on the GPU; hand callers CPU tensors as before.
    output = [
        {k: (v.cpu() if torch.is_tensor(v) else v) for k, v in det.items()}
        for det in output
    ]

    res = postprocess(image_pil, output)
    return output, res
|
|