# import sys, os, distutils.core | |
# # os.system('python -m pip install pyyaml==5.3.1') | |
# # dist = distutils.core.run_setup("./detectron2/setup.py") | |
# # temp = ' '.join([f"'{x}'" for x in dist.install_requires]) | |
# # cmd = "python -m pip install {0}".format(temp) | |
# # os.system(cmd) | |
# sys.path.insert(0, os.path.abspath('./detectron2')) | |
# import detectron2 | |
# import cv2 | |
# from detectron2.utils.logger import setup_logger | |
# setup_logger() | |
# # from detectron2.modeling import build_model | |
# from detectron2 import model_zoo | |
# from detectron2.engine import DefaultPredictor | |
# from detectron2.config import get_cfg | |
# from detectron2.utils.visualizer import Visualizer | |
# from detectron2.data import MetadataCatalog, DatasetCatalog | |
# from detectron2.utils.visualizer import Visualizer | |
# from detectron2.checkpoint import DetectionCheckpointer | |
# from detectron2.data.datasets import register_coco_instances | |
# cfg = get_cfg() | |
# cfg.OUTPUT_DIR = "./output/springboard/" | |
# # model = build_model(cfg) # returns a torch.nn.Module | |
# cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")) | |
# cfg.DATASETS.TEST = () | |
# cfg.DATALOADER.NUM_WORKERS = 2 | |
# cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml") # Let training initialize from model zoo | |
# cfg.SOLVER.IMS_PER_BATCH = 2 # This is the real "batch size" commonly known to deep learning people | |
# cfg.SOLVER.BASE_LR = 0.00025 # pick a good LR | |
# cfg.SOLVER.MAX_ITER = 300 # 300 iterations seems good enough for this toy dataset; you will need to train longer for a practical dataset | |
# cfg.SOLVER.STEPS = [] # do not decay learning rate | |
# cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128 # The "RoIHead batch size". 128 is faster, and good enough for this toy dataset (default: 512) | |
# cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1 | |
# cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth") # path to the model we just trained | |
# cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7 # set a custom testing threshold | |
# predictor = DefaultPredictor(cfg) | |
# register_coco_instances("springboard_trains", {}, "./coco_annotations/springboard/train.json", "../data/Boards/spring") | |
# register_coco_instances("springboard_vals", {}, "./coco_annotations/springboard/val.json", "../data/Boards/spring") | |
# from detectron2.utils.visualizer import ColorMode | |
# splash_metadata = MetadataCatalog.get('springboard_vals') | |
# dataset_dicts = DatasetCatalog.get("springboard_vals") | |
# outputs_array = [] | |
# for d in dataset_dicts: | |
# im = cv2.imread(d["file_name"]) | |
# outputs = predictor(im) | |
# outputs_array.append(outputs) # format is documented at https://detectron2.readthedocs.io/tutorials/models.html#model-output-format | |
# v = Visualizer(im[:, :, ::-1], | |
# metadata=splash_metadata, | |
# scale=0.5, | |
# instance_mode=ColorMode.IMAGE_BW # remove the colors of unsegmented pixels. This option is only available for segmentation models | |
# ) | |
# out = v.draw_instance_predictions(outputs["instances"].to("cpu")) | |
# img = out.get_image()[:, :, ::-1] | |
# filename = os.path.join("./output/", d["file_name"][3:]) | |
# print(filename) | |
# if not cv2.imwrite(filename, img): | |
# print('no image written') | |
import torch | |
import numpy as np | |
import math | |
import cv2 | |
import sys, os | |
from matplotlib import image | |
from matplotlib import pyplot as plt | |
from models.detectron2.springboard_detector_setup import get_springboard_detector | |
from models.detectron2.platform_detector_setup import get_platform_detector | |
from models.pose_estimator.pose_estimator_model_setup import get_pose_estimation | |
# Springboard micro-program helpers.
# find_which_side_board_on returns "left" or "right" depending on whether the board is on the left or right side of the frame (None if no board is found).
def find_which_side_board_on(output):
    """Return which side of the frame the detected board is anchored to.

    The highest-scoring instance with predicted class 0 is selected and the
    columns covered by its mask are compared against the two frame edges.

    Args:
        output: predictor output dict. ``output['instances']`` must expose
            ``pred_classes``, ``scores`` and ``pred_masks`` tensors (presumably
            a Detectron2 ``Instances`` object -- confirm against the detector).

    Returns:
        ``"left"`` if the mask's left-most column is at least as close to the
        left edge as its right-most column is to the right edge, ``"right"``
        otherwise, or ``None`` when there is no class-0 instance or its mask
        is empty. Ties (including a mask confined to the centre column of an
        odd-width frame) resolve to ``"left"``.
    """
    instances = output['instances']
    pred_classes = instances.pred_classes.cpu().numpy()
    platforms = np.where(pred_classes == 0)[0]
    scores = instances.scores[platforms]
    if len(scores) == 0:
        return None
    pred_masks = instances.pred_masks[platforms]
    max_instance = torch.argmax(scores)
    pred_mask = np.array(pred_masks[max_instance].cpu())  # highest-confidence board mask

    # Indices of columns that contain at least one mask pixel, left to right.
    occupied = np.flatnonzero(pred_mask.any(axis=0))
    if occupied.size == 0:
        return None
    width = pred_mask.shape[1]
    left_gap = occupied[0]                 # empty columns between left edge and mask
    right_gap = width - 1 - occupied[-1]   # empty columns between mask and right edge
    return "left" if left_gap <= right_gap else "right"
def board_end(output, board_side=None):
    """Locate the free end (tip) of the detected board.

    The highest-scoring instance with predicted class 0 is used. A board
    anchored on the left side of the frame extends to the right, so its tip is
    the right-most mask column; for a right-side board it is the left-most.

    Args:
        output: predictor output dict. ``output['instances']`` must expose
            ``pred_classes``, ``scores`` and ``pred_masks`` tensors (presumably
            a Detectron2 ``Instances`` object -- confirm against the detector).
        board_side: ``"left"`` or ``"right"``; when ``None`` it is inferred with
            ``find_which_side_board_on(output)``.

    Returns:
        ``(x, y)`` pixel coordinate as ints: ``x`` is the tip column and ``y``
        the top-most (smallest row index) mask pixel in that column. ``None``
        if there is no class-0 instance, its mask is empty, or the side is
        unknown / not one of ``"left"``/``"right"``.
    """
    instances = output['instances']
    pred_classes = instances.pred_classes.cpu().numpy()
    platforms = np.where(pred_classes == 0)[0]
    scores = instances.scores[platforms]
    if len(scores) == 0:
        return None
    pred_masks = instances.pred_masks[platforms]
    max_instance = torch.argmax(scores)
    pred_mask = np.array(pred_masks[max_instance].cpu())  # board instance with highest confidence

    if board_side is None:
        board_side = find_which_side_board_on(output)

    # Indices of columns that contain at least one mask pixel, left to right.
    occupied = np.flatnonzero(pred_mask.any(axis=0))
    if occupied.size == 0 or board_side not in ("left", "right"):
        return None
    tip_col = occupied[-1] if board_side == "left" else occupied[0]
    top_row = np.flatnonzero(pred_mask[:, tip_col])[0]
    return (int(tip_col), int(top_row))
def draw_board_end_coord(im, coord, source_path=None):
    """Draw a filled red dot at ``coord`` and save the frame under ./output/board_end/.

    Args:
        im: BGR image array; ``cv2.circle`` draws onto it in place.
        coord: ``(x, y)`` pixel coordinate, e.g. as returned by ``board_end``.
        source_path: path the frame was loaded from. The output file is
            ``"./output/board_end/"`` joined with this path minus its first
            three characters (presumably a leading ``"../"`` -- confirm for
            your inputs). When omitted, falls back to the legacy behaviour of
            reading ``d["file_name"]`` from a module-level ``d``.

    Raises:
        ValueError: if ``source_path`` is omitted and no module-level ``d``
            exists (previously this surfaced as a NameError).
    """
    if source_path is None:
        legacy_record = globals().get("d")
        if legacy_record is None:
            raise ValueError(
                "draw_board_end_coord requires source_path "
                "(no module-level 'd' record is defined)"
            )
        source_path = legacy_record["file_name"]
    print("hello, im in the drawing func")
    annotated = cv2.circle(im, (int(coord[0]), int(coord[1])), radius=10, color=(0, 0, 255), thickness=-1)
    filename = os.path.join("./output/board_end/", source_path[3:])
    print(filename)
    # cv2.imwrite returns False instead of raising when the directory is missing.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    if not cv2.imwrite(filename, annotated):
        print(filename)
        print("file failed to write")
# loops over each image, plots a point for the end of board | |
# i = 0 | |
# for d in dataset_dicts: | |
# im = cv2.imread(d["file_name"]) | |
# outputs = predictor(im) | |
# # to draw a point on co-ordinate (200,300) | |
# coord = board_end(outputs) | |
# if coord == None: | |
# continue | |
# # plt.plot(coord[0], coord[1], marker='v', color="white") | |
# draw_board_end_coord(im, coord) | |
# i+=1 | |
## TODO: ADD POSE ESTIMATOR, AND CALCULATE DISTANCE FROM BOARD | |
# PLOT RESULTS OF ONE FULL DIVE | |
# KEYPOINT_INDEXES = { | |
# 0: 'r ankle', | |
# 1: 'r knee', | |
# 2: 'r hip', | |
# 3: 'l hip', | |
# 4: 'l knee', | |
# 5: 'l ankle', | |
# 6: 'pelvis', | |
# 7: 'thorax', | |
# 8: 'upper neck', | |
# 9: 'head', | |
# 10: 'r wrist', | |
# 11: 'r elbow', | |
# 12: 'r shoulder', | |
# 13: 'l shoulder', | |
# 14: 'l elbow', | |
# 15: 'l wrist', | |
# } | |
# demo_image = '../data/Boards/spring/img_17_10_00014517.jpg' | |
# im = cv2.imread(demo_image) | |
# pose_pred = get_pose_estimation(demo_image) | |
# print("pose_pred", pose_pred) | |
# predictor = get_springboard_detector() | |
# outputs = predictor(im) | |
# # to draw a point on co-ordinate (200,300) | |
# coord = board_end(outputs) | |
def draw_two_coord(im, coord1, coord2, filename):
    """Mark two points on ``im`` and write the result to ``filename``.

    ``coord1`` is drawn as a filled red dot and ``coord2`` as a filled green
    dot (BGR colours, radius 5); ``coord2`` is drawn second, on top. Both are
    drawn with ``cv2.circle``, which modifies ``im`` in place. A message is
    printed if ``cv2.imwrite`` reports failure; nothing is raised.
    """
    print("hello, im in the drawing func")
    canvas = im
    markers = ((coord1, (0, 0, 255)), (coord2, (0, 255, 0)))
    for point, bgr in markers:
        center = (int(point[0]), int(point[1]))
        canvas = cv2.circle(canvas, center, radius=5, color=bgr, thickness=-1)
    print(filename)
    written = cv2.imwrite(filename, canvas)
    if not written:
        print(filename)
        print("file failed to write")
# draw_two_coord(im, coord, np.array(pose_pred)[0][5], filename==os.path.join("./output/board_end/", demo_image[3:])) | |
# print("pose_pred.shape", np.array(pose_pred).shape) | |
# print("coord.shape", np.array(coord).shape) | |
# print("DISTANCE BETWEEN END BOARD AND LEFT ANKLE:", math.dist(np.array(pose_pred)[0][5], np.array(coord))) | |
def calculate_distance_from_springboard_for_one_frame(filepath, visualize=False, dive_folder_num="", springboard_detector=None, pose_pred=None, pose_model=None, board_end_coord=None, board_side=None):
    """Return the distance from the springboard tip to the diver's closest joint.

    Any of ``springboard_detector``, ``pose_pred`` and ``board_end_coord`` may
    be passed in to reuse earlier work; missing pieces are computed. The
    detector is only loaded and run when ``board_end_coord`` is not given, and
    the image is only read from disk when detection or visualisation needs it.

    Args:
        filepath: path of the frame image.
        visualize: when True, save the frame with the board tip (red) and the
            closest joint (green) under
            ``./output/data/distance_from_board/<dive_folder_num>/``.
        dive_folder_num: sub-folder name used for the visualisation output.
        springboard_detector: callable predictor; built with
            ``get_springboard_detector()`` when needed and not supplied.
        pose_pred: pose prediction; ``np.array(pose_pred)[0]`` is treated as a
            sequence of ``(x, y)`` joints (presumably the first detected person).
        pose_model: forwarded to ``get_pose_estimation``.
        board_end_coord: precomputed ``(x, y)`` board tip.
        board_side: ``"left"``/``"right"`` hint forwarded to ``board_end``.

    Returns:
        Euclidean distance (``math.dist``) to the closest joint, assuming the
        joints and the board tip share the image's pixel frame; ``None`` if the
        board, the diver, or any joint could not be found.
    """
    if pose_pred is None:
        _, pose_pred = get_pose_estimation(filepath, pose_model=pose_model)

    im = None
    if board_end_coord is None:
        if springboard_detector is None:
            springboard_detector = get_springboard_detector()
        im = cv2.imread(filepath)
        board_end_coord = board_end(springboard_detector(im), board_side=board_side)

    if board_end_coord is None or pose_pred is None or len(board_end_coord) != 2:
        ## more verbose:
        # print("springboard or diver not detected in", filepath)
        return None

    joints = np.array(pose_pred)[0]
    board_xy = np.array(board_end_coord)
    min_dist = float('inf')
    min_joint = None
    for joint_idx, joint in enumerate(joints):
        distance = math.dist(joint, board_xy)
        if distance < min_dist:
            min_dist, min_joint = distance, joint_idx
    if min_joint is None:
        # No joints (or no finite distance): treat the diver as not detected
        # instead of failing later on an unbound joint index.
        return None

    if visualize:
        if im is None:
            im = cv2.imread(filepath)
        folder = "./output/data/distance_from_board/{}".format(dive_folder_num)
        out_filename = os.path.join(folder, os.path.basename(filepath))
        os.makedirs(folder, exist_ok=True)
        draw_two_coord(im, board_end_coord, joints[min_joint], filename=out_filename)
    return min_dist
def calculate_distance_from_platform_for_one_frame(filepath, im=None, visualize=False, dive_folder_num="", platform_detector=None, pose_pred=None, diver_detector=None, pose_model=None, board_end_coord=None, board_side=None):
    """Return the distance from the platform edge to the diver's closest joint.

    Any of ``platform_detector``, ``pose_pred`` and ``board_end_coord`` may be
    passed in to reuse earlier work; missing pieces are computed. The detector
    is only loaded and run when ``board_end_coord`` is not given, and the image
    is only read from ``filepath`` (when ``im`` is not supplied and
    ``filepath`` is non-empty) if detection or visualisation needs it.

    Args:
        filepath: path of the frame image (may be ``""`` when ``im`` is given).
        im: optional already-loaded image; forwarded to ``get_pose_estimation``
            as ``image_bgr``.
        visualize: when True, save the frame with the edge point (red) and the
            closest joint (green) under
            ``./output/data/distance_from_board/<dive_folder_num>/``.
        dive_folder_num: sub-folder name used for the visualisation output.
        platform_detector: callable predictor; built with
            ``get_platform_detector()`` when needed and not supplied.
        pose_pred: pose prediction; ``np.array(pose_pred)[0]`` is treated as a
            sequence of ``(x, y)`` joints (presumably the first detected person).
        diver_detector, pose_model: forwarded to ``get_pose_estimation``.
        board_end_coord: precomputed ``(x, y)`` platform edge point.
        board_side: ``"left"``/``"right"`` hint forwarded to ``board_end``.

    Returns:
        Euclidean distance (``math.dist``) to the closest joint, assuming the
        joints and the edge point share the image's pixel frame; ``None`` if
        the platform, the diver, or any joint could not be found.
    """
    if pose_pred is None:
        _, pose_pred = get_pose_estimation(filepath, image_bgr=im, diver_detector=diver_detector, pose_model=pose_model)

    if board_end_coord is None:
        if platform_detector is None:
            platform_detector = get_platform_detector()
        if im is None and filepath != "":
            im = cv2.imread(filepath)
        board_end_coord = board_end(platform_detector(im), board_side=board_side)

    if board_end_coord is None or pose_pred is None or len(board_end_coord) != 2:
        ## more verbose:
        # print("platform or diver not detected in", filepath)
        return None

    joints = np.array(pose_pred)[0]
    board_xy = np.array(board_end_coord)
    min_dist = float('inf')
    min_joint = None
    for joint_idx, joint in enumerate(joints):
        distance = math.dist(joint, board_xy)
        if distance < min_dist:
            min_dist, min_joint = distance, joint_idx
    if min_joint is None:
        # No joints (or no finite distance): treat the diver as not detected
        # instead of failing later on an unbound joint index.
        return None

    if visualize:
        if im is None and filepath != "":
            im = cv2.imread(filepath)
        folder = "./output/data/distance_from_board/{}".format(dive_folder_num)
        out_filename = os.path.join(folder, os.path.basename(filepath))
        os.makedirs(folder, exist_ok=True)
        draw_two_coord(im, board_end_coord, joints[min_joint], filename=out_filename)
    return min_dist
# distances = [] | |
# directory = "./FineDiving/datasets/FINADiving_MTL_256s/17/73/" | |
# file_names = os.listdir(directory) | |
# for file_name in file_names: | |
# path = os.path.join(directory, file_name) | |
# pose_pred = get_pose_estimation(path) | |
# print("PATH IM_17_73:", path) | |
# im = cv2.imread(path) | |
# outputs = predictor(im) | |
# coord = board_end(outputs) | |
# if coord is not None and pose_pred is not None and len(coord) == 2: | |
# distance = math.dist(np.array(pose_pred)[0][5], np.array(coord)) | |
# if distance is None: | |
# distances.append(0) | |
# else: | |
# distances.append(distance) | |
# filename = os.path.join("./output/data/img_17_73/", file_name) | |
# draw_two_coord(im, coord, np.array(pose_pred)[0][5], filename=filename) | |
# else: | |
# distances.append(0) | |
# plt.plot(range(len(distances)), distances) | |
# plt.savefig('./output/data/img_17_73/img_17_73_board_dist_graph.png') | |