# DECO: data/preprocess/hot_dca.py
# Added by ac5113 in commit 99a05f0 ("added files").
import os
import cv2
import numpy as np
from tqdm import tqdm
import sys
import imagesize
import argparse
import torch
import pandas as pd
import json
import monai.metrics as metrics
# Default locations of the HOT split files used by combine_hot_prox_split.
# Each .odgt file is read line by line, one JSON record per line.
# NOTE(review): these are cluster-specific absolute paths; adjust for other machines.
HOT_TRAIN_SPLIT = "/ps/scratch/ps_shared/ychen2/4shashank/split/hot_train.odgt"
HOT_VAL_SPLIT = "/ps/scratch/ps_shared/ychen2/4shashank/split/hot_validation.odgt"
HOT_TEST_SPLIT = "/ps/scratch/ps_shared/ychen2/4shashank/split/hot_test.odgt"
def metric(mask, pred, back=True):
    """Return the mean IoU between ``pred`` and the reference ``mask``.

    Thin wrapper around ``monai.metrics.compute_meaniou``, called positionally as
    ``(pred, mask, back, False)``. Per MONAI's signature, the last two arguments
    are ``include_background`` and ``ignore_empty``. The per-entry IoU tensor it
    returns is averaged into a single scalar tensor.

    Args:
        mask: reference (ground-truth) mask tensor.
        pred: predicted mask tensor, same layout as ``mask``.
        back: forwarded as MONAI's ``include_background`` flag.

    Returns:
        Scalar tensor holding the mean IoU.
    """
    return metrics.compute_meaniou(pred, mask, back, False).mean()
def combine_hot_prox_split(split, split_files=None):
    """Load the HOT records of one dataset split.

    Each split file is an ``.odgt`` file holding one JSON object per line.
    Blank lines, such as a trailing newline at end of file, are skipped.

    Args:
        split: split name, one of ``'train'``, ``'val'`` or ``'test'`` with the
            default ``split_files``.
        split_files: optional mapping from split name to ``.odgt`` path.
            Defaults to ``HOT_TRAIN_SPLIT`` / ``HOT_VAL_SPLIT`` /
            ``HOT_TEST_SPLIT``.

    Returns:
        List of dicts, one per non-blank line, in file order.

    Raises:
        ValueError: if ``split`` is not a key of ``split_files``.
    """
    if split_files is None:
        split_files = {
            'train': HOT_TRAIN_SPLIT,
            'val': HOT_VAL_SPLIT,
            'test': HOT_TEST_SPLIT,
        }
    try:
        split_path = split_files[split]
    except KeyError:
        raise ValueError(
            f"Unknown split {split!r}; expected one of {sorted(split_files)}"
        ) from None
    with open(split_path, "r", encoding="utf-8") as f:
        # json.loads tolerates the trailing newline / CRLF itself.
        return [json.loads(line) for line in f if line.strip()]
def _load_binary_mask(path):
    """Read ``path`` with OpenCV and return an int array that is 1 where a pixel is > 0, else 0.

    Raises:
        OSError: if OpenCV cannot read the file. ``cv2.imread`` returns ``None``
            instead of raising, which would otherwise fail later with an opaque
            ``TypeError``.
    """
    image = cv2.imread(path)
    if image is None:
        raise OSError(f'cv2.imread could not read {path!r}')
    return np.where(image > 0, 1, 0)


def _save_vis_mask(vis_path, src_path, mask):
    """Write a 0/1 ``mask`` into ``vis_path`` as a 0/255 image named after ``src_path``."""
    # np.where yields int64; cv2.imwrite generally only encodes 8-bit (or 16-bit
    # for some formats) data, so cast to uint8.
    cv2.imwrite(os.path.join(vis_path, os.path.basename(src_path)),
                (mask * 255).astype(np.uint8))


def _dca_contact_labels(vertices_field, imgname, include_supporting, n_vertices):
    """Turn one DCA ``vertices`` CSV cell into a binary per-vertex contact vector.

    Args:
        vertices_field: text of the ``vertices`` column. It is evaluated into a
            dict keyed by ``'hot/training/<imgname>'``. The value is a list of
            ``{label: [vertex indices]}`` dicts.
        imgname: image file name used to build that key.
        include_supporting: if False, vertices under the ``'SUPPORTING'`` label
            are ignored.
        n_vertices: length of the returned vector.

    Returns:
        float64 array of shape ``(n_vertices,)``, 1.0 at every contact vertex.
    """
    # NOTE(security): eval executes arbitrary code contained in the CSV cell. It is
    # only acceptable for trusted, locally produced annotation files. If the cell
    # is a plain Python literal, ast.literal_eval would be a safe replacement.
    vertices = eval(vertices_field)
    contact_3d_list = vertices[os.path.join('hot/training/', imgname)]
    # Aggregate vertex indices over all labels; the set removes repeated indices.
    contact_3d_idx = set()
    for item in contact_3d_list:
        for label, idx in item.items():
            if include_supporting or label != 'SUPPORTING':
                contact_3d_idx.update(idx)
    contact_3d_labels = np.zeros(n_vertices)
    contact_3d_labels[list(contact_3d_idx)] = 1.
    return contact_3d_labels


def _select_best_person(inds, part_seg_paths, polygon_2d_contact_path, iou_thresh, visualize, vis_path):
    """Return the entry of ``inds`` whose part mask best overlaps the HOT contact mask.

    Only candidates whose IoU is strictly greater than ``iou_thresh`` are
    eligible. On ties, the first candidate wins. Returns ``None`` if ``inds`` is
    empty or no candidate is eligible.
    """
    if len(inds) == 0:
        return None
    # The HOT contact mask is the same for every candidate, so load it only once.
    polygon_2d_contact = _load_binary_mask(polygon_2d_contact_path)
    if visualize:
        _save_vis_mask(vis_path, polygon_2d_contact_path, polygon_2d_contact)
    # (H, W, C) -> (1, C, H, W) for the MONAI metric.
    contact_tensor = torch.from_numpy(polygon_2d_contact)[None].permute(0, 3, 1, 2)

    best_ind, best_iou = None, None
    for ind in inds:
        part_path = part_seg_paths[ind]
        part_mask = _load_binary_mask(part_path)
        if visualize:
            _save_vis_mask(vis_path, part_path, part_mask)
        part_tensor = torch.from_numpy(part_mask)[None].permute(0, 3, 1, 2)
        iou = metric(contact_tensor, part_tensor)
        if iou > iou_thresh and (best_iou is None or iou > best_iou):
            best_ind, best_iou = ind, iou
    return best_ind


def hot_extract(img_dataset_path, smpl_params_path, dca_csv_path, out_dir, split=None, vis_path=None, visualize=False, include_supporting=True, iou_thresh=0):
    """Pair HOT images with DCA 3D contact labels and SMPL fits, and save them as an npz.

    Records of ``split`` (from ``combine_hot_prox_split``) are processed as follows:
    - Records without a row in the DCA CSV are skipped.
    - For the rest, the SMPL fit kept is the one whose part-segmentation mask has
      the highest IoU (> ``iou_thresh``) with the HOT 2D contact mask.
    - Records with no such fit are skipped as well.

    Args:
        img_dataset_path: HOT root. Images live under ``images/training`` and
            contact masks (``<stem>.png``) under ``annotations/training``.
        smpl_params_path: npz with per-person ``imgname``, ``part_seg``,
            ``scene_seg``, ``pose``, ``shape``, ``global_t`` and ``focal_l``.
        dca_csv_path: DCA CSV with ``imgnames`` and ``vertices`` columns.
        out_dir: output directory; created if missing.
        split: split name passed to ``combine_hot_prox_split``.
        vis_path: directory for debug images; required when ``visualize``.
        visualize: dump the input image and binarized masks into ``vis_path``.
        include_supporting: also count ``SUPPORTING`` vertices as contact.
        iou_thresh: a fit must have IoU strictly greater than this to be kept.

    Raises:
        ValueError: if ``visualize`` is set without a ``vis_path``.
    """
    n_vertices = 6890  # SMPL mesh vertex count
    if visualize:
        if vis_path is None:
            raise ValueError('vis_path is required when visualize=True')
        os.makedirs(vis_path, exist_ok=True)

    # structs we use
    imgnames_ = []
    poses_, shapes_, transls_ = [], [], []
    cams_k_ = []
    polygon_2d_contact_ = []
    contact_3d_labels_ = []
    scene_seg_, part_seg_ = [], []
    focal_length_accumulator = []
    num_with_3d_contact = 0

    img_dir = os.path.join(img_dataset_path, 'images', 'training')
    # Built directly rather than via img_dir.replace('images', 'annotations'),
    # which would also rewrite any 'images' inside img_dataset_path itself.
    annotations_dir = os.path.join(img_dataset_path, 'annotations', 'training')

    # Read each needed array once. Indexing an NpzFile re-reads the member from
    # the archive on every access, and the context manager closes the file.
    smpl_keys = ('imgname', 'part_seg', 'scene_seg', 'pose', 'shape', 'global_t', 'focal_l')
    with np.load(smpl_params_path) as smpl_npz:
        smpl_params = {key: smpl_npz[key] for key in smpl_keys}
    smpl_imgnames = smpl_params['imgname']

    records = combine_hot_prox_split(split)

    # Load the DCA csv. If there is no 'imgnames' column, run
    # scripts/datascripts/add_imgname_column_to_deco_csv.py first.
    dca_csv = pd.read_csv(dca_csv_path)
    # Map image name -> 'vertices' cell once instead of filtering the frame for
    # every record. For duplicated names the first row wins.
    dca_vertices = {}
    for name, vertices_field in zip(dca_csv['imgnames'], dca_csv['vertices']):
        dca_vertices.setdefault(name, vertices_field)

    for record in tqdm(records, dynamic_ncols=True):
        imgname = os.path.basename(record['fpath_img'])
        img_path = os.path.join(img_dir, imgname)
        if visualize:
            cv2.imwrite(os.path.join(vis_path, imgname), cv2.imread(img_path))
        img_w, img_h = record["width"], record["height"]
        polygon_2d_contact_path = os.path.join(annotations_dir, os.path.splitext(imgname)[0] + '.png')

        # Get 3D contact annotations from the DCA mturk csv.
        if imgname not in dca_vertices:
            continue
        num_with_3d_contact += 1
        contact_3d_labels = _dca_contact_labels(
            dca_vertices[imgname], imgname, include_supporting, n_vertices)

        # Heuristic: the DCA 3D contact labels belong to the SMPL fit whose part
        # mask best overlaps the HOT contact mask.
        inds = np.where(smpl_imgnames == img_path)[0]
        best_ind = _select_best_person(
            inds, smpl_params['part_seg'], polygon_2d_contact_path,
            iou_thresh, visualize, vis_path)
        if best_ind is None:
            continue

        focal_length = smpl_params['focal_l'][best_ind]
        # Pinhole intrinsics: shared focal length, principal point at image centre.
        K = np.eye(3, dtype=np.float64)
        K[0, 0] = focal_length
        K[1, 1] = focal_length
        K[0, 2] = img_w // 2
        K[1, 2] = img_h // 2

        # store data
        imgnames_.append(img_path)
        polygon_2d_contact_.append(polygon_2d_contact_path)
        contact_3d_labels_.append(contact_3d_labels)
        scene_seg_.append(smpl_params['scene_seg'][best_ind])
        part_seg_.append(smpl_params['part_seg'][best_ind])
        poses_.append(smpl_params['pose'][best_ind].squeeze())
        transls_.append(smpl_params['global_t'][best_ind].squeeze())
        shapes_.append(smpl_params['shape'][best_ind].squeeze())
        cams_k_.append(K.tolist())
        focal_length_accumulator.append(focal_length)

    if focal_length_accumulator:
        print('Average focal length: ', np.mean(focal_length_accumulator))
        print('Median focal length: ', np.median(focal_length_accumulator))
        print('Std Dev focal length: ', np.std(focal_length_accumulator))
    else:
        print('No samples selected; focal length statistics unavailable.')

    # store the data struct
    os.makedirs(out_dir, exist_ok=True)
    out_file = os.path.join(out_dir, f'hot_dca_supporting_{str(include_supporting)}_{split}.npz')
    np.savez(out_file, imgname=imgnames_,
             pose=poses_,
             transl=transls_,
             shape=shapes_,
             cam_k=cams_k_,
             polygon_2d_contact=polygon_2d_contact_,
             contact_label=contact_3d_labels_,
             scene_seg=scene_seg_,
             part_seg=part_seg_
             )
    print(f'Total number of rows: {len(imgnames_)}')
    print('Saved to ', out_file)
    print(f'Number of images with 3D contact labels: {num_with_3d_contact}')
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--img_dataset_path', type=str, default='/ps/project/datasets/HOT/Contact_Data/')
    parser.add_argument('--smpl_params_path', type=str, default='/ps/scratch/ps_shared/stripathi/deco/4agniv/hot/hot.npz')
    parser.add_argument('--dca_csv_path', type=str, default='/ps/scratch/ps_shared/stripathi/deco/4agniv/hot/dca.csv')
    parser.add_argument('--out_dir', type=str, default='/is/cluster/work/stripathi/pycharm_remote/dca_contact/data/dataset_extras')
    parser.add_argument('--vis_path', type=str, default='/is/cluster/work/stripathi/pycharm_remote/dca_contact/temp_images')
    parser.add_argument('--visualize', action='store_true', default=False)
    # NOTE: the CLI default (False) differs from hot_extract's default (True).
    parser.add_argument('--include_supporting', action='store_true', default=False)
    # Restrict to the splits combine_hot_prox_split knows, so a typo fails here
    # instead of after the heavy npz/CSV loading.
    parser.add_argument('--split', type=str, default='train', choices=['train', 'val', 'test'])
    args = parser.parse_args()
    hot_extract(img_dataset_path=args.img_dataset_path,
                smpl_params_path=args.smpl_params_path,
                dca_csv_path=args.dca_csv_path,
                out_dir=args.out_dir,
                vis_path=args.vis_path,
                visualize=args.visualize,
                split=args.split,
                include_supporting=args.include_supporting)