|
import sys, os, distutils.core |
|
|
|
|
|
|
|
|
|
|
|
|
|
sys.path.insert(0, os.path.abspath('./detectron2')) |
|
|
|
import detectron2 |
|
import cv2 |
|
|
|
from detectron2.utils.logger import setup_logger |
|
setup_logger() |
|
|
|
|
|
from detectron2 import model_zoo |
|
from detectron2.engine import DefaultPredictor |
|
from detectron2.config import get_cfg |
|
from detectron2.utils.visualizer import Visualizer |
|
from detectron2.data import MetadataCatalog, DatasetCatalog |
|
from detectron2.utils.visualizer import Visualizer |
|
from detectron2.checkpoint import DetectionCheckpointer |
|
from detectron2.data.datasets import register_coco_instances |
|
|
|
def get_springboard_detector():
    """Build a Detectron2 predictor for the single-class springboard model.

    The configuration starts from the model zoo's COCO instance-segmentation
    ``mask_rcnn_R_50_FPN_3x`` config, is overridden for one ROI-head class,
    and loads its weights from ``model_final.pth`` inside
    ``./output/springboard/``.

    Returns:
        DefaultPredictor: a predictor built from the assembled config, with
        the test-time score threshold set to 0.7.
    """
    cfg = get_cfg()
    cfg.merge_from_file(
        model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
    )

    # Assigned after the merge so the merged config file cannot overwrite it.
    cfg.OUTPUT_DIR = "./output/springboard/"

    cfg.DATASETS.TEST = ()
    cfg.DATALOADER.NUM_WORKERS = 2

    # Solver and ROI-head sampling values are carried on the config unchanged.
    # NOTE(review): presumably these mirror the run that produced the
    # checkpoint below -- confirm against the training script.
    cfg.SOLVER.IMS_PER_BATCH = 2
    cfg.SOLVER.BASE_LR = 0.00025
    cfg.SOLVER.MAX_ITER = 300
    cfg.SOLVER.STEPS = []
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1

    # Only the checkpoint in OUTPUT_DIR is used. The model-zoo COCO checkpoint
    # URL is deliberately not assigned to MODEL.WEIGHTS: it would be replaced
    # by this path before anything could read it.
    cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7

    return DefaultPredictor(cfg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|