Ubuntu
[mod] percent
5901ece
import argparse
import os
import torch
from model.detector import *
from model.backbone import *
from model.data import Therin
import datetime
from model.detector.fasterRCNN import FasterRCNN
from model.backbone.densenet import DenseNet
from model.utils.engine import *
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone, _resnet_fpn_extractor
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision import transforms as T
from PIL import Image, ImageDraw
# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def label_to_text_en(l):
    """Return the English posture name for class index ``l``.

    Indices 0-4 map to "creeping", "crawling", "stooping", "climbing" and
    "other" respectively. Any other key raises ``KeyError``.
    """
    names = ("creeping", "crawling", "stooping", "climbing", "other")
    # Look the index up through a dict (not the tuple itself) so that unknown
    # or negative indices raise KeyError instead of IndexError / wrapping.
    return dict(enumerate(names))[l]
def label_to_text_ja(l):
    """Return the Japanese posture description for class index ``l``.

    The indices mirror ``label_to_text_en``: 0 creeping, 1 crawling,
    2 stooping, 3 climbing, 4 other. Any other key raises ``KeyError``.

    The literals were previously stored as mojibake (UTF-8 bytes decoded
    through a single-byte code page); they are restored to real Japanese here.
    """
    d = {
        0: "しのびこんでいる",  # creeping / sneaking in
        1: "這っている",  # crawling
        2: "かがんでいる",  # stooping
        3: "よじ登っている",  # climbing
        4: "その他",  # other
    }
    return d[l]
def show_bb(img, x, y, w, h, text, textcolor, bbcolor):
    """Draw a box with a filled text label onto ``img`` (modified in place).

    Args:
        img: PIL image to draw on.
        x: Left edge of the box.
        y: Requested top edge of the box.
        w: Box width.
        h: Box height.
        text: Label string rendered with the default ImageDraw font.
        textcolor: Fill color of the label text.
        bbcolor: Outline color of the box and fill color of the label
            background.
    """
    draw = ImageDraw.Draw(img)

    if hasattr(draw, "textbbox"):
        # ImageDraw.textsize was removed in Pillow 10. The right/bottom edges
        # of the bbox measured from the origin are the extent the rendered
        # text occupies, which is what the label background must cover.
        _, _, text_w, text_h = draw.textbbox((0, 0), text)
    else:
        # Older Pillow releases without textbbox still provide textsize.
        text_w, text_h = draw.textsize(text)

    # Put the label above the requested top edge when there is room for it;
    # otherwise start it at y so it does not run off the top of the image.
    label_y = y if y <= text_h else y - text_h

    # NOTE(review): the box itself is anchored at label_y rather than y, so it
    # is shifted up by the label height whenever the label fits above —
    # confirm this is intended before changing it.
    draw.rectangle((x, label_y, x+w, label_y+h), outline=bbcolor)
    draw.rectangle((x, label_y, x+text_w, label_y+text_h), outline=bbcolor, fill=bbcolor)
    draw.text((x, label_y), text, fill=textcolor)
def postprocess(true_image, o):
    """Annotate a copy of ``true_image`` with the highest-scoring detection.

    The annotated copy is displayed with ``show()`` and written to
    ``img/detected.png``; ``true_image`` itself is not drawn on.

    Args:
        true_image: PIL image the detections refer to.
        o: Detector output; ``o[0]`` must be a mapping with ``"boxes"``,
            ``"labels"`` and ``"scores"`` entries that support ``.tolist()``.

    Returns:
        A tuple ``(selected_labels, selected_scores, selected_indices)`` of
        lists. Each list is empty when there are no boxes, otherwise it holds
        one entry for the first detection with the maximum score: its Japanese
        label text, its score formatted to three decimals, and its index.
    """
    annotated = true_image.copy()
    data = o[0]
    box_list = data["boxes"].tolist()
    labels = data["labels"].tolist()
    scores = data["scores"].tolist()

    selected_labels = []
    selected_scores = []
    selected_indices = []

    if box_list:
        # Locate the best detection once instead of re-scanning the scores
        # for every box.
        best = scores.index(max(scores))
        if best < len(box_list):
            box = box_list[best]
            # NOTE(review): show_bb treats box[2]/box[3] as width/height, but
            # torchvision-style detectors emit (x1, y1, x2, y2) — confirm which
            # format the project's FasterRCNN produces.
            show_bb(
                annotated,
                box[0], box[1], box[2], box[3],
                label_to_text_en(labels[best]),
                (255, 255, 255),
                (255, 0, 0),
            )
            selected_labels.append(label_to_text_ja(labels[best]))
            selected_scores.append('{:.3f}'.format(scores[best]))
            selected_indices.append(best)

    annotated.show()
    # Make sure the output directory exists so the save cannot fail after the
    # inference work has already been done.
    os.makedirs("img", exist_ok=True)
    annotated.save("img/detected.png")
    return selected_labels, selected_scores, selected_indices
def inference(image_pil):
    """Run the detector on a single PIL image and post-process the result.

    A Faster R-CNN with a ResNet-18 FPN backbone and 5 classes is built and
    its weights are loaded from a checkpoint file on every call.

    Args:
        image_pil: PIL image to analyse. It is converted to RGB for the model,
            while the unconverted image is handed to ``postprocess``.

    Returns:
        ``(output, res)`` where ``output`` is the raw model output for the
        one-image batch and ``res`` is whatever ``postprocess`` returns.
    """
    n_classes = 5
    # Second positional argument False: presumably "no pretrained weights" —
    # confirm against the installed torchvision version.
    fpn_backbone = resnet_fpn_backbone('resnet18', False)
    detector = FasterRCNN(fpn_backbone, n_classes)

    # NOTE(review): the checkpoint name mentions densenet while a resnet18
    # backbone is built above — confirm the two actually match.
    # NOTE(review): `device` is only used for map_location; the model and the
    # input tensor do not appear to be moved to it in this function.
    checkpoint = torch.load(
        'model/model/densenet-model-9-mAp--1.0.pth',
        map_location=device,
    )
    detector.load_state_dict(checkpoint["model"])
    detector.eval()

    to_tensor = T.ToTensor()
    tensor_image = to_tensor(image_pil.convert("RGB"))

    # Gradients are not needed for inference.
    with torch.no_grad():
        output = detector([tensor_image])

    return output, postprocess(image_pil, output)