diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..5e05fa914db14f49952a8405d5554d4b95b84dfd
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,44 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/01.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/02.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/03.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/04.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/05.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/06.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/07.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/08.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/09.mp4 filter=lfs diff=lfs merge=lfs -text
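The rules above route every large binary format in the repo (archives, serialized model weights, NumPy arrays, and the demo clips under assets/) through Git LFS, so only small pointer files land in the git history. As a rough illustration, the sketch below approximates that matching from Python with fnmatch; the pattern subset and the fnmatch semantics are simplifications (gitattributes patterns such as saved_model/**/* follow their own rules), so treat it as an assumption-laden sketch rather than git's actual behavior.

```python
# Rough sketch: emulate the LFS-tracking decision for a path using a subset of
# the patterns above. fnmatch only approximates gitattributes matching.
from fnmatch import fnmatch

LFS_PATTERNS = ["*.bin", "*.ckpt", "*.npy", "*.pth", "*.safetensors", "*.zip", "assets/01.mp4"]

def is_lfs_tracked(path: str) -> bool:
    name = path.rsplit("/", 1)[-1]  # patterns without a slash match at any directory level
    return any(fnmatch(path, pat) or fnmatch(name, pat) for pat in LFS_PATTERNS)

print(is_lfs_tracked("pretrained_models/aios_checkpoint.pth"))  # True
print(is_lfs_tracked("app.py"))                                 # False
```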
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..1fee058a018d4062806512f1bac29dca6b96e876
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+data_ssc/
+demo_out/
+pretrained_models/*
+.vscode/
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..cf5a2adf25531ebddfc4702c1fa68df337c1e2f0
--- /dev/null
+++ b/README.md
@@ -0,0 +1,13 @@
+---
+title: AiOS
+emoji: ⚡
+colorFrom: blue
+colorTo: indigo
+sdk: gradio
+python_version: 3.9
+sdk_version: 4.38.1
+app_file: app.py
+pinned: false
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3ebe114d4734c6e50ba3f8c9f725eb1ec2204f2
--- /dev/null
+++ b/app.py
@@ -0,0 +1,126 @@
+import os
+import sys
+import subprocess
+import pkg_resources
+
+def is_package_installed(package_name):
+    try:
+        pkg_resources.get_distribution(package_name)
+        return True
+    except pkg_resources.DistributionNotFound:
+        return False
+
+if is_package_installed("mmcv"):
+    print("MMCV is installed.")
+else:
+    print("MMCV is not installed. Building it from source.")
+    os.environ["MMCV_WITH_OPS"] = "1"
+    os.environ["FORCE_MLU"] = "1"
+    subprocess.run(["pip", "install", "-e", "./mmcv"], check=True)
+    subprocess.run(["pip", "list"], check=True)
+
+if is_package_installed("pytorch3d"):
+    print("pytorch3d is installed.")
+else:
+    print("pytorch3d is not installed. Building it from source.")
+    subprocess.run(["pip", "install", "-e", "./pytorch3d"], check=True)
+
+if is_package_installed("MultiScaleDeformableAttention"):
+    print("MultiScaleDeformableAttention is installed.")
+else:
+    print("MultiScaleDeformableAttention is not installed. Building it from source.")
+    subprocess.run(["pip", "install", "-e", "./models/aios/ops"], check=True)
+
+import os.path as osp
+from pathlib import Path
+import cv2
+import gradio as gr
+import torch
+import math
+import spaces
+from huggingface_hub import hf_hub_download
+
+hf_hub_download(repo_id="ttxskk/AiOS", filename="aios_checkpoint.pth", local_dir="/home/user/app/pretrained_models")
+
+OUT_FOLDER = '/home/user/app/demo_out'
+os.makedirs(OUT_FOLDER, exist_ok=True)
+
+DEMO_CONFIG = '/home/user/app/config/aios_smplx_demo.py'
+MODEL_PATH = '/home/user/app/pretrained_models/aios_checkpoint.pth'
+@spaces.GPU(enable_queue=True, duration=300)
+def infer(video_input, batch_size, threshold=0.5, num_person=1):
+    os.system(f'rm -rf {OUT_FOLDER}/*')
+    os.system(f'torchrun --nproc_per_node 1 \
+        main.py \
+        -c {DEMO_CONFIG} \
+        --options batch_size={batch_size} backbone="resnet50" num_person={num_person} threshold={threshold} \
+        --resume {MODEL_PATH} \
+        --eval \
+        --inference \
+        --inference_input {video_input} \
+        --to_vid \
+        --output_dir {OUT_FOLDER}')
+
+    video_path = os.path.join(OUT_FOLDER, 'demo_vid.mp4')
+    save_path_img = os.path.join(OUT_FOLDER, 'res_img')
+    save_path_mesh = os.path.join(OUT_FOLDER, 'mesh')
+    save_mesh_file = os.path.join(OUT_FOLDER, 'mesh.zip')
+    os.system(f'zip -r {save_mesh_file} {save_path_mesh}')
+    yield video_path, save_mesh_file
+
+TITLE = """
+# AiOS: All-in-One-Stage Expressive Human Pose and Shape Estimation
+
+Recover expressive human pose and shape for multiple people from an RGB image without any additional requirements, such as an off-the-shelf detection model.
+"""
+with gr.Blocks(title="AiOS", theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")) as demo:
+
+    gr.Markdown(TITLE)
+    with gr.Row():
+        with gr.Column(scale=2):
+            video_input = gr.Video(label="Input video", elem_classes="video")
+        with gr.Column(scale=1):
+            batch_size = gr.Textbox(label="Batch Size", type="text", value=8)
+            num_person = gr.Textbox(label="Number of Persons", type="text", value=1)
+            threshold = gr.Slider(0, 1.0, value=0.5, label='Score Threshold')
+            send_button = gr.Button("Infer")
+            gr.HTML("""
+            """)
+
+    with gr.Row():
+        with gr.Column():
+            # processed_frames = gr.Image(label="Last processed frame")
+            video_output = gr.Video(elem_classes="video")
+        with gr.Column():
+            meshes_output = gr.File(label="3D meshes")
+
+    send_button.click(fn=infer, inputs=[video_input, batch_size, threshold, num_person], outputs=[video_output, meshes_output])
+    # example_videos = gr.Examples([
+    #     ['./assets/01.mp4'],
+    #     ['./assets/02.mp4'],
+    #     ['./assets/03.mp4'],
+    #     ['./assets/04.mp4'],
+    #     ['./assets/05.mp4'],
+    #     ['./assets/06.mp4'],
+    #     ['./assets/07.mp4'],
+    #     ['./assets/08.mp4'],
+    #     ['./assets/09.mp4'],
+    #     ],
+    #     inputs=[video_input, 0.5])
+
+demo.queue().launch(debug=True)
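app.py's infer handler shells out to torchrun through os.system, interpolating the Gradio widget values straight into the command string. For readers who prefer an argument-list form, here is a hedged sketch of the same invocation via subprocess.run; the flag names are copied from the os.system call above, while the helper name run_inference and the switch to subprocess are illustrative choices, not what the app actually does.

```python
# Sketch only: the torchrun call from infer(), rebuilt as an argument list
# (no shell). Flags mirror the os.system string in app.py above.
import subprocess

def run_inference(video_input, batch_size, threshold, num_person, config, ckpt, out_dir):
    cmd = [
        "torchrun", "--nproc_per_node", "1", "main.py",
        "-c", config,
        "--options", f"batch_size={batch_size}", "backbone=resnet50",  # shell-quoted as backbone="resnet50" above
        f"num_person={num_person}", f"threshold={threshold}",
        "--resume", ckpt,
        "--eval", "--inference",
        "--inference_input", video_input,
        "--to_vid",
        "--output_dir", out_dir,
    ]
    subprocess.run(cmd, check=True)  # raises CalledProcessError if torchrun fails

# e.g. run_inference(video_path, 8, 0.5, 1, DEMO_CONFIG, MODEL_PATH, OUT_FOLDER)
```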
diff --git a/assets/01.mp4 b/assets/01.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..0a8e2831c3461c62dc6afa241addb829154bd812
--- /dev/null
+++ b/assets/01.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2ba560996c248d78be6556f1727ae6ced81cd62a002715c3ffd542f6202b204b
+size 2751935
diff --git a/assets/02.mp4 b/assets/02.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..1d418b5be50c00473f988180f7e4b07a8904f666
--- /dev/null
+++ b/assets/02.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:00702a08c978b27b3ddf6ddfd48c5a057753664c8e80d83f4b4e04dff45b8a71
+size 2827267
diff --git a/assets/03.mp4 b/assets/03.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..952e5c5b1e896ba31934ffbb5af03d93304f371f
--- /dev/null
+++ b/assets/03.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bfcc1ce90a0921ffa5550a04f743470081ff4599c265cf491e636a8ea70233d4
+size 4033767
diff --git a/assets/04.mp4 b/assets/04.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..1e95ea77e8571587b68657ce832f33b538ce2dd7
--- /dev/null
+++ b/assets/04.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:28531c3c0ad9cbcc097a00f8553aafcdc0513a881f0fa6d1a7937248f46fce0c
+size 2639842
diff --git a/assets/05.mp4 b/assets/05.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..e7d9d8f6ed14f0b8ddb5131840d02f8e8152195f
--- /dev/null
+++ b/assets/05.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1cf7f1b65d87f0a77c1d9456771e4f88228aa836426b4ad0cbad672e80d07e36
+size 3584040
diff --git a/assets/06.mp4 b/assets/06.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..dfb9d5a3279c180e8a4f03803fb42d966e451e80
--- /dev/null
+++ b/assets/06.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fcb4139d4863c5ec92224f7cb452ec4631be0613eb4c3f82ee7fbb6f89510fe2
+size 19797950
diff --git a/assets/07.mp4 b/assets/07.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..9d4b83e851a6040a7a864a8a546c60e20cfe3964
--- /dev/null
+++ b/assets/07.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4c71c5ed8573cb727c515d733e51c5da4654c58ab096cbca4bdf9b072e8284c7
+size 3274979
diff --git a/assets/08.mp4 b/assets/08.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..0442a9995c9ef1efe83ac64cad26ec7a0d93ef29
--- /dev/null
+++ b/assets/08.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d14f03e984a0ebefd9e8429c8e0d3ecdb0ffc9126ad91a489b57dc0f5d12695b
+size 6825913
diff --git a/assets/09.mp4 b/assets/09.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..763be21a7cc25fda8bf2ae924843d58ff9922b56
--- /dev/null
+++ b/assets/09.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:30b5b6f75f024647a9e430f02b33caa1ccec327b487ba5bb451e2859e1e45142
+size 6336699
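Each assets/*.mp4 entry above is committed as a Git LFS pointer: three text lines giving the spec version, the SHA-256 object id, and the payload size, while the video bytes themselves live in LFS storage. A small parser for exactly that three-line format (the helper name is illustrative):

```python
# Parse the three-line Git LFS pointer format shown above:
#   version https://git-lfs.github.com/spec/v1
#   oid sha256:<hex digest>
#   size <bytes>
def parse_lfs_pointer(path: str) -> dict:
    fields = {}
    with open(path, "r") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return {
        "version": fields["version"],
        "sha256": fields["oid"].split(":", 1)[1],
        "size_bytes": int(fields["size"]),
    }

# On the assets/01.mp4 pointer this yields sha256 2ba5...204b and size_bytes 2751935.
```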
diff --git a/config/__init__.py b/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/config/aios_smplx.py b/config/aios_smplx.py
new file mode 100644
index 0000000000000000000000000000000000000000..51192ad8fa0e096446bd106ac00465737012bc98
--- /dev/null
+++ b/config/aios_smplx.py
@@ -0,0 +1,259 @@
+
+num_classes = 2
+lr = 0.0001*1.414/10
+param_dict_type = 'default'
+lr_backbone = 1e-05*1.414/10
+lr_backbone_names = ['backbone.0']
+lr_linear_proj_names = ['reference_points', 'sampling_offsets']
+lr_linear_proj_mult = 0.1
+ddetr_lr_param = False
+batch_size = 2
+weight_decay = 0.0001
+epochs = 200
+lr_drop = 11
+save_checkpoint_interval = 1
+clip_max_norm = 0.1
+onecyclelr = False
+multi_step_lr = True
+lr_drop_list = [30, 60]
+
+modelname = 'aios_smplx'
+frozen_weights = None
+backbone = 'resnet50'
+use_checkpoint = False
+
+dilation = False
+position_embedding = 'sine'
+pe_temperatureH = 20
+pe_temperatureW = 20
+return_interm_indices = [1, 2, 3]
+backbone_freeze_keywords = None
+enc_layers = 6
+dec_layers = 6
+pre_norm = False
+dim_feedforward = 2048
+hidden_dim = 256
+dropout = 0.0
+nheads = 8
+num_queries = 900
+query_dim = 4
+num_patterns = 0
+random_refpoints_xy = False
+fix_refpoints_hw = -1
+dec_layer_number = None
+num_feature_levels = 4
+enc_n_points = 4
+dec_n_points = 4
+dln_xy_noise = 0.2
+dln_hw_noise = 0.2
+two_stage_type = 'standard'
+two_stage_bbox_embed_share = False
+two_stage_class_embed_share = False
+two_stage_learn_wh = False
+two_stage_default_hw = 0.05
+two_stage_keep_all_tokens = False
+rm_detach = None
+num_select = 50
+transformer_activation = 'relu'
+batch_norm_type = 'FrozenBatchNorm2d'
+
+masks = False
+losses = ["smpl_pose", "smpl_beta", "smpl_expr",
+ "smpl_kp2d","smpl_kp3d","smpl_kp3d_ra",'labels', 'boxes', "keypoints"]
+# losses = ['labels', 'boxes', "keypoints"]
+aux_loss = True
+set_cost_class = 2.0
+set_cost_bbox = 5.0
+set_cost_giou = 2.0
+set_cost_keypoints = 10.0
+set_cost_kpvis = 0.0
+set_cost_oks = 4.0
+cls_loss_coef = 2.0
+# keypoints_loss_coef = 10.0
+
+smpl_pose_loss_root_coef = 10 * 0.1
+smpl_pose_loss_body_coef = 1 * 0.1
+smpl_pose_loss_lhand_coef = 1 * 0.1
+smpl_pose_loss_rhand_coef = 1 * 0.1
+smpl_pose_loss_jaw_coef = 1 * 0.1
+smpl_beta_loss_coef = 0.01
+smpl_expr_loss_coef = 0.01
+
+# smpl_kp3d_loss_coef = 10
+smpl_body_kp3d_loss_coef = 10.0 * 0.1
+smpl_face_kp3d_loss_coef = 1.0 * 0.1
+smpl_lhand_kp3d_loss_coef = 1 * 0.1
+smpl_rhand_kp3d_loss_coef = 1 * 0.1
+
+# kp3d ra
+smpl_body_kp3d_ra_loss_coef = 10 * 0.1
+smpl_face_kp3d_ra_loss_coef = 1 * 0.1
+smpl_lhand_kp3d_ra_loss_coef = 1 * 0.1
+smpl_rhand_kp3d_ra_loss_coef = 1 * 0.1
+
+
+# smpl_kp2d_ba_loss_coef = 1.0
+smpl_body_kp2d_loss_coef = 10.0 * 0.1
+smpl_lhand_kp2d_loss_coef = 5.0 * 0.1
+smpl_rhand_kp2d_loss_coef = 5.0 * 0.1
+smpl_face_kp2d_loss_coef = 1.0 * 0.1
+
+smpl_body_kp2d_ba_loss_coef = 0 * 0.1
+smpl_face_kp2d_ba_loss_coef = 0 * 0.1
+smpl_lhand_kp2d_ba_loss_coef = 0 * 0.1
+smpl_rhand_kp2d_ba_loss_coef = 0 * 0.1
+
+bbox_loss_coef = 5.0
+body_bbox_loss_coef = 5.0
+lhand_bbox_loss_coef = 5.0
+rhand_bbox_loss_coef = 5.0
+face_bbox_loss_coef = 5.0
+
+giou_loss_coef = 2.0
+body_giou_loss_coef = 2.0
+rhand_giou_loss_coef = 2.0
+lhand_giou_loss_coef = 2.0
+face_giou_loss_coef = 2.0
+
+keypoints_loss_coef = 10.0
+rhand_keypoints_loss_coef = 10.0
+lhand_keypoints_loss_coef = 10.0
+face_keypoints_loss_coef = 10.0
+
+oks_loss_coef=4.0
+rhand_oks_loss_coef = 0.5
+lhand_oks_loss_coef = 0.5
+face_oks_loss_coef = 4.0
+
+
+enc_loss_coef = 1.0
+interm_loss_coef = 1.0
+no_interm_box_loss = False
+focal_alpha = 0.25
+rm_self_attn_layers = None
+indices_idx_list = [1, 2, 3, 4, 5, 6, 7]
+
+decoder_sa_type = 'sa'
+matcher_type = 'HungarianMatcher'
+decoder_module_seq = ['sa', 'ca', 'ffn']
+nms_iou_threshold = -1
+
+dec_pred_bbox_embed_share = False
+dec_pred_class_embed_share = False
+dec_pred_pose_embed_share = False
+body_only = True
+
+# for dn
+use_dn = True
+dn_number = 100
+dn_box_noise_scale = 0.4
+dn_label_noise_ratio = 0.5
+embed_init_tgt = False
+dn_label_coef = 0.3
+dn_bbox_coef = 0.5
+dn_batch_gt_fuse = False
+dn_attn_mask_type_list = ['match2dn', 'dn2dn', 'group2group']
+dn_labelbook_size = 100
+
+match_unstable_error = False
+
+# for ema
+use_ema = True
+ema_decay = 0.9997
+ema_epoch = 0
+
+cls_no_bias = False
+num_body_points = 17 # for coco
+num_hand_points = 6 # for coco
+num_face_points = 6 # for coco
+num_group = 100
+num_box_decoder_layers = 2
+num_hand_face_decoder_layers = 4
+no_mmpose_keypoint_evaluator = True
+strong_aug = False
+
+body_model_test=\
+ dict(
+ type='smplx',
+ keypoint_src='smplx',
+ num_expression_coeffs=10,
+ num_betas=10,
+ keypoint_dst='smplx_137',
+ model_path='data/body_models/smplx',
+ use_pca=False,
+ use_face_contour=True)
+
+body_model_train = \
+ dict(
+ type='smplx',
+ keypoint_src='smplx',
+ num_expression_coeffs=10,
+ num_betas=10,
+ keypoint_dst='smplx_137',
+ model_path='data/body_models/smplx',
+ use_pca=False,
+ use_face_contour=True)
+
+# will be updated in exp
+exp_name = 'output/exp52/dataset_debug'
+
+
+end_epoch = 150
+train_batch_size = 32
+
+scheduler = 'step'
+step_size = 20
+gamma = 0.1
+
+# continue
+continue_train = True
+pretrained_model_path = '../output/train_gta_synbody_ft_20230410_132110/model_dump/snapshot_2.pth.tar'
+
+# dataset setting
+# dataset_list = ['AGORA_MM','BEDLAM', 'COCO_NA']
+# trainset_3d = ['AGORA_MM','BEDLAM', 'COCO_NA']
+dataset_list = ['INFERENCE_demo']
+trainset_3d = []
+trainset_2d = []
+trainset_partition = {
+ }
+trainset_humandata = []
+testset = 'INFERENCE_demo'
+train_sizes=[480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
+train_max_size=1333
+test_sizes=[800]
+test_max_size=1333
+no_aug=False
+# model
+use_cache = True
+
+## UBody setting
+train_sample_interval = 10
+test_sample_interval = 100
+make_same_len = False
+
+## input, output size
+input_body_shape = (256, 192)
+output_hm_shape = (16, 16, 12)
+input_hand_shape = (256, 256)
+output_hand_hm_shape = (16, 16, 16)
+output_face_hm_shape = (8, 8, 8)
+input_face_shape = (192, 192)
+focal = (5000, 5000) # virtual focal lengths
+princpt = (input_body_shape[1] / 2, input_body_shape[0] / 2
+ ) # virtual principal point position
+body_3d_size = 2
+hand_3d_size = 0.3
+face_3d_size = 0.3
+camera_3d_size = 2.5
+
+bbox_ratio = 1.2
+
+## directory
+output_dir, model_dir, vis_dir, log_dir, result_dir, code_dir = None, None, None, None, None, None
+
+agora_benchmark = 'na' # 'agora_model', 'test_only'
+
+# strategy
+data_strategy = 'balance'  # 'balance' needs total_data_len to be defined
+total_data_len = 'auto'
\ No newline at end of file
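config/aios_smplx.py and the variants that follow are flat mmcv-style Python configs: every module-level assignment becomes a field of the loaded config object, which is how config/config.py later ingests them through MMConfig._file2dict and merge_from_dict. A minimal sketch of reading a few of these fields directly with mmcv's Config (assuming mmcv is installed, as the app builds it from source):

```python
# Minimal sketch: load the flat config above with mmcv and read some fields.
from mmcv import Config

cfg = Config.fromfile("config/aios_smplx.py")
print(cfg.num_queries, cfg.backbone)        # 900 resnet50
print(cfg.input_body_shape, cfg.princpt)    # (256, 192) (96.0, 128.0)
print(cfg.body_model_test["keypoint_dst"])  # smplx_137
```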
diff --git a/config/aios_smplx_agora_val.py b/config/aios_smplx_agora_val.py
new file mode 100644
index 0000000000000000000000000000000000000000..d85b8a081435723cddd741b703d68e5c6bb80ef6
--- /dev/null
+++ b/config/aios_smplx_agora_val.py
@@ -0,0 +1,265 @@
+
+num_classes = 2
+lr = 1e-04
+param_dict_type = 'default'
+lr_backbone = 1e-05
+lr_backbone_names = ['backbone.0']
+lr_linear_proj_names = ['reference_points', 'sampling_offsets']
+lr_linear_proj_mult = 0.1
+ddetr_lr_param = False
+batch_size = 2
+weight_decay = 0.0001
+epochs = 200
+lr_drop = 11
+save_checkpoint_interval = 1
+clip_max_norm = 0.1
+onecyclelr = False
+multi_step_lr = True
+lr_drop_list = [30, 60]
+
+modelname = 'aios_smplx'
+frozen_weights = None
+backbone = 'resnet50'
+use_checkpoint = False
+
+dilation = False
+position_embedding = 'sine'
+pe_temperatureH = 20
+pe_temperatureW = 20
+return_interm_indices = [1, 2, 3]
+backbone_freeze_keywords = None
+enc_layers = 6
+dec_layers = 6
+pre_norm = False
+dim_feedforward = 2048
+hidden_dim = 256
+dropout = 0.0
+nheads = 8
+num_queries = 900
+query_dim = 4
+num_patterns = 0
+random_refpoints_xy = False
+fix_refpoints_hw = -1
+dec_layer_number = None
+num_feature_levels = 4
+enc_n_points = 4
+dec_n_points = 4
+dln_xy_noise = 0.2
+dln_hw_noise = 0.2
+two_stage_type = 'standard'
+two_stage_bbox_embed_share = False
+two_stage_class_embed_share = False
+two_stage_learn_wh = False
+two_stage_default_hw = 0.05
+two_stage_keep_all_tokens = False
+rm_detach = None
+num_select = 50
+transformer_activation = 'relu'
+batch_norm_type = 'FrozenBatchNorm2d'
+
+masks = False
+losses = ["smpl_pose", "smpl_beta", "smpl_expr",
+ "smpl_kp2d","smpl_kp3d","smpl_kp3d_ra",'labels', 'boxes', "keypoints"]
+# losses = ['labels', 'boxes', "keypoints"]
+aux_loss = True
+set_cost_class = 2.0
+set_cost_bbox = 5.0
+set_cost_giou = 2.0
+set_cost_keypoints = 10.0
+set_cost_kpvis = 0.0
+set_cost_oks = 4.0
+cls_loss_coef = 2.0
+# keypoints_loss_coef = 10.0
+
+smpl_pose_loss_root_coef = 10 * 0.1
+smpl_pose_loss_body_coef = 1 * 0.1
+smpl_pose_loss_lhand_coef = 1 * 0.1
+smpl_pose_loss_rhand_coef = 1 * 0.1
+smpl_pose_loss_jaw_coef = 1 * 0.1
+smpl_beta_loss_coef = 0.01
+smpl_expr_loss_coef = 0.01
+
+# smpl_kp3d_loss_coef = 10
+smpl_body_kp3d_loss_coef = 10.0 * 0.1
+smpl_face_kp3d_loss_coef = 1.0 * 0.1
+smpl_lhand_kp3d_loss_coef = 1 * 0.1
+smpl_rhand_kp3d_loss_coef = 1 * 0.1
+
+# kp3d ra
+smpl_body_kp3d_ra_loss_coef = 10 * 0.1
+smpl_face_kp3d_ra_loss_coef = 1 * 0.1
+smpl_lhand_kp3d_ra_loss_coef = 1 * 0.1
+smpl_rhand_kp3d_ra_loss_coef = 1 * 0.1
+
+
+# smpl_kp2d_ba_loss_coef = 1.0
+smpl_body_kp2d_loss_coef = 10.0 * 0.1
+smpl_lhand_kp2d_loss_coef = 5.0 * 0.1
+smpl_rhand_kp2d_loss_coef = 5.0 * 0.1
+smpl_face_kp2d_loss_coef = 1.0 * 0.1
+
+smpl_body_kp2d_ba_loss_coef = 0 * 0.1
+smpl_face_kp2d_ba_loss_coef = 0 * 0.1
+smpl_lhand_kp2d_ba_loss_coef = 0 * 0.1
+smpl_rhand_kp2d_ba_loss_coef = 0 * 0.1
+
+bbox_loss_coef = 5.0
+body_bbox_loss_coef = 5.0
+lhand_bbox_loss_coef = 5.0
+rhand_bbox_loss_coef = 5.0
+face_bbox_loss_coef = 5.0
+
+giou_loss_coef = 2.0
+body_giou_loss_coef = 2.0
+rhand_giou_loss_coef = 2.0
+lhand_giou_loss_coef = 2.0
+face_giou_loss_coef = 2.0
+
+keypoints_loss_coef = 10.0
+rhand_keypoints_loss_coef = 10.0
+lhand_keypoints_loss_coef = 10.0
+face_keypoints_loss_coef = 10.0
+
+oks_loss_coef=4.0
+rhand_oks_loss_coef = 0.5
+lhand_oks_loss_coef = 0.5
+face_oks_loss_coef = 4.0
+
+
+enc_loss_coef = 1.0
+interm_loss_coef = 1.0
+no_interm_box_loss = False
+focal_alpha = 0.25
+rm_self_attn_layers = None
+indices_idx_list = [1, 2, 3, 4, 5, 6, 7]
+
+decoder_sa_type = 'sa'
+matcher_type = 'HungarianMatcher'
+decoder_module_seq = ['sa', 'ca', 'ffn']
+nms_iou_threshold = -1
+
+dec_pred_bbox_embed_share = False
+dec_pred_class_embed_share = False
+dec_pred_pose_embed_share = False
+body_only = True
+
+# for dn
+use_dn = True
+dn_number = 100
+dn_box_noise_scale = 0.4
+dn_label_noise_ratio = 0.5
+embed_init_tgt = False
+dn_label_coef = 0.3
+dn_bbox_coef = 0.5
+dn_batch_gt_fuse = False
+dn_attn_mask_type_list = ['match2dn', 'dn2dn', 'group2group']
+dn_labelbook_size = 100
+
+match_unstable_error = False
+
+# for ema
+use_ema = True
+ema_decay = 0.9997
+ema_epoch = 0
+
+cls_no_bias = False
+num_body_points = 17 # for coco
+num_hand_points = 6 # for coco
+num_face_points = 6 # for coco
+num_group = 100
+num_box_decoder_layers = 2
+num_hand_face_decoder_layers = 4
+no_mmpose_keypoint_evaluator = True
+strong_aug = False
+
+body_model_test=\
+ dict(
+ type='smplx',
+ keypoint_src='smplx',
+ num_expression_coeffs=10,
+ num_betas=10,
+ keypoint_dst='smplx_137',
+ model_path='data/body_models/smplx',
+ use_pca=False,
+ use_face_contour=True)
+
+body_model_train = \
+ dict(
+ type='smplx',
+ keypoint_src='smplx',
+ num_expression_coeffs=10,
+ num_betas=10,
+ keypoint_dst='smplx_137',
+ model_path='data/body_models/smplx',
+ use_pca=False,
+ use_face_contour=True)
+
+# will be updated in exp
+exp_name = 'output/exp52/dataset_debug'
+
+
+end_epoch = 150
+train_batch_size = 32
+
+scheduler = 'step'
+step_size = 20
+gamma = 0.1
+
+# continue
+continue_train = True
+pretrained_model_path = '../output/train_gta_synbody_ft_20230410_132110/model_dump/snapshot_2.pth.tar'
+
+# dataset setting
+# dataset_list = ['AGORA_MM','BEDLAM', 'COCO_NA']
+# trainset_3d = ['AGORA_MM','BEDLAM', 'COCO_NA']
+dataset_list = ['AGORA_MM','BEDLAM', 'COCO_NA']
+trainset_3d = ['AGORA_MM','BEDLAM', 'COCO_NA']
+trainset_2d = []
+trainset_partition = {
+ 'AGORA_MM': 0.4,
+ 'BEDLAM': 0.7,
+ 'COCO_NA': 1,
+
+ # 'EgoBody_Egocentric': 1,
+ # 'EgoBody_Kinect': 1.0,
+ }
+trainset_humandata = []
+testset = 'INFERENCE_AGORA'
+train_sizes=[480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
+train_max_size=1333
+test_sizes=[800]
+test_max_size=1333
+no_aug=False
+# model
+use_cache = True
+
+## UBody setting
+train_sample_interval = 10
+test_sample_interval = 100
+make_same_len = False
+
+## input, output size
+input_body_shape = (256, 192)
+output_hm_shape = (16, 16, 12)
+input_hand_shape = (256, 256)
+output_hand_hm_shape = (16, 16, 16)
+output_face_hm_shape = (8, 8, 8)
+input_face_shape = (192, 192)
+focal = (5000, 5000) # virtual focal lengths
+princpt = (input_body_shape[1] / 2, input_body_shape[0] / 2
+ ) # virtual principal point position
+body_3d_size = 2
+hand_3d_size = 0.3
+face_3d_size = 0.3
+camera_3d_size = 2.5
+
+bbox_ratio = 1.2
+
+## directory
+output_dir, model_dir, vis_dir, log_dir, result_dir, code_dir = None, None, None, None, None, None
+
+agora_benchmark = 'na' # 'agora_model', 'test_only'
+
+# strategy
+data_strategy = 'balance'  # 'balance' needs total_data_len to be defined
+total_data_len = 'auto'
\ No newline at end of file
diff --git a/config/aios_smplx_bedlam.py b/config/aios_smplx_bedlam.py
new file mode 100644
index 0000000000000000000000000000000000000000..88eeb29f74eceb0404269ffa56c628afa068fcff
--- /dev/null
+++ b/config/aios_smplx_bedlam.py
@@ -0,0 +1,265 @@
+
+num_classes = 2
+lr = 0.0001*1.414/10
+param_dict_type = 'default'
+lr_backbone = 1e-05*1.414/10
+lr_backbone_names = ['backbone.0']
+lr_linear_proj_names = ['reference_points', 'sampling_offsets']
+lr_linear_proj_mult = 0.1
+ddetr_lr_param = False
+batch_size = 2
+weight_decay = 0.0001
+epochs = 200
+lr_drop = 11
+save_checkpoint_interval = 1
+clip_max_norm = 0.1
+onecyclelr = False
+multi_step_lr = True
+lr_drop_list = [30, 60]
+
+modelname = 'aios_smplx'
+frozen_weights = None
+backbone = 'resnet50'
+use_checkpoint = False
+
+dilation = False
+position_embedding = 'sine'
+pe_temperatureH = 20
+pe_temperatureW = 20
+return_interm_indices = [1, 2, 3]
+backbone_freeze_keywords = None
+enc_layers = 6
+dec_layers = 6
+pre_norm = False
+dim_feedforward = 2048
+hidden_dim = 256
+dropout = 0.0
+nheads = 8
+num_queries = 900
+query_dim = 4
+num_patterns = 0
+random_refpoints_xy = False
+fix_refpoints_hw = -1
+dec_layer_number = None
+num_feature_levels = 4
+enc_n_points = 4
+dec_n_points = 4
+dln_xy_noise = 0.2
+dln_hw_noise = 0.2
+two_stage_type = 'standard'
+two_stage_bbox_embed_share = False
+two_stage_class_embed_share = False
+two_stage_learn_wh = False
+two_stage_default_hw = 0.05
+two_stage_keep_all_tokens = False
+rm_detach = None
+num_select = 50
+transformer_activation = 'relu'
+batch_norm_type = 'FrozenBatchNorm2d'
+
+masks = False
+losses = ["smpl_pose", "smpl_beta", "smpl_expr",
+ "smpl_kp2d","smpl_kp3d","smpl_kp3d_ra",'labels', 'boxes', "keypoints"]
+# losses = ['labels', 'boxes', "keypoints"]
+aux_loss = True
+set_cost_class = 2.0
+set_cost_bbox = 5.0
+set_cost_giou = 2.0
+set_cost_keypoints = 10.0
+set_cost_kpvis = 0.0
+set_cost_oks = 4.0
+cls_loss_coef = 2.0
+# keypoints_loss_coef = 10.0
+
+smpl_pose_loss_root_coef = 10 * 0.1
+smpl_pose_loss_body_coef = 1 * 0.1
+smpl_pose_loss_lhand_coef = 1 * 0.1
+smpl_pose_loss_rhand_coef = 1 * 0.1
+smpl_pose_loss_jaw_coef = 1 * 0.1
+smpl_beta_loss_coef = 0.01
+smpl_expr_loss_coef = 0.01
+
+# smpl_kp3d_loss_coef = 10
+smpl_body_kp3d_loss_coef = 10.0 * 0.1
+smpl_face_kp3d_loss_coef = 1.0 * 0.1
+smpl_lhand_kp3d_loss_coef = 1 * 0.1
+smpl_rhand_kp3d_loss_coef = 1 * 0.1
+
+# kp3d ra
+smpl_body_kp3d_ra_loss_coef = 10 * 0.1
+smpl_face_kp3d_ra_loss_coef = 1 * 0.1
+smpl_lhand_kp3d_ra_loss_coef = 1 * 0.1
+smpl_rhand_kp3d_ra_loss_coef = 1 * 0.1
+
+
+# smpl_kp2d_ba_loss_coef = 1.0
+smpl_body_kp2d_loss_coef = 10.0 * 0.1
+smpl_lhand_kp2d_loss_coef = 5.0 * 0.1
+smpl_rhand_kp2d_loss_coef = 5.0 * 0.1
+smpl_face_kp2d_loss_coef = 1.0 * 0.1
+
+smpl_body_kp2d_ba_loss_coef = 0 * 0.1
+smpl_face_kp2d_ba_loss_coef = 0 * 0.1
+smpl_lhand_kp2d_ba_loss_coef = 0 * 0.1
+smpl_rhand_kp2d_ba_loss_coef = 0 * 0.1
+
+bbox_loss_coef = 5.0
+body_bbox_loss_coef = 5.0
+lhand_bbox_loss_coef = 5.0
+rhand_bbox_loss_coef = 5.0
+face_bbox_loss_coef = 5.0
+
+giou_loss_coef = 2.0
+body_giou_loss_coef = 2.0
+rhand_giou_loss_coef = 2.0
+lhand_giou_loss_coef = 2.0
+face_giou_loss_coef = 2.0
+
+keypoints_loss_coef = 10.0
+rhand_keypoints_loss_coef = 10.0
+lhand_keypoints_loss_coef = 10.0
+face_keypoints_loss_coef = 10.0
+
+oks_loss_coef=4.0
+rhand_oks_loss_coef = 0.5
+lhand_oks_loss_coef = 0.5
+face_oks_loss_coef = 4.0
+
+
+enc_loss_coef = 1.0
+interm_loss_coef = 1.0
+no_interm_box_loss = False
+focal_alpha = 0.25
+rm_self_attn_layers = None
+indices_idx_list = [1, 2, 3, 4, 5, 6, 7]
+
+decoder_sa_type = 'sa'
+matcher_type = 'HungarianMatcher'
+decoder_module_seq = ['sa', 'ca', 'ffn']
+nms_iou_threshold = -1
+
+dec_pred_bbox_embed_share = False
+dec_pred_class_embed_share = False
+dec_pred_pose_embed_share = False
+body_only = True
+
+# for dn
+use_dn = True
+dn_number = 100
+dn_box_noise_scale = 0.4
+dn_label_noise_ratio = 0.5
+embed_init_tgt = False
+dn_label_coef = 0.3
+dn_bbox_coef = 0.5
+dn_batch_gt_fuse = False
+dn_attn_mask_type_list = ['match2dn', 'dn2dn', 'group2group']
+dn_labelbook_size = 100
+
+match_unstable_error = False
+
+# for ema
+use_ema = True
+ema_decay = 0.9997
+ema_epoch = 0
+
+cls_no_bias = False
+num_body_points = 17 # for coco
+num_hand_points = 6 # for coco
+num_face_points = 6 # for coco
+num_group = 100
+num_box_decoder_layers = 2
+num_hand_face_decoder_layers = 4
+no_mmpose_keypoint_evaluator = True
+strong_aug = False
+
+body_model_test=\
+ dict(
+ type='smplx',
+ keypoint_src='smplx',
+ num_expression_coeffs=10,
+ num_betas=10,
+ keypoint_dst='smplx_137',
+ model_path='data/body_models/smplx',
+ use_pca=False,
+ use_face_contour=True)
+
+body_model_train = \
+ dict(
+ type='smplx',
+ keypoint_src='smplx',
+ num_expression_coeffs=10,
+ num_betas=10,
+ keypoint_dst='smplx_137',
+ model_path='data/body_models/smplx',
+ use_pca=False,
+ use_face_contour=True)
+
+# will be updated in exp
+exp_name = 'output/exp52/dataset_debug'
+
+
+end_epoch = 150
+train_batch_size = 32
+
+scheduler = 'step'
+step_size = 20
+gamma = 0.1
+
+# continue
+continue_train = True
+pretrained_model_path = '../output/train_gta_synbody_ft_20230410_132110/model_dump/snapshot_2.pth.tar'
+
+# dataset setting
+# dataset_list = ['AGORA_MM','BEDLAM', 'COCO_NA']
+# trainset_3d = ['AGORA_MM','BEDLAM', 'COCO_NA']
+dataset_list = ['AGORA_MM','BEDLAM', 'COCO_NA']
+trainset_3d = ['AGORA_MM','BEDLAM', 'COCO_NA']
+trainset_2d = []
+trainset_partition = {
+ 'AGORA_MM': 0.4,
+ 'BEDLAM': 0.7,
+ 'COCO_NA': 1,
+
+ # 'EgoBody_Egocentric': 1,
+ # 'EgoBody_Kinect': 1.0,
+ }
+trainset_humandata = []
+testset = 'INFERENCE_BEDLAM'
+train_sizes=[480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
+train_max_size=1333
+test_sizes=[800]
+test_max_size=1333
+no_aug=False
+# model
+use_cache = True
+
+## UBody setting
+train_sample_interval = 10
+test_sample_interval = 100
+make_same_len = False
+
+## input, output size
+input_body_shape = (256, 192)
+output_hm_shape = (16, 16, 12)
+input_hand_shape = (256, 256)
+output_hand_hm_shape = (16, 16, 16)
+output_face_hm_shape = (8, 8, 8)
+input_face_shape = (192, 192)
+focal = (5000, 5000) # virtual focal lengths
+princpt = (input_body_shape[1] / 2, input_body_shape[0] / 2
+ ) # virtual principal point position
+body_3d_size = 2
+hand_3d_size = 0.3
+face_3d_size = 0.3
+camera_3d_size = 2.5
+
+bbox_ratio = 1.2
+
+## directory
+output_dir, model_dir, vis_dir, log_dir, result_dir, code_dir = None, None, None, None, None, None
+
+agora_benchmark = 'na' # 'agora_model', 'test_only'
+
+# strategy
+data_strategy = 'balance'  # 'balance' needs total_data_len to be defined
+total_data_len = 'auto'
\ No newline at end of file
diff --git a/config/aios_smplx_demo.py b/config/aios_smplx_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..51192ad8fa0e096446bd106ac00465737012bc98
--- /dev/null
+++ b/config/aios_smplx_demo.py
@@ -0,0 +1,259 @@
+
+num_classes = 2
+lr = 0.0001*1.414/10
+param_dict_type = 'default'
+lr_backbone = 1e-05*1.414/10
+lr_backbone_names = ['backbone.0']
+lr_linear_proj_names = ['reference_points', 'sampling_offsets']
+lr_linear_proj_mult = 0.1
+ddetr_lr_param = False
+batch_size = 2
+weight_decay = 0.0001
+epochs = 200
+lr_drop = 11
+save_checkpoint_interval = 1
+clip_max_norm = 0.1
+onecyclelr = False
+multi_step_lr = True
+lr_drop_list = [30, 60]
+
+modelname = 'aios_smplx'
+frozen_weights = None
+backbone = 'resnet50'
+use_checkpoint = False
+
+dilation = False
+position_embedding = 'sine'
+pe_temperatureH = 20
+pe_temperatureW = 20
+return_interm_indices = [1, 2, 3]
+backbone_freeze_keywords = None
+enc_layers = 6
+dec_layers = 6
+pre_norm = False
+dim_feedforward = 2048
+hidden_dim = 256
+dropout = 0.0
+nheads = 8
+num_queries = 900
+query_dim = 4
+num_patterns = 0
+random_refpoints_xy = False
+fix_refpoints_hw = -1
+dec_layer_number = None
+num_feature_levels = 4
+enc_n_points = 4
+dec_n_points = 4
+dln_xy_noise = 0.2
+dln_hw_noise = 0.2
+two_stage_type = 'standard'
+two_stage_bbox_embed_share = False
+two_stage_class_embed_share = False
+two_stage_learn_wh = False
+two_stage_default_hw = 0.05
+two_stage_keep_all_tokens = False
+rm_detach = None
+num_select = 50
+transformer_activation = 'relu'
+batch_norm_type = 'FrozenBatchNorm2d'
+
+masks = False
+losses = ["smpl_pose", "smpl_beta", "smpl_expr",
+ "smpl_kp2d","smpl_kp3d","smpl_kp3d_ra",'labels', 'boxes', "keypoints"]
+# losses = ['labels', 'boxes', "keypoints"]
+aux_loss = True
+set_cost_class = 2.0
+set_cost_bbox = 5.0
+set_cost_giou = 2.0
+set_cost_keypoints = 10.0
+set_cost_kpvis = 0.0
+set_cost_oks = 4.0
+cls_loss_coef = 2.0
+# keypoints_loss_coef = 10.0
+
+smpl_pose_loss_root_coef = 10 * 0.1
+smpl_pose_loss_body_coef = 1 * 0.1
+smpl_pose_loss_lhand_coef = 1 * 0.1
+smpl_pose_loss_rhand_coef = 1 * 0.1
+smpl_pose_loss_jaw_coef = 1 * 0.1
+smpl_beta_loss_coef = 0.01
+smpl_expr_loss_coef = 0.01
+
+# smpl_kp3d_loss_coef = 10
+smpl_body_kp3d_loss_coef = 10.0 * 0.1
+smpl_face_kp3d_loss_coef = 1.0 * 0.1
+smpl_lhand_kp3d_loss_coef = 1 * 0.1
+smpl_rhand_kp3d_loss_coef = 1 * 0.1
+
+# kp3d ra
+smpl_body_kp3d_ra_loss_coef = 10 * 0.1
+smpl_face_kp3d_ra_loss_coef = 1 * 0.1
+smpl_lhand_kp3d_ra_loss_coef = 1 * 0.1
+smpl_rhand_kp3d_ra_loss_coef = 1 * 0.1
+
+
+# smpl_kp2d_ba_loss_coef = 1.0
+smpl_body_kp2d_loss_coef = 10.0 * 0.1
+smpl_lhand_kp2d_loss_coef = 5.0 * 0.1
+smpl_rhand_kp2d_loss_coef = 5.0 * 0.1
+smpl_face_kp2d_loss_coef = 1.0 * 0.1
+
+smpl_body_kp2d_ba_loss_coef = 0 * 0.1
+smpl_face_kp2d_ba_loss_coef = 0 * 0.1
+smpl_lhand_kp2d_ba_loss_coef = 0 * 0.1
+smpl_rhand_kp2d_ba_loss_coef = 0 * 0.1
+
+bbox_loss_coef = 5.0
+body_bbox_loss_coef = 5.0
+lhand_bbox_loss_coef = 5.0
+rhand_bbox_loss_coef = 5.0
+face_bbox_loss_coef = 5.0
+
+giou_loss_coef = 2.0
+body_giou_loss_coef = 2.0
+rhand_giou_loss_coef = 2.0
+lhand_giou_loss_coef = 2.0
+face_giou_loss_coef = 2.0
+
+keypoints_loss_coef = 10.0
+rhand_keypoints_loss_coef = 10.0
+lhand_keypoints_loss_coef = 10.0
+face_keypoints_loss_coef = 10.0
+
+oks_loss_coef=4.0
+rhand_oks_loss_coef = 0.5
+lhand_oks_loss_coef = 0.5
+face_oks_loss_coef = 4.0
+
+
+enc_loss_coef = 1.0
+interm_loss_coef = 1.0
+no_interm_box_loss = False
+focal_alpha = 0.25
+rm_self_attn_layers = None
+indices_idx_list = [1, 2, 3, 4, 5, 6, 7]
+
+decoder_sa_type = 'sa'
+matcher_type = 'HungarianMatcher'
+decoder_module_seq = ['sa', 'ca', 'ffn']
+nms_iou_threshold = -1
+
+dec_pred_bbox_embed_share = False
+dec_pred_class_embed_share = False
+dec_pred_pose_embed_share = False
+body_only = True
+
+# for dn
+use_dn = True
+dn_number = 100
+dn_box_noise_scale = 0.4
+dn_label_noise_ratio = 0.5
+embed_init_tgt = False
+dn_label_coef = 0.3
+dn_bbox_coef = 0.5
+dn_batch_gt_fuse = False
+dn_attn_mask_type_list = ['match2dn', 'dn2dn', 'group2group']
+dn_labelbook_size = 100
+
+match_unstable_error = False
+
+# for ema
+use_ema = True
+ema_decay = 0.9997
+ema_epoch = 0
+
+cls_no_bias = False
+num_body_points = 17 # for coco
+num_hand_points = 6 # for coco
+num_face_points = 6 # for coco
+num_group = 100
+num_box_decoder_layers = 2
+num_hand_face_decoder_layers = 4
+no_mmpose_keypoint_evaluator = True
+strong_aug = False
+
+body_model_test=\
+ dict(
+ type='smplx',
+ keypoint_src='smplx',
+ num_expression_coeffs=10,
+ num_betas=10,
+ keypoint_dst='smplx_137',
+ model_path='data/body_models/smplx',
+ use_pca=False,
+ use_face_contour=True)
+
+body_model_train = \
+ dict(
+ type='smplx',
+ keypoint_src='smplx',
+ num_expression_coeffs=10,
+ num_betas=10,
+ keypoint_dst='smplx_137',
+ model_path='data/body_models/smplx',
+ use_pca=False,
+ use_face_contour=True)
+
+# will be updated in exp
+exp_name = 'output/exp52/dataset_debug'
+
+
+end_epoch = 150
+train_batch_size = 32
+
+scheduler = 'step'
+step_size = 20
+gamma = 0.1
+
+# continue
+continue_train = True
+pretrained_model_path = '../output/train_gta_synbody_ft_20230410_132110/model_dump/snapshot_2.pth.tar'
+
+# dataset setting
+# dataset_list = ['AGORA_MM','BEDLAM', 'COCO_NA']
+# trainset_3d = ['AGORA_MM','BEDLAM', 'COCO_NA']
+dataset_list = ['INFERENCE_demo']
+trainset_3d = []
+trainset_2d = []
+trainset_partition = {
+ }
+trainset_humandata = []
+testset = 'INFERENCE_demo'
+train_sizes=[480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
+train_max_size=1333
+test_sizes=[800]
+test_max_size=1333
+no_aug=False
+# model
+use_cache = True
+
+## UBody setting
+train_sample_interval = 10
+test_sample_interval = 100
+make_same_len = False
+
+## input, output size
+input_body_shape = (256, 192)
+output_hm_shape = (16, 16, 12)
+input_hand_shape = (256, 256)
+output_hand_hm_shape = (16, 16, 16)
+output_face_hm_shape = (8, 8, 8)
+input_face_shape = (192, 192)
+focal = (5000, 5000) # virtual focal lengths
+princpt = (input_body_shape[1] / 2, input_body_shape[0] / 2
+ ) # virtual principal point position
+body_3d_size = 2
+hand_3d_size = 0.3
+face_3d_size = 0.3
+camera_3d_size = 2.5
+
+bbox_ratio = 1.2
+
+## directory
+output_dir, model_dir, vis_dir, log_dir, result_dir, code_dir = None, None, None, None, None, None
+
+agora_benchmark = 'na' # 'agora_model', 'test_only'
+
+# strategy
+data_strategy = 'balance'  # 'balance' needs total_data_len to be defined
+total_data_len = 'auto'
\ No newline at end of file
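aios_smplx_demo.py is the config that app.py points torchrun at, and the values the app cares about (batch_size, backbone, num_person, threshold) are passed on top of it through --options. How main.py parses those overrides is not part of this diff, so the sketch below is only an assumption; it reuses mmcv's merge_from_dict, the same mechanism config/config.py relies on, to show the intended effect of key=value overrides on this file.

```python
# Assumed sketch: fold app.py's --options overrides into the demo config.
# main.py's actual option parsing is not shown in this diff.
from mmcv import Config

cfg = Config.fromfile("config/aios_smplx_demo.py")
overrides = {"batch_size": 8, "backbone": "resnet50", "num_person": 1, "threshold": 0.5}
cfg.merge_from_dict(overrides)  # new keys (num_person, threshold) are simply added
print(cfg.batch_size, cfg.threshold)  # 8 0.5
```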
diff --git a/config/aios_smplx_inference.py b/config/aios_smplx_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..13cf56fd6589c6b79671b71b2b7217ee0db593fb
--- /dev/null
+++ b/config/aios_smplx_inference.py
@@ -0,0 +1,265 @@
+
+num_classes = 2
+lr = 0.0001*1.414/10
+param_dict_type = 'default'
+lr_backbone = 1e-05*1.414/10
+lr_backbone_names = ['backbone.0']
+lr_linear_proj_names = ['reference_points', 'sampling_offsets']
+lr_linear_proj_mult = 0.1
+ddetr_lr_param = False
+batch_size = 2
+weight_decay = 0.0001
+epochs = 200
+lr_drop = 11
+save_checkpoint_interval = 1
+clip_max_norm = 0.1
+onecyclelr = False
+multi_step_lr = True
+lr_drop_list = [30, 60]
+
+modelname = 'aios_smplx'
+frozen_weights = None
+backbone = 'resnet50'
+use_checkpoint = False
+
+dilation = False
+position_embedding = 'sine'
+pe_temperatureH = 20
+pe_temperatureW = 20
+return_interm_indices = [1, 2, 3]
+backbone_freeze_keywords = None
+enc_layers = 6
+dec_layers = 6
+pre_norm = False
+dim_feedforward = 2048
+hidden_dim = 256
+dropout = 0.0
+nheads = 8
+num_queries = 900
+query_dim = 4
+num_patterns = 0
+random_refpoints_xy = False
+fix_refpoints_hw = -1
+dec_layer_number = None
+num_feature_levels = 4
+enc_n_points = 4
+dec_n_points = 4
+dln_xy_noise = 0.2
+dln_hw_noise = 0.2
+two_stage_type = 'standard'
+two_stage_bbox_embed_share = False
+two_stage_class_embed_share = False
+two_stage_learn_wh = False
+two_stage_default_hw = 0.05
+two_stage_keep_all_tokens = False
+rm_detach = None
+num_select = 50
+transformer_activation = 'relu'
+batch_norm_type = 'FrozenBatchNorm2d'
+
+masks = False
+losses = ["smpl_pose", "smpl_beta", "smpl_expr",
+ "smpl_kp2d","smpl_kp3d","smpl_kp3d_ra",'labels', 'boxes', "keypoints"]
+# losses = ['labels', 'boxes', "keypoints"]
+aux_loss = True
+set_cost_class = 2.0
+set_cost_bbox = 5.0
+set_cost_giou = 2.0
+set_cost_keypoints = 10.0
+set_cost_kpvis = 0.0
+set_cost_oks = 4.0
+cls_loss_coef = 2.0
+# keypoints_loss_coef = 10.0
+
+smpl_pose_loss_root_coef = 10 * 0.1
+smpl_pose_loss_body_coef = 1 * 0.1
+smpl_pose_loss_lhand_coef = 1 * 0.1
+smpl_pose_loss_rhand_coef = 1 * 0.1
+smpl_pose_loss_jaw_coef = 1 * 0.1
+smpl_beta_loss_coef = 0.01
+smpl_expr_loss_coef = 0.01
+
+# smpl_kp3d_loss_coef = 10
+smpl_body_kp3d_loss_coef = 10.0 * 0.1
+smpl_face_kp3d_loss_coef = 1.0 * 0.1
+smpl_lhand_kp3d_loss_coef = 1 * 0.1
+smpl_rhand_kp3d_loss_coef = 1 * 0.1
+
+# kp3d ra
+smpl_body_kp3d_ra_loss_coef = 10 * 0.1
+smpl_face_kp3d_ra_loss_coef = 1 * 0.1
+smpl_lhand_kp3d_ra_loss_coef = 1 * 0.1
+smpl_rhand_kp3d_ra_loss_coef = 1 * 0.1
+
+
+# smpl_kp2d_ba_loss_coef = 1.0
+smpl_body_kp2d_loss_coef = 10.0 * 0.1
+smpl_lhand_kp2d_loss_coef = 5.0 * 0.1
+smpl_rhand_kp2d_loss_coef = 5.0 * 0.1
+smpl_face_kp2d_loss_coef = 1.0 * 0.1
+
+smpl_body_kp2d_ba_loss_coef = 0 * 0.1
+smpl_face_kp2d_ba_loss_coef = 0 * 0.1
+smpl_lhand_kp2d_ba_loss_coef = 0 * 0.1
+smpl_rhand_kp2d_ba_loss_coef = 0 * 0.1
+
+bbox_loss_coef = 5.0
+body_bbox_loss_coef = 5.0
+lhand_bbox_loss_coef = 5.0
+rhand_bbox_loss_coef = 5.0
+face_bbox_loss_coef = 5.0
+
+giou_loss_coef = 2.0
+body_giou_loss_coef = 2.0
+rhand_giou_loss_coef = 2.0
+lhand_giou_loss_coef = 2.0
+face_giou_loss_coef = 2.0
+
+keypoints_loss_coef = 10.0
+rhand_keypoints_loss_coef = 10.0
+lhand_keypoints_loss_coef = 10.0
+face_keypoints_loss_coef = 10.0
+
+oks_loss_coef=4.0
+rhand_oks_loss_coef = 0.5
+lhand_oks_loss_coef = 0.5
+face_oks_loss_coef = 4.0
+
+
+enc_loss_coef = 1.0
+interm_loss_coef = 1.0
+no_interm_box_loss = False
+focal_alpha = 0.25
+rm_self_attn_layers = None
+indices_idx_list = [1, 2, 3, 4, 5, 6, 7]
+
+decoder_sa_type = 'sa'
+matcher_type = 'HungarianMatcher'
+decoder_module_seq = ['sa', 'ca', 'ffn']
+nms_iou_threshold = -1
+
+dec_pred_bbox_embed_share = False
+dec_pred_class_embed_share = False
+dec_pred_pose_embed_share = False
+body_only = True
+
+# for dn
+use_dn = True
+dn_number = 100
+dn_box_noise_scale = 0.4
+dn_label_noise_ratio = 0.5
+embed_init_tgt = False
+dn_label_coef = 0.3
+dn_bbox_coef = 0.5
+dn_batch_gt_fuse = False
+dn_attn_mask_type_list = ['match2dn', 'dn2dn', 'group2group']
+dn_labelbook_size = 100
+
+match_unstable_error = False
+
+# for ema
+use_ema = True
+ema_decay = 0.9997
+ema_epoch = 0
+
+cls_no_bias = False
+num_body_points = 17 # for coco
+num_hand_points = 6 # for coco
+num_face_points = 6 # for coco
+num_group = 100
+num_box_decoder_layers = 2
+num_hand_face_decoder_layers = 4
+no_mmpose_keypoint_evaluator = True
+strong_aug = False
+
+body_model_test=\
+ dict(
+ type='smplx',
+ keypoint_src='smplx',
+ num_expression_coeffs=10,
+ num_betas=10,
+ keypoint_dst='smplx_137',
+ model_path='data/body_models/smplx',
+ use_pca=False,
+ use_face_contour=True)
+
+body_model_train = \
+ dict(
+ type='smplx',
+ keypoint_src='smplx',
+ num_expression_coeffs=10,
+ num_betas=10,
+ keypoint_dst='smplx_137',
+ model_path='data/body_models/smplx',
+ use_pca=False,
+ use_face_contour=True)
+
+# will be updated in exp
+exp_name = 'output/exp52/dataset_debug'
+
+
+end_epoch = 150
+train_batch_size = 32
+
+scheduler = 'step'
+step_size = 20
+gamma = 0.1
+
+# continue
+continue_train = True
+pretrained_model_path = '../output/train_gta_synbody_ft_20230410_132110/model_dump/snapshot_2.pth.tar'
+
+# dataset setting
+# dataset_list = ['AGORA_MM','BEDLAM', 'COCO_NA']
+# trainset_3d = ['AGORA_MM','BEDLAM', 'COCO_NA']
+dataset_list = ['AGORA_MM','BEDLAM', 'COCO_NA']
+trainset_3d = ['AGORA_MM','BEDLAM', 'COCO_NA']
+trainset_2d = []
+trainset_partition = {
+ 'AGORA_MM': 0.4,
+ 'BEDLAM': 0.7,
+ 'COCO_NA': 1,
+
+ # 'EgoBody_Egocentric': 1,
+ # 'EgoBody_Kinect': 1.0,
+ }
+trainset_humandata = []
+testset = 'INFERENCE'
+train_sizes=[480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
+train_max_size=1333
+test_sizes=[800]
+test_max_size=1333
+no_aug=False
+# model
+use_cache = True
+
+## UBody setting
+train_sample_interval = 10
+test_sample_interval = 100
+make_same_len = False
+
+## input, output size
+input_body_shape = (256, 192)
+output_hm_shape = (16, 16, 12)
+input_hand_shape = (256, 256)
+output_hand_hm_shape = (16, 16, 16)
+output_face_hm_shape = (8, 8, 8)
+input_face_shape = (192, 192)
+focal = (5000, 5000) # virtual focal lengths
+princpt = (input_body_shape[1] / 2, input_body_shape[0] / 2
+ ) # virtual principal point position
+body_3d_size = 2
+hand_3d_size = 0.3
+face_3d_size = 0.3
+camera_3d_size = 2.5
+
+bbox_ratio = 1.2
+
+## directory
+output_dir, model_dir, vis_dir, log_dir, result_dir, code_dir = None, None, None, None, None, None
+
+agora_benchmark = 'na' # 'agora_model', 'test_only'
+
+# strategy
+data_strategy = 'balance'  # 'balance' needs total_data_len to be defined
+total_data_len = 'auto'
\ No newline at end of file
diff --git a/config/aios_smplx_pretrain.py b/config/aios_smplx_pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4c6e51bffd61e020b94ce25c61193a1b6c22369
--- /dev/null
+++ b/config/aios_smplx_pretrain.py
@@ -0,0 +1,264 @@
+num_classes = 2
+lr = 0.0001
+param_dict_type = 'default'
+lr_backbone = 1e-05
+lr_backbone_names = ['backbone.0']
+lr_linear_proj_names = ['reference_points', 'sampling_offsets']
+lr_linear_proj_mult = 0.1
+ddetr_lr_param = False
+batch_size = 2
+weight_decay = 0.0001
+epochs = 200
+lr_drop = 11
+save_checkpoint_interval = 1
+clip_max_norm = 0.1
+onecyclelr = False
+multi_step_lr = True
+lr_drop_list = [30, 60]
+
+modelname = 'aios_smplx'
+frozen_weights = None
+backbone = 'resnet50'
+use_checkpoint = False
+
+dilation = False
+position_embedding = 'sine'
+pe_temperatureH = 20
+pe_temperatureW = 20
+return_interm_indices = [1, 2, 3]
+backbone_freeze_keywords = None
+enc_layers = 6
+dec_layers = 6
+pre_norm = False
+dim_feedforward = 2048
+hidden_dim = 256
+dropout = 0.0
+nheads = 8
+num_queries = 900
+query_dim = 4
+num_patterns = 0
+random_refpoints_xy = False
+fix_refpoints_hw = -1
+dec_layer_number = None
+num_feature_levels = 4
+enc_n_points = 4
+dec_n_points = 4
+dln_xy_noise = 0.2
+dln_hw_noise = 0.2
+two_stage_type = 'standard'
+two_stage_bbox_embed_share = False
+two_stage_class_embed_share = False
+two_stage_learn_wh = False
+two_stage_default_hw = 0.05
+two_stage_keep_all_tokens = False
+rm_detach = None
+num_select = 50
+transformer_activation = 'relu'
+batch_norm_type = 'FrozenBatchNorm2d'
+
+masks = False
+losses = ["smpl_pose", "smpl_beta", "smpl_expr",
+ "smpl_kp2d","smpl_kp3d","smpl_kp3d_ra",'labels', 'boxes', "keypoints"]
+# losses = ['labels', 'boxes', "keypoints"]
+aux_loss = True
+set_cost_class = 2.0
+set_cost_bbox = 5.0
+set_cost_giou = 2.0
+set_cost_keypoints = 10.0
+set_cost_kpvis = 0.0
+set_cost_oks = 4.0
+cls_loss_coef = 2.0
+# keypoints_loss_coef = 10.0
+
+smpl_pose_loss_root_coef = 10 * 0.1
+smpl_pose_loss_body_coef = 1 * 0.1
+smpl_pose_loss_lhand_coef = 1 * 0.1
+smpl_pose_loss_rhand_coef = 1 * 0.1
+smpl_pose_loss_jaw_coef = 1 * 0.1
+smpl_beta_loss_coef = 0.01
+smpl_expr_loss_coef = 0.01
+
+# smpl_kp3d_loss_coef = 10
+smpl_body_kp3d_loss_coef = 10.0 * 0.1
+smpl_face_kp3d_loss_coef = 1.0 * 0.1
+smpl_lhand_kp3d_loss_coef = 1 * 0.1
+smpl_rhand_kp3d_loss_coef = 1 * 0.1
+
+# kp3d ra
+smpl_body_kp3d_ra_loss_coef = 10 * 0.1
+smpl_face_kp3d_ra_loss_coef = 1 * 0.1
+smpl_lhand_kp3d_ra_loss_coef = 1 * 0.1
+smpl_rhand_kp3d_ra_loss_coef = 1 * 0.1
+
+
+# smpl_kp2d_ba_loss_coef = 1.0
+smpl_body_kp2d_loss_coef = 10.0 * 0.1
+smpl_lhand_kp2d_loss_coef = 5.0 * 0.1
+smpl_rhand_kp2d_loss_coef = 5.0 * 0.1
+smpl_face_kp2d_loss_coef = 1.0 * 0.1
+
+smpl_body_kp2d_ba_loss_coef = 0 * 0.1
+smpl_face_kp2d_ba_loss_coef = 0 * 0.1
+smpl_lhand_kp2d_ba_loss_coef = 0 * 0.1
+smpl_rhand_kp2d_ba_loss_coef = 0 * 0.1
+
+bbox_loss_coef = 5.0
+body_bbox_loss_coef = 5.0
+lhand_bbox_loss_coef = 5.0
+rhand_bbox_loss_coef = 5.0
+face_bbox_loss_coef = 5.0
+
+giou_loss_coef = 2.0
+body_giou_loss_coef = 2.0
+rhand_giou_loss_coef = 2.0
+lhand_giou_loss_coef = 2.0
+face_giou_loss_coef = 2.0
+
+keypoints_loss_coef = 10.0
+rhand_keypoints_loss_coef = 10.0
+lhand_keypoints_loss_coef = 10.0
+face_keypoints_loss_coef = 10.0
+
+oks_loss_coef=4.0
+rhand_oks_loss_coef = 0.5
+lhand_oks_loss_coef = 0.5
+face_oks_loss_coef = 4.0
+
+
+enc_loss_coef = 1.0
+interm_loss_coef = 1.0
+no_interm_box_loss = False
+focal_alpha = 0.25
+rm_self_attn_layers = None
+indices_idx_list = [1, 2, 3, 4, 5, 6, 7]
+
+decoder_sa_type = 'sa'
+matcher_type = 'HungarianMatcher'
+decoder_module_seq = ['sa', 'ca', 'ffn']
+nms_iou_threshold = -1
+
+dec_pred_bbox_embed_share = False
+dec_pred_class_embed_share = False
+dec_pred_pose_embed_share = False
+body_only = True
+
+# for dn
+use_dn = True
+dn_number = 100
+dn_box_noise_scale = 0.4
+dn_label_noise_ratio = 0.5
+embed_init_tgt = False
+dn_label_coef = 0.3
+dn_bbox_coef = 0.5
+dn_batch_gt_fuse = False
+dn_attn_mask_type_list = ['match2dn', 'dn2dn', 'group2group']
+dn_labelbook_size = 100
+
+match_unstable_error = False
+
+# for ema
+use_ema = True
+ema_decay = 0.9997
+ema_epoch = 0
+
+cls_no_bias = False
+num_body_points = 17 # for coco
+num_hand_points = 6 # for coco
+num_face_points = 6 # for coco
+num_group = 100
+num_box_decoder_layers = 2
+num_hand_face_decoder_layers = 4
+no_mmpose_keypoint_evaluator = True
+strong_aug = False
+
+body_model_test=\
+ dict(
+ type='smplx',
+ keypoint_src='smplx',
+ num_expression_coeffs=10,
+ num_betas=10,
+ keypoint_dst='smplx_137',
+ model_path='data/body_models/smplx',
+ use_pca=False,
+ use_face_contour=True)
+
+body_model_train = \
+ dict(
+ type='smplx',
+ keypoint_src='smplx',
+ num_expression_coeffs=10,
+ num_betas=10,
+ keypoint_dst='smplx_137',
+ model_path='data/body_models/smplx',
+ use_pca=False,
+ use_face_contour=True)
+
+# will be updated in exp
+exp_name = 'output/exp52/dataset_debug'
+
+
+end_epoch = 150
+train_batch_size = 32
+
+scheduler = 'step'
+step_size = 20
+gamma = 0.1
+
+# continue
+continue_train = True
+pretrained_model_path = '../output/train_gta_synbody_ft_20230410_132110/model_dump/snapshot_2.pth.tar'
+
+# dataset setting
+# dataset_list = ['AGORA_MM','BEDLAM', 'COCO_NA']
+# trainset_3d = ['AGORA_MM','BEDLAM', 'COCO_NA']
+dataset_list = ['AGORA_MM','BEDLAM', 'COCO_NA']
+trainset_3d = ['AGORA_MM','BEDLAM', 'COCO_NA']
+trainset_2d = []
+trainset_partition = {
+ 'AGORA_MM': 0.4,
+ 'BEDLAM': 0.7,
+ 'COCO_NA': 1,
+
+ # 'EgoBody_Egocentric': 1,
+ # 'EgoBody_Kinect': 1.0,
+ }
+trainset_humandata = []
+testset = 'AGORA_MM'
+train_sizes=[480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
+train_max_size=1333
+test_sizes=[800]
+test_max_size=1333
+no_aug=False
+# model
+use_cache = True
+
+## UBody setting
+train_sample_interval = 10
+test_sample_interval = 100
+make_same_len = False
+
+## input, output size
+input_body_shape = (256, 192)
+output_hm_shape = (16, 16, 12)
+input_hand_shape = (256, 256)
+output_hand_hm_shape = (16, 16, 16)
+output_face_hm_shape = (8, 8, 8)
+input_face_shape = (192, 192)
+focal = (5000, 5000) # virtual focal lengths
+princpt = (input_body_shape[1] / 2, input_body_shape[0] / 2
+ ) # virtual principal point position
+body_3d_size = 2
+hand_3d_size = 0.3
+face_3d_size = 0.3
+camera_3d_size = 2.5
+
+bbox_ratio = 1.2
+
+## directory
+output_dir, model_dir, vis_dir, log_dir, result_dir, code_dir = None, None, None, None, None, None
+
+agora_benchmark = 'na' # 'agora_model', 'test_only'
+
+# strategy
+data_strategy = 'balance'  # 'balance' needs total_data_len to be defined
+total_data_len = 'auto'
\ No newline at end of file
diff --git a/config/config.py b/config/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fce085a9c5ca502df03c408a63971757001950f
--- /dev/null
+++ b/config/config.py
@@ -0,0 +1,91 @@
+import os
+import os.path as osp
+import sys
+import datetime
+from mmcv import Config as MMConfig
+
+class Config(MMConfig):
+    def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
+        super().__init__(cfg_dict, cfg_text, filename)
+
+    def get_config_fromfile(self, config_path):
+        self.config_path = config_path
+
+        cfg, _ = MMConfig._file2dict(self.config_path)
+
+        self.merge_from_dict(cfg)
+        # #import ipdb;ipdb.set_trace()
+        # self.__dict__.update(dict(cfg))
+        # # update dir
+        dir_dict = {}
+        exp_name = 'exps62'
+        time_str = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
+        dir_dict['cur_dir'] = osp.dirname(os.path.abspath(__file__))
+        dir_dict['root_dir'] = osp.join(dir_dict['cur_dir'], '..')
+        dir_dict['output_dir'] = osp.join(dir_dict['root_dir'], exp_name)
+        dir_dict['result_dir'] = osp.join(dir_dict['output_dir'], 'result')
+        dir_dict['data_dir'] = osp.join(dir_dict['root_dir'], 'dataset')
+        dir_dict['human_model_path'] = osp.join('data/body_models')
+        self.merge_from_dict(dir_dict)
+
+        ## add some paths to the system root dir
+        sys.path.insert(0, osp.join(self.root_dir, 'common'))
+        sys.path.insert(0, osp.join(self.root_dir, 'united-perception_utils'))
+        sys.path.insert(0, osp.join(self.cur_dir, 'humanbench_utils'))
+        sys.path.insert(0, osp.join(self.cur_dir, 'dinov2_utils'))
+        sys.path.insert(0, osp.join(self.cur_dir, 'lora_utils'))
+        sys.path.insert(0, osp.join(self.cur_dir, 'vit_adapter_utils'))
+        from util.dir import add_pypath
+        # add_pypath(osp.join(self.data_dir))
+        for dataset in os.listdir('datasets'):
+            if dataset not in ['humandata.py', '__pycache__', 'dataset.py']:
+                add_pypath(osp.join(self.root_dir, 'data', dataset))
+        add_pypath('datasets')
+        add_pypath(self.data_dir)
+
+    def prepare_dirs(self, exp_name):
+        time_str = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
+        self.output_dir = osp.join(self.root_dir, f'{exp_name}_{time_str}')
+        self.model_dir = osp.join(self.output_dir, 'model_dump')
+        self.vis_dir = osp.join(self.output_dir, 'vis')
+        self.log_dir = osp.join(self.output_dir, 'log')
+        self.code_dir = osp.join(self.output_dir, 'code')
+        self.result_dir = osp.dirname(self.output_dir)  # parent directory of output_dir
+        from util.dir import make_folder
+        make_folder(self.model_dir)
+        make_folder(self.vis_dir)
+        make_folder(self.log_dir)
+        make_folder(self.code_dir)
+        make_folder(self.result_dir)
+
+        ## copy some code to log dir as a backup
+        copy_files = [
+            'main/train.py', 'main/test.py', 'common/base.py', 'main/OSX.py',
+            'common/nets', 'main/OSX_WoDecoder.py', 'data/dataset.py',
+            'data/MSCOCO/MSCOCO.py', 'data/AGORA/AGORA.py'
+        ]
+        for file in copy_files:
+            os.system(f'cp -r {self.root_dir}/{file} {self.code_dir}')
+
+    def update_test_config(self, testset, agora_benchmark, shapy_eval_split,
+                           pretrained_model_path, use_cache):
+        self.testset = testset
+        self.agora_benchmark = agora_benchmark
+        self.pretrained_model_path = pretrained_model_path
+        self.shapy_eval_split = shapy_eval_split
+        self.use_cache = use_cache
+
+    def update_config(self, num_gpus, exp_name):
+        self.num_gpus = num_gpus
+        self.exp_name = exp_name
+
+        self.prepare_dirs(self.exp_name)
+
+        # Save
+        cfg_save = MMConfig(self.__dict__)
+        cfg_save.dump(osp.join(self.code_dir, 'config_base.py'))
+
+
+cfg = Config()
+cfg.get_config_fromfile('config/aios_smplx.py')
+
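config/config.py ends by building a module-level singleton and loading config/aios_smplx.py into it, so downstream code can simply import cfg and read fields; for example the virtual principal point follows from input_body_shape as (192 / 2, 256 / 2) = (96.0, 128.0). A usage sketch, assuming it runs from the repository root of the full project (the util helpers and datasets/ directory referenced in get_config_fromfile are not part of this diff):

```python
# Usage sketch for the module-level singleton defined above. Assumes the full
# repository layout (util.dir, datasets/) so that importing cfg succeeds.
from config.config import cfg

print(cfg.testset)             # 'INFERENCE_demo' (from config/aios_smplx.py)
print(cfg.focal, cfg.princpt)  # (5000, 5000) (96.0, 128.0)

# cfg.update_config(num_gpus=1, exp_name='output/exp_demo')  # stamps a timestamped output tree
```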
diff --git a/data/body_models/J_regressor_extra.npy b/data/body_models/J_regressor_extra.npy
new file mode 100644
index 0000000000000000000000000000000000000000..d6cf8c0f6747d3c623a0d300c5176843ae99031d
--- /dev/null
+++ b/data/body_models/J_regressor_extra.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc968ea4f9855571e82f90203280836b01f13ee42a8e1b89d8d580b801242a89
+size 496160
diff --git a/data/body_models/J_regressor_h36m.npy b/data/body_models/J_regressor_h36m.npy
new file mode 100644
index 0000000000000000000000000000000000000000..d8ea80f7f2fa4c3fde21c543d28376b84e22d77a
--- /dev/null
+++ b/data/body_models/J_regressor_h36m.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c655cd7013d7829eb9acbebf0e43f952a3fa0305a53c35880e39192bfb6444a0
+size 937168
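The two .npy files above are LFS pointers for auxiliary joint regressors. Their byte sizes (496160 and 937168) are consistent with float64 arrays of shape (9, 6890) and (17, 6890) once the .npy header is accounted for, i.e. regressors over the 6890 SMPL vertices; that shape reading is inferred from the sizes, not stated in the diff. Once the LFS objects are pulled they load directly with NumPy:

```python
# Sketch: load the joint regressors after `git lfs pull`. Shapes are inferred
# from the pointer sizes above, not stated anywhere in this diff.
import numpy as np

J_regressor_extra = np.load("data/body_models/J_regressor_extra.npy")  # expected (9, 6890)
J_regressor_h36m = np.load("data/body_models/J_regressor_h36m.npy")    # expected (17, 6890)

# A regressor maps mesh vertices to joints:
# joints = J_regressor_h36m @ vertices  # (17, 3) from a (6890, 3) vertex array
```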
diff --git a/data/body_models/J_regressor_mano_LEFT.txt b/data/body_models/J_regressor_mano_LEFT.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a392696c2a8ddd2af11ad9821f2d60352c4f4590
--- /dev/null
+++ b/data/body_models/J_regressor_mano_LEFT.txt
@@ -0,0 +1,1902 @@
+# 21 778
+0 4 0.0019103600293901542
+0 5 0.0027920646583394562
+0 6 0.00029390154298310065
+0 7 0.00014695077149155033
+0 25 0.0016164584864070536
+0 26 0.000440852314474651
+0 32 0.011756061719324026
+0 33 0.021234386480529024
+0 34 0.019838354151359296
+0 35 0.016311535635562088
+0 36 0.015870683321087434
+0 37 0.02343864805290228
+0 38 0.01671565025716385
+0 39 0.020499632623071272
+0 40 0.005437178545187362
+0 41 0.010139603232916973
+0 42 0.002645113886847906
+0 43 0.00014695077149155033
+0 44 0.02005878030859662
+0 45 0.02233651726671565
+0 50 0.01763409257898604
+0 51 0.01704628949301984
+0 52 0.019838354151359296
+0 53 0.02079353416605437
+0 54 0.00822924320352682
+0 55 0.00822924320352682
+0 78 0.011572373254959589
+0 79 0.011939750183688464
+0 84 0.01704628949301984
+0 85 0.019691403379867745
+0 88 0.005437178545187362
+0 89 0.0007347538574577516
+0 90 0.014548126377663484
+0 91 0.018736223365172666
+0 92 0.011645848640705364
+0 106 0.018515797207935343
+0 107 0.02204261572373255
+0 108 0.012417340191036004
+0 109 0.009992652461425423
+0 110 0.016311535635562088
+0 111 0.01880969875091844
+0 112 0.0073475385745775165
+0 113 0.0014695077149155032
+0 114 0.005731080088170463
+0 116 0.02204261572373255
+0 117 0.012123438648052902
+0 118 0.013005143277002204
+0 119 0.016385011021307863
+0 120 0.008155767817781044
+0 121 0.011315209404849376
+0 122 0.009037472446730345
+0 130 0.0073475385745775165
+0 131 0.00911094783247612
+0 178 0.001763409257898604
+0 179 0.002351212343864805
+0 190 0.019544452608376194
+0 191 0.019691403379867745
+0 192 0.01704628949301984
+0 193 0.016605437178545186
+0 200 0.002351212343864805
+0 203 0.00822924320352682
+0 204 0.007641440117560617
+0 205 0.01704628949301984
+0 207 0.001763409257898604
+0 208 0.005290227773695812
+0 209 0.01763409257898604
+0 210 0.019691403379867745
+0 211 0.019691403379867745
+0 214 0.011315209404849376
+0 215 0.011315209404849376
+0 216 0.007641440117560617
+0 217 0.00822924320352682
+0 218 0.002351212343864805
+0 219 0.0011756061719324026
+0 227 0.002351212343864805
+0 229 0.007788390889052168
+0 231 0.002204261572373255
+0 232 0.016311535635562088
+0 233 0.006759735488611315
+0 234 0.011168258633357825
+0 235 0.019544452608376194
+0 236 0.0016164584864070536
+0 239 0.011315209404849376
+0 241 0.0007347538574577516
+0 242 0.002351212343864805
+0 243 0.0036737692872887582
+0 244 0.0011756061719324026
+0 254 0.0064658339456282144
+0 255 0.0038207200587803084
+0 256 0.002351212343864805
+0 257 0.002351212343864805
+0 264 0.014107274063188832
+0 265 0.00440852314474651
+0 279 0.011315209404849376
+0 284 0.00896399706098457
+0 285 0.0029390154298310064
+1 0 0.014595751184471957
+1 1 0.025294207550053488
+1 2 0.019180803912578332
+1 3 0.01039278618370778
+1 4 0.03156044627846554
+1 5 0.025752712822864135
+1 6 0.014977838911814154
+1 7 0.023307351367874065
+1 8 0.005654898364664528
+1 9 0.009170105456212748
+1 10 0.002063273727647868
+1 11 0.0006113403637475165
+1 12 0.0018340210912425497
+1 14 0.001222680727495033
+1 15 7.641754546843957e-05
+1 16 0.0011462631820265935
+1 17 0.0004585052728106374
+1 18 0.00015283509093687913
+1 19 0.0003820877273421978
+1 22 7.641754546843957e-05
+1 24 0.01413724591166132
+1 25 0.019257221458046772
+1 26 0.024377197004432218
+1 27 0.017346782821335782
+1 28 0.0007641754546843956
+1 29 0.0022161088185847473
+1 30 0.0006877579092159561
+1 31 0.0005349228182790769
+1 32 0.0005349228182790768
+1 33 0.0005349228182790769
+1 34 0.0024071526822558465
+1 35 0.002445361454990066
+1 36 0.029802842732691428
+1 37 0.022122879413113253
+1 38 0.010029802842732692
+1 39 0.02334556014060829
+1 40 0.029344337459880795
+1 41 0.032171786642213054
+1 42 0.02009781445819961
+1 43 0.009934280910897143
+1 60 0.004355800091701055
+1 61 0.00855876509246523
+1 62 0.0004585052728106374
+1 63 0.003285954455142901
+1 64 0.0012990982729634726
+1 65 7.641754546843957e-05
+1 66 0.0019868561821794286
+1 67 0.004814305364511693
+1 68 0.008253094910591475
+1 69 0.0018340210912425497
+1 70 0.0003820877273421978
+1 71 7.641754546843957e-05
+1 88 0.021320495185694635
+1 89 0.013907993275256002
+1 90 0.01986856182179429
+1 91 0.013564114320648022
+1 92 0.003763564114320649
+1 93 0.0004585052728106374
+1 94 0.008329512456059913
+1 95 0.007565337001375517
+1 104 0.0027510316368638244
+1 105 0.0072596668195017595
+1 109 0.009705028274491823
+1 110 0.005654898364664528
+1 111 0.015436344184624792
+1 112 0.019180803912578332
+1 113 0.03339446736970809
+1 114 0.0340058077334556
+1 115 0.02559987773192725
+1 116 0.008405930001528351
+1 117 0.0017767079321412199
+1 118 0.00527281063732233
+1 119 0.00032477456824086816
+1 122 0.004967140455448571
+1 123 0.007259666819501758
+1 124 0.0016811860003056705
+1 125 0.0025217790004585057
+1 126 0.008176677365123033
+1 129 0.00030567018187375826
+1 145 0.00030567018187375826
+1 146 0.0006877579092159561
+1 147 7.641754546843957e-05
+1 152 7.641754546843957e-05
+1 157 0.002063273727647868
+1 158 0.0016047684548372307
+1 159 0.0032095369096744614
+1 188 0.0007641754546843956
+1 190 0.0019868561821794286
+1 191 0.0004585052728106374
+1 192 0.0016047684548372307
+1 193 0.005884151001069847
+1 207 0.00015283509093687913
+1 208 7.641754546843957e-05
+1 209 0.00030567018187375826
+1 216 0.0008405930001528353
+1 217 0.003897294818890417
+1 218 0.0008405930001528353
+1 219 0.0014519333639003516
+1 227 0.005502063273727648
+1 229 0.008635182637933671
+1 230 0.004126547455295736
+1 231 0.009705028274491824
+1 232 0.01245605991135565
+1 233 0.016888277548525142
+1 234 0.001413724591166132
+1 235 0.005654898364664528
+1 236 0.012838147638697846
+1 239 0.00026746140913953847
+1 240 0.01543634418462479
+1 241 0.0006877579092159561
+1 242 0.0032095369096744614
+1 248 0.004890722909980132
+1 249 0.0005349228182790769
+1 250 0.0015283509093687911
+1 251 0.0009170105456212748
+1 252 0.0029038667278007036
+1 253 0.005502063273727649
+1 254 0.0019868561821794286
+1 255 0.0002292526364053187
+1 264 0.028885832187070158
+1 265 0.029650007641754548
+1 266 0.006953996637628001
+1 267 0.002445361454990066
+1 268 0.00015283509093687913
+1 285 0.010087116001834023
+1 286 0.007794589637780836
+1 287 0.0025981965459269452
+1 697 0.0004585052728106374
+1 699 7.641754546843957e-05
+1 700 0.00030567018187375826
+1 704 0.0002292526364053187
+1 705 0.0008405930001528353
+1 706 7.641754546843957e-05
+2 0 0.0027531810402559712
+2 1 0.0034972840241089364
+2 2 0.007887491628841432
+2 3 0.0056551826772825355
+2 4 0.009152466701391472
+2 5 0.01674231713669172
+2 6 0.02708534861224793
+2 7 0.02209985862043307
+2 8 0.00833395341915321
+2 9 0.009152466701391472
+2 10 0.011682416846491553
+2 11 0.0055063620805119425
+2 12 0.005431951782126646
+2 13 0.0011161544757794478
+2 14 0.006176054765979612
+2 15 0.0017858471612471167
+2 16 0.0007441029838529652
+2 19 0.0003720514919264826
+2 26 0.000967333879008855
+2 27 0.0008929235806235583
+2 28 0.013245033112582783
+2 29 0.013765905201279856
+2 30 0.009970979983629735
+2 31 0.011384775652950369
+2 36 0.0023811295483294886
+2 37 0.00014882059677059304
+2 38 7.441029838529652e-05
+2 39 0.0020834883547883026
+2 40 0.0055063620805119425
+2 41 0.009896569685244438
+2 42 0.022843961604286034
+2 43 0.032666120991145166
+2 60 0.00364610462087953
+2 61 0.0017858471612471167
+2 62 0.0002976411935411861
+2 63 0.000967333879008855
+2 64 0.0014882059677059304
+2 65 0.0004464617903117792
+2 68 0.0002976411935411861
+2 69 7.441029838529652e-05
+2 88 0.01562616266091227
+2 89 0.027234169209018527
+2 90 0.00513431058858546
+2 91 0.0006696926854676687
+2 93 7.441029838529652e-05
+2 94 0.0005952823870823722
+2 104 0.025225091152615526
+2 105 0.017858471612471165
+2 113 0.0035716943224942334
+2 114 0.002604360443485378
+2 115 0.010566262370712107
+2 123 0.026787707418706754
+2 124 0.021504576233350697
+2 125 0.01882580549148002
+2 126 0.02083488354788303
+2 127 0.0002232308951558896
+2 128 0.0002976411935411861
+2 129 0.0017114368628618197
+2 144 0.0002232308951558896
+2 145 0.0013393853709353374
+2 158 0.002604360443485378
+2 193 0.0003720514919264826
+2 217 0.0007441029838529652
+2 219 0.0004464617903117792
+2 227 0.003199642830567751
+2 229 0.003125232532182454
+2 230 0.008854825507850286
+2 231 0.00982215938685914
+2 232 0.002009078056403006
+2 233 0.007813081330456134
+2 235 7.441029838529652e-05
+2 236 0.01912344668502121
+2 240 0.01480764937867401
+2 248 0.03318699307984225
+2 249 0.01823052310439765
+2 250 0.02887119577349505
+2 251 0.02500186025745963
+2 252 0.02864796487833916
+2 253 0.032889351886301064
+2 259 0.00014882059677059304
+2 264 0.0002232308951558896
+2 265 0.0005952823870823722
+2 266 0.015402931765756382
+2 267 0.01622144504799464
+2 286 0.02805268249125679
+2 287 0.025820373539697895
+2 697 0.014510008185132822
+2 698 0.008631594612694398
+2 699 0.011161544757794479
+2 700 0.01049185207232681
+2 701 0.00811072252399732
+2 702 0.013393853709353377
+2 703 0.010938313862638589
+2 704 0.008185132822382618
+2 705 0.02187662772527718
+2 706 0.018825805491480024
+2 707 0.011905647741647447
+2 708 0.007217798943373763
+2 709 0.005059900290200163
+2 710 0.003199642830567751
+2 711 0.0019346677580177095
+2 712 0.005952823870823722
+2 713 0.00364610462087953
+2 714 0.00364610462087953
+2 715 0.0026787707418706747
+2 716 0.0021578986531735995
+2 721 0.0006696926854676687
+2 722 0.0002232308951558896
+2 723 0.0002232308951558896
+2 725 0.0004464617903117792
+2 731 0.0032740531289530473
+2 732 0.0008185132822382618
+2 741 0.0005952823870823722
+2 742 0.0005208720886970756
+2 746 0.0002232308951558896
+2 749 0.0005208720886970756
+2 753 0.0034972840241089364
+2 754 0.004018156112806012
+2 755 0.0014882059677059304
+2 757 0.0008929235806235583
+2 758 0.0014137956693206339
+2 759 0.0003720514919264826
+2 760 7.441029838529652e-05
+3 6 0.0019164148301024542
+3 7 0.0014004569912287167
+3 8 0.000884499152354979
+3 9 0.00029483305078499295
+3 10 0.004422495761774894
+3 11 0.0011793322031399718
+3 12 0.0005896661015699859
+3 14 0.0011056239404437236
+3 28 0.011203655929829732
+3 29 0.0037591213975086604
+3 30 0.004496204024471142
+3 31 0.011645905506007222
+3 43 0.0019164148301024544
+3 89 0.0005896661015699859
+3 104 0.009729490675904768
+3 105 0.002137539618191199
+3 123 0.006412618854573597
+3 124 0.0187956069875433
+3 125 0.013414903810717178
+3 126 0.004938453600648632
+3 230 0.0007370826269624824
+3 231 0.00022112478808874474
+3 236 0.0005159578388737376
+3 240 0.0008844991523549787
+3 248 0.007665659320409817
+3 249 0.013120070759932186
+3 250 0.009434657625119773
+3 251 0.012088155082184712
+3 252 0.004348787499078646
+3 253 0.003022038770546178
+3 266 0.0029483305078499295
+3 267 0.0125304046583622
+3 286 0.002727205719761185
+3 287 0.005896661015699859
+3 697 0.01805852436058082
+3 698 0.019016731775632047
+3 699 0.021375396181911987
+3 700 0.01968010613989828
+3 701 0.023512935800103187
+3 702 0.01975381440259453
+3 703 0.021965062283481978
+3 704 0.019164148301024544
+3 705 0.015331318640819633
+3 706 0.017837399572492075
+3 707 0.02889363897692931
+3 708 0.02130168791921574
+3 709 0.027050932409523103
+3 710 0.024544851477850665
+3 711 0.0209331466057345
+3 712 0.0232181027493182
+3 713 0.023070686223925697
+3 714 0.024102601901673175
+3 715 0.018353357411365814
+3 716 0.017026608682833344
+3 717 0.0016952900420137097
+3 718 0.0062652023291811
+3 719 0.0033168718213311705
+3 720 0.00125304046583622
+3 721 0.016879192157440846
+3 722 0.01090882287904474
+3 723 0.008402741947372299
+3 724 0.004717328812559887
+3 725 0.010982531141740989
+3 726 0.0033168718213311705
+3 727 0.0008107908896587306
+3 730 7.370826269624824e-05
+3 731 0.022775853173140702
+3 732 0.018279649148669565
+3 733 0.009803198938601014
+3 734 0.003022038770546178
+3 735 0.0003685413134812412
+3 736 0.011719613768703471
+3 737 0.003906537922901157
+3 738 0.0008107908896587306
+3 739 0.013488612073413427
+3 740 0.005306994914129874
+3 741 0.021301687919215745
+3 742 0.019606397877202027
+3 743 0.0022112478808874476
+3 746 0.006338910591877348
+3 747 0.00125304046583622
+3 748 0.0016952900420137097
+3 749 0.009876907201297264
+3 750 0.003022038770546178
+3 751 7.370826269624824e-05
+3 753 0.025208225842116898
+3 754 0.0209331466057345
+3 755 0.023291811012014444
+3 756 0.017837399572492075
+3 757 0.021449104444608236
+3 758 0.01975381440259453
+3 759 0.01171961376870347
+3 760 0.01348861207341343
+3 761 0.003906537922901157
+3 762 0.005306994914129872
+3 763 0.007960492371194809
+3 764 0.0008107908896587306
+3 765 0.0003685413134812412
+3 767 0.0022112478808874476
+3 768 0.0011056239404437238
+4 745 1.0
+5 0 0.0012638674343491084
+5 1 0.0001404297149276787
+5 2 0.00035107428731919675
+5 3 0.002808594298553574
+5 8 0.004072461732902682
+5 9 0.0007723634321022329
+5 10 0.004774610307541076
+5 11 0.01418340120769555
+5 12 0.012357814913635726
+5 13 0.01930908580255582
+5 14 0.007934278893413846
+5 15 0.020011234377194213
+5 16 0.0021064457239151806
+5 17 0.0006319337171745541
+5 18 0.0022468754388428594
+5 19 0.009127931470299114
+5 21 0.00042128914478303613
+5 24 0.0009127931470299115
+5 25 7.021485746383936e-05
+5 26 0.0001404297149276787
+5 27 0.0010532228619575903
+5 28 0.0004212891447830361
+5 29 0.0015447268642044658
+5 30 0.003932032017975004
+5 31 0.0009127931470299115
+5 46 0.0006319337171745542
+5 47 0.00035107428731919675
+5 48 0.003721387445583485
+5 49 0.0027383794410897346
+5 56 0.0002808594298553574
+5 57 7.021485746383936e-05
+5 58 0.0010532228619575903
+5 59 0.0028788091560174134
+5 60 0.010040724617329027
+5 61 0.005687403454570988
+5 62 0.029981744137059403
+5 63 0.017483499508496
+5 64 0.02029209380704957
+5 65 0.024294340682488414
+5 66 0.0029490240134812527
+5 67 0.0011234377194214297
+5 68 0.005827833169498665
+5 69 0.00975986518747367
+5 74 0.00217666058137902
+5 75 0.0010532228619575903
+5 76 0.00035107428731919675
+5 77 0.00021064457239151807
+5 86 0.0007723634321022329
+5 87 0.0021064457239151806
+5 93 0.018536722370453586
+5 94 0.0016851565791321445
+5 95 0.0001404297149276787
+5 104 7.021485746383936e-05
+5 105 0.0001404297149276787
+5 127 0.023592192107850022
+5 128 0.02710293498104199
+5 129 0.020713382951832608
+5 132 0.023030473248139307
+5 133 0.005195899452324112
+5 134 0.005195899452324112
+5 135 0.01305996348827412
+5 136 0.008495997753124563
+5 137 0.014323830922623225
+5 138 0.01818564808313439
+5 139 0.011515236624069652
+5 140 0.008215138323269205
+5 143 0.010742873191967421
+5 144 0.016991995506249125
+5 145 0.010040724617329027
+5 146 0.00035107428731919675
+5 147 0.0011234377194214297
+5 149 0.013832326920376354
+5 150 0.016430276646538407
+5 151 0.010181154332256704
+5 152 0.011023732621822779
+5 155 0.00035107428731919675
+5 156 0.001966016008987502
+5 157 7.021485746383936e-05
+5 158 0.003932032017975004
+5 164 0.0034405280157281284
+5 165 0.005195899452324111
+5 166 0.0014745120067406266
+5 167 0.0014745120067406264
+5 168 0.026049712119084405
+5 169 0.02927959556242101
+5 170 0.023873051537705376
+5 171 0.016008987501755372
+5 172 0.027102934981041993
+5 173 0.016921780648785283
+5 174 0.005546973739643309
+5 175 0.005406544024715631
+5 176 0.013551467490520995
+5 177 0.00758320460609465
+5 183 7.021485746383936e-05
+5 185 0.009127931470299114
+5 186 0.017834573795815194
+5 187 0.008074708608341525
+5 189 0.007161915461311614
+5 194 0.010602443477039742
+5 195 0.01060244347703974
+5 206 0.0013340822918129478
+5 212 0.007091700603847775
+5 213 0.0013340822918129476
+5 219 0.0002808594298553574
+5 220 0.00435332116275804
+5 222 0.0002808594298553574
+5 223 0.00042128914478303613
+5 225 0.0016851565791321445
+5 226 0.00042128914478303613
+5 227 0.000983008004493751
+5 228 0.00975986518747367
+5 230 0.001825586294059823
+5 231 7.021485746383936e-05
+5 246 0.00035107428731919675
+5 258 0.020924027524224127
+5 259 0.022398539530964757
+5 260 0.015587698356972338
+5 261 0.012568459486027245
+5 262 0.009619435472545991
+5 263 0.01305996348827412
+5 266 0.0010532228619575903
+5 267 0.0005617188597107148
+5 268 0.004283106305294201
+5 269 0.0017553714365959837
+5 270 0.005266114309787951
+5 271 0.004844825165004915
+5 274 0.018045218368206713
+5 276 0.0002808594298553574
+5 277 0.00021064457239151807
+5 280 0.0001404297149276787
+5 288 0.00540654402471563
+5 290 7.021485746383936e-05
+5 358 0.0002808594298553574
+5 359 0.00035107428731919675
+5 362 0.00021064457239151807
+5 363 0.0002808594298553574
+5 365 7.021485746383936e-05
+5 366 0.0009127931470299116
+5 367 0.0013340822918129476
+5 368 0.005125684594860273
+5 369 0.0034405280157281284
+5 370 0.0013340822918129476
+5 371 0.00021064457239151807
+5 373 0.00042128914478303613
+5 375 0.00035107428731919675
+5 378 0.004493750877685719
+5 379 0.0034405280157281284
+5 380 0.004634180592613397
+5 383 0.00042128914478303613
+5 385 0.0016149417216683051
+5 386 0.001404297149276787
+5 387 0.0016851565791321445
+5 388 0.0002808594298553574
+5 399 0.0014745120067406264
+6 46 0.019904998869034157
+6 47 0.01960340797707909
+6 48 0.025559828093191583
+6 49 0.02352408957249491
+6 56 0.022166930558697125
+6 57 0.020131192038000453
+6 58 0.02194073738973083
+6 59 0.028952725627686037
+6 62 0.0005277840609213601
+6 65 0.00022619316896629722
+6 86 0.02382568046444997
+6 87 0.022543919173640955
+6 127 0.0012063635678202518
+6 128 0.0007539772298876573
+6 132 0.0006031817839101259
+6 133 0.017643067179371183
+6 134 0.02382568046444997
+6 135 0.01379778330694413
+6 136 0.01259141973912388
+6 137 0.004448465656337178
+6 138 0.003091306642539395
+6 139 0.009424715373595717
+6 140 0.012214431124180048
+6 143 0.0005277840609213601
+6 144 0.0012817612908090175
+6 150 0.0008293749528764231
+6 155 0.019678805700067855
+6 156 0.0244288622483601
+6 164 0.019980396592022914
+6 165 0.017944658071326246
+6 166 0.023222498680539848
+6 167 0.023901078187438737
+6 168 0.002789715750584332
+6 169 0.002186533966674206
+6 170 0.00987710171152831
+6 171 0.005881022393123726
+6 172 0.004071477041393349
+6 173 0.011837442509236221
+6 174 0.022166930558697128
+6 175 0.02382568046444997
+6 176 0.019377214808112796
+6 177 0.013119203800045236
+6 185 0.0016587499057528462
+6 186 0.004448465656337178
+6 187 0.0005277840609213601
+6 189 0.020809771544899342
+6 194 0.015154942320741913
+6 195 0.01839704440925884
+6 212 0.021262157882831936
+6 213 0.022317726004674656
+6 221 0.006333408731056322
+6 222 0.016210510442584633
+6 223 0.018472442132247607
+6 224 0.00987710171152831
+6 225 0.02744477116791073
+6 226 0.020583578375933047
+6 228 0.0005277840609213602
+6 237 0.012516022016135112
+6 238 0.011912840232224985
+6 245 0.011912840232224985
+6 258 0.0052024428862248355
+6 259 0.002337329412651738
+6 260 0.007162783683932745
+6 261 0.013043806077056472
+6 262 0.0016587499057528462
+6 263 0.007388976852899043
+6 272 0.014174771921887958
+6 273 0.012817612908090177
+6 274 0.0059564201161124925
+6 280 0.019301817085124028
+6 281 0.011385056171303627
+6 282 0.011460453894292393
+6 283 0.017643067179371186
+6 294 0.003920681595415819
+6 295 0.0069365905149664465
+6 296 0.0037698861494382865
+6 297 0.00512704516323607
+6 298 0.006634999623011385
+6 299 0.002789715750584332
+6 300 0.0021865339666742064
+6 301 0.0038452838724270517
+6 302 0.0005277840609213601
+6 303 0.0006031817839101259
+6 305 0.00030159089195506294
+6 316 0.0016587499057528462
+6 321 0.0009047726758651889
+6 330 0.0021111362436854408
+6 331 0.0015079544597753145
+6 340 0.00512704516323607
+6 341 0.004599261102314709
+6 342 0.0011309658448314859
+6 344 0.0007539772298876573
+6 345 0.00022619316896629722
+7 46 0.008690077640857611
+7 47 0.009188688653037966
+7 48 0.0033478167960680964
+7 49 0.0034902770852624832
+7 56 0.010898212123370611
+7 57 0.012322815015314481
+7 58 0.004202578531234419
+7 59 0.003276586651470902
+7 86 0.00648194315834461
+7 87 0.0016382933257354513
+7 133 0.00035615072298596765
+7 134 0.0015670631811382577
+7 155 0.009829759954412709
+7 156 0.004131348386637225
+7 164 0.0009259918797635161
+7 165 0.0006410713013747418
+7 166 0.003917657952845645
+7 167 0.0050573402664007405
+7 174 0.001638293325735451
+7 175 0.0014246028919438706
+7 189 0.0009259918797635161
+7 194 0.00028492057838877413
+7 195 0.0006410713013747418
+7 212 0.00042738086758316123
+7 213 0.0037039675190540643
+7 221 0.019517059619631027
+7 222 0.016739083980340477
+7 223 0.0143172590640359
+7 224 0.02443193959683738
+7 225 0.00683809388133058
+7 226 0.01111190255716219
+7 237 0.016739083980340477
+7 238 0.018092456727687157
+7 245 0.01367618776266116
+7 272 0.02236626540351877
+7 273 0.01923213904124225
+7 280 0.011040672412564997
+7 281 0.020086900776408578
+7 282 0.01859106773986751
+7 283 0.0165253935465489
+7 294 0.024004558729254222
+7 295 0.024075788873851416
+7 296 0.02443193959683738
+7 297 0.025357931476600898
+7 298 0.026283923356364414
+7 299 0.023933328584657028
+7 300 0.022722416126504736
+7 301 0.02514424104280932
+7 302 0.01738015528171522
+7 303 0.020941662511574897
+7 304 0.007835315905691288
+7 305 0.017380155281715225
+7 306 0.011396823135550965
+7 307 0.0036327373744568705
+7 308 0.0012821426027494836
+7 309 0.002777975639290548
+7 310 0.011966664292328516
+7 311 0.005342260844789515
+7 312 0.0038464278082484507
+7 313 0.0014958330365410642
+7 314 0.0007835315905691288
+7 315 0.008191466628677256
+7 316 0.022651185981907542
+7 317 0.00035615072298596765
+7 321 0.02101289265617209
+7 322 0.01225158487071729
+7 323 0.007764085761094094
+7 324 0.002564285205498967
+7 325 0.01994444048721419
+7 326 0.008690077640857611
+7 327 0.0024218249163045803
+7 328 0.0165253935465489
+7 329 0.006980554170524965
+7 330 0.028064676971294254
+7 331 0.021084122800769284
+7 332 0.0019232139041242254
+7 333 0.00021369043379158061
+7 334 0.010969442267967804
+7 335 0.0024930550609017737
+7 336 0.008690077640857611
+7 337 0.003988888097442838
+7 338 0.00028492057838877413
+7 340 0.019588289764228224
+7 341 0.0242182491630458
+7 342 0.021867654391338417
+7 343 0.014103568630244322
+7 344 0.018662297884464708
+7 345 0.014673409787021868
+7 346 0.006125792435358643
+7 347 0.009758529809815513
+7 348 0.0017095234703326447
+7 349 0.0031341263622765153
+7 350 0.004772419688011967
+7 351 0.0006410713013747418
+7 352 0.0008547617351663223
+7 353 0.00042738086758316123
+7 354 0.001068452168957903
+7 355 0.0009972220243607095
+8 317 1.0
+9 11 0.0002498906728306366
+9 13 0.0002498906728306366
+9 14 0.0009995626913225464
+9 15 0.0022490160554757294
+9 16 0.0029986880739676387
+9 17 0.002249016055475729
+9 18 0.007746610857749733
+9 19 0.00949584556756419
+9 20 0.0013743987005685012
+9 21 0.00437308677453614
+9 22 0.0009995626913225461
+9 23 0.00018741800462297744
+9 48 0.0004997813456612732
+9 59 0.0002498906728306366
+9 62 0.0014368713687761604
+9 63 0.000874617354907228
+9 64 6.247266820765915e-05
+9 65 6.247266820765915e-05
+9 66 0.0024989067283063657
+9 67 0.000437308677453614
+9 68 0.0006871993502842506
+9 69 0.0029986880739676387
+9 71 0.0004997813456612732
+9 74 0.015555694383707127
+9 75 0.017867183107390515
+9 76 0.017242456425313923
+9 77 0.00868370088086462
+9 83 6.247266820765915e-05
+9 87 0.0004997813456612732
+9 93 0.0033110514150059348
+9 127 0.0006247266820765914
+9 132 0.004810395451989753
+9 133 0.0006247266820765914
+9 135 0.0001249453364153183
+9 136 0.0004997813456612732
+9 137 0.015555694383707127
+9 138 0.007246829512088461
+9 139 0.005997376147935278
+9 140 0.008683700880864622
+9 141 0.005997376147935278
+9 142 0.0025613793965140247
+9 143 0.015743112388330104
+9 144 0.009558318235771848
+9 145 0.0032485787467982754
+9 146 0.0015618167051914785
+9 147 0.006122321484350596
+9 148 0.0025613793965140247
+9 149 0.0071843568438808006
+9 150 0.01243206097332417
+9 151 0.013993877678515648
+9 152 0.007809083525957393
+9 157 0.0001249453364153183
+9 158 0.0023114887236833884
+9 160 0.0019991253826450927
+9 161 0.0002498906728306366
+9 162 0.0005622540138689324
+9 163 0.0021240707190604106
+9 164 0.0029362154057599797
+9 165 0.002561379396514025
+9 166 0.0007496720184919098
+9 167 0.0007496720184919097
+9 168 0.002124070719060411
+9 169 0.0003123633410382957
+9 170 0.0006871993502842506
+9 171 0.002249016055475729
+9 174 0.0028737427375523207
+9 175 0.0018741800462297744
+9 176 0.009433372899356529
+9 177 0.006247266820765914
+9 181 0.00018741800462297744
+9 182 0.0009995626913225464
+9 183 0.004248141438120822
+9 185 0.019179109139751356
+9 186 0.01661772974323733
+9 187 0.019054163803336036
+9 194 0.0015618167051914785
+9 195 0.0001249453364153183
+9 196 0.0004997813456612732
+9 197 0.0014993440369838195
+9 198 0.0003748360092459549
+9 199 0.0001249453364153183
+9 202 6.247266820765915e-05
+9 206 0.013181732991816079
+9 207 6.247266820765915e-05
+9 212 0.0018741800462297742
+9 213 0.0002498906728306366
+9 218 0.0003123633410382957
+9 219 0.0006871993502842506
+9 220 0.014868495033422876
+9 225 0.0006247266820765914
+9 227 0.0006871993502842506
+9 228 0.021802961204473042
+9 230 0.0002498906728306366
+9 246 0.020803398513150495
+9 247 0.017304929093521583
+9 258 0.0004997813456612732
+9 259 0.0027487974011370024
+9 260 0.0017492347098144558
+9 261 0.002623852064721684
+9 262 0.01974136315362029
+9 263 0.01655525707502967
+9 268 0.007746610857749734
+9 269 0.02167801586805772
+9 270 0.019054163803336036
+9 271 0.011932279627662898
+9 274 0.0066221028300118695
+9 275 0.0007496720184919098
+9 276 0.016742675079652648
+9 277 0.02205285187730368
+9 288 0.022427687886549634
+9 289 0.0003123633410382957
+9 290 0.00730930218029612
+9 291 0.005685012806896982
+9 292 0.0057474854751046415
+9 293 0.008933591553695257
+9 356 0.0014993440369838195
+9 357 0.0014993440369838193
+9 358 0.00668457549821953
+9 359 0.004685450115574436
+9 360 0.0007496720184919098
+9 361 0.0007496720184919098
+9 362 0.0024989067283063657
+9 363 0.0038733054288748667
+9 364 0.0014368713687761604
+9 365 0.004498032110951459
+9 366 0.009933154245017804
+9 367 0.010245517586056099
+9 368 0.015993003061160742
+9 369 0.015993003061160742
+9 370 0.021115761854188793
+9 371 0.01693009308427563
+9 372 0.0009995626913225464
+9 373 0.0037483600924595483
+9 374 0.008996064221902918
+9 375 0.012432060973324168
+9 376 0.004498032110951458
+9 377 0.0031861060785906164
+9 378 0.017554819766352217
+9 379 0.01749234709814456
+9 380 0.01649278440682201
+9 381 0.008308864871618667
+9 382 0.006434684825388891
+9 383 0.016055475729368402
+9 384 0.012557006309739488
+9 385 0.01018304491784844
+9 386 0.015180858374461174
+9 387 0.01155744361841694
+9 388 0.009058536890110576
+9 389 0.0028112700693446614
+9 391 0.00018741800462297744
+9 392 0.0005622540138689324
+9 394 0.0018117073780221152
+9 395 0.0004997813456612732
+9 399 0.01611794839757606
+9 402 0.0008746173549072279
+9 470 0.0007496720184919098
+9 471 0.0004997813456612732
+9 478 0.0007496720184919098
+9 479 0.0004997813456612732
+9 480 0.0026863247329293434
+9 481 0.002623852064721684
+9 483 0.0001249453364153183
+9 484 0.0001249453364153183
+9 485 0.0014993440369838195
+9 486 0.0004997813456612732
+9 488 0.008996064221902916
+9 489 0.006059848816142937
+9 490 0.006497157493596552
+9 491 0.0001249453364153183
+9 492 0.0003748360092459549
+9 493 0.001311926032360842
+9 494 0.000437308677453614
+9 495 0.0017492347098144558
+9 496 0.002623852064721684
+9 497 0.0027487974011370024
+9 498 0.0006247266820765914
+9 509 0.0020615980508527517
+9 510 0.0003748360092459549
+9 579 0.0019991253826450927
+10 74 0.0005264345341054373
+10 75 0.0021809430698653833
+10 76 0.000752049334436339
+10 137 0.000827254267879973
+10 143 0.0006016394675490712
+10 150 0.0003008197337745356
+10 151 0.0006016394675490712
+10 185 0.004361886139730767
+10 186 0.0010528690682108748
+10 187 0.003910656539068963
+10 206 0.0001504098668872678
+10 220 0.0003008197337745356
+10 228 0.0030834022711889904
+10 246 0.003985861472512596
+10 247 0.0012784838685417762
+10 262 0.003910656539068963
+10 263 0.0011280740016545085
+10 269 0.0032338121380762574
+10 270 0.002857787470858088
+10 271 0.0003008197337745356
+10 276 0.000902459201323607
+10 277 0.00556516507482891
+10 288 0.0027825825374144545
+10 356 0.020305332029781156
+10 357 0.019703692562232082
+10 358 0.02549447243739189
+10 359 0.023764758968188315
+10 360 0.02587049710461006
+10 361 0.022486275099646538
+10 362 0.022411070166202904
+10 363 0.02278709483342107
+10 364 0.026321726705271865
+10 365 0.02007971722945025
+10 366 0.016093855756937656
+10 367 0.022260660299315636
+10 368 0.011882379484094157
+10 369 0.009400616680454237
+10 370 0.00962623148078514
+10 371 0.011431149883432353
+10 372 0.021583815898322933
+10 373 0.024742423102955553
+10 374 0.01947807776190118
+10 375 0.01789877415958487
+10 376 0.023388734300970146
+10 377 0.023689554034744677
+10 378 0.009400616680454237
+10 379 0.005865984808603443
+10 380 0.01135594494998872
+10 381 0.022486275099646538
+10 382 0.015341806422501316
+10 383 0.01135594494998872
+10 384 0.01158155975031962
+10 385 0.019703692562232082
+10 386 0.01504098668872678
+10 387 0.018124388959915774
+10 388 0.010077461081446944
+10 389 0.02293750470030834
+10 390 0.01383770775362864
+10 391 0.017372339625479433
+10 392 0.019703692562232086
+10 393 0.011882379484094157
+10 394 0.024667218169511923
+10 395 0.024667218169511916
+10 396 0.012333609084755958
+10 397 0.011506354816875987
+10 398 0.013236068286079568
+10 399 0.0070692637437015865
+10 400 0.01940287282845755
+10 401 0.016093855756937656
+10 402 0.020530946830112053
+10 403 0.008197337745356097
+10 404 0.01759795442581033
+10 405 0.021508610964879295
+10 406 0.008197337745356095
+10 407 0.013988117620515906
+10 408 0.008949387079792434
+10 409 0.006467624276152515
+10 410 0.005264345341054373
+10 411 0.005565165074828909
+10 412 0.003835451605625329
+10 413 0.002105738136421749
+10 414 0.0012784838685417764
+10 415 0.002556967737083553
+10 417 7.52049334436339e-05
+10 420 0.0020305332029781154
+10 421 0.0006016394675490712
+10 422 0.0006016394675490712
+10 427 7.52049334436339e-05
+10 430 0.004737910806948936
+10 431 0.002331352936752651
+10 432 0.0001504098668872678
+10 440 0.0010528690682108748
+10 441 0.0021057381364217496
+10 446 7.52049334436339e-05
+10 452 0.004512296006618034
+10 453 0.003609836805294428
+10 454 0.0006016394675490712
+10 456 0.0006016394675490712
+10 457 0.0004512296006618035
+11 356 0.011297349184080336
+11 357 0.011888060252528984
+11 358 0.004430333013364838
+11 359 0.004430333013364838
+11 360 0.009229860444510078
+11 361 0.011371188067636416
+11 362 0.0038396219449161927
+11 363 0.002805877575131064
+11 364 0.005759432917374288
+11 365 0.0014767776711216124
+11 366 0.0003691944177804031
+11 367 0.0014029387875655322
+11 372 0.011371188067636418
+11 373 0.004504171896920917
+11 374 0.0012552610204533705
+11 375 0.0011075832533412094
+11 376 0.005316399616037805
+11 377 0.005685594033818208
+11 381 0.001772133205345935
+11 382 0.0003691944177804031
+11 385 0.00118142213689729
+11 386 0.0005168721848925644
+11 387 0.0011075832533412094
+11 388 7.383888355608063e-05
+11 389 0.0031012331093553864
+11 390 0.019345787491693123
+11 391 0.010928154766299934
+11 392 0.01299564350587019
+11 393 0.02082256516281474
+11 394 0.0057594329173742895
+11 395 0.00945137709517832
+11 396 0.017352137635678947
+11 397 0.02001033744369785
+11 398 0.018238204238351912
+11 400 0.01794284870412759
+11 401 0.019124270841024884
+11 402 0.016170715498781657
+11 403 0.022816215018828915
+11 404 0.01727829875212287
+11 405 0.014546260060547885
+11 406 0.0239976371557262
+11 407 0.022963892785941076
+11 408 0.02695119249796943
+11 409 0.023776120505057962
+11 410 0.019493465258805284
+11 411 0.023849959388614037
+11 412 0.026581998080189025
+11 413 0.020601048512146496
+11 414 0.019493465258805288
+11 415 0.02163479288193162
+11 416 0.004873366314701322
+11 417 0.007900760540500627
+11 418 0.0042088163626965965
+11 419 0.0016982943217898545
+11 420 0.018238204238351912
+11 421 0.012035738019641142
+11 422 0.012331093553865465
+11 423 0.0055379162667060465
+11 424 0.004061138595584434
+11 425 0.0016982943217898542
+11 426 0.0008122277191168869
+11 427 0.00834379384183711
+11 428 0.0005168721848925643
+11 429 0.0015506165546776932
+11 430 0.023406926087277558
+11 431 0.019124270841024884
+11 432 0.016392232149449903
+11 433 0.005907110684486449
+11 434 0.0019198109724580966
+11 435 0.015432326663220851
+11 436 0.006940855054271579
+11 437 0.0013290999040094513
+11 438 0.013364837923650594
+11 439 0.00694085505427158
+11 440 0.02126559846415122
+11 441 0.02355460385438972
+11 442 0.002732038691574983
+11 444 7.383888355608063e-05
+11 446 0.010854315882743852
+11 447 0.0031012331093553864
+11 448 0.007753082773388465
+11 449 0.0018459720889020155
+11 450 0.00044303330133648377
+11 451 0.00044303330133648377
+11 452 0.023776120505057962
+11 453 0.02229934283393635
+11 454 0.02126559846415122
+11 455 0.013290999040094512
+11 456 0.018385882005464073
+11 457 0.015580004430333012
+11 458 0.010189765930739126
+11 459 0.012035738019641142
+11 460 0.0034704275271357893
+11 461 0.004578010780476998
+11 462 0.005907110684486449
+11 463 0.000590711068448645
+11 464 0.000590711068448645
+11 465 0.0002953555342243225
+11 466 0.0019936498560141768
+11 467 0.0013290999040094513
+12 445 1.0
+13 16 0.0014635288607891346
+13 17 0.002575810794988877
+13 18 0.005737033134293408
+13 19 0.001990399250673223
+13 20 0.007785973539398196
+13 21 0.008664090855871677
+13 22 0.002985598876009834
+13 23 0.002224563868399485
+13 63 5.854115443156538e-05
+13 66 0.0018147757873785268
+13 67 0.0006439526987472192
+13 68 0.0002927057721578269
+13 69 0.0008195761620419153
+13 70 0.0007024938531787846
+13 71 0.0033953869570307925
+13 72 0.0024001873316941806
+13 73 0.00023416461772626153
+13 74 0.009308043554618896
+13 75 0.007551808921671934
+13 76 0.01890879288139562
+13 77 0.013230300901533777
+13 80 0.0013464465519260039
+13 81 0.0002927057721578269
+13 82 0.0016976934785153963
+13 83 0.0040978808102095764
+13 93 0.00017562346329469617
+13 100 0.00017562346329469617
+13 102 0.00011708230886313077
+13 103 0.00035124692658939234
+13 137 0.00011708230886313077
+13 141 0.020021074815595362
+13 142 0.016625687858564567
+13 143 0.0016391523240838306
+13 144 0.0005268703898840885
+13 145 0.0002927057721578269
+13 146 0.002868516567146704
+13 147 0.006673691605198454
+13 148 0.008839714319166374
+13 149 0.0002927057721578269
+13 150 0.0002927057721578269
+13 151 0.0012293642430628731
+13 152 0.0011122819341997424
+13 157 0.0008781173164734808
+13 158 0.0004097880810209577
+13 160 0.02681184872965695
+13 161 0.023592085235920848
+13 162 0.03096827069429809
+13 163 0.02476290832455216
+13 178 0.0002927057721578269
+13 179 5.854115443156538e-05
+13 180 0.0009366584709050461
+13 181 0.00444912773679897
+13 182 0.013464465519260038
+13 183 0.0167427701674277
+13 184 0.00017562346329469617
+13 185 5.854115443156538e-05
+13 186 0.0002927057721578269
+13 187 0.0008195761620419153
+13 196 0.017503805175038047
+13 197 0.023416461772626154
+13 198 0.023416461772626154
+13 199 0.02921203606135113
+13 201 0.0018733169418100922
+13 202 0.006439526987472192
+13 206 0.015162158997775435
+13 207 0.0006439526987472192
+13 218 0.0007610350076103501
+13 219 0.00046832923545252306
+13 220 0.006673691605198454
+13 227 0.00011708230886313077
+13 228 0.0009951996253366115
+13 246 0.0106544901065449
+13 247 0.014576747453459781
+13 262 0.00011708230886313077
+13 268 0.0033368458025992264
+13 269 0.010420325488818641
+13 270 0.0035710104203254887
+13 271 0.002985598876009834
+13 275 0.009834913944502985
+13 276 0.02142606252195293
+13 277 0.01164968973188151
+13 278 0.00035124692658939234
+13 288 0.004741833508956796
+13 289 0.014693829762322912
+13 290 0.02207001522070015
+13 291 0.017913593256059006
+13 292 0.011005737033134292
+13 293 0.010478866643250203
+13 358 0.0003512469265893923
+13 363 5.854115443156538e-05
+13 365 0.00035124692658939234
+13 366 0.0007024938531787846
+13 367 0.00017562346329469617
+13 368 0.00017562346329469617
+13 369 0.0009951996253366115
+13 370 0.005151621589977754
+13 371 0.005385786207704015
+13 374 0.0016976934785153963
+13 375 0.0017562346329469615
+13 376 0.0004097880810209577
+13 377 0.0003512469265893923
+13 378 0.00046832923545252306
+13 379 0.0015220700152206996
+13 381 0.0014635288607891346
+13 382 0.0009951996253366115
+13 383 0.00532724505327245
+13 384 0.0037466338836201845
+13 386 0.0011708230886313077
+13 387 0.00011708230886313077
+13 388 0.0012293642430628731
+13 389 0.0002927057721578269
+13 394 0.00017562346329469617
+13 399 0.0033953869570307925
+13 468 5.854115443156538e-05
+13 469 0.0011122819341997424
+13 470 0.0027514342582835734
+13 471 0.0012879053974944384
+13 474 5.854115443156538e-05
+13 475 0.0002927057721578269
+13 476 0.0002927057721578269
+13 477 0.0018147757873785268
+13 478 0.0020489404051047887
+13 479 0.0011122819341997424
+13 480 0.004332045427935838
+13 481 0.006556609296335323
+13 483 0.00046832923545252306
+13 484 0.012352183585060298
+13 485 0.014869453225617611
+13 486 0.005912656597588104
+13 487 0.004214963119072708
+13 488 0.01164968973188151
+13 489 0.015806111696522657
+13 490 0.008312843929282283
+13 491 0.009834913944502985
+13 492 0.006146821215314366
+13 493 0.015513405924364829
+13 494 0.02007961597002693
+13 495 0.0024001873316941806
+13 496 0.008956796628029503
+13 497 0.004741833508956796
+13 498 0.003512469265893923
+13 499 0.002517269640557311
+13 501 0.0005854115443156538
+13 502 0.0004097880810209577
+13 504 0.001990399250673223
+13 505 0.00040978808102095764
+13 509 0.010478866643250205
+13 510 0.02207001522070015
+13 513 0.0012293642430628731
+13 579 0.021660227139679192
+13 580 0.0002927057721578269
+13 581 0.00011708230886313077
+13 582 0.0012879053974944388
+13 583 0.0018147757873785272
+13 584 5.854115443156538e-05
+13 585 0.00011708230886313077
+13 586 0.0011122819341997422
+13 587 0.0008195761620419154
+13 589 0.0007610350076103501
+13 590 0.003395386957030792
+13 591 0.0026928931038520073
+13 592 0.009834913944502985
+13 593 0.009834913944502985
+13 594 0.00011708230886313077
+13 595 0.0013464465519260039
+13 596 0.0015806111696522653
+13 597 0.0002927057721578269
+13 598 0.00023416461772626153
+13 599 0.0009951996253366115
+13 600 0.0002927057721578269
+13 601 0.0012293642430628731
+13 602 0.00046832923545252306
+13 603 0.00011708230886313077
+13 604 0.003980798501346446
+13 605 0.013523006673691603
+13 606 0.011591148577449948
+13 607 0.006263903524177495
+13 608 0.014693829762322912
+13 610 0.0003512469265893923
+13 611 0.0012293642430628734
+13 612 5.854115443156538e-05
+13 613 0.005327245053272449
+13 614 0.0019318580962416575
+13 615 0.006615150450766888
+13 616 0.0026928931038520073
+13 617 0.0002927057721578269
+13 627 0.005268703898840884
+13 630 0.00011708230886313077
+13 696 0.00023416461772626153
+13 769 0.00076103500761035
+13 770 0.004683292354525231
+13 771 0.0011122819341997424
+13 772 5.854115443156538e-05
+13 774 0.00076103500761035
+13 775 0.003512469265893923
+13 776 0.008020138157124457
+14 74 0.0005157677571470676
+14 75 0.0005157677571470676
+14 76 0.004273504273504274
+14 77 0.0008104921898025347
+14 141 0.002799882110226938
+14 142 0.0003684055408193339
+14 160 0.001326259946949602
+14 161 0.0005894488653109342
+14 162 0.004420866489832007
+14 163 0.0050103153551429415
+14 196 7.368110816386678e-05
+14 197 0.0014736221632773356
+14 198 0.0030209254347185383
+14 199 0.0009578544061302684
+14 206 7.368110816386678e-05
+14 246 0.0013262599469496023
+14 247 0.0061155319776009425
+14 269 7.368110816386678e-05
+14 276 0.0034630120837017397
+14 277 0.0008841732979664015
+14 290 0.001399941055113469
+14 291 0.0052313586796345415
+14 292 0.0058944886531093425
+14 293 0.008989095195991748
+14 468 0.0199675803124079
+14 469 0.02460949012673151
+14 470 0.021220159151193633
+14 471 0.02586206896551724
+14 472 0.020704391394046565
+14 473 0.017978190391983492
+14 474 0.020114942528735632
+14 475 0.02586206896551724
+14 476 0.02291482463896257
+14 477 0.02475685234305924
+14 478 0.021293840259357502
+14 479 0.026009431181844976
+14 480 0.019451812555260833
+14 481 0.014294134983790155
+14 482 0.01422045387562629
+14 483 0.02726201002063071
+14 484 0.02026230474506337
+14 485 0.015694076038903628
+14 486 0.02726201002063071
+14 487 0.02733569112879458
+14 488 0.01215738284703802
+14 489 0.009652225169466549
+14 490 0.015767757147067494
+14 491 0.02460949012673151
+14 492 0.020114942528735635
+14 493 0.013704686118479222
+14 494 0.01333628057765989
+14 495 0.022988505747126436
+14 496 0.018272914824638966
+14 497 0.020851753610374304
+14 498 0.016578249336870028
+14 499 0.025567344532861774
+14 500 0.007515473032714411
+14 501 0.019157088122605366
+14 502 0.015104627173592693
+14 503 0.00987326849395815
+14 504 0.021293840259357502
+14 505 0.020999115826702035
+14 506 0.013262599469496024
+14 507 0.013483642793987621
+14 508 0.010389036251105217
+14 509 0.011715296198054817
+14 510 0.010167992926613616
+14 511 0.011199528440907752
+14 512 0.009357500736811082
+14 513 0.020335985853227233
+14 514 0.010683760683760684
+14 515 0.01215738284703802
+14 516 0.016357206012378427
+14 517 0.004052460949012673
+14 518 0.006704980842911877
+14 519 0.004273504273504274
+14 520 0.0036103743000294726
+14 521 0.004494547597995874
+14 522 0.003020925434718538
+14 523 0.002136752136752137
+14 524 0.0037577365163572064
+14 525 0.0005894488653109342
+14 526 0.0008104921898025347
+14 531 0.0002947244326554671
+14 541 0.0016209843796050694
+14 542 0.0006631299734748011
+14 551 0.0019157088122605363
+14 552 0.0009578544061302684
+14 563 0.005010315355142941
+14 564 0.004715590922487474
+14 565 0.0010315355142941351
+14 567 0.000663129973474801
+14 568 0.00022104332449160037
+14 579 0.006115531977600943
+15 468 0.01103996467211305
+15 469 0.010230367262824759
+15 470 0.0023551924633841174
+15 471 0.004121586810922205
+15 472 0.009199970560094207
+15 473 0.011334363730036065
+15 474 0.004047987046441452
+15 475 0.0027967910502686394
+15 476 0.0059615809229410476
+15 477 0.0014719952896150733
+15 478 0.0003679988224037683
+15 479 0.0016191948185765807
+15 482 0.011187164201074557
+15 483 0.0056671818650180315
+15 484 0.0014719952896150733
+15 485 0.0003679988224037683
+15 486 0.004563185397806727
+15 487 0.0073599764480753675
+15 491 0.002134393169941856
+15 492 0.0003679988224037683
+15 495 0.0011775962316920587
+15 496 0.0005151983513652757
+15 497 0.0005151983513652757
+15 498 7.359976448075367e-05
+15 499 0.0032383896371531613
+15 500 0.019945536174284243
+15 501 0.01781114300434239
+15 502 0.014204754544785456
+15 503 0.02524471921689851
+15 504 0.005446382571575771
+15 505 0.010524766320747773
+15 506 0.01832634135570766
+15 507 0.01884153970707294
+15 508 0.018473540884669168
+15 511 0.01781114300434239
+15 512 0.019356738058438214
+15 513 0.012143961139324354
+15 514 0.020755133583572533
+15 515 0.01862074041363068
+15 516 0.015014351954073748
+15 517 0.024361522043129462
+15 518 0.02333112534039891
+15 519 0.027011113564436594
+15 520 0.02465592110105248
+15 521 0.024067122985206444
+15 522 0.024508721572090966
+15 523 0.023478324869360415
+15 524 0.025980716861706044
+15 525 0.018031942297784646
+15 526 0.020607934054611025
+15 527 0.0012511959961728123
+15 528 0.0059615809229410476
+15 529 0.0025023919923456246
+15 530 0.0009567969382497977
+15 531 0.018179141826746157
+15 532 0.006255979980864061
+15 533 0.011187164201074557
+15 534 0.005225583278133511
+15 535 0.004710384926768235
+15 536 0.0016927945830573343
+15 537 0.0007359976448075366
+15 538 0.013247957606535658
+15 540 7.359976448075367e-05
+15 541 0.02340472510487966
+15 542 0.02031353499668801
+15 543 0.010745565614190034
+15 544 0.0032383896371531613
+15 545 0.0003679988224037683
+15 546 0.015529550305439023
+15 547 0.005593582100537278
+15 548 0.001103996467211305
+15 549 0.019356738058438214
+15 550 0.009126370795613454
+15 551 0.025465518510340766
+15 552 0.022374328402149115
+15 553 0.0029439905792301465
+15 557 0.011923161845882095
+15 558 0.0029439905792301465
+15 559 0.00942076985353647
+15 560 0.003679988224037683
+15 561 0.0002943990579230147
+15 563 0.019356738058438214
+15 564 0.024582321336571723
+15 565 0.02244792816662987
+15 566 0.015382350776477514
+15 567 0.019503937587399718
+15 568 0.015161551483035255
+15 569 0.0059615809229410476
+15 570 0.01023036726282476
+15 571 0.0030175903437109006
+15 572 0.003459188930595423
+15 573 0.005519982336056524
+15 574 0.0008095974092882903
+15 575 0.0008095974092882903
+15 576 7.359976448075367e-05
+15 577 0.0013247957606535659
+15 578 0.0008095974092882903
+16 556 1.0
+17 17 0.0004919184820801125
+17 18 0.0006324666198172875
+17 20 0.005762473647224175
+17 21 0.0021082220660576245
+17 22 0.0014757554462403375
+17 23 0.0024595924104005625
+17 70 0.000140548137737175
+17 71 0.000983836964160225
+17 72 0.0023190442726633872
+17 73 0.0004919184820801125
+17 76 0.000140548137737175
+17 77 0.0006324666198172875
+17 80 0.008151791988756148
+17 81 0.006676036542515813
+17 82 0.016303583977512297
+17 83 0.012297962052002813
+17 96 0.002178496134926213
+17 97 0.0007027406886858749
+17 98 7.02740688685875e-05
+17 99 0.0011243851018974
+17 100 0.009065354884047786
+17 101 0.007308503162333099
+17 102 0.015038650737877725
+17 103 0.017919887561489812
+17 141 0.0018271257905832748
+17 142 0.003794799718903725
+17 148 0.0013352073085031624
+17 153 0.001546029515108925
+17 154 0.0024595924104005625
+17 160 0.0134926212227688
+17 161 0.01883345045678145
+17 162 0.012438510189739986
+17 163 0.005200281096275476
+17 178 0.0007730147575544624
+17 179 0.0007730147575544624
+17 180 0.003021784961349263
+17 181 0.00758959943780745
+17 182 0.01377371749824315
+17 183 0.007238229093464512
+17 184 0.0026001405481377374
+17 196 0.0123682361208714
+17 197 0.007449051300070275
+17 198 0.0071679550245959235
+17 199 0.0202389318341532
+17 200 0.0004919184820801125
+17 201 0.027406886858749122
+17 202 0.020028109627547436
+17 206 0.0019676739283204497
+17 207 0.0004919184820801125
+17 218 0.000140548137737175
+17 220 0.00028109627547435
+17 247 7.02740688685875e-05
+17 256 0.000140548137737175
+17 257 0.0004919184820801125
+17 269 0.00035137034434293746
+17 275 0.005270555165144062
+17 276 0.0010541110330288123
+17 277 7.02740688685875e-05
+17 278 0.02508784258608574
+17 289 0.019465917076598734
+17 290 0.0044975404075896
+17 291 0.001546029515108925
+17 292 0.0002108222066057625
+17 293 7.02740688685875e-05
+17 484 0.0009135628952916374
+17 485 0.0007730147575544624
+17 489 0.000421644413211525
+17 491 0.0004919184820801125
+17 492 7.02740688685875e-05
+17 493 0.0013352073085031622
+17 494 0.002951510892480675
+17 509 0.0002108222066057625
+17 510 0.0033731553056922
+17 579 0.0027406886858749122
+17 580 0.003162333099086437
+17 581 0.0023190442726633877
+17 582 0.009978917779339425
+17 583 0.009065354884047788
+17 584 0.0010541110330288125
+17 585 0.0016865776528460997
+17 586 0.004356992269852425
+17 587 0.003513703443429375
+17 588 0.0016865776528461
+17 589 0.004567814476458187
+17 590 0.009065354884047786
+17 591 0.006886858749121575
+17 592 0.013914265635980324
+17 593 0.016795502459592413
+17 594 0.021503865073787775
+17 595 0.028742094167252288
+17 596 0.02178496134926212
+17 597 0.02059030217849614
+17 598 0.0026001405481377374
+17 599 0.004427266338721012
+17 600 0.015038650737877721
+17 601 0.015390021082220663
+17 602 0.008081517919887564
+17 603 0.007308503162333099
+17 604 0.021995783555867888
+17 605 0.02312016865776529
+17 606 0.019465917076598734
+17 607 0.025720309205903027
+17 608 0.019465917076598737
+17 609 0.0123682361208714
+17 610 0.013070976809557275
+17 611 0.02009838369641603
+17 612 0.019184820801124384
+17 613 0.012719606465214337
+17 614 0.0134926212227688
+17 615 0.0179901616303584
+17 616 0.015319747013352071
+17 617 0.005621925509486999
+17 618 0.00028109627547435
+17 619 0.0022487702037948
+17 620 0.0026704146170063252
+17 621 0.0002108222066057625
+17 622 0.007589599437807451
+17 623 0.0018271257905832748
+17 624 0.00084328882642305
+17 625 0.0007027406886858749
+17 626 0.0009135628952916376
+17 627 0.01981728742094167
+17 628 0.0002108222066057625
+17 629 0.0004919184820801124
+17 630 0.003021784961349262
+17 631 0.0006324666198172875
+17 632 7.02740688685875e-05
+17 633 7.02740688685875e-05
+17 696 0.021152494729444835
+17 769 0.0179901616303584
+17 770 0.020941672522839076
+17 771 0.011243851018974
+17 772 0.005411103302881238
+17 773 0.0028109627547434997
+17 774 0.0036542515811665496
+17 775 0.00871398453970485
+17 776 0.012157413914265636
+17 777 0.00035137034434293746
+18 82 0.0006012777151446825
+18 83 0.00015031942878617063
+18 103 0.0004509582863585119
+18 142 7.515971439308531e-05
+18 160 0.001428034573468621
+18 161 0.0071401728673431055
+18 162 0.0018789928598271326
+18 182 0.0005261180007515971
+18 196 0.0018789928598271326
+18 197 0.0009019165727170237
+18 198 0.0008267568583239384
+18 199 0.006238256294626081
+18 201 0.003006388575723412
+18 202 0.0018038331454340473
+18 275 0.00015031942878617063
+18 278 0.002179631717399474
+18 289 0.003607666290868095
+18 580 0.020142803457346866
+18 581 0.01698609545283728
+18 582 0.03013904547162721
+18 583 0.02705749718151071
+18 584 0.017737692596768138
+18 585 0.016910935738444193
+18 586 0.021270199173243146
+18 587 0.018489289740698987
+18 588 0.01924088688462984
+18 589 0.016910935738444197
+18 590 0.016234498308906428
+18 591 0.017061255167230362
+18 592 0.015182262307403233
+18 593 0.010672679443818113
+18 594 0.004509582863585118
+18 595 0.005712138293874483
+18 596 0.01570838030815483
+18 597 0.011424276587748966
+18 598 0.018714768883878245
+18 599 0.01969184517098835
+18 600 0.020593761743705377
+18 601 0.023900789177001128
+18 602 0.027583615182262308
+18 603 0.0266065388951522
+18 604 0.007591131153701616
+18 605 0.007666290868094702
+18 606 0.012701991732431419
+18 607 0.020668921458098462
+18 608 0.016309658023299513
+18 609 0.015031942878617064
+18 610 0.017061255167230362
+18 611 0.009845922585494176
+18 612 0.009845922585494176
+18 613 0.02435174746335964
+18 614 0.02450206689214581
+18 615 0.02720781661029688
+18 616 0.016459977452085682
+18 617 0.021119879744456973
+18 618 0.012251033446072902
+18 619 0.020744081172491546
+18 620 0.02247275460353251
+18 621 0.008643367155204812
+18 622 0.025779782036828264
+18 623 0.02006764374295378
+18 624 0.011724915445321307
+18 625 0.01089815858699737
+18 626 0.011800075159714395
+18 627 0.006914693724163849
+18 628 0.011499436302142051
+18 629 0.011273957158962795
+18 630 0.0214956783164224
+18 631 0.012251033446072907
+18 632 0.013077790304396842
+18 633 0.014656144306651634
+18 634 0.00496054114994363
+18 635 0.006914693724163849
+18 636 0.004735062006764375
+18 637 0.004359263434798948
+18 638 0.004208944006012777
+18 639 0.003757985719654265
+18 640 0.0021796317173994736
+18 641 0.003908305148440436
+18 642 0.0010522360015031943
+18 643 0.0006012777151446825
+18 648 0.0005261180007515971
+18 658 0.0020293122886133035
+18 659 0.0011273957158962795
+18 668 0.000751597143930853
+18 669 0.0011273957158962795
+18 680 0.002931228861330327
+18 681 0.0057872980082675695
+18 682 0.001428034573468621
+18 684 0.0006012777151446825
+18 685 0.0006012777151446825
+18 696 0.003757985719654265
+18 769 0.0009019165727170238
+18 770 0.002931228861330327
+18 771 0.00015031942878617063
+18 775 0.00022547914317925594
+18 776 0.001202555430289365
+19 580 0.012027744982290436
+19 581 0.009961629279811098
+19 582 0.0059031877213695395
+19 583 0.004574970484061393
+19 584 0.009223730814639905
+19 585 0.011289846517119244
+19 586 0.0038370720188902006
+19 587 0.0028040141676505316
+19 588 0.0057556080283353
+19 589 0.0014757969303423849
+19 590 0.0003689492325855962
+19 591 0.0008116883116883117
+19 598 0.011806375442739079
+19 599 0.004501180637544274
+19 600 0.0012544273907910272
+19 601 0.0012544273907910272
+19 602 0.005165289256198347
+19 603 0.005165289256198347
+19 609 0.0005165289256198347
+19 610 0.0008116883116883117
+19 613 0.0016971664698937428
+19 614 0.0010330578512396697
+19 615 0.0011068476977567888
+19 616 0.0003689492325855962
+19 617 0.0028040141676505316
+19 618 0.021325265643447465
+19 619 0.012322904368358915
+19 620 0.016528925619834708
+19 621 0.022432113341204252
+19 622 0.0042060212514757975
+19 623 0.009518890200708384
+19 624 0.021989374262101534
+19 625 0.015053128689492327
+19 626 0.01977567886658796
+19 628 0.017635773317591502
+19 629 0.019406729634002362
+19 630 0.010256788665879575
+19 631 0.01977567886658796
+19 632 0.018299881936245575
+19 633 0.014684179456906728
+19 634 0.027007083825265645
+19 635 0.02317001180637544
+19 636 0.026638134592680048
+19 637 0.02368654073199528
+19 638 0.02435064935064935
+19 639 0.02457201889020071
+19 640 0.021915584415584416
+19 641 0.0256788665879575
+19 642 0.018816410861865408
+19 643 0.01682408500590319
+19 644 0.002656434474616293
+19 645 0.009445100354191263
+19 646 0.0031729634002361272
+19 647 0.0025826446280991736
+19 648 0.016602715466351833
+19 649 0.006419716646989375
+19 650 0.010478158205430934
+19 651 0.004870129870129871
+19 652 0.003246753246753247
+19 653 0.0014757969303423849
+19 654 0.0013282172373081465
+19 655 0.010847107438016527
+19 656 0.0005903187721369539
+19 657 0.0005165289256198347
+19 658 0.02169421487603306
+19 659 0.019406729634002362
+19 660 0.007747933884297522
+19 661 0.001844746162927981
+19 662 0.00014757969303423848
+19 663 0.012101534828807558
+19 664 0.0038370720188902014
+19 665 0.0007378984651711924
+19 666 0.0157172373081464
+19 667 0.006050767414403779
+19 668 0.02221074380165289
+19 669 0.021103896103896107
+19 670 0.0028040141676505316
+19 674 0.006419716646989373
+19 675 0.0014020070838252656
+19 676 0.010109208972845335
+19 677 0.0030253837072018895
+19 678 7.378984651711924e-05
+19 679 0.00014757969303423848
+19 680 0.02206316410861866
+19 681 0.022432113341204252
+19 682 0.0256788665879575
+19 683 0.01977567886658796
+19 684 0.022358323494687134
+19 685 0.02088252656434475
+19 686 0.012470484061393153
+19 687 0.01586481700118064
+19 688 0.004427390791027155
+19 689 0.006198347107438017
+19 690 0.00974025974025974
+19 691 0.000885478158205431
+19 692 0.0003689492325855962
+19 693 0.00014757969303423848
+19 694 0.002877804014167651
+19 695 0.0016233766233766235
+20 673 1.0
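
The two `J_regressor_mano_*.txt` files use a plain sparse triplet format: a header line `# 21 778` giving the matrix shape (21 MANO hand joints by 778 hand-mesh vertices), followed by `joint vertex weight` rows; fingertip joints such as `4 745 1.0` and `20 673 1.0` are pinned to a single vertex with weight 1. The following is a small illustrative parser for that format, written from the layout visible above rather than taken from the repository's own code.

```python
# Hedged sketch (illustrative, not from this repo): parse the sparse
# "joint vertex weight" triplet files into a dense (n_joints, n_vertices)
# regressor matrix, using the "# 21 778" header for the shape.
import numpy as np


def load_sparse_regressor(path):
    with open(path) as f:
        header = f.readline().split()                 # e.g. ['#', '21', '778']
        n_joints, n_vertices = int(header[1]), int(header[2])
        J = np.zeros((n_joints, n_vertices))
        for line in f:
            j, v, w = line.split()
            J[int(j), int(v)] = float(w)
    return J


# Example usage (paths and vertex array are placeholders):
# J_left = load_sparse_regressor('data/body_models/J_regressor_mano_LEFT.txt')
# hand_joints = J_left @ mano_vertices   # (21, 3) given (778, 3) vertices
```
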
diff --git a/data/body_models/J_regressor_mano_RIGHT.txt b/data/body_models/J_regressor_mano_RIGHT.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3151077d5f7d3a680eb3ad55d115685def2ea33b
--- /dev/null
+++ b/data/body_models/J_regressor_mano_RIGHT.txt
@@ -0,0 +1,1902 @@
+# 21 778
+0 4 0.0019103600293901542
+0 5 0.0027920646583394562
+0 6 0.00029390154298310065
+0 7 0.00014695077149155033
+0 25 0.0016164584864070536
+0 26 0.000440852314474651
+0 32 0.011756061719324026
+0 33 0.021234386480529024
+0 34 0.019838354151359296
+0 35 0.016311535635562088
+0 36 0.015870683321087434
+0 37 0.02343864805290228
+0 38 0.01671565025716385
+0 39 0.020499632623071272
+0 40 0.005437178545187362
+0 41 0.010139603232916973
+0 42 0.002645113886847906
+0 43 0.00014695077149155033
+0 44 0.02005878030859662
+0 45 0.02233651726671565
+0 50 0.01763409257898604
+0 51 0.01704628949301984
+0 52 0.019838354151359296
+0 53 0.02079353416605437
+0 54 0.00822924320352682
+0 55 0.00822924320352682
+0 78 0.011572373254959589
+0 79 0.011939750183688464
+0 84 0.01704628949301984
+0 85 0.019691403379867745
+0 88 0.005437178545187362
+0 89 0.0007347538574577516
+0 90 0.014548126377663484
+0 91 0.018736223365172666
+0 92 0.011645848640705364
+0 106 0.018515797207935343
+0 107 0.02204261572373255
+0 108 0.012417340191036004
+0 109 0.009992652461425423
+0 110 0.016311535635562088
+0 111 0.01880969875091844
+0 112 0.0073475385745775165
+0 113 0.0014695077149155032
+0 114 0.005731080088170463
+0 116 0.02204261572373255
+0 117 0.012123438648052902
+0 118 0.013005143277002204
+0 119 0.016385011021307863
+0 120 0.008155767817781044
+0 121 0.011315209404849376
+0 122 0.009037472446730345
+0 130 0.0073475385745775165
+0 131 0.00911094783247612
+0 178 0.001763409257898604
+0 179 0.002351212343864805
+0 190 0.019544452608376194
+0 191 0.019691403379867745
+0 192 0.01704628949301984
+0 193 0.016605437178545186
+0 200 0.002351212343864805
+0 203 0.00822924320352682
+0 204 0.007641440117560617
+0 205 0.01704628949301984
+0 207 0.001763409257898604
+0 208 0.005290227773695812
+0 209 0.01763409257898604
+0 210 0.019691403379867745
+0 211 0.019691403379867745
+0 214 0.011315209404849376
+0 215 0.011315209404849376
+0 216 0.007641440117560617
+0 217 0.00822924320352682
+0 218 0.002351212343864805
+0 219 0.0011756061719324026
+0 227 0.002351212343864805
+0 229 0.007788390889052168
+0 231 0.002204261572373255
+0 232 0.016311535635562088
+0 233 0.006759735488611315
+0 234 0.011168258633357825
+0 235 0.019544452608376194
+0 236 0.0016164584864070536
+0 239 0.011315209404849376
+0 241 0.0007347538574577516
+0 242 0.002351212343864805
+0 243 0.0036737692872887582
+0 244 0.0011756061719324026
+0 254 0.0064658339456282144
+0 255 0.0038207200587803084
+0 256 0.002351212343864805
+0 257 0.002351212343864805
+0 264 0.014107274063188832
+0 265 0.00440852314474651
+0 279 0.011315209404849376
+0 284 0.00896399706098457
+0 285 0.0029390154298310064
+1 0 0.014595751184471957
+1 1 0.025294207550053488
+1 2 0.019180803912578332
+1 3 0.01039278618370778
+1 4 0.03156044627846554
+1 5 0.025752712822864135
+1 6 0.014977838911814154
+1 7 0.023307351367874065
+1 8 0.005654898364664528
+1 9 0.009170105456212748
+1 10 0.002063273727647868
+1 11 0.0006113403637475165
+1 12 0.0018340210912425497
+1 14 0.001222680727495033
+1 15 7.641754546843957e-05
+1 16 0.0011462631820265935
+1 17 0.0004585052728106374
+1 18 0.00015283509093687913
+1 19 0.0003820877273421978
+1 22 7.641754546843957e-05
+1 24 0.01413724591166132
+1 25 0.019257221458046772
+1 26 0.024377197004432218
+1 27 0.017346782821335782
+1 28 0.0007641754546843956
+1 29 0.0022161088185847473
+1 30 0.0006877579092159561
+1 31 0.0005349228182790769
+1 32 0.0005349228182790768
+1 33 0.0005349228182790769
+1 34 0.0024071526822558465
+1 35 0.002445361454990066
+1 36 0.029802842732691428
+1 37 0.022122879413113253
+1 38 0.010029802842732692
+1 39 0.02334556014060829
+1 40 0.029344337459880795
+1 41 0.032171786642213054
+1 42 0.02009781445819961
+1 43 0.009934280910897143
+1 60 0.004355800091701055
+1 61 0.00855876509246523
+1 62 0.0004585052728106374
+1 63 0.003285954455142901
+1 64 0.0012990982729634726
+1 65 7.641754546843957e-05
+1 66 0.0019868561821794286
+1 67 0.004814305364511693
+1 68 0.008253094910591475
+1 69 0.0018340210912425497
+1 70 0.0003820877273421978
+1 71 7.641754546843957e-05
+1 88 0.021320495185694635
+1 89 0.013907993275256002
+1 90 0.01986856182179429
+1 91 0.013564114320648022
+1 92 0.003763564114320649
+1 93 0.0004585052728106374
+1 94 0.008329512456059913
+1 95 0.007565337001375517
+1 104 0.0027510316368638244
+1 105 0.0072596668195017595
+1 109 0.009705028274491823
+1 110 0.005654898364664528
+1 111 0.015436344184624792
+1 112 0.019180803912578332
+1 113 0.03339446736970809
+1 114 0.0340058077334556
+1 115 0.02559987773192725
+1 116 0.008405930001528351
+1 117 0.0017767079321412199
+1 118 0.00527281063732233
+1 119 0.00032477456824086816
+1 122 0.004967140455448571
+1 123 0.007259666819501758
+1 124 0.0016811860003056705
+1 125 0.0025217790004585057
+1 126 0.008176677365123033
+1 129 0.00030567018187375826
+1 145 0.00030567018187375826
+1 146 0.0006877579092159561
+1 147 7.641754546843957e-05
+1 152 7.641754546843957e-05
+1 157 0.002063273727647868
+1 158 0.0016047684548372307
+1 159 0.0032095369096744614
+1 188 0.0007641754546843956
+1 190 0.0019868561821794286
+1 191 0.0004585052728106374
+1 192 0.0016047684548372307
+1 193 0.005884151001069847
+1 207 0.00015283509093687913
+1 208 7.641754546843957e-05
+1 209 0.00030567018187375826
+1 216 0.0008405930001528353
+1 217 0.003897294818890417
+1 218 0.0008405930001528353
+1 219 0.0014519333639003516
+1 227 0.005502063273727648
+1 229 0.008635182637933671
+1 230 0.004126547455295736
+1 231 0.009705028274491824
+1 232 0.01245605991135565
+1 233 0.016888277548525142
+1 234 0.001413724591166132
+1 235 0.005654898364664528
+1 236 0.012838147638697846
+1 239 0.00026746140913953847
+1 240 0.01543634418462479
+1 241 0.0006877579092159561
+1 242 0.0032095369096744614
+1 248 0.004890722909980132
+1 249 0.0005349228182790769
+1 250 0.0015283509093687911
+1 251 0.0009170105456212748
+1 252 0.0029038667278007036
+1 253 0.005502063273727649
+1 254 0.0019868561821794286
+1 255 0.0002292526364053187
+1 264 0.028885832187070158
+1 265 0.029650007641754548
+1 266 0.006953996637628001
+1 267 0.002445361454990066
+1 268 0.00015283509093687913
+1 285 0.010087116001834023
+1 286 0.007794589637780836
+1 287 0.0025981965459269452
+1 697 0.0004585052728106374
+1 699 7.641754546843957e-05
+1 700 0.00030567018187375826
+1 704 0.0002292526364053187
+1 705 0.0008405930001528353
+1 706 7.641754546843957e-05
+2 0 0.0027531810402559712
+2 1 0.0034972840241089364
+2 2 0.007887491628841432
+2 3 0.0056551826772825355
+2 4 0.009152466701391472
+2 5 0.01674231713669172
+2 6 0.02708534861224793
+2 7 0.02209985862043307
+2 8 0.00833395341915321
+2 9 0.009152466701391472
+2 10 0.011682416846491553
+2 11 0.0055063620805119425
+2 12 0.005431951782126646
+2 13 0.0011161544757794478
+2 14 0.006176054765979612
+2 15 0.0017858471612471167
+2 16 0.0007441029838529652
+2 19 0.0003720514919264826
+2 26 0.000967333879008855
+2 27 0.0008929235806235583
+2 28 0.013245033112582783
+2 29 0.013765905201279856
+2 30 0.009970979983629735
+2 31 0.011384775652950369
+2 36 0.0023811295483294886
+2 37 0.00014882059677059304
+2 38 7.441029838529652e-05
+2 39 0.0020834883547883026
+2 40 0.0055063620805119425
+2 41 0.009896569685244438
+2 42 0.022843961604286034
+2 43 0.032666120991145166
+2 60 0.00364610462087953
+2 61 0.0017858471612471167
+2 62 0.0002976411935411861
+2 63 0.000967333879008855
+2 64 0.0014882059677059304
+2 65 0.0004464617903117792
+2 68 0.0002976411935411861
+2 69 7.441029838529652e-05
+2 88 0.01562616266091227
+2 89 0.027234169209018527
+2 90 0.00513431058858546
+2 91 0.0006696926854676687
+2 93 7.441029838529652e-05
+2 94 0.0005952823870823722
+2 104 0.025225091152615526
+2 105 0.017858471612471165
+2 113 0.0035716943224942334
+2 114 0.002604360443485378
+2 115 0.010566262370712107
+2 123 0.026787707418706754
+2 124 0.021504576233350697
+2 125 0.01882580549148002
+2 126 0.02083488354788303
+2 127 0.0002232308951558896
+2 128 0.0002976411935411861
+2 129 0.0017114368628618197
+2 144 0.0002232308951558896
+2 145 0.0013393853709353374
+2 158 0.002604360443485378
+2 193 0.0003720514919264826
+2 217 0.0007441029838529652
+2 219 0.0004464617903117792
+2 227 0.003199642830567751
+2 229 0.003125232532182454
+2 230 0.008854825507850286
+2 231 0.00982215938685914
+2 232 0.002009078056403006
+2 233 0.007813081330456134
+2 235 7.441029838529652e-05
+2 236 0.01912344668502121
+2 240 0.01480764937867401
+2 248 0.03318699307984225
+2 249 0.01823052310439765
+2 250 0.02887119577349505
+2 251 0.02500186025745963
+2 252 0.02864796487833916
+2 253 0.032889351886301064
+2 259 0.00014882059677059304
+2 264 0.0002232308951558896
+2 265 0.0005952823870823722
+2 266 0.015402931765756382
+2 267 0.01622144504799464
+2 286 0.02805268249125679
+2 287 0.025820373539697895
+2 697 0.014510008185132822
+2 698 0.008631594612694398
+2 699 0.011161544757794479
+2 700 0.01049185207232681
+2 701 0.00811072252399732
+2 702 0.013393853709353377
+2 703 0.010938313862638589
+2 704 0.008185132822382618
+2 705 0.02187662772527718
+2 706 0.018825805491480024
+2 707 0.011905647741647447
+2 708 0.007217798943373763
+2 709 0.005059900290200163
+2 710 0.003199642830567751
+2 711 0.0019346677580177095
+2 712 0.005952823870823722
+2 713 0.00364610462087953
+2 714 0.00364610462087953
+2 715 0.0026787707418706747
+2 716 0.0021578986531735995
+2 721 0.0006696926854676687
+2 722 0.0002232308951558896
+2 723 0.0002232308951558896
+2 725 0.0004464617903117792
+2 731 0.0032740531289530473
+2 732 0.0008185132822382618
+2 741 0.0005952823870823722
+2 742 0.0005208720886970756
+2 746 0.0002232308951558896
+2 749 0.0005208720886970756
+2 753 0.0034972840241089364
+2 754 0.004018156112806012
+2 755 0.0014882059677059304
+2 757 0.0008929235806235583
+2 758 0.0014137956693206339
+2 759 0.0003720514919264826
+2 760 7.441029838529652e-05
+3 6 0.0019164148301024542
+3 7 0.0014004569912287167
+3 8 0.000884499152354979
+3 9 0.00029483305078499295
+3 10 0.004422495761774894
+3 11 0.0011793322031399718
+3 12 0.0005896661015699859
+3 14 0.0011056239404437236
+3 28 0.011203655929829732
+3 29 0.0037591213975086604
+3 30 0.004496204024471142
+3 31 0.011645905506007222
+3 43 0.0019164148301024544
+3 89 0.0005896661015699859
+3 104 0.009729490675904768
+3 105 0.002137539618191199
+3 123 0.006412618854573597
+3 124 0.0187956069875433
+3 125 0.013414903810717178
+3 126 0.004938453600648632
+3 230 0.0007370826269624824
+3 231 0.00022112478808874474
+3 236 0.0005159578388737376
+3 240 0.0008844991523549787
+3 248 0.007665659320409817
+3 249 0.013120070759932186
+3 250 0.009434657625119773
+3 251 0.012088155082184712
+3 252 0.004348787499078646
+3 253 0.003022038770546178
+3 266 0.0029483305078499295
+3 267 0.0125304046583622
+3 286 0.002727205719761185
+3 287 0.005896661015699859
+3 697 0.01805852436058082
+3 698 0.019016731775632047
+3 699 0.021375396181911987
+3 700 0.01968010613989828
+3 701 0.023512935800103187
+3 702 0.01975381440259453
+3 703 0.021965062283481978
+3 704 0.019164148301024544
+3 705 0.015331318640819633
+3 706 0.017837399572492075
+3 707 0.02889363897692931
+3 708 0.02130168791921574
+3 709 0.027050932409523103
+3 710 0.024544851477850665
+3 711 0.0209331466057345
+3 712 0.0232181027493182
+3 713 0.023070686223925697
+3 714 0.024102601901673175
+3 715 0.018353357411365814
+3 716 0.017026608682833344
+3 717 0.0016952900420137097
+3 718 0.0062652023291811
+3 719 0.0033168718213311705
+3 720 0.00125304046583622
+3 721 0.016879192157440846
+3 722 0.01090882287904474
+3 723 0.008402741947372299
+3 724 0.004717328812559887
+3 725 0.010982531141740989
+3 726 0.0033168718213311705
+3 727 0.0008107908896587306
+3 730 7.370826269624824e-05
+3 731 0.022775853173140702
+3 732 0.018279649148669565
+3 733 0.009803198938601014
+3 734 0.003022038770546178
+3 735 0.0003685413134812412
+3 736 0.011719613768703471
+3 737 0.003906537922901157
+3 738 0.0008107908896587306
+3 739 0.013488612073413427
+3 740 0.005306994914129874
+3 741 0.021301687919215745
+3 742 0.019606397877202027
+3 743 0.0022112478808874476
+3 746 0.006338910591877348
+3 747 0.00125304046583622
+3 748 0.0016952900420137097
+3 749 0.009876907201297264
+3 750 0.003022038770546178
+3 751 7.370826269624824e-05
+3 753 0.025208225842116898
+3 754 0.0209331466057345
+3 755 0.023291811012014444
+3 756 0.017837399572492075
+3 757 0.021449104444608236
+3 758 0.01975381440259453
+3 759 0.01171961376870347
+3 760 0.01348861207341343
+3 761 0.003906537922901157
+3 762 0.005306994914129872
+3 763 0.007960492371194809
+3 764 0.0008107908896587306
+3 765 0.0003685413134812412
+3 767 0.0022112478808874476
+3 768 0.0011056239404437238
+4 745 1.0
+5 0 0.0012638674343491084
+5 1 0.0001404297149276787
+5 2 0.00035107428731919675
+5 3 0.002808594298553574
+5 8 0.004072461732902682
+5 9 0.0007723634321022329
+5 10 0.004774610307541076
+5 11 0.01418340120769555
+5 12 0.012357814913635726
+5 13 0.01930908580255582
+5 14 0.007934278893413846
+5 15 0.020011234377194213
+5 16 0.0021064457239151806
+5 17 0.0006319337171745541
+5 18 0.0022468754388428594
+5 19 0.009127931470299114
+5 21 0.00042128914478303613
+5 24 0.0009127931470299115
+5 25 7.021485746383936e-05
+5 26 0.0001404297149276787
+5 27 0.0010532228619575903
+5 28 0.0004212891447830361
+5 29 0.0015447268642044658
+5 30 0.003932032017975004
+5 31 0.0009127931470299115
+5 46 0.0006319337171745542
+5 47 0.00035107428731919675
+5 48 0.003721387445583485
+5 49 0.0027383794410897346
+5 56 0.0002808594298553574
+5 57 7.021485746383936e-05
+5 58 0.0010532228619575903
+5 59 0.0028788091560174134
+5 60 0.010040724617329027
+5 61 0.005687403454570988
+5 62 0.029981744137059403
+5 63 0.017483499508496
+5 64 0.02029209380704957
+5 65 0.024294340682488414
+5 66 0.0029490240134812527
+5 67 0.0011234377194214297
+5 68 0.005827833169498665
+5 69 0.00975986518747367
+5 74 0.00217666058137902
+5 75 0.0010532228619575903
+5 76 0.00035107428731919675
+5 77 0.00021064457239151807
+5 86 0.0007723634321022329
+5 87 0.0021064457239151806
+5 93 0.018536722370453586
+5 94 0.0016851565791321445
+5 95 0.0001404297149276787
+5 104 7.021485746383936e-05
+5 105 0.0001404297149276787
+5 127 0.023592192107850022
+5 128 0.02710293498104199
+5 129 0.020713382951832608
+5 132 0.023030473248139307
+5 133 0.005195899452324112
+5 134 0.005195899452324112
+5 135 0.01305996348827412
+5 136 0.008495997753124563
+5 137 0.014323830922623225
+5 138 0.01818564808313439
+5 139 0.011515236624069652
+5 140 0.008215138323269205
+5 143 0.010742873191967421
+5 144 0.016991995506249125
+5 145 0.010040724617329027
+5 146 0.00035107428731919675
+5 147 0.0011234377194214297
+5 149 0.013832326920376354
+5 150 0.016430276646538407
+5 151 0.010181154332256704
+5 152 0.011023732621822779
+5 155 0.00035107428731919675
+5 156 0.001966016008987502
+5 157 7.021485746383936e-05
+5 158 0.003932032017975004
+5 164 0.0034405280157281284
+5 165 0.005195899452324111
+5 166 0.0014745120067406266
+5 167 0.0014745120067406264
+5 168 0.026049712119084405
+5 169 0.02927959556242101
+5 170 0.023873051537705376
+5 171 0.016008987501755372
+5 172 0.027102934981041993
+5 173 0.016921780648785283
+5 174 0.005546973739643309
+5 175 0.005406544024715631
+5 176 0.013551467490520995
+5 177 0.00758320460609465
+5 183 7.021485746383936e-05
+5 185 0.009127931470299114
+5 186 0.017834573795815194
+5 187 0.008074708608341525
+5 189 0.007161915461311614
+5 194 0.010602443477039742
+5 195 0.01060244347703974
+5 206 0.0013340822918129478
+5 212 0.007091700603847775
+5 213 0.0013340822918129476
+5 219 0.0002808594298553574
+5 220 0.00435332116275804
+5 222 0.0002808594298553574
+5 223 0.00042128914478303613
+5 225 0.0016851565791321445
+5 226 0.00042128914478303613
+5 227 0.000983008004493751
+5 228 0.00975986518747367
+5 230 0.001825586294059823
+5 231 7.021485746383936e-05
+5 246 0.00035107428731919675
+5 258 0.020924027524224127
+5 259 0.022398539530964757
+5 260 0.015587698356972338
+5 261 0.012568459486027245
+5 262 0.009619435472545991
+5 263 0.01305996348827412
+5 266 0.0010532228619575903
+5 267 0.0005617188597107148
+5 268 0.004283106305294201
+5 269 0.0017553714365959837
+5 270 0.005266114309787951
+5 271 0.004844825165004915
+5 274 0.018045218368206713
+5 276 0.0002808594298553574
+5 277 0.00021064457239151807
+5 280 0.0001404297149276787
+5 288 0.00540654402471563
+5 290 7.021485746383936e-05
+5 358 0.0002808594298553574
+5 359 0.00035107428731919675
+5 362 0.00021064457239151807
+5 363 0.0002808594298553574
+5 365 7.021485746383936e-05
+5 366 0.0009127931470299116
+5 367 0.0013340822918129476
+5 368 0.005125684594860273
+5 369 0.0034405280157281284
+5 370 0.0013340822918129476
+5 371 0.00021064457239151807
+5 373 0.00042128914478303613
+5 375 0.00035107428731919675
+5 378 0.004493750877685719
+5 379 0.0034405280157281284
+5 380 0.004634180592613397
+5 383 0.00042128914478303613
+5 385 0.0016149417216683051
+5 386 0.001404297149276787
+5 387 0.0016851565791321445
+5 388 0.0002808594298553574
+5 399 0.0014745120067406264
+6 46 0.019904998869034157
+6 47 0.01960340797707909
+6 48 0.025559828093191583
+6 49 0.02352408957249491
+6 56 0.022166930558697125
+6 57 0.020131192038000453
+6 58 0.02194073738973083
+6 59 0.028952725627686037
+6 62 0.0005277840609213601
+6 65 0.00022619316896629722
+6 86 0.02382568046444997
+6 87 0.022543919173640955
+6 127 0.0012063635678202518
+6 128 0.0007539772298876573
+6 132 0.0006031817839101259
+6 133 0.017643067179371183
+6 134 0.02382568046444997
+6 135 0.01379778330694413
+6 136 0.01259141973912388
+6 137 0.004448465656337178
+6 138 0.003091306642539395
+6 139 0.009424715373595717
+6 140 0.012214431124180048
+6 143 0.0005277840609213601
+6 144 0.0012817612908090175
+6 150 0.0008293749528764231
+6 155 0.019678805700067855
+6 156 0.0244288622483601
+6 164 0.019980396592022914
+6 165 0.017944658071326246
+6 166 0.023222498680539848
+6 167 0.023901078187438737
+6 168 0.002789715750584332
+6 169 0.002186533966674206
+6 170 0.00987710171152831
+6 171 0.005881022393123726
+6 172 0.004071477041393349
+6 173 0.011837442509236221
+6 174 0.022166930558697128
+6 175 0.02382568046444997
+6 176 0.019377214808112796
+6 177 0.013119203800045236
+6 185 0.0016587499057528462
+6 186 0.004448465656337178
+6 187 0.0005277840609213601
+6 189 0.020809771544899342
+6 194 0.015154942320741913
+6 195 0.01839704440925884
+6 212 0.021262157882831936
+6 213 0.022317726004674656
+6 221 0.006333408731056322
+6 222 0.016210510442584633
+6 223 0.018472442132247607
+6 224 0.00987710171152831
+6 225 0.02744477116791073
+6 226 0.020583578375933047
+6 228 0.0005277840609213602
+6 237 0.012516022016135112
+6 238 0.011912840232224985
+6 245 0.011912840232224985
+6 258 0.0052024428862248355
+6 259 0.002337329412651738
+6 260 0.007162783683932745
+6 261 0.013043806077056472
+6 262 0.0016587499057528462
+6 263 0.007388976852899043
+6 272 0.014174771921887958
+6 273 0.012817612908090177
+6 274 0.0059564201161124925
+6 280 0.019301817085124028
+6 281 0.011385056171303627
+6 282 0.011460453894292393
+6 283 0.017643067179371186
+6 294 0.003920681595415819
+6 295 0.0069365905149664465
+6 296 0.0037698861494382865
+6 297 0.00512704516323607
+6 298 0.006634999623011385
+6 299 0.002789715750584332
+6 300 0.0021865339666742064
+6 301 0.0038452838724270517
+6 302 0.0005277840609213601
+6 303 0.0006031817839101259
+6 305 0.00030159089195506294
+6 316 0.0016587499057528462
+6 321 0.0009047726758651889
+6 330 0.0021111362436854408
+6 331 0.0015079544597753145
+6 340 0.00512704516323607
+6 341 0.004599261102314709
+6 342 0.0011309658448314859
+6 344 0.0007539772298876573
+6 345 0.00022619316896629722
+7 46 0.008690077640857611
+7 47 0.009188688653037966
+7 48 0.0033478167960680964
+7 49 0.0034902770852624832
+7 56 0.010898212123370611
+7 57 0.012322815015314481
+7 58 0.004202578531234419
+7 59 0.003276586651470902
+7 86 0.00648194315834461
+7 87 0.0016382933257354513
+7 133 0.00035615072298596765
+7 134 0.0015670631811382577
+7 155 0.009829759954412709
+7 156 0.004131348386637225
+7 164 0.0009259918797635161
+7 165 0.0006410713013747418
+7 166 0.003917657952845645
+7 167 0.0050573402664007405
+7 174 0.001638293325735451
+7 175 0.0014246028919438706
+7 189 0.0009259918797635161
+7 194 0.00028492057838877413
+7 195 0.0006410713013747418
+7 212 0.00042738086758316123
+7 213 0.0037039675190540643
+7 221 0.019517059619631027
+7 222 0.016739083980340477
+7 223 0.0143172590640359
+7 224 0.02443193959683738
+7 225 0.00683809388133058
+7 226 0.01111190255716219
+7 237 0.016739083980340477
+7 238 0.018092456727687157
+7 245 0.01367618776266116
+7 272 0.02236626540351877
+7 273 0.01923213904124225
+7 280 0.011040672412564997
+7 281 0.020086900776408578
+7 282 0.01859106773986751
+7 283 0.0165253935465489
+7 294 0.024004558729254222
+7 295 0.024075788873851416
+7 296 0.02443193959683738
+7 297 0.025357931476600898
+7 298 0.026283923356364414
+7 299 0.023933328584657028
+7 300 0.022722416126504736
+7 301 0.02514424104280932
+7 302 0.01738015528171522
+7 303 0.020941662511574897
+7 304 0.007835315905691288
+7 305 0.017380155281715225
+7 306 0.011396823135550965
+7 307 0.0036327373744568705
+7 308 0.0012821426027494836
+7 309 0.002777975639290548
+7 310 0.011966664292328516
+7 311 0.005342260844789515
+7 312 0.0038464278082484507
+7 313 0.0014958330365410642
+7 314 0.0007835315905691288
+7 315 0.008191466628677256
+7 316 0.022651185981907542
+7 317 0.00035615072298596765
+7 321 0.02101289265617209
+7 322 0.01225158487071729
+7 323 0.007764085761094094
+7 324 0.002564285205498967
+7 325 0.01994444048721419
+7 326 0.008690077640857611
+7 327 0.0024218249163045803
+7 328 0.0165253935465489
+7 329 0.006980554170524965
+7 330 0.028064676971294254
+7 331 0.021084122800769284
+7 332 0.0019232139041242254
+7 333 0.00021369043379158061
+7 334 0.010969442267967804
+7 335 0.0024930550609017737
+7 336 0.008690077640857611
+7 337 0.003988888097442838
+7 338 0.00028492057838877413
+7 340 0.019588289764228224
+7 341 0.0242182491630458
+7 342 0.021867654391338417
+7 343 0.014103568630244322
+7 344 0.018662297884464708
+7 345 0.014673409787021868
+7 346 0.006125792435358643
+7 347 0.009758529809815513
+7 348 0.0017095234703326447
+7 349 0.0031341263622765153
+7 350 0.004772419688011967
+7 351 0.0006410713013747418
+7 352 0.0008547617351663223
+7 353 0.00042738086758316123
+7 354 0.001068452168957903
+7 355 0.0009972220243607095
+8 317 1.0
+9 11 0.0002498906728306366
+9 13 0.0002498906728306366
+9 14 0.0009995626913225464
+9 15 0.0022490160554757294
+9 16 0.0029986880739676387
+9 17 0.002249016055475729
+9 18 0.007746610857749733
+9 19 0.00949584556756419
+9 20 0.0013743987005685012
+9 21 0.00437308677453614
+9 22 0.0009995626913225461
+9 23 0.00018741800462297744
+9 48 0.0004997813456612732
+9 59 0.0002498906728306366
+9 62 0.0014368713687761604
+9 63 0.000874617354907228
+9 64 6.247266820765915e-05
+9 65 6.247266820765915e-05
+9 66 0.0024989067283063657
+9 67 0.000437308677453614
+9 68 0.0006871993502842506
+9 69 0.0029986880739676387
+9 71 0.0004997813456612732
+9 74 0.015555694383707127
+9 75 0.017867183107390515
+9 76 0.017242456425313923
+9 77 0.00868370088086462
+9 83 6.247266820765915e-05
+9 87 0.0004997813456612732
+9 93 0.0033110514150059348
+9 127 0.0006247266820765914
+9 132 0.004810395451989753
+9 133 0.0006247266820765914
+9 135 0.0001249453364153183
+9 136 0.0004997813456612732
+9 137 0.015555694383707127
+9 138 0.007246829512088461
+9 139 0.005997376147935278
+9 140 0.008683700880864622
+9 141 0.005997376147935278
+9 142 0.0025613793965140247
+9 143 0.015743112388330104
+9 144 0.009558318235771848
+9 145 0.0032485787467982754
+9 146 0.0015618167051914785
+9 147 0.006122321484350596
+9 148 0.0025613793965140247
+9 149 0.0071843568438808006
+9 150 0.01243206097332417
+9 151 0.013993877678515648
+9 152 0.007809083525957393
+9 157 0.0001249453364153183
+9 158 0.0023114887236833884
+9 160 0.0019991253826450927
+9 161 0.0002498906728306366
+9 162 0.0005622540138689324
+9 163 0.0021240707190604106
+9 164 0.0029362154057599797
+9 165 0.002561379396514025
+9 166 0.0007496720184919098
+9 167 0.0007496720184919097
+9 168 0.002124070719060411
+9 169 0.0003123633410382957
+9 170 0.0006871993502842506
+9 171 0.002249016055475729
+9 174 0.0028737427375523207
+9 175 0.0018741800462297744
+9 176 0.009433372899356529
+9 177 0.006247266820765914
+9 181 0.00018741800462297744
+9 182 0.0009995626913225464
+9 183 0.004248141438120822
+9 185 0.019179109139751356
+9 186 0.01661772974323733
+9 187 0.019054163803336036
+9 194 0.0015618167051914785
+9 195 0.0001249453364153183
+9 196 0.0004997813456612732
+9 197 0.0014993440369838195
+9 198 0.0003748360092459549
+9 199 0.0001249453364153183
+9 202 6.247266820765915e-05
+9 206 0.013181732991816079
+9 207 6.247266820765915e-05
+9 212 0.0018741800462297742
+9 213 0.0002498906728306366
+9 218 0.0003123633410382957
+9 219 0.0006871993502842506
+9 220 0.014868495033422876
+9 225 0.0006247266820765914
+9 227 0.0006871993502842506
+9 228 0.021802961204473042
+9 230 0.0002498906728306366
+9 246 0.020803398513150495
+9 247 0.017304929093521583
+9 258 0.0004997813456612732
+9 259 0.0027487974011370024
+9 260 0.0017492347098144558
+9 261 0.002623852064721684
+9 262 0.01974136315362029
+9 263 0.01655525707502967
+9 268 0.007746610857749734
+9 269 0.02167801586805772
+9 270 0.019054163803336036
+9 271 0.011932279627662898
+9 274 0.0066221028300118695
+9 275 0.0007496720184919098
+9 276 0.016742675079652648
+9 277 0.02205285187730368
+9 288 0.022427687886549634
+9 289 0.0003123633410382957
+9 290 0.00730930218029612
+9 291 0.005685012806896982
+9 292 0.0057474854751046415
+9 293 0.008933591553695257
+9 356 0.0014993440369838195
+9 357 0.0014993440369838193
+9 358 0.00668457549821953
+9 359 0.004685450115574436
+9 360 0.0007496720184919098
+9 361 0.0007496720184919098
+9 362 0.0024989067283063657
+9 363 0.0038733054288748667
+9 364 0.0014368713687761604
+9 365 0.004498032110951459
+9 366 0.009933154245017804
+9 367 0.010245517586056099
+9 368 0.015993003061160742
+9 369 0.015993003061160742
+9 370 0.021115761854188793
+9 371 0.01693009308427563
+9 372 0.0009995626913225464
+9 373 0.0037483600924595483
+9 374 0.008996064221902918
+9 375 0.012432060973324168
+9 376 0.004498032110951458
+9 377 0.0031861060785906164
+9 378 0.017554819766352217
+9 379 0.01749234709814456
+9 380 0.01649278440682201
+9 381 0.008308864871618667
+9 382 0.006434684825388891
+9 383 0.016055475729368402
+9 384 0.012557006309739488
+9 385 0.01018304491784844
+9 386 0.015180858374461174
+9 387 0.01155744361841694
+9 388 0.009058536890110576
+9 389 0.0028112700693446614
+9 391 0.00018741800462297744
+9 392 0.0005622540138689324
+9 394 0.0018117073780221152
+9 395 0.0004997813456612732
+9 399 0.01611794839757606
+9 402 0.0008746173549072279
+9 470 0.0007496720184919098
+9 471 0.0004997813456612732
+9 478 0.0007496720184919098
+9 479 0.0004997813456612732
+9 480 0.0026863247329293434
+9 481 0.002623852064721684
+9 483 0.0001249453364153183
+9 484 0.0001249453364153183
+9 485 0.0014993440369838195
+9 486 0.0004997813456612732
+9 488 0.008996064221902916
+9 489 0.006059848816142937
+9 490 0.006497157493596552
+9 491 0.0001249453364153183
+9 492 0.0003748360092459549
+9 493 0.001311926032360842
+9 494 0.000437308677453614
+9 495 0.0017492347098144558
+9 496 0.002623852064721684
+9 497 0.0027487974011370024
+9 498 0.0006247266820765914
+9 509 0.0020615980508527517
+9 510 0.0003748360092459549
+9 579 0.0019991253826450927
+10 74 0.0005264345341054373
+10 75 0.0021809430698653833
+10 76 0.000752049334436339
+10 137 0.000827254267879973
+10 143 0.0006016394675490712
+10 150 0.0003008197337745356
+10 151 0.0006016394675490712
+10 185 0.004361886139730767
+10 186 0.0010528690682108748
+10 187 0.003910656539068963
+10 206 0.0001504098668872678
+10 220 0.0003008197337745356
+10 228 0.0030834022711889904
+10 246 0.003985861472512596
+10 247 0.0012784838685417762
+10 262 0.003910656539068963
+10 263 0.0011280740016545085
+10 269 0.0032338121380762574
+10 270 0.002857787470858088
+10 271 0.0003008197337745356
+10 276 0.000902459201323607
+10 277 0.00556516507482891
+10 288 0.0027825825374144545
+10 356 0.020305332029781156
+10 357 0.019703692562232082
+10 358 0.02549447243739189
+10 359 0.023764758968188315
+10 360 0.02587049710461006
+10 361 0.022486275099646538
+10 362 0.022411070166202904
+10 363 0.02278709483342107
+10 364 0.026321726705271865
+10 365 0.02007971722945025
+10 366 0.016093855756937656
+10 367 0.022260660299315636
+10 368 0.011882379484094157
+10 369 0.009400616680454237
+10 370 0.00962623148078514
+10 371 0.011431149883432353
+10 372 0.021583815898322933
+10 373 0.024742423102955553
+10 374 0.01947807776190118
+10 375 0.01789877415958487
+10 376 0.023388734300970146
+10 377 0.023689554034744677
+10 378 0.009400616680454237
+10 379 0.005865984808603443
+10 380 0.01135594494998872
+10 381 0.022486275099646538
+10 382 0.015341806422501316
+10 383 0.01135594494998872
+10 384 0.01158155975031962
+10 385 0.019703692562232082
+10 386 0.01504098668872678
+10 387 0.018124388959915774
+10 388 0.010077461081446944
+10 389 0.02293750470030834
+10 390 0.01383770775362864
+10 391 0.017372339625479433
+10 392 0.019703692562232086
+10 393 0.011882379484094157
+10 394 0.024667218169511923
+10 395 0.024667218169511916
+10 396 0.012333609084755958
+10 397 0.011506354816875987
+10 398 0.013236068286079568
+10 399 0.0070692637437015865
+10 400 0.01940287282845755
+10 401 0.016093855756937656
+10 402 0.020530946830112053
+10 403 0.008197337745356097
+10 404 0.01759795442581033
+10 405 0.021508610964879295
+10 406 0.008197337745356095
+10 407 0.013988117620515906
+10 408 0.008949387079792434
+10 409 0.006467624276152515
+10 410 0.005264345341054373
+10 411 0.005565165074828909
+10 412 0.003835451605625329
+10 413 0.002105738136421749
+10 414 0.0012784838685417764
+10 415 0.002556967737083553
+10 417 7.52049334436339e-05
+10 420 0.0020305332029781154
+10 421 0.0006016394675490712
+10 422 0.0006016394675490712
+10 427 7.52049334436339e-05
+10 430 0.004737910806948936
+10 431 0.002331352936752651
+10 432 0.0001504098668872678
+10 440 0.0010528690682108748
+10 441 0.0021057381364217496
+10 446 7.52049334436339e-05
+10 452 0.004512296006618034
+10 453 0.003609836805294428
+10 454 0.0006016394675490712
+10 456 0.0006016394675490712
+10 457 0.0004512296006618035
+11 356 0.011297349184080336
+11 357 0.011888060252528984
+11 358 0.004430333013364838
+11 359 0.004430333013364838
+11 360 0.009229860444510078
+11 361 0.011371188067636416
+11 362 0.0038396219449161927
+11 363 0.002805877575131064
+11 364 0.005759432917374288
+11 365 0.0014767776711216124
+11 366 0.0003691944177804031
+11 367 0.0014029387875655322
+11 372 0.011371188067636418
+11 373 0.004504171896920917
+11 374 0.0012552610204533705
+11 375 0.0011075832533412094
+11 376 0.005316399616037805
+11 377 0.005685594033818208
+11 381 0.001772133205345935
+11 382 0.0003691944177804031
+11 385 0.00118142213689729
+11 386 0.0005168721848925644
+11 387 0.0011075832533412094
+11 388 7.383888355608063e-05
+11 389 0.0031012331093553864
+11 390 0.019345787491693123
+11 391 0.010928154766299934
+11 392 0.01299564350587019
+11 393 0.02082256516281474
+11 394 0.0057594329173742895
+11 395 0.00945137709517832
+11 396 0.017352137635678947
+11 397 0.02001033744369785
+11 398 0.018238204238351912
+11 400 0.01794284870412759
+11 401 0.019124270841024884
+11 402 0.016170715498781657
+11 403 0.022816215018828915
+11 404 0.01727829875212287
+11 405 0.014546260060547885
+11 406 0.0239976371557262
+11 407 0.022963892785941076
+11 408 0.02695119249796943
+11 409 0.023776120505057962
+11 410 0.019493465258805284
+11 411 0.023849959388614037
+11 412 0.026581998080189025
+11 413 0.020601048512146496
+11 414 0.019493465258805288
+11 415 0.02163479288193162
+11 416 0.004873366314701322
+11 417 0.007900760540500627
+11 418 0.0042088163626965965
+11 419 0.0016982943217898545
+11 420 0.018238204238351912
+11 421 0.012035738019641142
+11 422 0.012331093553865465
+11 423 0.0055379162667060465
+11 424 0.004061138595584434
+11 425 0.0016982943217898542
+11 426 0.0008122277191168869
+11 427 0.00834379384183711
+11 428 0.0005168721848925643
+11 429 0.0015506165546776932
+11 430 0.023406926087277558
+11 431 0.019124270841024884
+11 432 0.016392232149449903
+11 433 0.005907110684486449
+11 434 0.0019198109724580966
+11 435 0.015432326663220851
+11 436 0.006940855054271579
+11 437 0.0013290999040094513
+11 438 0.013364837923650594
+11 439 0.00694085505427158
+11 440 0.02126559846415122
+11 441 0.02355460385438972
+11 442 0.002732038691574983
+11 444 7.383888355608063e-05
+11 446 0.010854315882743852
+11 447 0.0031012331093553864
+11 448 0.007753082773388465
+11 449 0.0018459720889020155
+11 450 0.00044303330133648377
+11 451 0.00044303330133648377
+11 452 0.023776120505057962
+11 453 0.02229934283393635
+11 454 0.02126559846415122
+11 455 0.013290999040094512
+11 456 0.018385882005464073
+11 457 0.015580004430333012
+11 458 0.010189765930739126
+11 459 0.012035738019641142
+11 460 0.0034704275271357893
+11 461 0.004578010780476998
+11 462 0.005907110684486449
+11 463 0.000590711068448645
+11 464 0.000590711068448645
+11 465 0.0002953555342243225
+11 466 0.0019936498560141768
+11 467 0.0013290999040094513
+12 444 1.0
+13 16 0.0014635288607891346
+13 17 0.002575810794988877
+13 18 0.005737033134293408
+13 19 0.001990399250673223
+13 20 0.007785973539398196
+13 21 0.008664090855871677
+13 22 0.002985598876009834
+13 23 0.002224563868399485
+13 63 5.854115443156538e-05
+13 66 0.0018147757873785268
+13 67 0.0006439526987472192
+13 68 0.0002927057721578269
+13 69 0.0008195761620419153
+13 70 0.0007024938531787846
+13 71 0.0033953869570307925
+13 72 0.0024001873316941806
+13 73 0.00023416461772626153
+13 74 0.009308043554618896
+13 75 0.007551808921671934
+13 76 0.01890879288139562
+13 77 0.013230300901533777
+13 80 0.0013464465519260039
+13 81 0.0002927057721578269
+13 82 0.0016976934785153963
+13 83 0.0040978808102095764
+13 93 0.00017562346329469617
+13 100 0.00017562346329469617
+13 102 0.00011708230886313077
+13 103 0.00035124692658939234
+13 137 0.00011708230886313077
+13 141 0.020021074815595362
+13 142 0.016625687858564567
+13 143 0.0016391523240838306
+13 144 0.0005268703898840885
+13 145 0.0002927057721578269
+13 146 0.002868516567146704
+13 147 0.006673691605198454
+13 148 0.008839714319166374
+13 149 0.0002927057721578269
+13 150 0.0002927057721578269
+13 151 0.0012293642430628731
+13 152 0.0011122819341997424
+13 157 0.0008781173164734808
+13 158 0.0004097880810209577
+13 160 0.02681184872965695
+13 161 0.023592085235920848
+13 162 0.03096827069429809
+13 163 0.02476290832455216
+13 178 0.0002927057721578269
+13 179 5.854115443156538e-05
+13 180 0.0009366584709050461
+13 181 0.00444912773679897
+13 182 0.013464465519260038
+13 183 0.0167427701674277
+13 184 0.00017562346329469617
+13 185 5.854115443156538e-05
+13 186 0.0002927057721578269
+13 187 0.0008195761620419153
+13 196 0.017503805175038047
+13 197 0.023416461772626154
+13 198 0.023416461772626154
+13 199 0.02921203606135113
+13 201 0.0018733169418100922
+13 202 0.006439526987472192
+13 206 0.015162158997775435
+13 207 0.0006439526987472192
+13 218 0.0007610350076103501
+13 219 0.00046832923545252306
+13 220 0.006673691605198454
+13 227 0.00011708230886313077
+13 228 0.0009951996253366115
+13 246 0.0106544901065449
+13 247 0.014576747453459781
+13 262 0.00011708230886313077
+13 268 0.0033368458025992264
+13 269 0.010420325488818641
+13 270 0.0035710104203254887
+13 271 0.002985598876009834
+13 275 0.009834913944502985
+13 276 0.02142606252195293
+13 277 0.01164968973188151
+13 278 0.00035124692658939234
+13 288 0.004741833508956796
+13 289 0.014693829762322912
+13 290 0.02207001522070015
+13 291 0.017913593256059006
+13 292 0.011005737033134292
+13 293 0.010478866643250203
+13 358 0.0003512469265893923
+13 363 5.854115443156538e-05
+13 365 0.00035124692658939234
+13 366 0.0007024938531787846
+13 367 0.00017562346329469617
+13 368 0.00017562346329469617
+13 369 0.0009951996253366115
+13 370 0.005151621589977754
+13 371 0.005385786207704015
+13 374 0.0016976934785153963
+13 375 0.0017562346329469615
+13 376 0.0004097880810209577
+13 377 0.0003512469265893923
+13 378 0.00046832923545252306
+13 379 0.0015220700152206996
+13 381 0.0014635288607891346
+13 382 0.0009951996253366115
+13 383 0.00532724505327245
+13 384 0.0037466338836201845
+13 386 0.0011708230886313077
+13 387 0.00011708230886313077
+13 388 0.0012293642430628731
+13 389 0.0002927057721578269
+13 394 0.00017562346329469617
+13 399 0.0033953869570307925
+13 468 5.854115443156538e-05
+13 469 0.0011122819341997424
+13 470 0.0027514342582835734
+13 471 0.0012879053974944384
+13 474 5.854115443156538e-05
+13 475 0.0002927057721578269
+13 476 0.0002927057721578269
+13 477 0.0018147757873785268
+13 478 0.0020489404051047887
+13 479 0.0011122819341997424
+13 480 0.004332045427935838
+13 481 0.006556609296335323
+13 483 0.00046832923545252306
+13 484 0.012352183585060298
+13 485 0.014869453225617611
+13 486 0.005912656597588104
+13 487 0.004214963119072708
+13 488 0.01164968973188151
+13 489 0.015806111696522657
+13 490 0.008312843929282283
+13 491 0.009834913944502985
+13 492 0.006146821215314366
+13 493 0.015513405924364829
+13 494 0.02007961597002693
+13 495 0.0024001873316941806
+13 496 0.008956796628029503
+13 497 0.004741833508956796
+13 498 0.003512469265893923
+13 499 0.002517269640557311
+13 501 0.0005854115443156538
+13 502 0.0004097880810209577
+13 504 0.001990399250673223
+13 505 0.00040978808102095764
+13 509 0.010478866643250205
+13 510 0.02207001522070015
+13 513 0.0012293642430628731
+13 579 0.021660227139679192
+13 580 0.0002927057721578269
+13 581 0.00011708230886313077
+13 582 0.0012879053974944388
+13 583 0.0018147757873785272
+13 584 5.854115443156538e-05
+13 585 0.00011708230886313077
+13 586 0.0011122819341997422
+13 587 0.0008195761620419154
+13 589 0.0007610350076103501
+13 590 0.003395386957030792
+13 591 0.0026928931038520073
+13 592 0.009834913944502985
+13 593 0.009834913944502985
+13 594 0.00011708230886313077
+13 595 0.0013464465519260039
+13 596 0.0015806111696522653
+13 597 0.0002927057721578269
+13 598 0.00023416461772626153
+13 599 0.0009951996253366115
+13 600 0.0002927057721578269
+13 601 0.0012293642430628731
+13 602 0.00046832923545252306
+13 603 0.00011708230886313077
+13 604 0.003980798501346446
+13 605 0.013523006673691603
+13 606 0.011591148577449948
+13 607 0.006263903524177495
+13 608 0.014693829762322912
+13 610 0.0003512469265893923
+13 611 0.0012293642430628734
+13 612 5.854115443156538e-05
+13 613 0.005327245053272449
+13 614 0.0019318580962416575
+13 615 0.006615150450766888
+13 616 0.0026928931038520073
+13 617 0.0002927057721578269
+13 627 0.005268703898840884
+13 630 0.00011708230886313077
+13 696 0.00023416461772626153
+13 769 0.00076103500761035
+13 770 0.004683292354525231
+13 771 0.0011122819341997424
+13 772 5.854115443156538e-05
+13 774 0.00076103500761035
+13 775 0.003512469265893923
+13 776 0.008020138157124457
+14 74 0.0005157677571470676
+14 75 0.0005157677571470676
+14 76 0.004273504273504274
+14 77 0.0008104921898025347
+14 141 0.002799882110226938
+14 142 0.0003684055408193339
+14 160 0.001326259946949602
+14 161 0.0005894488653109342
+14 162 0.004420866489832007
+14 163 0.0050103153551429415
+14 196 7.368110816386678e-05
+14 197 0.0014736221632773356
+14 198 0.0030209254347185383
+14 199 0.0009578544061302684
+14 206 7.368110816386678e-05
+14 246 0.0013262599469496023
+14 247 0.0061155319776009425
+14 269 7.368110816386678e-05
+14 276 0.0034630120837017397
+14 277 0.0008841732979664015
+14 290 0.001399941055113469
+14 291 0.0052313586796345415
+14 292 0.0058944886531093425
+14 293 0.008989095195991748
+14 468 0.0199675803124079
+14 469 0.02460949012673151
+14 470 0.021220159151193633
+14 471 0.02586206896551724
+14 472 0.020704391394046565
+14 473 0.017978190391983492
+14 474 0.020114942528735632
+14 475 0.02586206896551724
+14 476 0.02291482463896257
+14 477 0.02475685234305924
+14 478 0.021293840259357502
+14 479 0.026009431181844976
+14 480 0.019451812555260833
+14 481 0.014294134983790155
+14 482 0.01422045387562629
+14 483 0.02726201002063071
+14 484 0.02026230474506337
+14 485 0.015694076038903628
+14 486 0.02726201002063071
+14 487 0.02733569112879458
+14 488 0.01215738284703802
+14 489 0.009652225169466549
+14 490 0.015767757147067494
+14 491 0.02460949012673151
+14 492 0.020114942528735635
+14 493 0.013704686118479222
+14 494 0.01333628057765989
+14 495 0.022988505747126436
+14 496 0.018272914824638966
+14 497 0.020851753610374304
+14 498 0.016578249336870028
+14 499 0.025567344532861774
+14 500 0.007515473032714411
+14 501 0.019157088122605366
+14 502 0.015104627173592693
+14 503 0.00987326849395815
+14 504 0.021293840259357502
+14 505 0.020999115826702035
+14 506 0.013262599469496024
+14 507 0.013483642793987621
+14 508 0.010389036251105217
+14 509 0.011715296198054817
+14 510 0.010167992926613616
+14 511 0.011199528440907752
+14 512 0.009357500736811082
+14 513 0.020335985853227233
+14 514 0.010683760683760684
+14 515 0.01215738284703802
+14 516 0.016357206012378427
+14 517 0.004052460949012673
+14 518 0.006704980842911877
+14 519 0.004273504273504274
+14 520 0.0036103743000294726
+14 521 0.004494547597995874
+14 522 0.003020925434718538
+14 523 0.002136752136752137
+14 524 0.0037577365163572064
+14 525 0.0005894488653109342
+14 526 0.0008104921898025347
+14 531 0.0002947244326554671
+14 541 0.0016209843796050694
+14 542 0.0006631299734748011
+14 551 0.0019157088122605363
+14 552 0.0009578544061302684
+14 563 0.005010315355142941
+14 564 0.004715590922487474
+14 565 0.0010315355142941351
+14 567 0.000663129973474801
+14 568 0.00022104332449160037
+14 579 0.006115531977600943
+15 468 0.01103996467211305
+15 469 0.010230367262824759
+15 470 0.0023551924633841174
+15 471 0.004121586810922205
+15 472 0.009199970560094207
+15 473 0.011334363730036065
+15 474 0.004047987046441452
+15 475 0.0027967910502686394
+15 476 0.0059615809229410476
+15 477 0.0014719952896150733
+15 478 0.0003679988224037683
+15 479 0.0016191948185765807
+15 482 0.011187164201074557
+15 483 0.0056671818650180315
+15 484 0.0014719952896150733
+15 485 0.0003679988224037683
+15 486 0.004563185397806727
+15 487 0.0073599764480753675
+15 491 0.002134393169941856
+15 492 0.0003679988224037683
+15 495 0.0011775962316920587
+15 496 0.0005151983513652757
+15 497 0.0005151983513652757
+15 498 7.359976448075367e-05
+15 499 0.0032383896371531613
+15 500 0.019945536174284243
+15 501 0.01781114300434239
+15 502 0.014204754544785456
+15 503 0.02524471921689851
+15 504 0.005446382571575771
+15 505 0.010524766320747773
+15 506 0.01832634135570766
+15 507 0.01884153970707294
+15 508 0.018473540884669168
+15 511 0.01781114300434239
+15 512 0.019356738058438214
+15 513 0.012143961139324354
+15 514 0.020755133583572533
+15 515 0.01862074041363068
+15 516 0.015014351954073748
+15 517 0.024361522043129462
+15 518 0.02333112534039891
+15 519 0.027011113564436594
+15 520 0.02465592110105248
+15 521 0.024067122985206444
+15 522 0.024508721572090966
+15 523 0.023478324869360415
+15 524 0.025980716861706044
+15 525 0.018031942297784646
+15 526 0.020607934054611025
+15 527 0.0012511959961728123
+15 528 0.0059615809229410476
+15 529 0.0025023919923456246
+15 530 0.0009567969382497977
+15 531 0.018179141826746157
+15 532 0.006255979980864061
+15 533 0.011187164201074557
+15 534 0.005225583278133511
+15 535 0.004710384926768235
+15 536 0.0016927945830573343
+15 537 0.0007359976448075366
+15 538 0.013247957606535658
+15 540 7.359976448075367e-05
+15 541 0.02340472510487966
+15 542 0.02031353499668801
+15 543 0.010745565614190034
+15 544 0.0032383896371531613
+15 545 0.0003679988224037683
+15 546 0.015529550305439023
+15 547 0.005593582100537278
+15 548 0.001103996467211305
+15 549 0.019356738058438214
+15 550 0.009126370795613454
+15 551 0.025465518510340766
+15 552 0.022374328402149115
+15 553 0.0029439905792301465
+15 557 0.011923161845882095
+15 558 0.0029439905792301465
+15 559 0.00942076985353647
+15 560 0.003679988224037683
+15 561 0.0002943990579230147
+15 563 0.019356738058438214
+15 564 0.024582321336571723
+15 565 0.02244792816662987
+15 566 0.015382350776477514
+15 567 0.019503937587399718
+15 568 0.015161551483035255
+15 569 0.0059615809229410476
+15 570 0.01023036726282476
+15 571 0.0030175903437109006
+15 572 0.003459188930595423
+15 573 0.005519982336056524
+15 574 0.0008095974092882903
+15 575 0.0008095974092882903
+15 576 7.359976448075367e-05
+15 577 0.0013247957606535659
+15 578 0.0008095974092882903
+16 556 1.0
+17 17 0.0004919184820801125
+17 18 0.0006324666198172875
+17 20 0.005762473647224175
+17 21 0.0021082220660576245
+17 22 0.0014757554462403375
+17 23 0.0024595924104005625
+17 70 0.000140548137737175
+17 71 0.000983836964160225
+17 72 0.0023190442726633872
+17 73 0.0004919184820801125
+17 76 0.000140548137737175
+17 77 0.0006324666198172875
+17 80 0.008151791988756148
+17 81 0.006676036542515813
+17 82 0.016303583977512297
+17 83 0.012297962052002813
+17 96 0.002178496134926213
+17 97 0.0007027406886858749
+17 98 7.02740688685875e-05
+17 99 0.0011243851018974
+17 100 0.009065354884047786
+17 101 0.007308503162333099
+17 102 0.015038650737877725
+17 103 0.017919887561489812
+17 141 0.0018271257905832748
+17 142 0.003794799718903725
+17 148 0.0013352073085031624
+17 153 0.001546029515108925
+17 154 0.0024595924104005625
+17 160 0.0134926212227688
+17 161 0.01883345045678145
+17 162 0.012438510189739986
+17 163 0.005200281096275476
+17 178 0.0007730147575544624
+17 179 0.0007730147575544624
+17 180 0.003021784961349263
+17 181 0.00758959943780745
+17 182 0.01377371749824315
+17 183 0.007238229093464512
+17 184 0.0026001405481377374
+17 196 0.0123682361208714
+17 197 0.007449051300070275
+17 198 0.0071679550245959235
+17 199 0.0202389318341532
+17 200 0.0004919184820801125
+17 201 0.027406886858749122
+17 202 0.020028109627547436
+17 206 0.0019676739283204497
+17 207 0.0004919184820801125
+17 218 0.000140548137737175
+17 220 0.00028109627547435
+17 247 7.02740688685875e-05
+17 256 0.000140548137737175
+17 257 0.0004919184820801125
+17 269 0.00035137034434293746
+17 275 0.005270555165144062
+17 276 0.0010541110330288123
+17 277 7.02740688685875e-05
+17 278 0.02508784258608574
+17 289 0.019465917076598734
+17 290 0.0044975404075896
+17 291 0.001546029515108925
+17 292 0.0002108222066057625
+17 293 7.02740688685875e-05
+17 484 0.0009135628952916374
+17 485 0.0007730147575544624
+17 489 0.000421644413211525
+17 491 0.0004919184820801125
+17 492 7.02740688685875e-05
+17 493 0.0013352073085031622
+17 494 0.002951510892480675
+17 509 0.0002108222066057625
+17 510 0.0033731553056922
+17 579 0.0027406886858749122
+17 580 0.003162333099086437
+17 581 0.0023190442726633877
+17 582 0.009978917779339425
+17 583 0.009065354884047788
+17 584 0.0010541110330288125
+17 585 0.0016865776528460997
+17 586 0.004356992269852425
+17 587 0.003513703443429375
+17 588 0.0016865776528461
+17 589 0.004567814476458187
+17 590 0.009065354884047786
+17 591 0.006886858749121575
+17 592 0.013914265635980324
+17 593 0.016795502459592413
+17 594 0.021503865073787775
+17 595 0.028742094167252288
+17 596 0.02178496134926212
+17 597 0.02059030217849614
+17 598 0.0026001405481377374
+17 599 0.004427266338721012
+17 600 0.015038650737877721
+17 601 0.015390021082220663
+17 602 0.008081517919887564
+17 603 0.007308503162333099
+17 604 0.021995783555867888
+17 605 0.02312016865776529
+17 606 0.019465917076598734
+17 607 0.025720309205903027
+17 608 0.019465917076598737
+17 609 0.0123682361208714
+17 610 0.013070976809557275
+17 611 0.02009838369641603
+17 612 0.019184820801124384
+17 613 0.012719606465214337
+17 614 0.0134926212227688
+17 615 0.0179901616303584
+17 616 0.015319747013352071
+17 617 0.005621925509486999
+17 618 0.00028109627547435
+17 619 0.0022487702037948
+17 620 0.0026704146170063252
+17 621 0.0002108222066057625
+17 622 0.007589599437807451
+17 623 0.0018271257905832748
+17 624 0.00084328882642305
+17 625 0.0007027406886858749
+17 626 0.0009135628952916376
+17 627 0.01981728742094167
+17 628 0.0002108222066057625
+17 629 0.0004919184820801124
+17 630 0.003021784961349262
+17 631 0.0006324666198172875
+17 632 7.02740688685875e-05
+17 633 7.02740688685875e-05
+17 696 0.021152494729444835
+17 769 0.0179901616303584
+17 770 0.020941672522839076
+17 771 0.011243851018974
+17 772 0.005411103302881238
+17 773 0.0028109627547434997
+17 774 0.0036542515811665496
+17 775 0.00871398453970485
+17 776 0.012157413914265636
+17 777 0.00035137034434293746
+18 82 0.0006012777151446825
+18 83 0.00015031942878617063
+18 103 0.0004509582863585119
+18 142 7.515971439308531e-05
+18 160 0.001428034573468621
+18 161 0.0071401728673431055
+18 162 0.0018789928598271326
+18 182 0.0005261180007515971
+18 196 0.0018789928598271326
+18 197 0.0009019165727170237
+18 198 0.0008267568583239384
+18 199 0.006238256294626081
+18 201 0.003006388575723412
+18 202 0.0018038331454340473
+18 275 0.00015031942878617063
+18 278 0.002179631717399474
+18 289 0.003607666290868095
+18 580 0.020142803457346866
+18 581 0.01698609545283728
+18 582 0.03013904547162721
+18 583 0.02705749718151071
+18 584 0.017737692596768138
+18 585 0.016910935738444193
+18 586 0.021270199173243146
+18 587 0.018489289740698987
+18 588 0.01924088688462984
+18 589 0.016910935738444197
+18 590 0.016234498308906428
+18 591 0.017061255167230362
+18 592 0.015182262307403233
+18 593 0.010672679443818113
+18 594 0.004509582863585118
+18 595 0.005712138293874483
+18 596 0.01570838030815483
+18 597 0.011424276587748966
+18 598 0.018714768883878245
+18 599 0.01969184517098835
+18 600 0.020593761743705377
+18 601 0.023900789177001128
+18 602 0.027583615182262308
+18 603 0.0266065388951522
+18 604 0.007591131153701616
+18 605 0.007666290868094702
+18 606 0.012701991732431419
+18 607 0.020668921458098462
+18 608 0.016309658023299513
+18 609 0.015031942878617064
+18 610 0.017061255167230362
+18 611 0.009845922585494176
+18 612 0.009845922585494176
+18 613 0.02435174746335964
+18 614 0.02450206689214581
+18 615 0.02720781661029688
+18 616 0.016459977452085682
+18 617 0.021119879744456973
+18 618 0.012251033446072902
+18 619 0.020744081172491546
+18 620 0.02247275460353251
+18 621 0.008643367155204812
+18 622 0.025779782036828264
+18 623 0.02006764374295378
+18 624 0.011724915445321307
+18 625 0.01089815858699737
+18 626 0.011800075159714395
+18 627 0.006914693724163849
+18 628 0.011499436302142051
+18 629 0.011273957158962795
+18 630 0.0214956783164224
+18 631 0.012251033446072907
+18 632 0.013077790304396842
+18 633 0.014656144306651634
+18 634 0.00496054114994363
+18 635 0.006914693724163849
+18 636 0.004735062006764375
+18 637 0.004359263434798948
+18 638 0.004208944006012777
+18 639 0.003757985719654265
+18 640 0.0021796317173994736
+18 641 0.003908305148440436
+18 642 0.0010522360015031943
+18 643 0.0006012777151446825
+18 648 0.0005261180007515971
+18 658 0.0020293122886133035
+18 659 0.0011273957158962795
+18 668 0.000751597143930853
+18 669 0.0011273957158962795
+18 680 0.002931228861330327
+18 681 0.0057872980082675695
+18 682 0.001428034573468621
+18 684 0.0006012777151446825
+18 685 0.0006012777151446825
+18 696 0.003757985719654265
+18 769 0.0009019165727170238
+18 770 0.002931228861330327
+18 771 0.00015031942878617063
+18 775 0.00022547914317925594
+18 776 0.001202555430289365
+19 580 0.012027744982290436
+19 581 0.009961629279811098
+19 582 0.0059031877213695395
+19 583 0.004574970484061393
+19 584 0.009223730814639905
+19 585 0.011289846517119244
+19 586 0.0038370720188902006
+19 587 0.0028040141676505316
+19 588 0.0057556080283353
+19 589 0.0014757969303423849
+19 590 0.0003689492325855962
+19 591 0.0008116883116883117
+19 598 0.011806375442739079
+19 599 0.004501180637544274
+19 600 0.0012544273907910272
+19 601 0.0012544273907910272
+19 602 0.005165289256198347
+19 603 0.005165289256198347
+19 609 0.0005165289256198347
+19 610 0.0008116883116883117
+19 613 0.0016971664698937428
+19 614 0.0010330578512396697
+19 615 0.0011068476977567888
+19 616 0.0003689492325855962
+19 617 0.0028040141676505316
+19 618 0.021325265643447465
+19 619 0.012322904368358915
+19 620 0.016528925619834708
+19 621 0.022432113341204252
+19 622 0.0042060212514757975
+19 623 0.009518890200708384
+19 624 0.021989374262101534
+19 625 0.015053128689492327
+19 626 0.01977567886658796
+19 628 0.017635773317591502
+19 629 0.019406729634002362
+19 630 0.010256788665879575
+19 631 0.01977567886658796
+19 632 0.018299881936245575
+19 633 0.014684179456906728
+19 634 0.027007083825265645
+19 635 0.02317001180637544
+19 636 0.026638134592680048
+19 637 0.02368654073199528
+19 638 0.02435064935064935
+19 639 0.02457201889020071
+19 640 0.021915584415584416
+19 641 0.0256788665879575
+19 642 0.018816410861865408
+19 643 0.01682408500590319
+19 644 0.002656434474616293
+19 645 0.009445100354191263
+19 646 0.0031729634002361272
+19 647 0.0025826446280991736
+19 648 0.016602715466351833
+19 649 0.006419716646989375
+19 650 0.010478158205430934
+19 651 0.004870129870129871
+19 652 0.003246753246753247
+19 653 0.0014757969303423849
+19 654 0.0013282172373081465
+19 655 0.010847107438016527
+19 656 0.0005903187721369539
+19 657 0.0005165289256198347
+19 658 0.02169421487603306
+19 659 0.019406729634002362
+19 660 0.007747933884297522
+19 661 0.001844746162927981
+19 662 0.00014757969303423848
+19 663 0.012101534828807558
+19 664 0.0038370720188902014
+19 665 0.0007378984651711924
+19 666 0.0157172373081464
+19 667 0.006050767414403779
+19 668 0.02221074380165289
+19 669 0.021103896103896107
+19 670 0.0028040141676505316
+19 674 0.006419716646989373
+19 675 0.0014020070838252656
+19 676 0.010109208972845335
+19 677 0.0030253837072018895
+19 678 7.378984651711924e-05
+19 679 0.00014757969303423848
+19 680 0.02206316410861866
+19 681 0.022432113341204252
+19 682 0.0256788665879575
+19 683 0.01977567886658796
+19 684 0.022358323494687134
+19 685 0.02088252656434475
+19 686 0.012470484061393153
+19 687 0.01586481700118064
+19 688 0.004427390791027155
+19 689 0.006198347107438017
+19 690 0.00974025974025974
+19 691 0.000885478158205431
+19 692 0.0003689492325855962
+19 693 0.00014757969303423848
+19 694 0.002877804014167651
+19 695 0.0016233766233766235
+20 673 1.0
diff --git a/data/body_models/SMPLX_to_J14.pkl b/data/body_models/SMPLX_to_J14.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..db8aa5c74b860a2b9555383d5ca2a09523851fe4
--- /dev/null
+++ b/data/body_models/SMPLX_to_J14.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5df844ddea85b0a400a2e8dbe63d09d19f2b1b7ec0e0e952daeae08f83d82d61
+size 4692193
diff --git a/data/body_models/SMPL_NEUTRAL.pkl b/data/body_models/SMPL_NEUTRAL.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..65ae47d34e5b26720c9ccdd2614044832f0e30f2
--- /dev/null
+++ b/data/body_models/SMPL_NEUTRAL.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4924f235e63f7c5d5b690acedf736419c2edb846a2d69fc0956169615fa75688
+size 247186228
diff --git a/data/body_models/all_means.pkl b/data/body_models/all_means.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..03ff93f70af27dac0b808dbe45761c95ce8df397
--- /dev/null
+++ b/data/body_models/all_means.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:010c2178eff5fd58d07bab3717002e959fe62541aaaef778b09414ec0237690d
+size 4758
diff --git a/data/body_models/downsample_mat_smplx.pkl b/data/body_models/downsample_mat_smplx.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..45e09fb0bf098421656f6c3418ac05bd8fc32f18
--- /dev/null
+++ b/data/body_models/downsample_mat_smplx.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b67d12e8e9af767d9856fea8cb3366bfa8025fdf17cd4e25fc8b10f9a45eca9e
+size 18310685
diff --git a/data/body_models/joints_regressor_cmr.npy b/data/body_models/joints_regressor_cmr.npy
new file mode 100644
index 0000000000000000000000000000000000000000..06bcf3ff5f0f2797e8d090e4a5b1ea7c6c37db13
--- /dev/null
+++ b/data/body_models/joints_regressor_cmr.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a408b885040d714c94b41f64b2ec329d20dce673ae330d04a07a4b02dae7a13d
+size 661568
diff --git a/data/body_models/smpl/SMPL_FEMALE.pkl b/data/body_models/smpl/SMPL_FEMALE.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..92a201f4839bd95c1c1986437c7c6a02d7d1ae99
--- /dev/null
+++ b/data/body_models/smpl/SMPL_FEMALE.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a583c1b98e4afc19042641f1bae5cd8a1f712a6724886291a7627ec07acd408d
+size 39056454
diff --git a/data/body_models/smpl/SMPL_MALE.pkl b/data/body_models/smpl/SMPL_MALE.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..43dfecc57d9b7aa99cd2398df818ba252be7f605
--- /dev/null
+++ b/data/body_models/smpl/SMPL_MALE.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e8c0bbbbc635dcb166ed29c303fb4bef16ea5f623e5a89263495a9e403575bd
+size 39056404
diff --git a/data/body_models/smpl/SMPL_NEUTRAL.pkl b/data/body_models/smpl/SMPL_NEUTRAL.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..65ae47d34e5b26720c9ccdd2614044832f0e30f2
--- /dev/null
+++ b/data/body_models/smpl/SMPL_NEUTRAL.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4924f235e63f7c5d5b690acedf736419c2edb846a2d69fc0956169615fa75688
+size 247186228
diff --git a/data/body_models/smpl/index.html b/data/body_models/smpl/index.html
new file mode 100644
index 0000000000000000000000000000000000000000..60897cdaf2b7687e48b31e7025731020cfd13a5f
--- /dev/null
+++ b/data/body_models/smpl/index.html
@@ -0,0 +1,17 @@
+
+
+
+
+Directory listing for /body_models/smpl/
+
+
+Directory listing for /body_models/smpl/
+
+
+
+
+
diff --git a/data/body_models/smpl_mean_params.npz b/data/body_models/smpl_mean_params.npz
new file mode 100644
index 0000000000000000000000000000000000000000..c6f60a76976b877cbc08345b2977c6ddd83ced87
--- /dev/null
+++ b/data/body_models/smpl_mean_params.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6fd6dd687800da946d0a0492383f973b92ec20f166a0b829775882868c35fcdd
+size 1310
diff --git a/data/body_models/smplx/MANO_SMPLX_vertex_ids.pkl b/data/body_models/smplx/MANO_SMPLX_vertex_ids.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..dabec1377a0da4c511a519a00f51f1a3a23f33af
--- /dev/null
+++ b/data/body_models/smplx/MANO_SMPLX_vertex_ids.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e5abe70b6574de25470475091e8008314a5b90127eb48c3e63bfa0adf8c04dcf
+size 13535
diff --git a/data/body_models/smplx/SMPL-X__FLAME_vertex_ids.npy b/data/body_models/smplx/SMPL-X__FLAME_vertex_ids.npy
new file mode 100644
index 0000000000000000000000000000000000000000..c940d3aa6cb4cbbcc348fd518b15d8777dc350fd
--- /dev/null
+++ b/data/body_models/smplx/SMPL-X__FLAME_vertex_ids.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7e70cdc3659aae699b9732e8dd4af49106310c69b90dc83d9f73e96dbf871e49
+size 40312
diff --git a/data/body_models/smplx/SMPLX_FEMALE.npz b/data/body_models/smplx/SMPLX_FEMALE.npz
new file mode 100644
index 0000000000000000000000000000000000000000..da0a200cd85eb10f73aa36d44f1d9c509a82dfcc
--- /dev/null
+++ b/data/body_models/smplx/SMPLX_FEMALE.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b2a3686c9d6d218ff6822fba411c607a3c8125a70af340f384ce68bebecabe0e
+size 108794146
diff --git a/data/body_models/smplx/SMPLX_FEMALE.pkl b/data/body_models/smplx/SMPLX_FEMALE.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..3b3c8f90629a55b1af53896ab37d9e6863f77d3d
--- /dev/null
+++ b/data/body_models/smplx/SMPLX_FEMALE.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f3ac7af258fd217ab480b839c011545e5826cfa333ab34b3c98244ee3039bddd
+size 544434140
diff --git a/data/body_models/smplx/SMPLX_MALE.npz b/data/body_models/smplx/SMPLX_MALE.npz
new file mode 100644
index 0000000000000000000000000000000000000000..41fdef3ff2784eb06bb479ebf5fb6887aafbc183
--- /dev/null
+++ b/data/body_models/smplx/SMPLX_MALE.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ab318e3f37d2bfaae26abf4e6fab445c2a610e1d63714794d60379cc263bc2a5
+size 108753445
diff --git a/data/body_models/smplx/SMPLX_MALE.pkl b/data/body_models/smplx/SMPLX_MALE.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..450a5c0a51fb0b382cd746efae420a7131a349cc
--- /dev/null
+++ b/data/body_models/smplx/SMPLX_MALE.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:af7ebc82e44cf098598685474c0592049ddfaca8e850feb0c2b88343f9aacee3
+size 544477159
diff --git a/data/body_models/smplx/SMPLX_NEUTRAL.npz b/data/body_models/smplx/SMPLX_NEUTRAL.npz
new file mode 100644
index 0000000000000000000000000000000000000000..6f42b326bd60123bd813c0fa2df7f4660862a920
--- /dev/null
+++ b/data/body_models/smplx/SMPLX_NEUTRAL.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:376021446ddc86e99acacd795182bbef903e61d33b76b9d8b359c2b0865bd992
+size 108752058
diff --git a/data/body_models/smplx/SMPLX_NEUTRAL.pkl b/data/body_models/smplx/SMPLX_NEUTRAL.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..c2ef9ea8a36f2bf51256325bc6d24c181975483c
--- /dev/null
+++ b/data/body_models/smplx/SMPLX_NEUTRAL.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:381c808965deb4f5e845f8c3eddb0cd69930cc72e5774ce4f34c4ce3cf058361
+size 544173380
diff --git a/data/body_models/smplx/SMPLX_to_J14.npy b/data/body_models/smplx/SMPLX_to_J14.npy
new file mode 100644
index 0000000000000000000000000000000000000000..d336545c180ad9c89421cf9eae65aca2faf631d1
--- /dev/null
+++ b/data/body_models/smplx/SMPLX_to_J14.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:be01f37aa99e794ace8f52abe7b31df302fe54c68e75062ea0431a6c2f5e084f
+size 1173328
diff --git a/data/body_models/smplx/SMPLX_to_J14.pkl b/data/body_models/smplx/SMPLX_to_J14.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..db8aa5c74b860a2b9555383d5ca2a09523851fe4
--- /dev/null
+++ b/data/body_models/smplx/SMPLX_to_J14.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5df844ddea85b0a400a2e8dbe63d09d19f2b1b7ec0e0e952daeae08f83d82d61
+size 4692193
diff --git a/data/body_models/smplx/smplx_kid_template.npy b/data/body_models/smplx/smplx_kid_template.npy
new file mode 100644
index 0000000000000000000000000000000000000000..8ce7bc403545dfb29f361787cb7bca1df8316d6e
--- /dev/null
+++ b/data/body_models/smplx/smplx_kid_template.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bdce4f5886b9ddcb6da3ee0f70ae636b1aa1292f2b379c4c3149fce8abc0a604
+size 251528
diff --git a/data/body_models/smplx2smpl.pkl b/data/body_models/smplx2smpl.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..0f25e10571181989524020c803280607b7ee9a85
--- /dev/null
+++ b/data/body_models/smplx2smpl.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c1d912d121ad98132e4492d8e7a0f1a8cf4412811e14a7ef8cb337bb48eef99e
+size 578019251
diff --git a/datasets/AGORA_MM.py b/datasets/AGORA_MM.py
new file mode 100644
index 0000000000000000000000000000000000000000..45f97d3d3bde58c29f684540f87b60c7845665cc
--- /dev/null
+++ b/datasets/AGORA_MM.py
@@ -0,0 +1,974 @@
+import os
+import os.path as osp
+from glob import glob
+import numpy as np
+from config.config import cfg
+import copy
+import json
+import pickle
+import cv2
+import torch
+from pycocotools.coco import COCO
+from util.human_models import smpl_x
+from util.preprocessing import load_img, sanitize_bbox, process_bbox, load_ply, load_obj
+from util.transforms import rigid_align, rigid_align_batch
+import tqdm
+import random
+from util.formatting import DefaultFormatBundle
+from detrsmpl.data.datasets.pipelines.transforms import Normalize
+import time
+from util.preprocessing import (
+    load_img, process_bbox, augmentation_instance_sample,
+    process_human_model_output_batch_simplify, process_db_coord_batch_no_valid)
+# from util.human_models import smpl_x
+from .humandata import HumanDataset
+import csv
+KPS2D_KEYS = [
+ 'keypoints2d_ori', 'keypoints2d_smplx', 'keypoints2d_smpl',
+ 'keypoints2d_original','keypoints2d_gta'
+]
+KPS3D_KEYS = [
+ 'keypoints3d_cam', 'keypoints3d', 'keypoints3d_smplx', 'keypoints3d_smpl',
+ 'keypoints3d_original', 'keypoints3d_gta'
+]
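+
+# AGORA_MM loads the preprocessed multi-person AGORA annotations (HumanData-style npz files)
+# through the shared HumanDataset pipeline, with optional on-disk caching of the parsed datalist.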
+class AGORA_MM(HumanDataset):
+ def __init__(self, transform, data_split):
+ super(AGORA_MM, self).__init__(transform, data_split)
+ self.img_shape = [2160,3840]
+ pre_prc_file_train = 'spec_train_smpl.npz'
+ pre_prc_file_test = 'spec_test_smpl.npz'
+ self.save_idx = 0
+ if self.data_split == 'train':
+ filename = getattr(cfg, 'filename', pre_prc_file_train)
+ else:
+ self.test_set = 'val'
+
+ self.img_dir = './data/datasets/agora'
+
+
+ if data_split == 'train':
+ if self.img_shape == [2160,3840]:
+ self.annot_path = 'data/preprocessed_npz/multihuman_data/agora_train_3840_w_occ_multi_2010.npz'
+ self.annot_path_cache = 'data/preprocessed_npz/cache/agora_train_3840_w_occ_cache_2010.npz'
+ elif self.img_shape == [720,1280]:
+ self.annot_path = 'data/preprocessed_npz/multihuman_data/agora_train_1280_multi_1010.npz'
+ self.annot_path_cache = 'data/preprocessed_npz/cache/agora_train_cache_1280_1010.npz'
+
+ elif data_split == 'test':
+ if self.img_shape == [2160,3840]:
+ self.annot_path = 'data/preprocessed_npz/multihuman_data/agora_validation_multi_3840_1010.npz'
+ self.annot_path_cache = 'data/preprocessed_npz/cache/agora_validation_cache_3840_1010_occ_cache_balance.npz'
+ elif self.img_shape == [720,1280]:
+ self.annot_path = 'data/preprocessed_npz/multihuman_data/agora_validation_1280_1010_occ.npz'
+ self.annot_path_cache = 'data/preprocessed_npz/cache/agora_validation_cache_1280_1010_occ.npz'
+
+ self.use_cache = getattr(cfg, 'use_cache', False)
+ self.cam_param = {}
+
+ # load data or cache
+ if self.use_cache and osp.isfile(self.annot_path_cache):
+ print(f'[{self.__class__.__name__}] loading cache from {self.annot_path_cache}')
+ self.datalist = self.load_cache(self.annot_path_cache)
+ else:
+ if self.use_cache:
+ print(f'[{self.__class__.__name__}] Cache not found, generating cache...')
+ self.datalist = self.load_data(
+ train_sample_interval=getattr(cfg, f'{self.__class__.__name__}_train_sample_interval', 1))
+ if self.use_cache:
+ self.save_cache(self.annot_path_cache, self.datalist)
+
+
+ def load_data(self, train_sample_interval=1):
+
+ content = np.load(self.annot_path, allow_pickle=True)
+
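+        # frame_range maps each image to the [start, end) slice of its per-person annotation rows;
+        # if it is absent, fall back to treating every annotation as its own single-frame range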
+ try:
+ frame_range = content['frame_range']
+ except KeyError:
+ frame_range = \
+ np.array([[i, i + 1] for i in range(self.num_data)])
+
+ num_examples = len(frame_range)
+
+ if 'meta' in content:
+ meta = content['meta'].item()
+ print('meta keys:', meta.keys())
+ else:
+ meta = None
+ print(
+ 'No meta info provided! Please give height and width manually')
+
+ print(
+ f'Start loading humandata {self.annot_path} into memory...\nDataset includes: {content.files}'
+ )
+ tic = time.time()
+ image_path = content['image_path']
+
+ if meta is not None and 'height' in meta:
+ height = np.array(meta['height'])
+ width = np.array(meta['width'])
+ image_shape = np.stack([height, width], axis=-1)
+ else:
+ image_shape = None
+
+ if meta is not None and 'gender' in meta and len(meta['gender']) != 0:
+ gender = meta['gender']
+ else:
+ gender = None
+
+ if meta is not None and 'is_kid' in meta and len(meta['is_kid']) != 0:
+ is_kid = meta['is_kid']
+ else:
+ is_kid = None
+
+ bbox_xywh = content['bbox_xywh']
+
+ if 'smplx' in content:
+ smplx = content['smplx'].item()
+ as_smplx = 'smplx'
+ elif 'smpl' in content:
+ smplx = content['smpl'].item()
+ as_smplx = 'smpl'
+ elif 'smplh' in content:
+ smplx = content['smplh'].item()
+ as_smplx = 'smplh'
+ # TODO: temp solution, should be more general. But SHAPY is very special
+ elif self.__class__.__name__ == 'SHAPY':
+ smplx = {}
+ else:
+            raise KeyError('No SMPL or SMPL-X parameters available, please check keys:\n'
+ f'{content.files}')
+
+ print('Smplx param', smplx.keys())
+
+ if 'lhand_bbox_xywh' in content and 'rhand_bbox_xywh' in content:
+ lhand_bbox_xywh = content['lhand_bbox_xywh']
+ rhand_bbox_xywh = content['rhand_bbox_xywh']
+ else:
+ lhand_bbox_xywh = np.zeros_like(bbox_xywh)
+ rhand_bbox_xywh = np.zeros_like(bbox_xywh)
+
+ if 'face_bbox_xywh' in content:
+ face_bbox_xywh = content['face_bbox_xywh']
+ else:
+ face_bbox_xywh = np.zeros_like(bbox_xywh)
+
+ decompressed = False
+ if content['__keypoints_compressed__']:
+ decompressed_kps = self.decompress_keypoints(content)
+ decompressed = True
+
+ keypoints3d = None
+ valid_kps3d = False
+ keypoints3d_mask = None
+ valid_kps3d_mask = False
+
+
+ # processing keypoints
+ for kps3d_key in KPS3D_KEYS:
+ if kps3d_key in content:
+ keypoints3d = decompressed_kps[kps3d_key][:, self.SMPLX_137_MAPPING, :] if decompressed \
+ else content[kps3d_key][:, self.SMPLX_137_MAPPING, :]
+ valid_kps3d = True
+ if keypoints3d.shape[-1] == 4:
+ valid_kps3d_mask = True
+ break
+ if self.keypoints2d is not None:
+ keypoints2d = decompressed_kps[self.keypoints2d][:, self.SMPLX_137_MAPPING, :] if decompressed \
+ else content[self.keypoints2d][:, self.SMPLX_137_MAPPING, :]
+
+
+ else:
+ for kps2d_key in KPS2D_KEYS:
+ if kps2d_key in content:
+ keypoints2d = decompressed_kps[kps2d_key][:, self.SMPLX_137_MAPPING, :] if decompressed \
+ else content[kps2d_key][:, self.SMPLX_137_MAPPING, :]
+
+ if keypoints2d.shape[-1] == 3:
+ valid_kps3d_mask = True
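+        # per-person occlusion values from the AGORA meta info (if available); used later to drop
+        # heavily occluded instances from supervision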
+ occlusion = content['meta'][()]['occ'] if 'occ' in content['meta'][()] and len(content['meta'][()]['occ'])>0 else None
+
+ print('Done. Time: {:.2f}s'.format(time.time() - tic))
+
+ datalist = []
+ # num_examples
+
+ # processing each image, filter according to bbox valid
+ for i in tqdm.tqdm(range(int(num_examples))):
+ if self.data_split == 'train' and i % train_sample_interval != 0:
+ continue
+ frame_start, frame_end = frame_range[i]
+ img_path = osp.join(self.img_dir, image_path[frame_start])
+ # im_shape = cv2.imread(img_path).shape[:2]
+ img_shape = image_shape[
+ frame_start] if image_shape is not None else self.img_shape
+
+
+ bbox_list = bbox_xywh[frame_start:frame_end, :4]
+
+ valid_idx = []
+ body_bbox_list = []
+
+ if hasattr(cfg, 'bbox_ratio'):
+                bbox_ratio = cfg.bbox_ratio * 0.833 # preprocessed body bboxes already include 1.2x padding; 0.833 ~ 1/1.2 compensates for it
+ else:
+ bbox_ratio = 1.25
+
+ for bbox_i, bbox in enumerate(bbox_list):
+
+ bbox = process_bbox(bbox,
+ img_width=img_shape[1],
+ img_height=img_shape[0],
+ ratio=bbox_ratio)
+ if bbox is None:
+ continue
+ else:
+ valid_idx.append(frame_start + bbox_i)
+ bbox[2:] += bbox[:2]
+ body_bbox_list.append(bbox)
+ if len(valid_idx) == 0:
+ continue
+ valid_num = len(valid_idx)
+ # hand/face bbox
+ lhand_bbox_list = []
+ rhand_bbox_list = []
+ face_bbox_list = []
+
+ for bbox_i in valid_idx:
+ lhand_bbox = lhand_bbox_xywh[bbox_i]
+
+ rhand_bbox = rhand_bbox_xywh[bbox_i]
+ face_bbox = face_bbox_xywh[bbox_i]
+ if lhand_bbox[-1] > 0: # conf > 0
+ lhand_bbox = lhand_bbox[:4]
+ if hasattr(cfg, 'bbox_ratio'):
+ lhand_bbox = process_bbox(lhand_bbox,
+ img_width=img_shape[1],
+ img_height=img_shape[0],
+ ratio=bbox_ratio)
+ if lhand_bbox is not None:
+ lhand_bbox[2:] += lhand_bbox[:2] # xywh -> xyxy
+ else:
+ lhand_bbox = None
+ if rhand_bbox[-1] > 0:
+ rhand_bbox = rhand_bbox[:4]
+ if hasattr(cfg, 'bbox_ratio'):
+ rhand_bbox = process_bbox(rhand_bbox,
+ img_width=img_shape[1],
+ img_height=img_shape[0],
+ ratio=bbox_ratio)
+ if rhand_bbox is not None:
+ rhand_bbox[2:] += rhand_bbox[:2] # xywh -> xyxy
+ else:
+ rhand_bbox = None
+ if face_bbox[-1] > 0:
+ face_bbox = face_bbox[:4]
+ if hasattr(cfg, 'bbox_ratio'):
+ face_bbox = process_bbox(face_bbox,
+ img_width=img_shape[1],
+ img_height=img_shape[0],
+ ratio=bbox_ratio)
+ if face_bbox is not None:
+ face_bbox[2:] += face_bbox[:2] # xywh -> xyxy
+ else:
+ face_bbox = None
+ lhand_bbox_list.append(lhand_bbox)
+ rhand_bbox_list.append(rhand_bbox)
+ face_bbox_list.append(face_bbox)
+
+ # lhand_bbox = np.stack(lhand_bbox_list,axis=0)
+ # rhand_bbox = np.stack(rhand_bbox_list,axis=0)
+ # face_bbox = np.stack(face_bbox_list,axis=0)
+ joint_img = keypoints2d[valid_idx]
+
+ # num_joints = joint_cam.shape[0]
+ # joint_valid = np.ones((num_joints, 1))
+ if valid_kps3d:
+ joint_cam = keypoints3d[valid_idx]
+ else:
+ joint_cam = None
+
+ if 'leye_pose_0' in smplx.keys():
+ smplx.pop('leye_pose_0')
+ if 'leye_pose_1' in smplx.keys():
+ smplx.pop('leye_pose_1')
+ if 'leye_pose' in smplx.keys():
+ smplx.pop('leye_pose')
+ if 'reye_pose_0' in smplx.keys():
+ smplx.pop('reye_pose_0')
+ if 'reye_pose_1' in smplx.keys():
+ smplx.pop('reye_pose_1')
+ if 'reye_pose' in smplx.keys():
+ smplx.pop('reye_pose')
+
+ occlusion_frame = occlusion[valid_idx] \
+ if occlusion is not None else np.array([1]*(valid_num))
+
+ smplx_param = {k: v[valid_idx] for k, v in smplx.items()}
+ gender_ = gender[valid_idx] \
+ if gender is not None else np.array(['neutral']*(valid_num))
+
+ is_kid_ = is_kid[valid_idx] \
+ if is_kid is not None else np.array([1]*(valid_num))
+ lhand_bbox_valid = lhand_bbox_xywh[valid_idx,4]
+ rhand_bbox_valid = rhand_bbox_xywh[valid_idx,4]
+ face_bbox_valid = face_bbox_xywh[valid_idx,4]
+
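+            # rename SMPL-X parameters from the smplx/HumanData naming (global_orient, betas,
+            # transl, ...) to the internal names used by the rest of the pipeline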
+ smplx_param['root_pose'] = smplx_param.pop('global_orient', None)
+ smplx_param['shape'] = smplx_param.pop('betas', None)
+ smplx_param['trans'] = smplx_param.pop('transl', np.zeros(3))
+ smplx_param['lhand_pose'] = smplx_param.pop('left_hand_pose', None)
+ smplx_param['rhand_pose'] = smplx_param.pop(
+ 'right_hand_pose', None)
+ smplx_param['expr'] = smplx_param.pop('expression', None)
+
+ # TODO do not fix betas, give up shape supervision
+ if 'betas_neutral' in smplx_param and self.data_split == 'train':
+ smplx_param['shape'] = smplx_param.pop('betas_neutral')
+ # smplx_param['shape'] = np.zeros(10, dtype=np.float32)
+
+            if smplx_param['lhand_pose'] is None or self.body_only:
+                smplx_param['lhand_valid'] = np.zeros(valid_num, dtype=np.bool_)
+            else:
+                smplx_param['lhand_valid'] = lhand_bbox_valid.astype(np.bool_)
+
+            if smplx_param['rhand_pose'] is None or self.body_only:
+                smplx_param['rhand_valid'] = np.zeros(valid_num, dtype=np.bool_)
+            else:
+                smplx_param['rhand_valid'] = rhand_bbox_valid.astype(np.bool_)
+
+            if smplx_param['expr'] is None or self.body_only:
+                smplx_param['face_valid'] = np.zeros(valid_num, dtype=np.bool_)
+            else:
+                smplx_param['face_valid'] = face_bbox_valid.astype(np.bool_)
+
+ if joint_cam is not None and np.any(np.isnan(joint_cam)):
+ continue
+
+
+ datalist.append({
+ 'img_path': img_path,
+ 'img_shape': img_shape,
+ 'bbox': body_bbox_list,
+ 'lhand_bbox': lhand_bbox_list,
+ 'rhand_bbox': rhand_bbox_list,
+ 'face_bbox': face_bbox_list,
+ 'joint_img': joint_img,
+ 'joint_cam': joint_cam,
+ 'smplx_param': smplx_param,
+ 'as_smplx': as_smplx,
+ 'gender': gender_,
+ 'occlusion': occlusion_frame,
+ 'is_kid': is_kid_,
+ })
+
+ # save memory
+ del content, image_path, bbox_xywh, lhand_bbox_xywh, rhand_bbox_xywh, face_bbox_xywh, keypoints3d, keypoints2d
+
+ if self.data_split == 'train':
+ print(f'[{self.__class__.__name__} train] original size:',
+ int(num_examples), '. Sample interval:',
+ train_sample_interval, '. Sampled size:', len(datalist))
+
+ if getattr(cfg, 'data_strategy',
+ None) == 'balance' and self.data_split == 'train':
+ print(
+ f'[{self.__class__.__name__}] Using [balance] strategy with datalist shuffled...'
+ )
+ random.shuffle(datalist)
+
+ return datalist
+
+ def __getitem__(self, idx):
+ try:
+ data = copy.deepcopy(self.datalist[idx])
+ except Exception as e:
+ print(f'[{self.__class__.__name__}] Error loading data {idx}')
+ print(e)
+ exit(0)
+
+ img_path, img_shape, bbox = \
+ data['img_path'], data['img_shape'], data['bbox']
+ as_smplx = data['as_smplx']
+ gender = data['gender'].copy()
+ for gender_str, gender_num in {
+ 'neutral': -1, 'male': 0, 'female': 1}.items():
+ gender[gender==gender_str]=gender_num
+ gender = gender.astype(int)
+
+ img_whole_bbox = np.array([0, 0, img_shape[1], img_shape[0]])
+ img = load_img(img_path, order='BGR')
+
+ num_person = len(data['bbox'])
+ data_name = self.__class__.__name__
+ img, img2bb_trans, bb2img_trans, rot, do_flip = \
+ augmentation_instance_sample(img, img_whole_bbox, self.data_split,data,data_name)
+ cropped_img_shape=img.shape[:2]
+
+ num_person = len(data['bbox'])
+ if self.data_split == 'train':
+ joint_cam = data['joint_cam'] # num, 137,4
+ if joint_cam is not None:
+ dummy_cord = False
+ joint_cam[:,:,:3] = \
+ joint_cam[:,:,:3] - joint_cam[:, self.joint_set['root_joint_idx'], None, :3] # root-relative
+ else:
+ # dummy cord as joint_cam
+ dummy_cord = True
+ joint_cam = np.zeros(
+ (num_person, self.joint_set['joint_num'], 4),
+ dtype=np.float32)
+
+ joint_img = data['joint_img']
+ # do rotation on keypoints
+ joint_img_aug, joint_cam_wo_ra, joint_cam_ra, joint_trunc = \
+ process_db_coord_batch_no_valid(
+ joint_img, joint_cam, do_flip, img_shape,
+ self.joint_set['flip_pairs'], img2bb_trans, rot,
+ self.joint_set['joints_name'], smpl_x.joints_name,
+ cropped_img_shape)
+ joint_img_aug[:,:,2:] = joint_img_aug[:,:,2:] * joint_trunc
+
+ # smplx coordinates and parameters
+ smplx_param = data['smplx_param']
+ smplx_pose, smplx_shape, smplx_expr, smplx_pose_valid, \
+ smplx_joint_valid, smplx_expr_valid, smplx_shape_valid = \
+ process_human_model_output_batch_simplify(
+ smplx_param, do_flip, rot, as_smplx)
+ # if cam not provided, we take joint_img as smplx joint 2d,
+ # which is commonly the case for our processed humandata
+ # change smplx_shape if use_betas_neutral
+ # processing follows that in process_human_model_output
+
+ if self.use_betas_neutral:
+ smplx_shape = smplx_param['betas_neutral'].reshape(
+ num_person, -1)
+ smplx_shape[(np.abs(smplx_shape) > 3).any(axis=1)] = 0.
+ smplx_shape = smplx_shape.reshape(num_person, -1)
+ # SMPLX joint coordinate validity
+ # for name in ('L_Big_toe', 'L_Small_toe', 'L_Heel', 'R_Big_toe', 'R_Small_toe', 'R_Heel'):
+ # smplx_joint_valid[smpl_x.joints_name.index(name)] = 0
+ smplx_joint_valid = smplx_joint_valid[:, :, None]
+
+ lhand_bbox_center_list = []
+ lhand_bbox_valid_list = []
+ lhand_bbox_size_list = []
+ lhand_bbox_list = []
+ face_bbox_center_list = []
+ face_bbox_size_list = []
+ face_bbox_valid_list = []
+ face_bbox_list = []
+ rhand_bbox_center_list = []
+ rhand_bbox_valid_list = []
+ rhand_bbox_size_list = []
+ rhand_bbox_list = []
+ body_bbox_center_list = []
+ body_bbox_size_list = []
+ body_bbox_valid_list = []
+ body_bbox_list = []
+
+ for i in range(num_person):
+ body_bbox, body_bbox_valid = self.process_hand_face_bbox(
+ data['bbox'][i], do_flip, img_shape, img2bb_trans,
+ cropped_img_shape)
+
+ lhand_bbox, lhand_bbox_valid = self.process_hand_face_bbox(
+ data['lhand_bbox'][i], do_flip, img_shape, img2bb_trans,
+ cropped_img_shape)
+ lhand_bbox_valid *= smplx_param['lhand_valid'][i]
+
+ rhand_bbox, rhand_bbox_valid = self.process_hand_face_bbox(
+ data['rhand_bbox'][i], do_flip, img_shape, img2bb_trans,
+ cropped_img_shape)
+ rhand_bbox_valid *= smplx_param['rhand_valid'][i]
+
+ face_bbox, face_bbox_valid = self.process_hand_face_bbox(
+ data['face_bbox'][i], do_flip, img_shape, img2bb_trans,
+ cropped_img_shape)
+ face_bbox_valid *= smplx_param['face_valid'][i]
+
+ if do_flip:
+ lhand_bbox, rhand_bbox = rhand_bbox, lhand_bbox
+ lhand_bbox_valid, rhand_bbox_valid = rhand_bbox_valid, lhand_bbox_valid
+
+ body_bbox_list.append(body_bbox)
+ lhand_bbox_list.append(lhand_bbox)
+ rhand_bbox_list.append(rhand_bbox)
+ face_bbox_list.append(face_bbox)
+
+ lhand_bbox_center = (lhand_bbox[0] + lhand_bbox[1]) / 2.
+ rhand_bbox_center = (rhand_bbox[0] + rhand_bbox[1]) / 2.
+ face_bbox_center = (face_bbox[0] + face_bbox[1]) / 2.
+ body_bbox_center = (body_bbox[0] + body_bbox[1]) / 2.
+ lhand_bbox_size = lhand_bbox[1] - lhand_bbox[0]
+ rhand_bbox_size = rhand_bbox[1] - rhand_bbox[0]
+
+ face_bbox_size = face_bbox[1] - face_bbox[0]
+ body_bbox_size = body_bbox[1] - body_bbox[0]
+ lhand_bbox_center_list.append(lhand_bbox_center)
+ lhand_bbox_valid_list.append(lhand_bbox_valid)
+ lhand_bbox_size_list.append(lhand_bbox_size)
+ face_bbox_center_list.append(face_bbox_center)
+ face_bbox_size_list.append(face_bbox_size)
+ face_bbox_valid_list.append(face_bbox_valid)
+ rhand_bbox_center_list.append(rhand_bbox_center)
+ rhand_bbox_valid_list.append(rhand_bbox_valid)
+ rhand_bbox_size_list.append(rhand_bbox_size)
+ body_bbox_center_list.append(body_bbox_center)
+ body_bbox_size_list.append(body_bbox_size)
+ body_bbox_valid_list.append(body_bbox_valid)
+
+
+ body_bbox = np.stack(body_bbox_list, axis=0)
+ lhand_bbox = np.stack(lhand_bbox_list, axis=0)
+ rhand_bbox = np.stack(rhand_bbox_list, axis=0)
+ face_bbox = np.stack(face_bbox_list, axis=0)
+ lhand_bbox_center = np.stack(lhand_bbox_center_list, axis=0)
+ lhand_bbox_valid = np.stack(lhand_bbox_valid_list, axis=0)
+ lhand_bbox_size = np.stack(lhand_bbox_size_list, axis=0)
+ face_bbox_center = np.stack(face_bbox_center_list, axis=0)
+ face_bbox_size = np.stack(face_bbox_size_list, axis=0)
+ face_bbox_valid = np.stack(face_bbox_valid_list, axis=0)
+ body_bbox_center = np.stack(body_bbox_center_list, axis=0)
+ body_bbox_size = np.stack(body_bbox_size_list, axis=0)
+ body_bbox_valid = np.stack(body_bbox_valid_list, axis=0)
+ rhand_bbox_center = np.stack(rhand_bbox_center_list, axis=0)
+ rhand_bbox_valid = np.stack(rhand_bbox_valid_list, axis=0)
+ rhand_bbox_size = np.stack(rhand_bbox_size_list, axis=0)
+
+
+ if 'occlusion' in data:
+ occlusion = data['occlusion']
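+                # instances with an occlusion value >= 97 (the value appears to be a percentage)
+                # are treated as fully occluded and excluded from every supervision term below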
+ occ_mask = occlusion<97
+
+ joint_img_aug[:,:,2] = joint_img_aug[:,:,2]*occ_mask[:,None]
+ joint_cam_wo_ra[:,:,3] = joint_cam_wo_ra[:,:,3]*occ_mask[:,None]
+ joint_trunc = joint_trunc*occ_mask[:,None,None]
+ smplx_pose_valid = smplx_pose_valid*occ_mask[:,None]
+ smplx_joint_valid = smplx_joint_valid*occ_mask[:,None,None]
+ smplx_expr_valid = smplx_expr_valid*occ_mask
+ smplx_shape_valid = smplx_shape_valid*occ_mask
+ rhand_bbox_valid = rhand_bbox_valid*occ_mask
+ lhand_bbox_valid = lhand_bbox_valid*occ_mask
+ face_bbox_valid = face_bbox_valid*occ_mask
+
+
+ if 'is_kid' in data:
+ is_kid = data['is_kid'].copy()
+ smplx_shape_valid = smplx_shape_valid * (is_kid==0)
+
+
+ inputs = {'img': img}
+
+ joint_img_aug[:,:,2] = joint_img_aug[:,:,2] * body_bbox_valid[:,None]
+
+ is_3D = float(False) if dummy_cord else float(True)
+
+ targets = {
+ # keypoints2d, [0,img_w],[0,img_h] -> [0,1] -> [0,output_hm_shape]
+ 'joint_img': joint_img_aug[body_bbox_valid>0],
+ # joint_cam, kp3d wo ra # raw kps3d probably without ra
+ 'joint_cam': joint_cam_wo_ra[body_bbox_valid>0],
+ # kps3d with body, face, hand ra
+ 'smplx_joint_cam': joint_cam_ra[body_bbox_valid>0],
+ 'smplx_pose': smplx_pose[body_bbox_valid>0],
+ 'smplx_shape': smplx_shape[body_bbox_valid>0],
+ 'smplx_expr': smplx_expr[body_bbox_valid>0],
+ 'lhand_bbox_center': lhand_bbox_center[body_bbox_valid>0],
+ 'lhand_bbox_size': lhand_bbox_size[body_bbox_valid>0],
+ 'rhand_bbox_center': rhand_bbox_center[body_bbox_valid>0],
+ 'rhand_bbox_size': rhand_bbox_size[body_bbox_valid>0],
+ 'face_bbox_center': face_bbox_center[body_bbox_valid>0],
+ 'face_bbox_size': face_bbox_size[body_bbox_valid>0],
+ 'body_bbox_center': body_bbox_center[body_bbox_valid>0],
+ 'body_bbox_size': body_bbox_size[body_bbox_valid>0],
+ 'body_bbox': body_bbox.reshape(-1,4)[body_bbox_valid>0],
+ 'lhand_bbox': lhand_bbox.reshape(-1,4)[body_bbox_valid>0],
+ 'rhand_bbox': rhand_bbox.reshape(-1,4)[body_bbox_valid>0],
+ 'face_bbox': face_bbox.reshape(-1,4)[body_bbox_valid>0],
+ 'gender': gender[body_bbox_valid>0]}
+
+ meta_info = {
+ 'joint_trunc': joint_trunc[body_bbox_valid>0],
+ 'smplx_pose_valid': smplx_pose_valid[body_bbox_valid>0],
+ 'smplx_shape_valid': smplx_shape_valid[body_bbox_valid>0],
+ 'smplx_expr_valid': smplx_expr_valid[body_bbox_valid>0],
+ 'is_3D': is_3D,
+ 'lhand_bbox_valid': lhand_bbox_valid[body_bbox_valid>0],
+ 'rhand_bbox_valid': rhand_bbox_valid[body_bbox_valid>0],
+ 'face_bbox_valid': face_bbox_valid[body_bbox_valid>0],
+ 'body_bbox_valid': body_bbox_valid[body_bbox_valid>0],
+ 'img_shape': np.array(img.shape[:2]),
+ 'ori_shape':data['img_shape'],
+ 'idx': idx
+
+ }
+ result = {**inputs, **targets, **meta_info}
+
+ result = self.normalize(result)
+ result = self.format(result)
+ return result
+
+
+
+ if self.data_split == 'test':
+ self.cam_param = {}
+ joint_cam = data['joint_cam']
+
+ if joint_cam is not None:
+ dummy_cord = False
+ joint_cam[:,:,:3] = joint_cam[:,:,:3] - joint_cam[
+ :, self.joint_set['root_joint_idx'], None, :3] # root-relative
+ else:
+ # dummy cord as joint_cam
+ dummy_cord = True
+ joint_cam = np.zeros(
+ (num_person, self.joint_set['joint_num'], 3),
+ dtype=np.float32)
+
+ joint_img = data['joint_img']
+
+
+ joint_img_aug, joint_cam_wo_ra, joint_cam_ra, joint_trunc = \
+ process_db_coord_batch_no_valid(
+ joint_img, joint_cam, do_flip, img_shape,
+ self.joint_set['flip_pairs'], img2bb_trans, rot,
+ self.joint_set['joints_name'], smpl_x.joints_name,
+ cropped_img_shape)
+
+
+
+ # smplx coordinates and parameters
+ smplx_param = data['smplx_param']
+ # smplx_cam_trans = np.array(
+ # smplx_param['trans']) if 'trans' in smplx_param else None
+            # TODO: remove this, separate smpl and smplx
+ smplx_pose, smplx_shape, smplx_expr, smplx_pose_valid, \
+ smplx_joint_valid, smplx_expr_valid, smplx_shape_valid = \
+ process_human_model_output_batch_simplify(
+ smplx_param, do_flip, rot, as_smplx)
+
+ # if cam not provided, we take joint_img as smplx joint 2d,
+ # which is commonly the case for our processed humandata
+ if self.use_betas_neutral:
+ smplx_shape = smplx_param['betas_neutral'].reshape(
+ num_person, -1)
+ smplx_shape[(np.abs(smplx_shape) > 3).any(axis=1)] = 0.
+ smplx_shape = smplx_shape.reshape(num_person, -1)
+
+ smplx_joint_valid = smplx_joint_valid[:, :, None]
+
+ lhand_bbox_center_list = []
+ lhand_bbox_valid_list = []
+ lhand_bbox_size_list = []
+ lhand_bbox_list = []
+ face_bbox_center_list = []
+ face_bbox_size_list = []
+ face_bbox_valid_list = []
+ face_bbox_list = []
+ rhand_bbox_center_list = []
+ rhand_bbox_valid_list = []
+ rhand_bbox_size_list = []
+ rhand_bbox_list = []
+ body_bbox_center_list = []
+ body_bbox_size_list = []
+ body_bbox_valid_list = []
+ body_bbox_list = []
+
+ for i in range(num_person):
+ lhand_bbox, lhand_bbox_valid = self.process_hand_face_bbox(
+ data['lhand_bbox'][i], do_flip, img_shape, img2bb_trans,
+ cropped_img_shape)
+ rhand_bbox, rhand_bbox_valid = self.process_hand_face_bbox(
+ data['rhand_bbox'][i], do_flip, img_shape, img2bb_trans,
+ cropped_img_shape)
+ face_bbox, face_bbox_valid = self.process_hand_face_bbox(
+ data['face_bbox'][i], do_flip, img_shape, img2bb_trans,
+ cropped_img_shape)
+
+ body_bbox, body_bbox_valid = self.process_hand_face_bbox(
+ data['bbox'][i], do_flip, img_shape, img2bb_trans,
+ cropped_img_shape)
+
+ if do_flip:
+ lhand_bbox, rhand_bbox = rhand_bbox, lhand_bbox
+ lhand_bbox_valid, rhand_bbox_valid = rhand_bbox_valid, lhand_bbox_valid
+
+ body_bbox_list.append(body_bbox)
+ lhand_bbox_list.append(lhand_bbox)
+ rhand_bbox_list.append(rhand_bbox)
+ face_bbox_list.append(face_bbox)
+
+ lhand_bbox_center = (lhand_bbox[0] + lhand_bbox[1]) / 2.
+ rhand_bbox_center = (rhand_bbox[0] + rhand_bbox[1]) / 2.
+ face_bbox_center = (face_bbox[0] + face_bbox[1]) / 2.
+ body_bbox_center = (body_bbox[0] + body_bbox[1]) / 2.
+ lhand_bbox_size = lhand_bbox[1] - lhand_bbox[0]
+ rhand_bbox_size = rhand_bbox[1] - rhand_bbox[0]
+
+ face_bbox_size = face_bbox[1] - face_bbox[0]
+ body_bbox_size = body_bbox[1] - body_bbox[0]
+ lhand_bbox_center_list.append(lhand_bbox_center)
+ lhand_bbox_valid_list.append(lhand_bbox_valid)
+ lhand_bbox_size_list.append(lhand_bbox_size)
+ face_bbox_center_list.append(face_bbox_center)
+ face_bbox_size_list.append(face_bbox_size)
+ face_bbox_valid_list.append(face_bbox_valid)
+ rhand_bbox_center_list.append(rhand_bbox_center)
+ rhand_bbox_valid_list.append(rhand_bbox_valid)
+ rhand_bbox_size_list.append(rhand_bbox_size)
+ body_bbox_center_list.append(body_bbox_center)
+ body_bbox_size_list.append(body_bbox_size)
+ body_bbox_valid_list.append(body_bbox_valid)
+
+ body_bbox = np.stack(body_bbox_list, axis=0)
+ lhand_bbox = np.stack(lhand_bbox_list, axis=0)
+ rhand_bbox = np.stack(rhand_bbox_list, axis=0)
+ face_bbox = np.stack(face_bbox_list, axis=0)
+ lhand_bbox_center = np.stack(lhand_bbox_center_list, axis=0)
+ lhand_bbox_valid = np.stack(lhand_bbox_valid_list, axis=0)
+ lhand_bbox_size = np.stack(lhand_bbox_size_list, axis=0)
+ face_bbox_center = np.stack(face_bbox_center_list, axis=0)
+ face_bbox_size = np.stack(face_bbox_size_list, axis=0)
+ face_bbox_valid = np.stack(face_bbox_valid_list, axis=0)
+ body_bbox_center = np.stack(body_bbox_center_list, axis=0)
+ body_bbox_size = np.stack(body_bbox_size_list, axis=0)
+ body_bbox_valid = np.stack(body_bbox_valid_list, axis=0)
+ rhand_bbox_center = np.stack(rhand_bbox_center_list, axis=0)
+ rhand_bbox_valid = np.stack(rhand_bbox_valid_list, axis=0)
+ rhand_bbox_size = np.stack(rhand_bbox_size_list, axis=0)
+
+
+ inputs = {'img': img}
+
+ targets = {
+ # keypoints2d, [0,img_w],[0,img_h] -> [0,1] -> [0,output_hm_shape]
+ 'joint_img': joint_img_aug,
+ # projected smplx if valid cam_param, else same as keypoints2d
+ # joint_cam, kp3d wo ra # raw kps3d probably without ra
+ 'joint_cam': joint_cam_wo_ra,
+ 'ann_idx': idx,
+ # kps3d with body, face, hand ra
+ 'smplx_joint_cam': joint_cam_ra,
+ 'smplx_pose': smplx_pose,
+ 'smplx_shape': smplx_shape,
+ 'smplx_expr': smplx_expr,
+ 'lhand_bbox_center': lhand_bbox_center,
+ 'lhand_bbox_size': lhand_bbox_size,
+ 'rhand_bbox_center': rhand_bbox_center,
+ 'rhand_bbox_size': rhand_bbox_size,
+ 'face_bbox_center': face_bbox_center,
+ 'face_bbox_size': face_bbox_size,
+ 'body_bbox_center': body_bbox_center,
+ 'body_bbox_size': body_bbox_size,
+ 'body_bbox': body_bbox.reshape(-1,4),
+ 'lhand_bbox': lhand_bbox.reshape(-1,4),
+ 'rhand_bbox': rhand_bbox.reshape(-1,4),
+ 'face_bbox': face_bbox.reshape(-1,4),
+ 'gender': gender,
+ 'bb2img_trans': bb2img_trans,
+ }
+
+ if self.body_only:
+ meta_info = {
+ 'joint_trunc': joint_trunc,
+ 'smplx_pose_valid': smplx_pose_valid,
+ 'smplx_shape_valid': float(smplx_shape_valid),
+ 'smplx_expr_valid': smplx_expr_valid,
+ 'is_3D': float(False) if dummy_cord else float(True),
+ 'lhand_bbox_valid': lhand_bbox_valid,
+ 'rhand_bbox_valid': rhand_bbox_valid,
+ 'face_bbox_valid': face_bbox_valid,
+ 'body_bbox_valid': body_bbox_valid,
+ 'img_shape': np.array(img.shape[:2]),
+ 'ori_shape':data['img_shape'],
+ 'idx': idx
+ }
+ else:
+ meta_info = {
+ 'joint_trunc': joint_trunc,
+ 'smplx_pose_valid': smplx_pose_valid,
+ 'smplx_shape_valid': smplx_shape_valid,
+ 'smplx_expr_valid': smplx_expr_valid,
+ 'is_3D': float(False) if dummy_cord else float(True),
+ 'lhand_bbox_valid': lhand_bbox_valid,
+ 'rhand_bbox_valid': rhand_bbox_valid,
+ 'face_bbox_valid': face_bbox_valid,
+ 'body_bbox_valid': body_bbox_valid,
+ 'img_shape': np.array(img.shape[:2]),
+ 'ori_shape':data['img_shape'],
+ 'idx': idx
+ }
+
+ result = {**inputs, **targets, **meta_info}
+ result = self.normalize(result)
+ result = self.format(result)
+ return result
+
+ def evaluate(self, outs, cur_sample_idx):
+ annots = self.datalist
+ sample_num = len(outs)
+ eval_result = {
+ 'pa_mpvpe_all': [],
+ 'pa_mpvpe_l_hand': [],
+ 'pa_mpvpe_r_hand': [],
+ 'pa_mpvpe_hand': [],
+ 'pa_mpvpe_face': [],
+ 'mpvpe_all': [],
+ 'mpvpe_l_hand': [],
+ 'mpvpe_r_hand': [],
+ 'mpvpe_hand': [],
+ 'mpvpe_face': []
+ }
+
+ vis = getattr(cfg, 'vis', False)
+ vis_save_dir = cfg.vis_dir
+
+ csv_file = f'{cfg.result_dir}/agora_smplx_error.csv'
+ file = open(csv_file, 'a', newline='')
+ for n in range(sample_num):
+ annot = annots[cur_sample_idx + n]
+ out = outs[n]
+ mesh_gt = out['smplx_mesh_cam_target']
+ mesh_out = out['smplx_mesh_cam']
+
+ # print('zzz',mesh_gt.shape,mesh_out.shape)
+ # from pytorch3d.io import save_obj
+ # for m_i,(mesh_gt_i,mesh_out_i) in enumerate(zip(mesh_gt,mesh_out)):
+ # save_obj('temp_gt_%d.obj'%m_i,verts=torch.Tensor(mesh_gt_i),faces=torch.tensor([]))
+ # save_obj('temp_pred_%d.obj'%m_i,verts=torch.Tensor(mesh_out_i),faces=torch.tensor([]))
+
+ ann_idx = out['gt_ann_idx']
+ img_path = []
+ for ann_id in ann_idx:
+ img_path.append(annots[ann_id]['img_path'])
+ eval_result['img_path'] = img_path
+ eval_result['ann_idx'] = ann_idx
+ # MPVPE from all vertices
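+            # translation-align the prediction so its regressed pelvis coincides with the GT pelvis
+            # before measuring the per-vertex error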
+ mesh_out_align = \
+ mesh_out - np.dot(
+ smpl_x.J_regressor, mesh_out).transpose(1,0,2)[:, smpl_x.J_regressor_idx['pelvis'], None, :] + \
+ np.dot(smpl_x.J_regressor, mesh_gt).transpose(1,0,2)[:, smpl_x.J_regressor_idx['pelvis'], None, :]
+
+ eval_result['mpvpe_all'].extend(
+ np.sqrt(np.sum(
+ (mesh_out_align - mesh_gt)**2, -1)).mean(-1) * 1000)
+ mesh_out_align = rigid_align_batch(mesh_out, mesh_gt)
+ eval_result['pa_mpvpe_all'].extend(
+ np.sqrt(np.sum(
+ (mesh_out_align - mesh_gt)**2, -1)).mean(-1) * 1000)
+
+ # MPVPE from hand vertices
+ mesh_gt_lhand = mesh_gt[:, smpl_x.hand_vertex_idx['left_hand'], :]
+ mesh_out_lhand = mesh_out[:, smpl_x.hand_vertex_idx['left_hand'], :]
+ mesh_gt_rhand = mesh_gt[:, smpl_x.hand_vertex_idx['right_hand'], :]
+ mesh_out_rhand = mesh_out[:, smpl_x.hand_vertex_idx['right_hand'], :]
+ mesh_out_lhand_align = \
+ mesh_out_lhand - \
+ np.dot(smpl_x.J_regressor, mesh_out).transpose(1,0,2)[:, smpl_x.J_regressor_idx['lwrist'], None, :] + \
+ np.dot(smpl_x.J_regressor, mesh_gt).transpose(1,0,2)[:, smpl_x.J_regressor_idx['lwrist'], None, :]
+
+ mesh_out_rhand_align = \
+ mesh_out_rhand - \
+ np.dot(smpl_x.J_regressor, mesh_out).transpose(1,0,2)[:, smpl_x.J_regressor_idx['rwrist'], None, :] + \
+ np.dot(smpl_x.J_regressor, mesh_gt).transpose(1,0,2)[:, smpl_x.J_regressor_idx['rwrist'], None, :]
+
+ eval_result['mpvpe_l_hand'].extend(
+ np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, -1)).mean(-1) *
+ 1000)
+ eval_result['mpvpe_r_hand'].extend(
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, -1)).mean(-1) *
+ 1000)
+ eval_result['mpvpe_hand'].extend(
+ (np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, -1)).mean(-1) *
+ 1000 +
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, -1)).mean(-1) *
+ 1000) / 2.)
+ mesh_out_lhand_align = rigid_align_batch(mesh_out_lhand, mesh_gt_lhand)
+ mesh_out_rhand_align = rigid_align_batch(mesh_out_rhand, mesh_gt_rhand)
+ eval_result['pa_mpvpe_l_hand'].extend(
+ np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, -1)).mean(-1) *
+ 1000)
+ eval_result['pa_mpvpe_r_hand'].extend(
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, -1)).mean(-1) *
+ 1000)
+ eval_result['pa_mpvpe_hand'].extend(
+ (np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, -1)).mean(-1) *
+ 1000 +
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, -1)).mean(-1) *
+ 1000) / 2.)
+
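+            # NOTE: face-vertex MPVPE is not computed for AGORA, so the 'mpvpe_face' and
+            # 'pa_mpvpe_face' lists stay empty here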
+
+ save_error=True
+ if save_error:
+ writer = csv.writer(file)
+ new_line = [ann_idx[n],img_path[n], eval_result['mpvpe_all'][-1], eval_result['pa_mpvpe_all'][-1]]
+ writer.writerow(new_line)
+ self.save_idx += 1
+
+
+ return eval_result
+
+
+ def print_eval_result(self, eval_result):
+
+ print('AGORA test results are dumped at: ' +
+ osp.join(cfg.result_dir, 'predictions'))
+
+ if self.data_split == 'test' and self.test_set == 'test': # do not print. just submit the results to the official evaluation server
+ return
+
+ print('======AGORA-val======')
+ print('PA MPVPE (All): %.2f mm' % np.mean(eval_result['pa_mpvpe_all']))
+ print('PA MPVPE (L-Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_l_hand']))
+ print('PA MPVPE (R-Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_r_hand']))
+ print('PA MPVPE (Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_hand']))
+ print('PA MPVPE (Face): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_face']))
+ print()
+
+ print('MPVPE (All): %.2f mm' % np.mean(eval_result['mpvpe_all']))
+ print('MPVPE (L-Hands): %.2f mm' %
+ np.mean(eval_result['mpvpe_l_hand']))
+ print('MPVPE (R-Hands): %.2f mm' %
+ np.mean(eval_result['mpvpe_r_hand']))
+ print('MPVPE (Hands): %.2f mm' % np.mean(eval_result['mpvpe_hand']))
+ print('MPVPE (Face): %.2f mm' % np.mean(eval_result['mpvpe_face']))
+
+ out_file = osp.join(cfg.result_dir,'agora_val.txt')
+ if os.path.exists(out_file):
+ f = open(out_file, 'a+')
+ else:
+ f = open(out_file, 'w', encoding="utf-8")
+
+ f.write('\n')
+ f.write(f'{cfg.exp_name}\n')
+ f.write(f'AGORA-val dataset: \n')
+ f.write('PA MPVPE (All): %.2f mm\n' %
+ np.mean(eval_result['pa_mpvpe_all']))
+ f.write('PA MPVPE (L-Hands): %.2f mm\n' %
+ np.mean(eval_result['pa_mpvpe_l_hand']))
+ f.write('PA MPVPE (R-Hands): %.2f mm\n' %
+ np.mean(eval_result['pa_mpvpe_r_hand']))
+ f.write('PA MPVPE (Hands): %.2f mm\n' %
+ np.mean(eval_result['pa_mpvpe_hand']))
+ f.write('PA MPVPE (Face): %.2f mm\n' %
+ np.mean(eval_result['pa_mpvpe_face']))
+ f.write('MPVPE (All): %.2f mm\n' % np.mean(eval_result['mpvpe_all']))
+ f.write('MPVPE (L-Hands): %.2f mm\n' %
+ np.mean(eval_result['mpvpe_l_hand']))
+ f.write('MPVPE (R-Hands): %.2f mm\n' %
+ np.mean(eval_result['mpvpe_r_hand']))
+ f.write('MPVPE (Hands): %.2f mm\n' % np.mean(eval_result['mpvpe_hand']))
+ f.write('MPVPE (Face): %.2f mm\n' % np.mean(eval_result['mpvpe_face']))
+        f.close()
diff --git a/datasets/ARCTIC.py b/datasets/ARCTIC.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c96065d579fb642596ffa3f3f15686a6e4e0d46
--- /dev/null
+++ b/datasets/ARCTIC.py
@@ -0,0 +1,215 @@
+import os
+import os.path as osp
+from glob import glob
+import numpy as np
+from config.config import cfg
+
+import csv
+
+from util.human_models import smpl_x
+
+from util.transforms import rigid_align_batch
+
+from humandata import HumanDataset
+
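+# ARCTIC wrapper: reuses the shared HumanDataset loading/caching pipeline and adds SMPL-X mesh
+# evaluation (MPVPE / PA-MPVPE for body, hands and face) on the validation split.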
+class ARCTIC(HumanDataset):
+ def __init__(self, transform, data_split):
+ super(ARCTIC, self).__init__(transform, data_split)
+
+ self.img_dir = 'data/osx_data/ARCTIC'
+
+
+ if data_split == 'train':
+ self.annot_path = 'data/preprocessed_npz/multihuman_data/p1_train_multi.npz'
+ self.annot_path_cache = 'data/preprocessed_npz/cache/p1_train_cache_sample1000_080824.npz'
+ self.sample_interval = 1000
+ elif data_split == 'test':
+ self.annot_path = 'data/preprocessed_npz_old/multihuman_data/p1_val_multi.npz'
+ self.annot_path_cache = 'data/preprocessed_npz_old/cache/p1_val_cache_30.npz'
+ self.sample_interval = 30
+
+
+ self.use_cache = getattr(cfg, 'use_cache', False)
+        self.img_shape = None  # (1024, 1024) # (h, w)
+ self.cam_param = {}
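+        # NOTE: caching is force-enabled below, overriding the 'use_cache' value read from the
+        # config above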
+ self.use_cache=True
+ # load data
+ if self.use_cache and osp.isfile(self.annot_path_cache):
+ print(
+ f'[{self.__class__.__name__}] loading cache from {self.annot_path_cache}'
+ )
+ self.datalist = self.load_cache(self.annot_path_cache)
+ else:
+ if self.use_cache:
+ print(
+ f'[{self.__class__.__name__}] Cache not found, generating cache...'
+ )
+ self.datalist = self.load_data(train_sample_interval=getattr(
+ cfg, f'{self.__class__.__name__}_train_sample_interval', self.sample_interval))
+ if self.use_cache:
+ self.save_cache(self.annot_path_cache, self.datalist)
+
+
+ def evaluate(self, outs, cur_sample_idx):
+ annots = self.datalist
+ sample_num = len(outs)
+ eval_result = {
+ 'pa_mpvpe_all': [],
+ 'pa_mpvpe_l_hand': [],
+ 'pa_mpvpe_r_hand': [],
+ 'pa_mpvpe_hand': [],
+ 'pa_mpvpe_face': [],
+ 'mpvpe_all': [],
+ 'mpvpe_l_hand': [],
+ 'mpvpe_r_hand': [],
+ 'mpvpe_hand': [],
+ 'mpvpe_face': []
+ }
+
+ vis = getattr(cfg, 'vis', False)
+ vis_save_dir = cfg.vis_dir
+ csv_file = f'{cfg.result_dir}/arctic_smplx_error.csv'
+ file = open(csv_file, 'a', newline='')
+
+ for n in range(sample_num):
+ annot = annots[cur_sample_idx + n]
+ out = outs[n]
+ mesh_gt = out['smplx_mesh_cam_target']
+ mesh_out = out['smplx_mesh_cam']
+ ann_idx = out['gt_ann_idx']
+ img_path = []
+ for ann_id in ann_idx:
+ img_path.append(annots[ann_id]['img_path'])
+ eval_result['img_path'] = img_path
+ # MPVPE from all vertices
+ mesh_out_align = \
+ mesh_out - np.dot(
+ smpl_x.J_regressor, mesh_out).transpose(1,0,2)[:, smpl_x.J_regressor_idx['pelvis'], None, :] + \
+ np.dot(smpl_x.J_regressor, mesh_gt).transpose(1,0,2)[:, smpl_x.J_regressor_idx['pelvis'], None, :]
+
+ eval_result['mpvpe_all'].append(
+ np.sqrt(np.sum(
+ (mesh_out_align - mesh_gt)**2, -1)).mean() * 1000)
+ mesh_out_align = rigid_align_batch(mesh_out, mesh_gt)
+ eval_result['pa_mpvpe_all'].append(
+ np.sqrt(np.sum(
+ (mesh_out_align - mesh_gt)**2, -1)).mean() * 1000)
+
+ # MPVPE from hand vertices
+ mesh_gt_lhand = mesh_gt[:, smpl_x.hand_vertex_idx['left_hand'], :]
+ mesh_out_lhand = mesh_out[:, smpl_x.hand_vertex_idx['left_hand'], :]
+ mesh_gt_rhand = mesh_gt[:, smpl_x.hand_vertex_idx['right_hand'], :]
+ mesh_out_rhand = mesh_out[:, smpl_x.hand_vertex_idx['right_hand'], :]
+ mesh_out_lhand_align = \
+ mesh_out_lhand - \
+ np.dot(smpl_x.J_regressor, mesh_out).transpose(1,0,2)[:, smpl_x.J_regressor_idx['lwrist'], None, :] + \
+ np.dot(smpl_x.J_regressor, mesh_gt).transpose(1,0,2)[:, smpl_x.J_regressor_idx['lwrist'], None, :]
+
+ mesh_out_rhand_align = \
+ mesh_out_rhand - \
+ np.dot(smpl_x.J_regressor, mesh_out).transpose(1,0,2)[:, smpl_x.J_regressor_idx['rwrist'], None, :] + \
+ np.dot(smpl_x.J_regressor, mesh_gt).transpose(1,0,2)[:, smpl_x.J_regressor_idx['rwrist'], None, :]
+
+ eval_result['mpvpe_l_hand'].append(
+ np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, -1)).mean() *
+ 1000)
+ eval_result['mpvpe_r_hand'].append(
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, -1)).mean() *
+ 1000)
+ eval_result['mpvpe_hand'].append(
+ (np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, -1)).mean() *
+ 1000 +
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, -1)).mean() *
+ 1000) / 2.)
+ mesh_out_lhand_align = rigid_align_batch(mesh_out_lhand, mesh_gt_lhand)
+ mesh_out_rhand_align = rigid_align_batch(mesh_out_rhand, mesh_gt_rhand)
+ eval_result['pa_mpvpe_l_hand'].append(
+ np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, -1)).mean() *
+ 1000)
+ eval_result['pa_mpvpe_r_hand'].append(
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, -1)).mean() *
+ 1000)
+ eval_result['pa_mpvpe_hand'].append(
+ (np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, -1)).mean() *
+ 1000 +
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, -1)).mean() *
+ 1000) / 2.)
+
+ # MPVPE from face vertices
+ mesh_gt_face = mesh_gt[:, smpl_x.face_vertex_idx, :]
+ mesh_out_face = mesh_out[:, smpl_x.face_vertex_idx, :]
+ mesh_out_face_align = \
+ mesh_out_face - \
+ np.dot(smpl_x.J_regressor, mesh_out).transpose(1,0,2)[:, smpl_x.J_regressor_idx['neck'], None, :] + \
+ np.dot(smpl_x.J_regressor, mesh_gt).transpose(1,0,2)[:, smpl_x.J_regressor_idx['neck'], None, :]
+ eval_result['mpvpe_face'].append(
+ np.sqrt(np.sum(
+ (mesh_out_face_align - mesh_gt_face)**2, -1)).mean() * 1000)
+ mesh_out_face_align = rigid_align_batch(mesh_out_face, mesh_gt_face)
+ eval_result['pa_mpvpe_face'].append(
+ np.sqrt(np.sum(
+ (mesh_out_face_align - mesh_gt_face)**2, -1)).mean() * 1000)
+
+ save_error=True
+ if save_error:
+ writer = csv.writer(file)
+ new_line = [ann_idx[n], img_path[n], eval_result['mpvpe_all'][-1], eval_result['pa_mpvpe_all'][-1]]
+ writer.writerow(new_line)
+ # self.save_idx += 1
+ return eval_result
+
+ def print_eval_result(self, eval_result):
+
+ print('======ARCTIC-val======')
+ print('PA MPVPE (All): %.2f mm' % np.mean(eval_result['pa_mpvpe_all']))
+ print('PA MPVPE (L-Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_l_hand']))
+ print('PA MPVPE (R-Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_r_hand']))
+ print('PA MPVPE (Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_hand']))
+ print('PA MPVPE (Face): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_face']))
+ print()
+
+ print('MPVPE (All): %.2f mm' % np.mean(eval_result['mpvpe_all']))
+ print('MPVPE (L-Hands): %.2f mm' %
+ np.mean(eval_result['mpvpe_l_hand']))
+ print('MPVPE (R-Hands): %.2f mm' %
+ np.mean(eval_result['mpvpe_r_hand']))
+ print('MPVPE (Hands): %.2f mm' % np.mean(eval_result['mpvpe_hand']))
+ print('MPVPE (Face): %.2f mm' % np.mean(eval_result['mpvpe_face']))
+
+ out_file = osp.join(cfg.result_dir,'arctic_val.txt')
+ if os.path.exists(out_file):
+ f = open(out_file, 'a+')
+ else:
+ f = open(out_file, 'w', encoding="utf-8")
+ f.write('\n')
+ f.write(f'{cfg.exp_name}\n')
+ f.write(f'ARCTIC-val dataset: \n')
+ f.write('PA MPVPE (All): %.2f mm\n' %
+ np.mean(eval_result['pa_mpvpe_all']))
+ f.write('PA MPVPE (L-Hands): %.2f mm\n' %
+ np.mean(eval_result['pa_mpvpe_l_hand']))
+ f.write('PA MPVPE (R-Hands): %.2f mm\n' %
+ np.mean(eval_result['pa_mpvpe_r_hand']))
+ f.write('PA MPVPE (Hands): %.2f mm\n' %
+ np.mean(eval_result['pa_mpvpe_hand']))
+ f.write('PA MPVPE (Face): %.2f mm\n' %
+ np.mean(eval_result['pa_mpvpe_face']))
+ f.write('MPVPE (All): %.2f mm\n' % np.mean(eval_result['mpvpe_all']))
+ f.write('MPVPE (L-Hands): %.2f mm\n' %
+ np.mean(eval_result['mpvpe_l_hand']))
+ f.write('MPVPE (R-Hands): %.2f mm\n' %
+ np.mean(eval_result['mpvpe_r_hand']))
+ f.write('MPVPE (Hands): %.2f mm\n' % np.mean(eval_result['mpvpe_hand']))
+ f.write('MPVPE (Face): %.2f mm\n' % np.mean(eval_result['mpvpe_face']))
+        f.close()
diff --git a/datasets/BEDLAM.py b/datasets/BEDLAM.py
new file mode 100644
index 0000000000000000000000000000000000000000..566de0e3253a12a4086b5d635b5dcf6410f42fe5
--- /dev/null
+++ b/datasets/BEDLAM.py
@@ -0,0 +1,32 @@
+import os.path as osp
+from config.config import cfg
+from humandata import HumanDataset
+
+
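+# BEDLAM synthetic training data; it is used for training only here, so no evaluate() override is
+# needed and the generic HumanDataset loader (with optional caching) is reused as-is.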
+class BEDLAM(HumanDataset):
+ def __init__(self, transform, data_split):
+ super(BEDLAM, self).__init__(transform, data_split)
+
+ self.img_dir = './data/datasets/bedlam/train_images/'
+ self.annot_path = 'data/preprocessed_npz/multihuman_data/bedlam_train_multi_0915.npz'
+ self.annot_path_cache = 'data/preprocessed_npz/cache/bedlam_train_cache_080824.npz'
+ self.use_cache = getattr(cfg, 'use_cache', False)
+
+        self.img_shape = None  # (1024, 1024) # (h, w)
+ self.cam_param = {}
+
+ # load data or cache
+ if self.use_cache and osp.isfile(self.annot_path_cache):
+ print(
+ f'[{self.__class__.__name__}] loading cache from {self.annot_path_cache}'
+ )
+ self.datalist = self.load_cache(self.annot_path_cache)
+ else:
+ if self.use_cache:
+ print(
+ f'[{self.__class__.__name__}] Cache not found, generating cache...'
+ )
+ self.datalist = self.load_data(train_sample_interval=getattr(
+ cfg, f'{self.__class__.__name__}_train_sample_interval', 5))
+ if self.use_cache:
+ self.save_cache(self.annot_path_cache, self.datalist)
diff --git a/datasets/COCO_NA.py b/datasets/COCO_NA.py
new file mode 100644
index 0000000000000000000000000000000000000000..38a553deb21e5e005becc4442bc0ea7a32189dcf
--- /dev/null
+++ b/datasets/COCO_NA.py
@@ -0,0 +1,36 @@
+import os
+import os.path as osp
+import numpy as np
+
+# from osx.common.utils.human_models import smpl_x
+
+from humandata import HumanDataset
+from config.config import cfg
+
+
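+# COCO whole-body training set loaded via HumanDataset; self.keypoints2d selects the original 2D
+# keypoints ('keypoints2d_ori') from the preprocessed npz.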
+class COCO_NA(HumanDataset):
+ def __init__(self, transform, data_split):
+ super(COCO_NA, self).__init__(transform, data_split)
+ self.img_dir = 'data/datasets/coco_2017'
+ self.annot_path = 'data/preprocessed_npz/multihuman_data/coco_wholebody_new_train_multi.npz'
+ self.annot_path_cache = 'data/preprocessed_npz/cache/coco_train_cache_080824.npz'
+ # osp.join(cfg.data_dir, 'cache', filename)
+ self.keypoints2d = 'keypoints2d_ori'
+ self.use_cache = getattr(cfg, 'use_cache', False)
+ self.cam_param = {}
+
+ # load data or cache
+ if self.use_cache and osp.isfile(self.annot_path_cache):
+ print(
+ f'[{self.__class__.__name__}] loading cache from {self.annot_path_cache}'
+ )
+ self.datalist = self.load_cache(self.annot_path_cache)
+ else:
+ if self.use_cache:
+ print(
+ f'[{self.__class__.__name__}] Cache not found, generating cache...'
+ )
+ self.datalist = self.load_data(train_sample_interval=getattr(
+ cfg, f'{self.__class__.__name__}_train_sample_interval', 1))
+ if self.use_cache:
+ self.save_cache(self.annot_path_cache, self.datalist)
diff --git a/datasets/EHF.py b/datasets/EHF.py
new file mode 100644
index 0000000000000000000000000000000000000000..f75fad8515111c09c2dd968ae69f27e10908eb88
--- /dev/null
+++ b/datasets/EHF.py
@@ -0,0 +1,289 @@
+import os
+import os.path as osp
+from glob import glob
+import numpy as np
+from config.config import cfg
+import copy
+import json
+import cv2
+import torch
+from pycocotools.coco import COCO
+from util.human_models import smpl_x
+from util.preprocessing import load_img, process_bbox, load_ply
+from util.transforms import rigid_align, rigid_align_batch
+from humandata import HumanDataset
+import csv
+
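+# EHF evaluation set (the bundled annotations cover 100 frames with SMPL-X ground-truth meshes);
+# reports MPVPE / PA-MPVPE and PA-MPJPE for body, hands and face.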
+class EHF(HumanDataset):
+ def __init__(self, transform, data_split):
+ super(EHF, self).__init__(transform, data_split)
+
+ self.transform = transform
+ self.data_split = data_split
+ self.save_idx = 0
+ # self.cam_param = {'R': [-2.98747896, 0.01172457, -0.05704687]}
+ # self.cam_param['R'], _ = cv2.Rodrigues(np.array(self.cam_param['R']))
+ self.cam_param = {}
+ self.img_dir = 'data/data_weichen/ehf'
+ self.img_shape = [1200, 1600]
+
+ self.annot_path = 'data_tmp/multihuman_data/ehf_val_230908_100.npz'
+ self.annot_path_cache = 'data_tmp/cache/ehf_val_cache_230908_100.npz'
+        self.use_cache = getattr(cfg, 'use_cache', False)
+
+ if self.use_cache and osp.isfile(self.annot_path_cache):
+ print(f'[{self.__class__.__name__}] loading cache from {self.annot_path_cache}')
+ self.datalist = self.load_cache(self.annot_path_cache)
+ else:
+ if self.use_cache:
+ print(f'[{self.__class__.__name__}] Cache not found, generating cache...')
+ self.datalist = self.load_data(
+ train_sample_interval=getattr(cfg, f'{self.__class__.__name__}_train_sample_interval', 1))
+ if self.use_cache:
+ self.save_cache(self.annot_path_cache, self.datalist)
+
+
+ def evaluate(self, outs, cur_sample_idx):
+ annots = self.datalist
+ sample_num = len(outs)
+ eval_result = {
+ 'pa_mpvpe_all': [],
+ 'pa_mpvpe_l_hand': [],
+ 'pa_mpvpe_r_hand': [],
+ 'pa_mpvpe_hand': [],
+ 'pa_mpvpe_face': [],
+ 'mpvpe_all': [],
+ 'mpvpe_l_hand': [],
+ 'mpvpe_r_hand': [],
+ 'mpvpe_hand': [],
+ 'mpvpe_face': [],
+ 'pa_mpjpe_body': [],
+ 'pa_mpjpe_l_hand': [],
+ 'pa_mpjpe_r_hand': [],
+ 'pa_mpjpe_hand': []
+ }
+
+ csv_file = f'{cfg.result_dir}/ehf_smplx_error.csv'
+ file = open(csv_file, 'a', newline='')
+ for n in range(sample_num):
+ annot = annots[cur_sample_idx + n]
+ ann_id = annot['img_path'].split('/')[-1].split('_')[0]
+ out = outs[n]
+ ann_idx = out['gt_ann_idx']
+ img_path = []
+ for ann_id in ann_idx:
+ img_path.append(annots[ann_id]['img_path'])
+ eval_result['img_path'] = img_path
+ eval_result['ann_idx'] = ann_idx
+            # MPVPE from all vertices
+ # mesh_gt = np.dot(
+ # self.cam_param['R'],
+ # out['smplx_mesh_cam_target'].transpose(0,2,1)
+ # ).transpose(1,2,0)
+ mesh_gt = out['smplx_mesh_cam_target']
+ mesh_out = out['smplx_mesh_cam']
+
+ # mesh_gt_align = rigid_align(mesh_gt, mesh_out)
+
+ # print(mesh_out.shape)
+ mesh_out_align = rigid_align_batch(mesh_out, mesh_gt)
+ eval_result['pa_mpvpe_all'].append(
+ np.sqrt(np.sum(
+ (mesh_out_align - mesh_gt)**2, -1)).mean() * 1000)
+ mesh_out_align = mesh_out - np.dot(
+ smpl_x.J_regressor,
+ mesh_out).transpose(1,0,2)[:, smpl_x.J_regressor_idx['pelvis'], None, :] + np.dot(
+ smpl_x.J_regressor,
+ mesh_gt).transpose(1,0,2)[:, smpl_x.J_regressor_idx['pelvis'], None, :]
+ eval_result['mpvpe_all'].append(
+ np.sqrt(np.sum(
+ (mesh_out_align - mesh_gt)**2, -1)).mean() * 1000)
+
+ # MPVPE from hand vertices
+ mesh_gt_lhand = mesh_gt[:, smpl_x.hand_vertex_idx['left_hand'], :]
+ mesh_out_lhand = mesh_out[:, smpl_x.hand_vertex_idx['left_hand'], :]
+ mesh_out_lhand_align = rigid_align_batch(mesh_out_lhand, mesh_gt_lhand)
+ mesh_gt_rhand = mesh_gt[:, smpl_x.hand_vertex_idx['right_hand'], :]
+ mesh_out_rhand = mesh_out[:, smpl_x.hand_vertex_idx['right_hand'], :]
+ mesh_out_rhand_align = rigid_align_batch(mesh_out_rhand, mesh_gt_rhand)
+ eval_result['pa_mpvpe_l_hand'].append(
+ np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, -1)).mean() *
+ 1000)
+ eval_result['pa_mpvpe_r_hand'].append(
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, -1)).mean() *
+ 1000)
+ eval_result['pa_mpvpe_hand'].append(
+ (np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, -1)).mean() *
+ 1000 +
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, -1)).mean() *
+ 1000) / 2.)
+
+ mesh_out_lhand_align = mesh_out_lhand - np.dot(
+ smpl_x.J_regressor,
+ mesh_out).transpose(1,0,2)[:, smpl_x.J_regressor_idx['lwrist'], None, :] + np.dot(
+ smpl_x.J_regressor,
+ mesh_gt).transpose(1,0,2)[:, smpl_x.J_regressor_idx['lwrist'], None, :]
+ mesh_out_rhand_align = mesh_out_rhand - np.dot(
+ smpl_x.J_regressor,
+ mesh_out).transpose(1,0,2)[:, smpl_x.J_regressor_idx['rwrist'], None, :] + np.dot(
+ smpl_x.J_regressor,
+ mesh_gt).transpose(1,0,2)[:, smpl_x.J_regressor_idx['rwrist'], None, :]
+
+ eval_result['mpvpe_l_hand'].append(
+ np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, -1)).mean() *
+ 1000)
+ eval_result['mpvpe_r_hand'].append(
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, -1)).mean() *
+ 1000)
+ eval_result['mpvpe_hand'].append(
+ (np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, -1)).mean() *
+ 1000 +
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, -1)).mean() *
+ 1000) / 2.)
+
+ # MPVPE from face vertices
+ mesh_gt_face = mesh_gt[:, smpl_x.face_vertex_idx, :]
+ mesh_out_face = mesh_out[:, smpl_x.face_vertex_idx, :]
+ mesh_out_face_align = rigid_align_batch(mesh_out_face, mesh_gt_face)
+ eval_result['pa_mpvpe_face'].append(
+ np.sqrt(np.sum(
+ (mesh_out_face_align - mesh_gt_face)**2, -1)).mean() * 1000)
+ mesh_out_face_align = mesh_out_face - np.dot(
+ smpl_x.J_regressor,
+ mesh_out).transpose(1,0,2)[:, smpl_x.J_regressor_idx['neck'], None, :] + np.dot(
+ smpl_x.J_regressor,
+ mesh_gt).transpose(1,0,2)[:, smpl_x.J_regressor_idx['neck'], None, :]
+ eval_result['mpvpe_face'].append(
+ np.sqrt(np.sum(
+ (mesh_out_face_align - mesh_gt_face)**2, -1)).mean() * 1000)
+
+ # MPJPE from body joints
+ joint_gt_body = np.dot(smpl_x.j14_regressor, mesh_gt).transpose(1,0,2)
+ joint_out_body = np.dot(smpl_x.j14_regressor, mesh_out).transpose(1,0,2)
+ joint_out_body_align = rigid_align_batch(joint_out_body, joint_gt_body)
+ eval_result['pa_mpjpe_body'].append(
+ np.sqrt(np.sum(
+ (joint_out_body_align - joint_gt_body)**2, -1)).mean() *
+ 1000)
+
+ # MPJPE from hand joints
+ joint_gt_lhand = np.dot(smpl_x.orig_hand_regressor['left'],
+ mesh_gt).transpose(1,0,2)
+ joint_out_lhand = np.dot(smpl_x.orig_hand_regressor['left'],
+ mesh_out).transpose(1,0,2)
+ joint_out_lhand_align = rigid_align_batch(joint_out_lhand,
+ joint_gt_lhand)
+ joint_gt_rhand = np.dot(smpl_x.orig_hand_regressor['right'],
+ mesh_gt).transpose(1,0,2)
+ joint_out_rhand = np.dot(smpl_x.orig_hand_regressor['right'],
+ mesh_out).transpose(1,0,2)
+ joint_out_rhand_align = rigid_align_batch(joint_out_rhand,
+ joint_gt_rhand)
+ eval_result['pa_mpjpe_l_hand'].append(
+ np.sqrt(np.sum(
+ (joint_out_lhand_align - joint_gt_lhand)**2, -1)).mean() *
+ 1000)
+ eval_result['pa_mpjpe_r_hand'].append(
+ np.sqrt(np.sum(
+                    (joint_out_rhand_align - joint_gt_rhand)**2, -1)).mean() *
+ 1000)
+ eval_result['pa_mpjpe_hand'].append(
+ (np.sqrt(np.sum(
+ (joint_out_lhand_align - joint_gt_lhand)**2, -1)).mean() *
+ 1000 +
+ np.sqrt(np.sum(
+ (joint_out_rhand_align - joint_gt_rhand)**2, -1)).mean() *
+ 1000) / 2.)
+ save_error=True
+ if save_error:
+ writer = csv.writer(file)
+ new_line = [ann_idx[n],img_path[n], eval_result['mpvpe_all'][-1], eval_result['pa_mpvpe_all'][-1]]
+ writer.writerow(new_line)
+ self.save_idx += 1
+
+ # vis = cfg.vis
+
+
+ for k,v in eval_result.items():
+ if k != 'img_path' and k != 'ann_idx':
+
+ if len(v)>1:
+ eval_result[k] = np.concatenate(v,axis=0)
+ else:
+ eval_result[k] = np.array(v)
+ return eval_result
+
+ def print_eval_result(self, eval_result):
+ print('======EHF======')
+ print('PA MPVPE (All): %.2f mm' % np.mean(eval_result['pa_mpvpe_all']))
+ print('PA MPVPE (L-Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_l_hand']))
+ print('PA MPVPE (R-Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_r_hand']))
+ print('PA MPVPE (Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_hand']))
+ print('PA MPVPE (Face): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_face']))
+ print()
+
+ print('MPVPE (All): %.2f mm' % np.mean(eval_result['mpvpe_all']))
+ print('MPVPE (L-Hands): %.2f mm' %
+ np.mean(eval_result['mpvpe_l_hand']))
+ print('MPVPE (R-Hands): %.2f mm' %
+ np.mean(eval_result['mpvpe_r_hand']))
+ print('MPVPE (Hands): %.2f mm' % np.mean(eval_result['mpvpe_hand']))
+ print('MPVPE (Face): %.2f mm' % np.mean(eval_result['mpvpe_face']))
+ print()
+
+ print('PA MPJPE (Body): %.2f mm' %
+ np.mean(eval_result['pa_mpjpe_body']))
+ print('PA MPJPE (L-Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpjpe_l_hand']))
+ print('PA MPJPE (R-Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpjpe_r_hand']))
+ print('PA MPJPE (Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpjpe_hand']))
+ out_file = osp.join(cfg.result_dir,'ehf_test.txt')
+ if os.path.exists(out_file):
+ f = open(out_file, 'a+')
+ else:
+ f = open(out_file, 'w', encoding="utf-8")
+
+ f.write('\n')
+ f.write(f'{cfg.exp_name}\n')
+ f.write(f'EHF dataset: \n')
+ f.write('PA MPVPE (All): %.2f mm\n' %
+ np.mean(eval_result['pa_mpvpe_all']))
+ f.write('PA MPVPE (L-Hands): %.2f mm\n' %
+ np.mean(eval_result['pa_mpvpe_l_hand']))
+ f.write('PA MPVPE (R-Hands): %.2f mm\n' %
+ np.mean(eval_result['pa_mpvpe_r_hand']))
+ f.write('PA MPVPE (Hands): %.2f mm\n' %
+ np.mean(eval_result['pa_mpvpe_hand']))
+ f.write('PA MPVPE (Face): %.2f mm\n' %
+ np.mean(eval_result['pa_mpvpe_face']))
+ f.write('MPVPE (All): %.2f mm\n' % np.mean(eval_result['mpvpe_all']))
+ f.write('MPVPE (L-Hands): %.2f mm\n' %
+ np.mean(eval_result['mpvpe_l_hand']))
+ f.write('MPVPE (R-Hands): %.2f mm\n' %
+ np.mean(eval_result['mpvpe_r_hand']))
+ f.write('MPVPE (Hands): %.2f mm\n' % np.mean(eval_result['mpvpe_hand']))
+ f.write('MPVPE (Face): %.2f mm\n' % np.mean(eval_result['mpvpe_face']))
+ f.write('PA MPJPE (Body): %.2f mm\n' %
+ np.mean(eval_result['pa_mpjpe_body']))
+ f.write('PA MPJPE (L-Hands): %.2f mm\n' %
+ np.mean(eval_result['pa_mpjpe_l_hand']))
+ f.write('PA MPJPE (R-Hands): %.2f mm\n' %
+ np.mean(eval_result['pa_mpjpe_r_hand']))
+ f.write('PA MPJPE (Hands): %.2f mm\n' %
+ np.mean(eval_result['pa_mpjpe_hand']))
+
+ f.close()
+
diff --git a/datasets/EgoBody_Egocentric.py b/datasets/EgoBody_Egocentric.py
new file mode 100644
index 0000000000000000000000000000000000000000..c69993965abdd12b6da0bdb4b6400913abebceb1
--- /dev/null
+++ b/datasets/EgoBody_Egocentric.py
@@ -0,0 +1,211 @@
+import os
+import os.path as osp
+import numpy as np
+import torch
+import cv2
+import json
+import copy
+import csv
+from pycocotools.coco import COCO
+from config.config import cfg
+from util.human_models import smpl_x
+
+from util.transforms import world2cam, cam2pixel, rigid_align
+from humandata import HumanDataset
+from util.transforms import rigid_align, rigid_align_batch
+
+
+
+class EgoBody_Egocentric(HumanDataset):
+ def __init__(self, transform, data_split):
+ super(EgoBody_Egocentric, self).__init__(transform, data_split)
+
+ if self.data_split == 'train':
+ filename = 'data/preprocessed_npz/multihuman_data/egobody_egocentric_train_multi_080824.npz'
+ self.annot_path_cache = 'data/preprocessed_npz/cache/egobody_egocentric_train_cache_080824.npz'
+ self.sample_interval = 5
+ else:
+ filename = 'data/preprocessed_npz/multihuman_data/egobody_egocentric_val_multi_080824.npz'
+ self.annot_path_cache = 'data/preprocessed_npz/cache/egobody_egocentric_val_cache_080824.npz'
+ self.sample_interval = 1
+ self.use_betas_neutral = getattr(cfg, 'egobody_fix_betas', False)
+
+ self.img_dir = 'data/osx_data/EgoBody'
+ self.annot_path = filename
+ self.use_cache = getattr(cfg, 'use_cache', False)
+ self.img_shape = (1080, 1920) # (h, w)
+ self.cam_param = {}
+
+ # check image shape
+ img_path = osp.join(self.img_dir,
+ np.load(self.annot_path)['image_path'][0])
+ img_shape = cv2.imread(img_path).shape[:2]
+ assert self.img_shape == img_shape, 'image shape is incorrect: {} vs {}'.format(
+ self.img_shape, img_shape)
+
+ # load data or cache
+ if self.use_cache and osp.isfile(self.annot_path_cache):
+ print(
+ f'[{self.__class__.__name__}] loading cache from {self.annot_path_cache}'
+ )
+ self.datalist = self.load_cache(self.annot_path_cache)
+ else:
+ if self.use_cache:
+ print(
+ f'[{self.__class__.__name__}] Cache not found, generating cache...'
+ )
+ self.datalist = self.load_data(train_sample_interval=getattr(
+ cfg, f'{self.__class__.__name__}_train_sample_interval', self.sample_interval))
+ if self.use_cache:
+ self.save_cache(self.annot_path_cache, self.datalist)
+
+ def evaluate(self, outs, cur_sample_idx):
+ annots = self.datalist
+ sample_num = len(outs)
+ eval_result = {
+ 'pa_mpvpe_all': [],
+ 'pa_mpvpe_l_hand': [],
+ 'pa_mpvpe_r_hand': [],
+ 'pa_mpvpe_hand': [],
+ 'pa_mpvpe_face': [],
+ 'mpvpe_all': [],
+ 'mpvpe_l_hand': [],
+ 'mpvpe_r_hand': [],
+ 'mpvpe_hand': [],
+ 'mpvpe_face': []
+ }
+
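+        # per-sample errors are stored in millimetres; the face entries are
+        # initialised but never filled in this dataset's evaluate()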
+ vis = getattr(cfg, 'vis', False)
+ vis_save_dir = cfg.vis_dir
+ csv_file = f'{cfg.result_dir}/egobody_smplx_error.csv'
+ file = open(csv_file, 'a', newline='')
+ for n in range(sample_num):
+ annot = annots[cur_sample_idx + n]
+ out = outs[n]
+ mesh_gt = out['smplx_mesh_cam_target']
+ mesh_out = out['smplx_mesh_cam']
+ ann_idx = out['gt_ann_idx']
+ img_path = []
+ for ann_id in ann_idx:
+ img_path.append(annots[ann_id]['img_path'])
+ eval_result['img_path'] = img_path
+ eval_result['ann_idx'] = ann_idx
+ # MPVPE from all vertices
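+            # shift the predicted mesh so its pelvis matches the GT pelvis before measuring MPVPE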
+ mesh_out_align = \
+ mesh_out - np.dot(
+ smpl_x.J_regressor, mesh_out).transpose(1,0,2)[:, smpl_x.J_regressor_idx['pelvis'], None, :] + \
+ np.dot(smpl_x.J_regressor, mesh_gt).transpose(1,0,2)[:, smpl_x.J_regressor_idx['pelvis'], None, :]
+
+ eval_result['mpvpe_all'].extend(
+ np.sqrt(np.sum(
+ (mesh_out_align - mesh_gt)**2, -1)).mean(-1) * 1000)
+ mesh_out_align = rigid_align_batch(mesh_out, mesh_gt)
+ eval_result['pa_mpvpe_all'].extend(
+ np.sqrt(np.sum(
+ (mesh_out_align - mesh_gt)**2, -1)).mean(-1) * 1000)
+
+ # MPVPE from hand vertices
+ mesh_gt_lhand = mesh_gt[:, smpl_x.hand_vertex_idx['left_hand'], :]
+ mesh_out_lhand = mesh_out[:, smpl_x.hand_vertex_idx['left_hand'], :]
+ mesh_gt_rhand = mesh_gt[:, smpl_x.hand_vertex_idx['right_hand'], :]
+ mesh_out_rhand = mesh_out[:, smpl_x.hand_vertex_idx['right_hand'], :]
+ mesh_out_lhand_align = \
+ mesh_out_lhand - \
+ np.dot(smpl_x.J_regressor, mesh_out).transpose(1,0,2)[:, smpl_x.J_regressor_idx['lwrist'], None, :] + \
+ np.dot(smpl_x.J_regressor, mesh_gt).transpose(1,0,2)[:, smpl_x.J_regressor_idx['lwrist'], None, :]
+
+ mesh_out_rhand_align = \
+ mesh_out_rhand - \
+ np.dot(smpl_x.J_regressor, mesh_out).transpose(1,0,2)[:, smpl_x.J_regressor_idx['rwrist'], None, :] + \
+ np.dot(smpl_x.J_regressor, mesh_gt).transpose(1,0,2)[:, smpl_x.J_regressor_idx['rwrist'], None, :]
+
+ eval_result['mpvpe_l_hand'].extend(
+ np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, -1)).mean(-1) *
+ 1000)
+ eval_result['mpvpe_r_hand'].extend(
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, -1)).mean(-1) *
+ 1000)
+ eval_result['mpvpe_hand'].extend(
+ (np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, -1)).mean(-1) *
+ 1000 +
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, -1)).mean(-1) *
+ 1000) / 2.)
+ mesh_out_lhand_align = rigid_align_batch(mesh_out_lhand, mesh_gt_lhand)
+ mesh_out_rhand_align = rigid_align_batch(mesh_out_rhand, mesh_gt_rhand)
+ eval_result['pa_mpvpe_l_hand'].extend(
+ np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, -1)).mean(-1) *
+ 1000)
+ eval_result['pa_mpvpe_r_hand'].extend(
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, -1)).mean(-1) *
+ 1000)
+ eval_result['pa_mpvpe_hand'].extend(
+ (np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, -1)).mean(-1) *
+ 1000 +
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, -1)).mean(-1) *
+ 1000) / 2.)
+            save_error = True
+ if save_error:
+ writer = csv.writer(file)
+ new_line = [ann_idx[n], img_path[n], eval_result['mpvpe_all'][-1], eval_result['pa_mpvpe_all'][-1]]
+ writer.writerow(new_line)
+
+        file.close()
+        return eval_result
+
+
+ def print_eval_result(self, eval_result):
+
+ print('======Egocentric======')
+ print('PA MPVPE (All): %.2f mm' % np.mean(eval_result['pa_mpvpe_all']))
+ print('PA MPVPE (L-Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_l_hand']))
+ print('PA MPVPE (R-Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_r_hand']))
+ print('PA MPVPE (Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_hand']))
+ print('PA MPVPE (Face): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_face']))
+ print()
+
+ print('MPVPE (All): %.2f mm' % np.mean(eval_result['mpvpe_all']))
+ print('MPVPE (L-Hands): %.2f mm' %
+ np.mean(eval_result['mpvpe_l_hand']))
+ print('MPVPE (R-Hands): %.2f mm' %
+ np.mean(eval_result['mpvpe_r_hand']))
+ print('MPVPE (Hands): %.2f mm' % np.mean(eval_result['mpvpe_hand']))
+ print('MPVPE (Face): %.2f mm' % np.mean(eval_result['mpvpe_face']))
+
+        out_file = osp.join(cfg.result_dir, 'Egocentric_val.txt')
+        f = open(out_file, 'a', encoding="utf-8")
+
+ f.write('\n')
+ f.write(f'{cfg.exp_name}\n')
+ f.write(f'Egocentric dataset: \n')
+ f.write('PA MPVPE (All): %.2f mm\n' %
+ np.mean(eval_result['pa_mpvpe_all']))
+ f.write('PA MPVPE (L-Hands): %.2f mm\n' %
+ np.mean(eval_result['pa_mpvpe_l_hand']))
+ f.write('PA MPVPE (R-Hands): %.2f mm\n' %
+ np.mean(eval_result['pa_mpvpe_r_hand']))
+ f.write('PA MPVPE (Hands): %.2f mm\n' %
+ np.mean(eval_result['pa_mpvpe_hand']))
+ f.write('PA MPVPE (Face): %.2f mm\n' %
+ np.mean(eval_result['pa_mpvpe_face']))
+ f.write('MPVPE (All): %.2f mm\n' % np.mean(eval_result['mpvpe_all']))
+ f.write('MPVPE (L-Hands): %.2f mm\n' %
+ np.mean(eval_result['mpvpe_l_hand']))
+ f.write('MPVPE (R-Hands): %.2f mm\n' %
+ np.mean(eval_result['mpvpe_r_hand']))
+ f.write('MPVPE (Hands): %.2f mm\n' % np.mean(eval_result['mpvpe_hand']))
+        f.write('MPVPE (Face): %.2f mm\n' % np.mean(eval_result['mpvpe_face']))
+        f.close()
diff --git a/datasets/EgoBody_Kinect.py b/datasets/EgoBody_Kinect.py
new file mode 100644
index 0000000000000000000000000000000000000000..999fe8654a9f9f712aaf3c3ea63ec6f20efd1604
--- /dev/null
+++ b/datasets/EgoBody_Kinect.py
@@ -0,0 +1,194 @@
+import os
+import os.path as osp
+import numpy as np
+import torch
+import cv2
+import json
+import copy
+import csv
+from pycocotools.coco import COCO
+from config.config import cfg
+from util.human_models import smpl_x
+
+from util.transforms import world2cam, cam2pixel, rigid_align
+from humandata import HumanDataset
+from util.transforms import rigid_align, rigid_align_batch
+
+
+class EgoBody_Kinect(HumanDataset):
+ def __init__(self, transform, data_split):
+ super(EgoBody_Kinect, self).__init__(transform, data_split)
+
+ if self.data_split == 'train':
+ filename = 'data/preprocessed_npz/multihuman_data/egobody_kinect_train_multi_080824.npz'
+ self.annot_path_cache = 'data/preprocessed_npz/cache/egobody_kinect_train_cache_080824.npz'
+ self.sample_interval = 10
+ else:
+ filename = 'data/preprocessed_npz/egobody_kinect_test_230503_043_fix_betas_multi.npz'
+ self.annot_path_cache = 'data/preprocessed_npz/egobody_kinect_test_230503_043_fix_betas_multi_cache_100.npz'
+ self.sample_interval = 100
+ self.use_betas_neutral = getattr(cfg, 'egobody_fix_betas', False)
+
+ self.img_dir = 'data/osx_data/EgoBody'
+ self.annot_path = filename
+
+ self.use_cache = getattr(cfg, 'use_cache', False)
+ self.img_shape = (1080, 1920) # (h, w)
+ self.cam_param = {}
+
+ # check image shape
+ img_path = osp.join(self.img_dir,
+ np.load(self.annot_path)['image_path'][0])
+ img_shape = cv2.imread(img_path).shape[:2]
+ assert self.img_shape == img_shape, 'image shape is incorrect: {} vs {}'.format(
+ self.img_shape, img_shape)
+
+ # load data or cache
+ if self.use_cache and osp.isfile(self.annot_path_cache):
+ print(
+ f'[{self.__class__.__name__}] loading cache from {self.annot_path_cache}'
+ )
+ self.datalist = self.load_cache(self.annot_path_cache)
+ else:
+ if self.use_cache:
+ print(
+ f'[{self.__class__.__name__}] Cache not found, generating cache...'
+ )
+ self.datalist = self.load_data(train_sample_interval=self.sample_interval)
+ if self.use_cache:
+ self.save_cache(self.annot_path_cache, self.datalist)
+
+    def evaluate(self, outs, cur_sample_idx):
+ annots = self.datalist
+ sample_num = len(outs)
+ eval_result = {
+ 'pa_mpvpe_all': [],
+ 'pa_mpvpe_l_hand': [],
+ 'pa_mpvpe_r_hand': [],
+ 'pa_mpvpe_hand': [],
+ 'pa_mpvpe_face': [],
+ 'mpvpe_all': [],
+ 'mpvpe_l_hand': [],
+ 'mpvpe_r_hand': [],
+ 'mpvpe_hand': [],
+ 'mpvpe_face': []
+ }
+
+ vis = getattr(cfg, 'vis', False)
+ vis_save_dir = cfg.vis_dir
+
+ for n in range(sample_num):
+ annot = annots[cur_sample_idx + n]
+ out = outs[n]
+ mesh_gt = out['smplx_mesh_cam_target']
+ mesh_out = out['smplx_mesh_cam']
+ ann_idx = out['gt_ann_idx']
+ img_path = []
+ for ann_id in ann_idx:
+ img_path.append(annots[ann_id]['img_path'])
+ eval_result['img_path'] = img_path
+ eval_result['ann_idx'] = ann_idx
+ # MPVPE from all vertices
+ mesh_out_align = \
+ mesh_out - np.dot(
+ smpl_x.J_regressor, mesh_out).transpose(1,0,2)[:, smpl_x.J_regressor_idx['pelvis'], None, :] + \
+ np.dot(smpl_x.J_regressor, mesh_gt).transpose(1,0,2)[:, smpl_x.J_regressor_idx['pelvis'], None, :]
+
+ eval_result['mpvpe_all'].extend(
+ np.sqrt(np.sum(
+ (mesh_out_align - mesh_gt)**2, -1)).mean(-1) * 1000)
+ mesh_out_align = rigid_align_batch(mesh_out, mesh_gt)
+ eval_result['pa_mpvpe_all'].extend(
+ np.sqrt(np.sum(
+ (mesh_out_align - mesh_gt)**2, -1)).mean(-1) * 1000)
+
+ # MPVPE from hand vertices
+ mesh_gt_lhand = mesh_gt[:, smpl_x.hand_vertex_idx['left_hand'], :]
+ mesh_out_lhand = mesh_out[:, smpl_x.hand_vertex_idx['left_hand'], :]
+ mesh_gt_rhand = mesh_gt[:, smpl_x.hand_vertex_idx['right_hand'], :]
+ mesh_out_rhand = mesh_out[:, smpl_x.hand_vertex_idx['right_hand'], :]
+ mesh_out_lhand_align = \
+ mesh_out_lhand - \
+ np.dot(smpl_x.J_regressor, mesh_out).transpose(1,0,2)[:, smpl_x.J_regressor_idx['lwrist'], None, :] + \
+ np.dot(smpl_x.J_regressor, mesh_gt).transpose(1,0,2)[:, smpl_x.J_regressor_idx['lwrist'], None, :]
+
+ mesh_out_rhand_align = \
+ mesh_out_rhand - \
+ np.dot(smpl_x.J_regressor, mesh_out).transpose(1,0,2)[:, smpl_x.J_regressor_idx['rwrist'], None, :] + \
+ np.dot(smpl_x.J_regressor, mesh_gt).transpose(1,0,2)[:, smpl_x.J_regressor_idx['rwrist'], None, :]
+
+ eval_result['mpvpe_l_hand'].extend(
+ np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, -1)).mean(-1) *
+ 1000)
+ eval_result['mpvpe_r_hand'].extend(
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, -1)).mean(-1) *
+ 1000)
+ eval_result['mpvpe_hand'].extend(
+ (np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, -1)).mean(-1) *
+ 1000 +
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, -1)).mean(-1) *
+ 1000) / 2.)
+ mesh_out_lhand_align = rigid_align_batch(mesh_out_lhand, mesh_gt_lhand)
+ mesh_out_rhand_align = rigid_align_batch(mesh_out_rhand, mesh_gt_rhand)
+ eval_result['pa_mpvpe_l_hand'].extend(
+ np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, -1)).mean(-1) *
+ 1000)
+ eval_result['pa_mpvpe_r_hand'].extend(
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, -1)).mean(-1) *
+ 1000)
+ eval_result['pa_mpvpe_hand'].extend(
+ (np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, -1)).mean(-1) *
+ 1000 +
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, -1)).mean(-1) *
+ 1000) / 2.)
+ vis = False
+ if vis:
+ import mmcv
+ img = (out['img']).transpose(0,2,3,1)
+ img = mmcv.imdenormalize(
+ img=img[0],
+ mean=np.array([123.675, 116.28, 103.53]),
+ std=np.array([58.395, 57.12, 57.375]),
+ to_bgr=True).astype(np.uint8)
+ from detrsmpl.core.visualization.visualize_keypoints2d import visualize_kp2d
+                # import ipdb; ipdb.set_trace()
+ visualize_kp2d(
+ out['smplx_joint_proj'][0][None],
+ image_array=img[None].copy(),
+ disable_limbs=True,
+ overwrite=True,
+ output_path='./figs/pred2d'
+ )
+ from pytorch3d.io import save_obj
+ save_obj('temp.obj',verts=out['smplx_mesh_cam'][0],faces=torch.tensor([]))
+ # MPVPE from face vertices
+ mesh_gt_face = mesh_gt[:, smpl_x.face_vertex_idx, :]
+ mesh_out_face = mesh_out[:, smpl_x.face_vertex_idx, :]
+ mesh_out_face_align = \
+ mesh_out_face - \
+ np.dot(smpl_x.J_regressor, mesh_out).transpose(1,0,2)[:, smpl_x.J_regressor_idx['neck'], None, :] + \
+ np.dot(smpl_x.J_regressor, mesh_gt).transpose(1,0,2)[:, smpl_x.J_regressor_idx['neck'], None, :]
+ eval_result['mpvpe_face'].extend(
+ np.sqrt(np.sum(
+ (mesh_out_face_align - mesh_gt_face)**2, -1)).mean(-1) * 1000)
+ mesh_out_face_align = rigid_align_batch(mesh_out_face, mesh_gt_face)
+ eval_result['pa_mpvpe_face'].extend(
+ np.sqrt(np.sum(
+ (mesh_out_face_align - mesh_gt_face)**2, -1)).mean(-1) * 1000)
+
+ # for k,v in eval_result.items():
+ # if k != 'img_path' and k != 'ann_idx':
+ # # import ipdb;ipdb.set_trace()
+ # if len(v)>1:
+ # eval_result[k] = np.concatenate(v,axis=0)
+ # else:
+ # eval_result[k] = np.array(v)
+
+ return eval_result
\ No newline at end of file
diff --git a/datasets/INFERENCE.py b/datasets/INFERENCE.py
new file mode 100644
index 0000000000000000000000000000000000000000..122c4edb6e7c495c0b5df5274d628dbe79e8b3a3
--- /dev/null
+++ b/datasets/INFERENCE.py
@@ -0,0 +1,289 @@
+import os
+import os.path as osp
+from glob import glob
+import numpy as np
+from config.config import cfg
+import copy
+import json
+import pickle
+import cv2
+import torch
+from pycocotools.coco import COCO
+from util.human_models import smpl_x
+from util.preprocessing import load_img, sanitize_bbox, process_bbox,augmentation_keep_size, load_ply, load_obj
+from util.transforms import rigid_align, rigid_align_batch
+import tqdm
+import random
+from util.formatting import DefaultFormatBundle
+from detrsmpl.data.datasets.pipelines.transforms import Normalize
+from humandata import HumanDataset
+from detrsmpl.utils.demo_utils import xywh2xyxy, xyxy2xywh, box2cs
+from detrsmpl.core.conventions.keypoints_mapping import convert_kps
+import mmcv
+import cv2
+import numpy as np
+from detrsmpl.core.visualization.visualize_keypoints2d import visualize_kp2d
+from detrsmpl.core.visualization.visualize_smpl import visualize_smpl_hmr,render_smpl
+from detrsmpl.models.body_models.builder import build_body_model
+from detrsmpl.core.visualization.visualize_keypoints3d import visualize_kp3d
+from detrsmpl.data.data_structures.multi_human_data import MultiHumanData
+from detrsmpl.utils.ffmpeg_utils import video_to_images
+from mmcv.runner import get_dist_info
+from config.config import cfg
+import torch.distributed as dist
+import shutil
+
+class INFERENCE(torch.utils.data.Dataset):
+ def __init__(self, img_dir=None,out_path=None):
+
+ self.output_path = out_path
+
+ self.img_dir = img_dir
+
+ self.is_vid = False
+
+        # decide whether the input is an .mp4 video or a directory of images
+ rank, _ = get_dist_info()
+ if self.img_dir.endswith('.mp4'):
+ self.is_vid = True
+ img_name = self.img_dir.split('/')[-1][:-4]
+ # self.img_dir = self.img_dir[:-4]
+ else:
+ img_name = self.img_dir.split('/')[-1]
+ self.img_name = img_name+'_out'
+ self.output_path = os.path.join(self.output_path,self.img_name)
+ os.makedirs(self.output_path, exist_ok=True)
+ self.tmp_dir = os.path.join(self.output_path, 'temp_img')
+ os.makedirs(self.tmp_dir, exist_ok=True)
+        self.result_img_dir = os.path.join(self.output_path, 'res_img')
+        os.makedirs(self.result_img_dir, exist_ok=True)
+
+
+ if not self.is_vid:
+ if rank == 0:
+ image_files = sorted(glob(self.img_dir + '/*.jpg') + glob(self.img_dir + '/*.png'))
+ for i, image_file in enumerate(image_files):
+ new_name = os.path.join(self.tmp_dir, '%06d.png'%i)
+ shutil.copy(image_file, new_name)
+ dist.barrier()
+ else:
+ if rank == 0:
+ video_to_images(self.img_dir, self.tmp_dir)
+ dist.barrier()
+ self.img_paths = sorted(glob(self.tmp_dir+'/*',recursive=True))
+ self.score_threshold = 0.2
+        self.resolution = [720, 1280]  # AGORA test
+ # self.resolution = [1200, 1600] # EHF
+ # self.img_paths = sorted(glob(self.img_dir,recursive=True))
+ self.format = DefaultFormatBundle()
+ self.normalize = Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])
+
+ def __len__(self):
+ return len(self.img_paths)
+
+ def __getitem__(self, idx):
+
+ img = load_img(self.img_paths[idx],'BGR')
+ img_whole_bbox = np.array([0, 0, img.shape[1],img.shape[0]])
+ img, img2bb_trans, bb2img_trans, _, _ = \
+ augmentation_keep_size(img, img_whole_bbox, 'test')
+
+ cropped_img_shape=img.shape[:2]
+ img = (img.astype(np.float32))
+
+ inputs = {'img': img}
+ targets = {
+ 'body_bbox_center': np.array(img_whole_bbox[None]),
+ 'body_bbox_size': np.array(img_whole_bbox[None])}
+ meta_info = {
+ 'ori_shape':np.array(self.resolution),
+ 'img_shape': np.array(img.shape[:2]),
+ 'img2bb_trans': img2bb_trans,
+ 'bb2img_trans': bb2img_trans,
+ 'ann_idx': idx}
+ result = {**inputs, **targets, **meta_info}
+
+ result = self.normalize(result)
+ result = self.format(result)
+
+ return result
+
+ def inference(self, outs):
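+        """Dump per-person SMPL-X parameters to .pkl files, render mesh overlays, and return detected body bboxes keyed by image name."""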
+ img_paths = self.img_paths
+ sample_num = len(outs)
+ output = {}
+
+ for out in outs:
+ ann_idx = out['image_idx']
+ img_cropped = mmcv.imdenormalize(
+ img=(out['img'].cpu().numpy()).transpose(1, 2, 0),
+ mean=np.array([123.675, 116.28, 103.53]),
+ std=np.array([58.395, 57.12, 57.375]),
+ to_bgr=True).astype(np.uint8)
+ # bb2img_trans = out['bb2img_trans']
+ # img2bb_trans = out['img2bb_trans']
+ scores = out['scores'].clone().cpu().numpy()
+ img_shape = out['img_shape'].cpu().numpy()[::-1] # w, h
+ width,height = img_shape
+ width += width % 2
+ height += height % 2
+ img_shape = np.array([width, height])
+ img = cv2.imread(img_paths[ann_idx]) # h, w
+
+
+ joint_proj = out['smplx_joint_proj'].clone().cpu().numpy()
+ joint_vis = out['smplx_joint_proj'].clone().cpu().numpy()
+ joint_coco = out['keypoints_coco'].clone().cpu().numpy()
+ joint_coco_raw = joint_coco.copy()
+ smpl_kp3d_coco, _ = convert_kps(out['smpl_kp3d'].clone().cpu().numpy(),src='smplx',dst='coco', approximate=True)
+
+
+
+ body_bbox = out['body_bbox'].clone().cpu().numpy()
+ lhand_bbox = out['lhand_bbox'].clone().cpu().numpy()
+ rhand_bbox = out['rhand_bbox'].clone().cpu().numpy()
+ face_bbox = out['face_bbox'].clone().cpu().numpy()
+
+ if self.resolution == [720, 1280]:
+ joint_proj[:, :, 0] = joint_proj[:, :, 0] / img_shape[0] * 3840
+ joint_proj[:, :, 1] = joint_proj[:, :, 1] / img_shape[1] * 2160
+ joint_vis[:, :, 0] = joint_vis[:, :, 0] / img_shape[0] * img.shape[1]
+ joint_vis[:, :, 1] = joint_vis[:, :, 1]/ img_shape[1] * img.shape[0]
+
+ joint_coco[:, :, 0] = joint_coco[:, :, 0] / img_shape[0] * img.shape[1]
+ joint_coco[:, :, 1] = joint_coco[:, :, 1]/ img_shape[1] * img.shape[0]
+ scale = np.array([
+ img.shape[1]/img_shape[0],
+ img.shape[1]/img_shape[0],
+ img.shape[1]/img_shape[0],
+ img.shape[1]/img_shape[0],
+ ])
+ body_bbox_raw = body_bbox.copy()
+ body_bbox = body_bbox * scale
+ lhand_bbox = lhand_bbox * scale
+ rhand_bbox = rhand_bbox * scale
+ face_bbox = face_bbox * scale
+ elif self.resolution == [1200, 1600]:
+
+ joint_proj[:, :, 0] = joint_proj[:, :, 0] * (1200 / 800)
+ joint_proj[:, :, 1] = joint_proj[:, :, 1] * (1600 / 1066)
+
+ joint_vis[:, :, 0] = joint_vis[:, :, 0] * (1200 / 800)
+ joint_vis[:, :, 1] = joint_vis[:, :, 1] * (1600 / 1066)
+
+ scale = np.array([1600/1066, 1200/800, 1600/1066, 1200/800])[None]
+ body_bbox = body_bbox * scale
+ lhand_bbox = lhand_bbox * scale
+ rhand_bbox = rhand_bbox * scale
+ face_bbox = face_bbox * scale
+
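+            # detections are assumed to be sorted by confidence, so stop at the first score below the threshold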
+ for i, score in enumerate(scores):
+ if score < self.score_threshold:
+ break
+
+ save_name = img_paths[ann_idx].split('/')[-1][:-4] # if not crop should be -4
+                if self.resolution == [2160, 3840]:
+ save_name = save_name.split('_ann_id')[0]
+ else:
+ save_name = save_name.split('_1280x720')[0]
+
+
+
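+                # collect SMPL-X parameters for this detection; eye poses are filled with zeros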
+ save_dict = {
+ 'params': {
+ 'transl': out['cam_trans'][i].reshape(1, -1).cpu().numpy(),
+ 'global_orient': out['smplx_root_pose'][i].reshape(1, -1).cpu().numpy(),
+ 'body_pose': out['smplx_body_pose'][i].reshape(1, -1).cpu().numpy(),
+ 'left_hand_pose': out['smplx_lhand_pose'][i].reshape(1, -1).cpu().numpy(),
+ 'right_hand_pose': out['smplx_rhand_pose'][i].reshape(1, -1).cpu().numpy(),
+ 'reye_pose': np.zeros((1, 3)),
+ 'leye_pose': np.zeros((1, 3)),
+ 'jaw_pose': out['smplx_jaw_pose'][i].reshape(1, -1).cpu().numpy(),
+ 'expression': out['smplx_expr'][i].reshape(1, -1).cpu().numpy(),
+ 'betas': out['smplx_shape'][i].reshape(1, -1).cpu().numpy()},
+
+ 'joints': joint_proj[i].reshape(1, -1, 2)[0,:24]}
+
+ # save
+ exist_result_path = glob(osp.join(self.output_path, 'predictions', save_name + '*'))
+ if len(exist_result_path) == 0:
+ person_idx = 0
+ else:
+ last_person_idx = max([
+ int(name.split('personId_')[1].split('.pkl')[0])
+ for name in exist_result_path
+ ])
+ person_idx = last_person_idx + 1
+
+ save_name += '_personId_' + str(person_idx) + '.pkl'
+ os.makedirs(osp.join(self.output_path, 'predictions'), exist_ok=True)
+ with open(osp.join(self.output_path, 'predictions', save_name),'wb') as f:
+ pickle.dump(save_dict, f)
+ # mesh
+ # bbox
+
+
+ if i == 0:
+ save_name = img_paths[ann_idx].split('/')[-1][:-4]
+ cv2.imwrite(os.path.join(self.result_img_dir,img_paths[ann_idx].split('/')[-1]), img)
+ else:
+ # dump bbox
+ body_xywh = xyxy2xywh(body_bbox[:i])
+ score = scores[:i]
+ out_value = [{'bbox': b, 'score': s} for b, s in zip(body_xywh, score)]
+ out_key = img_paths[ann_idx].split('/')[-1]
+ output.update({out_key: out_value})
+
+ # show bbox
+ img = mmcv.imshow_bboxes(img, body_bbox[:i], show=False, colors='green')
+ img = mmcv.imshow_bboxes(img, lhand_bbox[:i], show=False, colors='blue')
+ img = mmcv.imshow_bboxes(img, rhand_bbox[:i], show=False, colors='yellow')
+ img = mmcv.imshow_bboxes(img, face_bbox[:i], show=False, colors='red')
+
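+                # render all accepted detections on the frame with a fixed 5000 px focal length and the image centre as principal point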
+ verts = out['smpl_verts'][:i] + out['cam_trans'][:i][:, None]
+ body_model_cfg = dict(
+ type='smplx',
+ keypoint_src='smplx',
+ num_expression_coeffs=10,
+ num_betas=10,
+ gender='neutral',
+ keypoint_dst='smplx_137',
+ model_path='data/body_models/smplx',
+ use_pca=False,
+ use_face_contour=True)
+ body_model = build_body_model(body_model_cfg).to('cuda')
+ # for n, v in enumerate(verts):
+ # save_obj(
+ # osp.join(self.out_path, 'vis', img_paths[ann_idx].split('/')[-1].rjust(5+4,'0')).replace('.jpg',f'_{n}_.obj'),
+ # verts = v,
+ # faces=torch.tensor(body_model.faces.astype(np.int32))
+ # )
+ # print(osp.join(self.out_path, 'vis', img_paths[ann_idx].split('/')[-1]))
+
+ render_smpl(
+ verts=verts[None],
+ body_model=body_model,
+ # K= np.array(
+ # [[img_shape[0]/2, 0, img_shape[0]/2],
+ # [0, img_shape[0]/2, img_shape[1]/2],
+ # [0, 0, 1]]),
+ K= np.array(
+ [[5000, 0, img_shape[0]/2],
+ [0, 5000, img_shape[1]/2],
+ [0, 0, 1]]),
+ R=None,
+ T=None,
+ # output_path=osp.join(self.out_path, 'vis', img_paths[ann_idx].split('/')[-1].rjust(5+4,'0')),
+ output_path=os.path.join(self.result_img_dir,img_paths[ann_idx].split('/')[-1]),
+                image_array=cv2.resize(img, (img_shape[0], img_shape[1]), interpolation=cv2.INTER_CUBIC),
+ in_ndc=False,
+ alpha=0.9,
+ convention='opencv',
+ projection='perspective',
+ overwrite=True,
+ no_grad=True,
+ device='cuda',
+                resolution=[img_shape[1], img_shape[0]],
+ render_choice='hq',
+ )
+ return output
+
diff --git a/datasets/INFERENCE_demo.py b/datasets/INFERENCE_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..39ba4bce29b86d481c49600209b21274393eba17
--- /dev/null
+++ b/datasets/INFERENCE_demo.py
@@ -0,0 +1,169 @@
+import os
+import os.path as osp
+from glob import glob
+import numpy as np
+from config.config import cfg
+import copy
+import json
+import pickle
+import cv2
+import torch
+from pycocotools.coco import COCO
+from util.human_models import smpl_x
+from util.preprocessing import load_img, sanitize_bbox, process_bbox,augmentation_keep_size, load_ply, load_obj
+from util.transforms import rigid_align, rigid_align_batch
+import tqdm
+import random
+from util.formatting import DefaultFormatBundle
+from detrsmpl.data.datasets.pipelines.transforms import Normalize
+from humandata import HumanDataset
+from detrsmpl.utils.demo_utils import xywh2xyxy, xyxy2xywh, box2cs
+from detrsmpl.core.conventions.keypoints_mapping import convert_kps
+import mmcv
+import cv2
+import numpy as np
+from detrsmpl.core.visualization.visualize_keypoints2d import visualize_kp2d
+from detrsmpl.core.visualization.visualize_smpl import visualize_smpl_hmr,render_smpl
+from detrsmpl.models.body_models.builder import build_body_model
+from detrsmpl.core.visualization.visualize_keypoints3d import visualize_kp3d
+from detrsmpl.data.data_structures.multi_human_data import MultiHumanData
+from detrsmpl.utils.ffmpeg_utils import video_to_images
+from mmcv.runner import get_dist_info
+from config.config import cfg
+import torch.distributed as dist
+import shutil
+import re
+from pytorch3d.io import save_obj
+
+class INFERENCE_demo(torch.utils.data.Dataset):
+ def __init__(self, img_dir=None,out_path=None):
+
+ self.output_path = out_path
+ self.mesh_path = os.path.join(self.output_path, 'mesh')
+ self.img_dir = img_dir
+ self.is_vid = True
+ body_model_cfg = dict(
+ type='smplx',
+ keypoint_src='smplx',
+ num_expression_coeffs=10,
+ num_betas=10,
+ gender='neutral',
+ keypoint_dst='smplx_137',
+ model_path='data/body_models/smplx',
+ use_pca=False,
+ use_face_contour=True)
+ self.body_model = build_body_model(body_model_cfg).to('cuda')
+
+ os.makedirs(self.output_path, exist_ok=True)
+ self.tmp_dir = os.path.join(self.output_path, 'temp_img')
+ os.makedirs(self.tmp_dir, exist_ok=True)
+        self.result_img_dir = os.path.join(self.output_path, 'res_img')
+        os.makedirs(self.result_img_dir, exist_ok=True)
+ video_to_images(self.img_dir, self.tmp_dir)
+ self.img_paths = sorted(glob(self.tmp_dir+'/*',recursive=True))
+
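+        # num_person caps how many detections are kept per frame; the 0.1 fallback effectively keeps only the first detection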
+ self.num_person = cfg.num_person if 'num_person' in cfg else 0.1
+ self.score_threshold = cfg.threshold if 'threshold' in cfg else 0.1
+ self.format = DefaultFormatBundle()
+ self.normalize = Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])
+
+ def __len__(self):
+ return len(self.img_paths)
+
+ def __getitem__(self, idx):
+ img = load_img(self.img_paths[idx],'BGR')
+ self.resolution = img.shape[:2]
+ img_whole_bbox = np.array([0, 0, img.shape[1],img.shape[0]])
+ img, img2bb_trans, bb2img_trans, _, _ = \
+ augmentation_keep_size(img, img_whole_bbox, 'test')
+
+ # cropped_img_shape=img.shape[:2]
+ img = (img.astype(np.float32))
+
+ inputs = {'img': img}
+ targets = {
+ 'body_bbox_center': np.array(img_whole_bbox[None]),
+ 'body_bbox_size': np.array(img_whole_bbox[None])}
+ meta_info = {
+ 'ori_shape':np.array(self.resolution),
+ 'img_shape': np.array(img.shape[:2]),
+ 'img2bb_trans': img2bb_trans,
+ 'bb2img_trans': bb2img_trans,
+ 'ann_idx': idx}
+ result = {**inputs, **targets, **meta_info}
+
+ result = self.normalize(result)
+ result = self.format(result)
+
+ return result
+
+ def inference(self, outs):
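+        """Save per-person meshes as .obj files and render SMPL-X overlays for every video frame."""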
+ img_paths = self.img_paths
+ for out in outs:
+ ann_idx = out['image_idx']
+ # img_cropped = mmcv.imdenormalize(
+ # img=(out['img'].cpu().numpy()).transpose(1, 2, 0),
+ # mean=np.array([123.675, 116.28, 103.53]),
+ # std=np.array([58.395, 57.12, 57.375]),
+ # to_bgr=True).astype(np.uint8)
+ # bb2img_trans = out['bb2img_trans']
+ # img2bb_trans = out['img2bb_trans']
+ scores = out['scores'].clone().cpu().numpy()
+ img_shape = out['img_shape'].cpu().numpy()[::-1] # w, h
+ img = cv2.imread(img_paths[ann_idx]) # h, w
+ scale = img.shape[1]/img_shape[0]
+ body_bbox = out['body_bbox'].clone().cpu().numpy()
+ body_bbox = body_bbox * scale
+ joint_3d, _ = convert_kps(out['smpl_kp3d'].clone().cpu().numpy(),src='smplx',dst='smplx', approximate=True)
+
+ for i, score in enumerate(scores):
+ if score < self.score_threshold:
+ break
+ if i>self.num_person:
+ break
+ save_name = img_paths[ann_idx].split('/')[-1]
+ save_name = save_name.split('.')[0]
+ vert = out['smpl_verts'][i] + out['cam_trans'][i][None]
+ # save mesh
+ exist_result_path = glob(osp.join(self.mesh_path, save_name + '*'))
+ if len(exist_result_path) == 0:
+ person_idx = 0
+ else:
+ last_person_idx = max([
+ int(name.split('personId_')[1].split('.obj')[0])
+ for name in exist_result_path
+ ])
+ person_idx = last_person_idx + 1
+
+ save_name += '_personId_' + str(person_idx) + '.obj'
+ os.makedirs(self.mesh_path, exist_ok=True)
+ save_obj(osp.join(self.mesh_path, save_name), vert, faces=torch.tensor(self.body_model.faces.astype(np.int32)))
+
+ if i == 0:
+ save_name = img_paths[ann_idx].split('/')[-1][:-4]
+ cv2.imwrite(os.path.join(self.result_img_dir,img_paths[ann_idx].split('/')[-1]), img)
+ else:
+ verts = out['smpl_verts'][:i] + out['cam_trans'][:i][:, None]
+ img = mmcv.imshow_bboxes(img, body_bbox[:i], show=False, colors='green')
+ render_smpl(
+ verts=verts[None],
+ body_model=self.body_model,
+ K= np.array(
+ [[5000, 0, img_shape[0]/2],
+ [0, 5000, img_shape[1]/2],
+ [0, 0, 1]]),
+ R=None,
+ T=None,
+ output_path=os.path.join(self.result_img_dir,img_paths[ann_idx].split('/')[-1]),
+                image_array=cv2.resize(img, (img_shape[0], img_shape[1]), interpolation=cv2.INTER_CUBIC),
+ in_ndc=False,
+ alpha=0.9,
+ convention='opencv',
+ projection='perspective',
+ overwrite=True,
+ no_grad=True,
+ device='cuda',
+ resolution=[img_shape[1],img_shape[0]],
+ render_choice='hq'
+ )
+ return None
+
diff --git a/datasets/SynBody.py b/datasets/SynBody.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f9428069440cd2723aac31e3e2c9e9f43ec7831
--- /dev/null
+++ b/datasets/SynBody.py
@@ -0,0 +1,53 @@
+import os
+import os.path as osp
+import numpy as np
+import torch
+import cv2
+import json
+import copy
+from pycocotools.coco import COCO
+from config.config import cfg
+from util.human_models import smpl_x
+from util.preprocessing import (
+ load_img, process_bbox, augmentation_instance_sample
+ ,process_human_model_output_batch_simplify,process_db_coord_batch_no_valid)
+from util.transforms import world2cam, cam2pixel, rigid_align
+from humandata import HumanDataset
+
+
+class SynBody(HumanDataset):
+ def __init__(self, transform, data_split):
+ super(SynBody, self).__init__(transform, data_split)
+ self.img_dir = 'data/datasets/synbody'
+ self.annot_path = 'data/preprocessed_npz/multihuman_data/synbody_v1.1_multi_new.npz'
+ self.annot_path_cache = 'data/preprocessed_npz/cache/synbody_v1.1_cache_new_10.npz'
+ self.use_cache = getattr(cfg, 'use_cache', False)
+ self.img_shape = (720, 1280) # (h, w)
+ self.cam_param = {
+ 'focal': (540, 540), # (fx, fy)
+ 'princpt': (640, 360) # (cx, cy)
+ }
+
+ # check image shape
+ img_path = osp.join(self.img_dir,
+ np.load(self.annot_path)['image_path'][0])
+
+ img_shape = cv2.imread(img_path).shape[:2]
+ assert self.img_shape == img_shape, 'image shape is incorrect: {} vs {}'.format(
+ self.img_shape, img_shape)
+
+ # load data or cache
+ if self.use_cache and osp.isfile(self.annot_path_cache):
+ print(
+ f'[{self.__class__.__name__}] loading cache from {self.annot_path_cache}'
+ )
+ self.datalist = self.load_cache(self.annot_path_cache)
+ else:
+ if self.use_cache:
+ print(
+ f'[{self.__class__.__name__}] Cache not found, generating cache...'
+ )
+ self.datalist = self.load_data(train_sample_interval=getattr(
+ cfg, f'{self.__class__.__name__}_train_sample_interval', 15))
+ if self.use_cache:
+ self.save_cache(self.annot_path_cache, self.datalist)
diff --git a/datasets/UBody_MM.py b/datasets/UBody_MM.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1aa88b2542a1e06ebc8d51100940cf7e5aa1668
--- /dev/null
+++ b/datasets/UBody_MM.py
@@ -0,0 +1,1122 @@
+import os
+import os.path as osp
+from glob import glob
+import numpy as np
+from config.config import cfg
+import copy
+
+import cv2
+import torch
+from pycocotools.coco import COCO
+from util.human_models import smpl_x
+from util.preprocessing import load_img, process_bbox
+from util.transforms import rigid_align_batch
+import tqdm
+from detrsmpl.utils.geometry import batch_rodrigues, project_points_new
+import random
+from util.formatting import DefaultFormatBundle
+from detrsmpl.data.datasets.pipelines.transforms import Normalize
+from datasets.humandata import HumanDataset
+import time
+from util.preprocessing import (
+ load_img, process_bbox, augmentation_instance_sample,process_human_model_output_batch_simplify,process_db_coord_batch_no_valid,process_human_model_output_batch_ubody)
+KPS2D_KEYS = [
+ 'keypoints2d_ori', 'keypoints2d_smplx', 'keypoints2d_smpl',
+ 'keypoints2d_original','keypoints2d_gta'
+]
+KPS3D_KEYS = [
+ 'keypoints3d_cam', 'keypoints3d', 'keypoints3d_smplx', 'keypoints3d_smpl',
+ 'keypoints3d_original', 'keypoints3d_gta'
+]
+class UBody_MM(HumanDataset):
+ def __init__(self, transform, data_split):
+ super(UBody_MM, self).__init__(transform, data_split)
+
+ self.img_dir = 'data/osx_data/UBody'
+ self.data_split = data_split
+ self.test_vid_list = np.load('data/osx_data/UBody/splits/intra_scene_test_list.npy')
+ if self.data_split == 'train':
+ # self.annot_path = 'data/preprocessed_npz/multihuman_data/ubody_intra_train_multi_all.npz'
+ # self.annot_path_cache = 'data/preprocessed_npz/cache/ubody_intra_train_cache_fix8.npz'
+ self.annot_path = 'data/preprocessed_npz/multihuman_data/ubody_train_intra_multi.npz'
+ self.annot_path_cache = 'data/preprocessed_npz/cache/ubody_train_intra_cache_080824.npz'
+ self.sample_interval = getattr(
+ cfg, f'{self.__class__.__name__}_train_sample_interval', 5)
+ elif self.data_split == 'test':
+ self.annot_path = 'data/preprocessed_npz/ubody_intra_test_all.npz'
+ self.annot_path_cache = 'data/preprocessed_npz/cache/ubody_intra_test_multi_all_smpler_x.npz'
+ self.sample_interval = getattr(
+ cfg, f'{self.__class__.__name__}_test_sample_interval', 100)
+ # self.test_set = 'val'
+ self.use_cache = getattr(cfg, 'use_cache', False)
+        self.img_shape = None  # (1024, 1024) # (h, w)
+ self.cam_param = {}
+ self.keypoints2d = 'keypoints2d_ubody'
+ # load data
+ if self.use_cache and osp.isfile(self.annot_path_cache):
+ print(
+ f'[{self.__class__.__name__}] loading cache from {self.annot_path_cache}'
+ )
+ self.datalist = self.load_cache(self.annot_path_cache)
+ else:
+ if self.use_cache:
+ print(
+ f'[{self.__class__.__name__}] Cache not found, generating cache...'
+ )
+ self.datalist = self.load_data(train_sample_interval=self.sample_interval)
+
+ if self.use_cache:
+ self.save_cache(self.annot_path_cache, self.datalist)
+
+
+ def evaluate(self, outs, cur_sample_idx):
+ annots = self.datalist
+ sample_num = len(outs)
+ eval_result = {
+ 'pa_mpvpe_all': [],
+ 'pa_mpvpe_l_hand': [],
+ 'pa_mpvpe_r_hand': [],
+ 'pa_mpvpe_hand': [],
+ 'pa_mpvpe_face': [],
+ 'mpvpe_all': [],
+ 'mpvpe_l_hand': [],
+ 'mpvpe_r_hand': [],
+ 'mpvpe_hand': [],
+ 'mpvpe_face': []
+ }
+
+ vis = getattr(cfg, 'vis', False)
+ vis_save_dir = cfg.vis_dir
+
+ for n in range(sample_num):
+
+ out = outs[n]
+ mesh_gt = out['smplx_mesh_cam_target']
+ mesh_out = out['smplx_mesh_cam']
+ cam_trans = out['cam_trans']
+ joint_proj = out['smplx_joint_proj']
+ img_wh = (out['img_shape'])
+ ann_idx = out['gt_ann_idx']
+ img_path = []
+ for ann_id in ann_idx:
+ img_path.append(annots[ann_id]['img_path'])
+ # print(img_path)
+ eval_result['img_path'] = img_path
+ eval_result['ann_idx'] = ann_idx
+
+            # project GT joints and vertices into the image plane to build visibility masks
+ joint_gt_body_wo_trans = np.dot(smpl_x.j14_regressor,
+ mesh_gt).transpose(1,0,2)
+ joint_gt_body_proj = project_points_new(
+ points_3d=torch.Tensor(joint_gt_body_wo_trans),
+ pred_cam=torch.Tensor(cam_trans),
+ focal_length=5000,
+ camera_center=torch.Tensor(img_wh/2)
+ ) # origin image space
+
+
+
+ joint_gt_lhand_wo_trans = np.dot(
+ smpl_x.orig_hand_regressor['left'], mesh_gt).transpose(1,0,2)
+ joint_gt_lhand_proj = project_points_new(
+ points_3d=torch.Tensor(joint_gt_lhand_wo_trans),
+ pred_cam=torch.Tensor(cam_trans),
+ focal_length=5000,
+ camera_center=torch.Tensor(img_wh/2)
+ ) # origin image space
+ joint_gt_rhand_wo_trans = np.dot(
+                smpl_x.orig_hand_regressor['right'], mesh_gt).transpose(1,0,2)
+ joint_gt_rhand_proj = project_points_new(
+ points_3d=torch.Tensor(joint_gt_rhand_wo_trans),
+ pred_cam=torch.Tensor(cam_trans),
+ focal_length=5000,
+ camera_center=torch.Tensor(img_wh/2)
+ ) # origin image space
+ mesh_gt_proj = project_points_new(
+ points_3d=torch.Tensor(mesh_gt),
+ pred_cam=torch.Tensor(cam_trans),
+ focal_length=5000,
+ camera_center=torch.Tensor(img_wh/2))
+
+
+ joint_gt_body_valid = self.validate_within_img_batch(
+ img_wh, joint_gt_body_proj)
+ joint_gt_lhand_valid = self.validate_within_img_batch(
+ img_wh, joint_gt_lhand_proj)
+ joint_gt_rhand_valid = self.validate_within_img_batch(
+ img_wh, joint_gt_rhand_proj)
+ mesh_valid = self.validate_within_img_batch(img_wh, mesh_gt_proj)
+ mesh_valid = mesh_valid.cpu().numpy()>0
+ mesh_lhand_valid = mesh_valid[:,smpl_x.hand_vertex_idx['left_hand']]
+ mesh_rhand_valid = mesh_valid[:,smpl_x.hand_vertex_idx['right_hand']]
+ mesh_face_valid = mesh_valid[:,smpl_x.face_vertex_idx]
+
+ # MPVPE from all vertices
+ mesh_out = out['smplx_mesh_cam']
+ mesh_out_align = rigid_align_batch(mesh_out, mesh_gt)
+
+ if mesh_valid.sum()>0:
+ pa_mpvpe_all = np.sqrt(np.sum(
+ (mesh_out_align - mesh_gt)**2, -1))[mesh_valid].mean() * 1000
+ else:
+ pa_mpvpe_all = 0
+
+ eval_result['pa_mpvpe_all'].append(pa_mpvpe_all)
+
+ mesh_out_align = mesh_out - np.dot(smpl_x.J_regressor, mesh_out).transpose(1,0,2)[:,smpl_x.J_regressor_idx['pelvis'], None, :] + \
+ np.dot(smpl_x.J_regressor, mesh_gt).transpose(1,0,2)[:,smpl_x.J_regressor_idx['pelvis'], None, :]
+ if mesh_valid.sum()>0:
+ mpvpe_all = np.sqrt(np.sum(
+ (mesh_out_align - mesh_gt)**2, -1))[mesh_valid].mean() * 1000
+ else:
+ mpvpe_all = 0
+ eval_result['mpvpe_all'].append(mpvpe_all)
+ vis = False
+
+ if vis:
+ import mmcv
+ img = (out['img']).transpose(0,2,3,1)
+
+ img = mmcv.imdenormalize(
+ img=img[0],
+ mean=np.array([123.675, 116.28, 103.53]),
+ std=np.array([58.395, 57.12, 57.375]),
+ to_bgr=True).astype(np.uint8)
+ cv2.imwrite('temp.png',img)
+ from detrsmpl.core.visualization.visualize_keypoints2d import visualize_kp2d
+
+ # out['smplx_joint_proj']
+ from pytorch3d.io import save_obj
+
+ mesh_pred_proj = project_points_new(
+ points_3d=torch.Tensor(mesh_gt),
+ pred_cam=torch.Tensor(cam_trans),
+ focal_length=5000,
+ camera_center=torch.Tensor(img_wh/2))
+ mesh_pred_proj = (mesh_valid[:,:,None])*mesh_pred_proj.detach().cpu().numpy()
+ visualize_kp2d(
+ mesh_pred_proj[0][None],
+ image_array=img[None].copy(),
+ disable_limbs=True,
+ overwrite=True,
+ output_path='./figs/gt2d/%d'%ann_idx
+ )
+ mesh_pred_proj = project_points_new(
+ points_3d=torch.Tensor(mesh_out),
+ pred_cam=torch.Tensor(cam_trans),
+ focal_length=5000,
+ camera_center=torch.Tensor(img_wh/2))
+ mesh_pred_proj = (mesh_valid[:,:,None])*mesh_pred_proj.detach().cpu().numpy()
+ visualize_kp2d(
+ mesh_pred_proj[0][None],
+ image_array=img[None].copy(),
+ disable_limbs=True,
+ overwrite=True,
+ output_path='./figs/pred2d/%d'%ann_idx
+ )
+ save_obj('./figs/pred_smpl_%d.obj'%mpvpe_all,verts = torch.tensor(mesh_out_align[0]),faces=torch.tensor([]))
+ save_obj('./figs/gt_smpl_%d.obj'%mpvpe_all,verts = torch.tensor(mesh_gt[0]),faces=torch.tensor([]))
+ # MPVPE from hand vertices
+ mesh_gt_lhand = mesh_gt[:, smpl_x.hand_vertex_idx['left_hand'], :]
+ mesh_out_lhand = mesh_out[:, smpl_x.hand_vertex_idx['left_hand'], :]
+ mesh_gt_rhand = mesh_gt[:, smpl_x.hand_vertex_idx['right_hand'], :]
+ mesh_out_rhand = mesh_out[:, smpl_x.hand_vertex_idx['right_hand'], :]
+ mesh_out_lhand_align = \
+ mesh_out_lhand - \
+ np.dot(smpl_x.J_regressor, mesh_out).transpose(1,0,2)[:, smpl_x.J_regressor_idx['lwrist'], None, :] + \
+ np.dot(smpl_x.J_regressor, mesh_gt).transpose(1,0,2)[:, smpl_x.J_regressor_idx['lwrist'], None, :]
+
+ mesh_out_rhand_align = \
+ mesh_out_rhand - \
+ np.dot(smpl_x.J_regressor, mesh_out).transpose(1,0,2)[:, smpl_x.J_regressor_idx['rwrist'], None, :] + \
+ np.dot(smpl_x.J_regressor, mesh_gt).transpose(1,0,2)[:, smpl_x.J_regressor_idx['rwrist'], None, :]
+ mpvpe_hand = []
+
+ if mesh_lhand_valid.sum() != 0:
+ mpvpe_lhand = np.sqrt(
+ np.sum((mesh_out_lhand_align - mesh_gt_lhand)**2,
+ -1))[mesh_lhand_valid].mean() * 1000
+ mpvpe_hand.append(mpvpe_lhand)
+ eval_result['mpvpe_l_hand'].append(mpvpe_lhand)
+ else:
+ eval_result['mpvpe_l_hand'].append(np.zeros_like(mpvpe_all))
+ if mesh_rhand_valid.sum() != 0:
+ mpvpe_rhand = np.sqrt(
+ np.sum((mesh_out_rhand_align - mesh_gt_rhand)**2,
+ -1))[mesh_rhand_valid].mean() * 1000
+ mpvpe_hand.append(mpvpe_rhand)
+ eval_result['mpvpe_r_hand'].append(mpvpe_rhand)
+ else:
+ eval_result['mpvpe_r_hand'].append(np.zeros_like(mpvpe_all))
+ if len(mpvpe_hand) > 0:
+ mpvpe_hand = np.stack(mpvpe_hand,axis=-1)
+ eval_result['mpvpe_hand'].append(np.mean(mpvpe_hand,axis=-1))
+ else:
+ eval_result['mpvpe_hand'].append(np.zeros_like(mpvpe_all))
+ mesh_out_lhand_align = rigid_align_batch(mesh_out_lhand, mesh_gt_lhand)
+ mesh_out_rhand_align = rigid_align_batch(mesh_out_rhand, mesh_gt_rhand)
+ pa_mpvpe_hand = []
+ if mesh_lhand_valid.sum() != 0:
+ pa_mpvpe_lhand = np.sqrt(
+ np.sum((mesh_out_lhand_align - mesh_gt_lhand)**2,
+ -1))[mesh_lhand_valid].mean() * 1000
+ pa_mpvpe_hand.append(pa_mpvpe_lhand)
+ eval_result['pa_mpvpe_l_hand'].append(pa_mpvpe_lhand)
+ else:
+ eval_result['pa_mpvpe_l_hand'].append(np.zeros_like(mpvpe_all))
+ if mesh_rhand_valid.sum() != 0:
+ # pa_mpvpe_rhand = np.sqrt(np.sum((mesh_out_rhand_align - mesh_gt_rhand)**2, -1)).sum(-1) * 1000 / (mesh_rhand_valid.sum(-1)+1e-6)
+ pa_mpvpe_rhand = np.sqrt(
+ np.sum((mesh_out_rhand_align - mesh_gt_rhand)**2,
+ -1))[mesh_rhand_valid].mean() * 1000
+ pa_mpvpe_hand.append(pa_mpvpe_rhand)
+ eval_result['pa_mpvpe_r_hand'].append(pa_mpvpe_rhand)
+ else:
+ eval_result['pa_mpvpe_r_hand'].append(np.zeros_like(mpvpe_all))
+ if len(pa_mpvpe_hand) > 0:
+ pa_mpvpe_hand = np.stack(pa_mpvpe_hand,axis=-1)
+ eval_result['pa_mpvpe_hand'].append(np.mean(pa_mpvpe_hand,axis=-1))
+ else:
+                eval_result['pa_mpvpe_hand'].append(np.zeros_like(mpvpe_all))
+
+ # MPVPE from face vertices
+ mesh_gt_face = mesh_gt[:, smpl_x.face_vertex_idx, :]
+ mesh_out_face = mesh_out[:, smpl_x.face_vertex_idx, :]
+ mesh_out_face_align = \
+ mesh_out_face - \
+ np.dot(smpl_x.J_regressor, mesh_out).transpose(1,0,2)[:, smpl_x.J_regressor_idx['neck'], None, :] + \
+ np.dot(smpl_x.J_regressor, mesh_gt).transpose(1,0,2)[:, smpl_x.J_regressor_idx['neck'], None, :]
+ if mesh_face_valid.sum() != 0:
+ eval_result['mpvpe_face'].append(
+ np.sqrt(np.sum((mesh_out_face_align - mesh_gt_face)**2,
+ -1))[mesh_face_valid].mean() * 1000)
+ else:
+                eval_result['mpvpe_face'].append(np.zeros_like(mpvpe_all))
+ mesh_out_face_align = rigid_align_batch(mesh_out_face, mesh_gt_face)
+
+ if mesh_face_valid.sum() != 0:
+ eval_result['pa_mpvpe_face'].append(
+ np.sqrt(np.sum((mesh_out_face_align - mesh_gt_face)**2,
+ -1))[mesh_face_valid].mean() * 1000)
+ else:
+                eval_result['pa_mpvpe_face'].append(np.zeros_like(mpvpe_all))
+ for k,v in eval_result.items():
+ if k != 'img_path' and k != 'ann_idx':
+
+ if len(v)>1:
+ eval_result[k] = np.concatenate(v,axis=0)
+ else:
+ eval_result[k] = np.array(v)
+ return eval_result
+
+ def load_data(self, train_sample_interval=1):
+
+ content = np.load(self.annot_path, allow_pickle=True)
+ try:
+ frame_range = content['frame_range']
+ except KeyError:
+ self.num_data = len(content['image_path'])
+ frame_range = \
+ np.array([[i, i + 1] for i in range(self.num_data)])
+
+ num_examples = len(frame_range)
+
+ if 'meta' in content:
+ meta = content['meta'].item()
+ print('meta keys:', meta.keys())
+ else:
+ meta = None
+ print(
+ 'No meta info provided! Please give height and width manually')
+
+ print(
+ f'Start loading humandata {self.annot_path} into memory...\nDataset includes: {content.files}'
+ )
+ tic = time.time()
+ image_path = content['image_path']
+
+ if meta is not None and 'height' in meta:
+ height = np.array(meta['height'])
+ width = np.array(meta['width'])
+ image_shape = np.stack([height, width], axis=-1)
+ else:
+ image_shape = None
+
+ if meta is not None and 'gender' in meta and len(meta['gender']) != 0:
+ gender = meta['gender']
+ else:
+ gender = None
+
+ face_valid = meta['face_valid']
+ lhand_valid = meta['lefthand_valid']
+ rhand_valid = meta['righthand_valid']
+ valid_label = meta['valid_label']
+ is_crowd = meta['iscrowd']
+ keypoints_valid = content['keypoints2d_ubody'][:,:,2].sum(-1)!=0
+ bbox_xywh = content['bbox_xywh']
+ if 'smplx' in content:
+ smplx = content['smplx'].item()
+ as_smplx = 'smplx'
+ elif 'smpl' in content:
+ smplx = content['smpl'].item()
+ as_smplx = 'smpl'
+ elif 'smplh' in content:
+ smplx = content['smplh'].item()
+ as_smplx = 'smplh'
+ # TODO: temp solution, should be more general. But SHAPY is very special
+ elif self.__class__.__name__ == 'SHAPY':
+ smplx = {}
+ else:
+            raise KeyError('No SMPL-X/SMPL/SMPL-H parameters found, please check keys:\n'
+ f'{content.files}')
+
+ print('Smplx param', smplx.keys())
+
+ if 'lhand_bbox_xywh' in content and 'rhand_bbox_xywh' in content:
+ lhand_bbox_xywh = content['lhand_bbox_xywh']
+ rhand_bbox_xywh = content['rhand_bbox_xywh']
+ else:
+ lhand_bbox_xywh = np.zeros_like(bbox_xywh)
+ rhand_bbox_xywh = np.zeros_like(bbox_xywh)
+
+ if 'face_bbox_xywh' in content:
+ face_bbox_xywh = content['face_bbox_xywh']
+ else:
+ face_bbox_xywh = np.zeros_like(bbox_xywh)
+
+ decompressed = False
+ if content['__keypoints_compressed__']:
+ decompressed_kps = self.decompress_keypoints(content)
+ decompressed = True
+
+ keypoints3d = None
+ valid_kps3d = False
+ keypoints3d_mask = None
+ valid_kps3d_mask = False
+ for kps3d_key in KPS3D_KEYS:
+ if kps3d_key in content:
+ keypoints3d = decompressed_kps[kps3d_key][:, self.SMPLX_137_MAPPING, :] if decompressed \
+ else content[kps3d_key][:, self.SMPLX_137_MAPPING, :]
+ valid_kps3d = True
+ if keypoints3d.shape[-1] == 4:
+ valid_kps3d_mask = True
+ break
+
+ if self.keypoints2d is not None:
+ keypoints2d = decompressed_kps[self.keypoints2d][:, self.SMPLX_137_MAPPING, :] if decompressed \
+ else content[self.keypoints2d][:, self.SMPLX_137_MAPPING, :]
+ keypoints2d = keypoints2d[:,:,:3]
+ if keypoints2d.shape[-1] == 3:
+ valid_kps3d_mask = True
+
+
+ print('Done. Time: {:.2f}s'.format(time.time() - tic))
+
+ datalist = []
+
+        # process each frame and keep only samples with valid bboxes
+ for i in tqdm.tqdm(range(int(num_examples))):
+ if self.data_split == 'train' and i % self.sample_interval != 0:
+ continue
+
+ frame_start, frame_end = frame_range[i]
+ img_path = osp.join(self.img_dir, image_path[frame_start])
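+            # skip any frame whose video id appears in the intra-scene test list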
+ vid_name = img_path.split('/')[-2]
+ if 'Trim' in vid_name:
+ vid_name = vid_name.split('_Trim')[0]
+ if str(vid_name) in self.test_vid_list:
+ continue
+ # im_shape = cv2.imread(img_path).shape[:2]
+ img_shape = image_shape[
+ frame_start] if image_shape is not None else self.img_shape
+
+ bbox_list = bbox_xywh[frame_start:frame_end, :4]
+
+ unique_bbox_idx = np.unique(bbox_list,axis=0,return_index=True)[1]
+ unique_bbox_idx.sort()
+ unique_bbox_list = bbox_list[unique_bbox_idx]
+
+ valid_idx = []
+ body_bbox_list = []
+
+ if hasattr(cfg, 'bbox_ratio'):
+ bbox_ratio = cfg.bbox_ratio * 0.833 # preprocess body bbox is giving 1.2 box padding
+ else:
+ bbox_ratio = 1.25
+
+ for bbox_i, bbox in zip(unique_bbox_idx,unique_bbox_list):
+
+ bbox = process_bbox(bbox,
+ img_width=img_shape[1],
+ img_height=img_shape[0],
+ ratio=bbox_ratio)
+ if bbox is None:
+ continue
+
+                if is_crowd[frame_start + bbox_i] == 0 and \
+                        valid_label[frame_start + bbox_i] != 0 and \
+                        keypoints_valid[frame_start + bbox_i]:
+
+ valid_idx.append(frame_start + bbox_i)
+ bbox[2:] += bbox[:2]
+ body_bbox_list.append(bbox)
+ if len(valid_idx) == 0:
+ continue
+ valid_num = len(valid_idx)
+ # hand/face bbox
+ lhand_bbox_list = []
+ rhand_bbox_list = []
+ face_bbox_list = []
+
+ for bbox_i in valid_idx:
+ lhand_bbox = lhand_bbox_xywh[bbox_i]
+ rhand_bbox = rhand_bbox_xywh[bbox_i]
+ face_bbox = face_bbox_xywh[bbox_i]
+ if lhand_valid[bbox_i] > 0: # conf > 0
+ lhand_bbox = lhand_bbox[:4]
+ if hasattr(cfg, 'bbox_ratio'):
+ lhand_bbox = process_bbox(lhand_bbox,
+ img_width=img_shape[1],
+ img_height=img_shape[0],
+ ratio=cfg.bbox_ratio)
+ if lhand_bbox is not None:
+ lhand_bbox[2:] += lhand_bbox[:2] # xywh -> xyxy
+ else:
+ lhand_bbox = None
+ if rhand_valid[bbox_i] > 0:
+ rhand_bbox = rhand_bbox[:4]
+ if hasattr(cfg, 'bbox_ratio'):
+ rhand_bbox = process_bbox(rhand_bbox,
+ img_width=img_shape[1],
+ img_height=img_shape[0],
+ ratio=cfg.bbox_ratio)
+ if rhand_bbox is not None:
+ rhand_bbox[2:] += rhand_bbox[:2] # xywh -> xyxy
+ else:
+ rhand_bbox = None
+ if face_valid[bbox_i] > 0:
+ face_bbox = face_bbox[:4]
+ if hasattr(cfg, 'bbox_ratio'):
+ face_bbox = process_bbox(face_bbox,
+ img_width=img_shape[1],
+ img_height=img_shape[0],
+ ratio=cfg.bbox_ratio)
+ if face_bbox is not None:
+ face_bbox[2:] += face_bbox[:2] # xywh -> xyxy
+ else:
+ face_bbox = None
+ lhand_bbox_list.append(lhand_bbox)
+ rhand_bbox_list.append(rhand_bbox)
+ face_bbox_list.append(face_bbox)
+
+ # lhand_bbox = np.stack(lhand_bbox_list,axis=0)
+ # rhand_bbox = np.stack(rhand_bbox_list,axis=0)
+ # face_bbox = np.stack(face_bbox_list,axis=0)
+ joint_img = keypoints2d[valid_idx]
+
+ # num_joints = joint_cam.shape[0]
+ # joint_valid = np.ones((num_joints, 1))
+ if valid_kps3d:
+ joint_cam = keypoints3d[valid_idx]
+ else:
+ joint_cam = None
+
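+            # remove eye-pose entries from the SMPL-X parameters before further processing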
+ if 'leye_pose_0' in smplx.keys():
+ smplx.pop('leye_pose_0')
+ if 'leye_pose_1' in smplx.keys():
+ smplx.pop('leye_pose_1')
+ if 'leye_pose' in smplx.keys():
+ smplx.pop('leye_pose')
+ if 'reye_pose_0' in smplx.keys():
+ smplx.pop('reye_pose_0')
+ if 'reye_pose_1' in smplx.keys():
+ smplx.pop('reye_pose_1')
+ if 'reye_pose' in smplx.keys():
+ smplx.pop('reye_pose')
+
+
+ smplx_param = {k: v[valid_idx] for k, v in smplx.items()}
+ gender_ = gender[valid_idx] \
+ if gender is not None else np.array(['neutral']*(valid_num))
+
+ # TODO: set invalid if None?
+ smplx_param['root_pose'] = smplx_param.pop('global_orient', None)
+ smplx_param['shape'] = smplx_param.pop('betas', None)
+ smplx_param['trans'] = smplx_param.pop('transl', np.zeros(3))
+ smplx_param['lhand_pose'] = smplx_param.pop('left_hand_pose', None)
+ smplx_param['rhand_pose'] = smplx_param.pop(
+ 'right_hand_pose', None)
+ smplx_param['expr'] = smplx_param.pop('expression', None)
+
+ # TODO do not fix betas, give up shape supervision
+ if 'betas_neutral' in smplx_param and self.data_split == 'train':
+ smplx_param['shape'] = smplx_param.pop('betas_neutral')
+ # smplx_param['shape'] = np.zeros(10, dtype=np.float32)
+
+ # # TODO fix shape of poses
+ if self.__class__.__name__ == 'Talkshow':
+ smplx_param['body_pose'] = smplx_param['body_pose'].reshape(
+ -1, 21, 3)
+ smplx_param['lhand_pose'] = smplx_param['lhand_pose'].reshape(
+ -1, 15, 3)
+                smplx_param['rhand_pose'] = smplx_param['rhand_pose'].reshape(
+ -1, 15, 3)
+ smplx_param['expr'] = smplx_param['expr'][:, :10]
+
+ if self.__class__.__name__ == 'BEDLAM':
+ smplx_param['shape'] = smplx_param['shape'][:, :10]
+
+ if as_smplx == 'smpl':
+ smplx_param['shape'] = np.zeros(
+ [valid_num, 10],
+ dtype=np.float32) # drop smpl betas for smplx
+ smplx_param['body_pose'] = smplx_param[
+ 'body_pose'][:, :21, :] # use smpl body_pose on smplx
+ if as_smplx == 'smplh':
+ smplx_param['shape'] = np.zeros(
+ [valid_num, 10],
+ dtype=np.float32) # drop smpl betas for smplx
+
+            if smplx_param['lhand_pose'] is None or self.body_only:
+                smplx_param['lhand_valid'] = np.zeros(valid_num, dtype=np.bool_)
+ else:
+ smplx_param['lhand_valid'] = lhand_valid[valid_idx]
+
+            if smplx_param['rhand_pose'] is None or self.body_only:
+                smplx_param['rhand_valid'] = np.zeros(valid_num, dtype=np.bool_)
+ else:
+ smplx_param['rhand_valid'] = rhand_valid[valid_idx]
+
+            if smplx_param['expr'] is None or self.body_only:
+                smplx_param['face_valid'] = np.zeros(valid_num, dtype=np.bool_)
+ else:
+ smplx_param['face_valid'] = face_valid[valid_idx]
+
+ if joint_cam is not None and np.any(np.isnan(joint_cam)):
+ continue
+
+
+
+ datalist.append({
+ 'img_path': img_path,
+ 'img_shape': img_shape,
+ 'bbox': body_bbox_list,
+ 'lhand_bbox': lhand_bbox_list,
+ 'rhand_bbox': rhand_bbox_list,
+ 'face_bbox': face_bbox_list,
+ 'joint_img': joint_img,
+ 'joint_cam': joint_cam,
+ 'smplx_param': smplx_param,
+ 'as_smplx': as_smplx,
+ 'gender': gender_
+ })
+
+ # save memory
+ del content, image_path, bbox_xywh, lhand_bbox_xywh, rhand_bbox_xywh, face_bbox_xywh, keypoints3d, keypoints2d
+
+ if self.data_split == 'train':
+ print(f'[{self.__class__.__name__} train] original size:',
+ int(num_examples), '. Sample interval:',
+ train_sample_interval, '. Sampled size:', len(datalist))
+
+ if getattr(cfg, 'data_strategy',
+ None) == 'balance' and self.data_split == 'train':
+ print(
+ f'[{self.__class__.__name__}] Using [balance] strategy with datalist shuffled...'
+ )
+ random.shuffle(datalist)
+
+ return datalist
+
+    def __getitem__(self, idx):
+ try:
+ data = copy.deepcopy(self.datalist[idx])
+ except Exception as e:
+ print(f'[{self.__class__.__name__}] Error loading data {idx}')
+ print(e)
+ exit(0)
+
+ img_path, img_shape, bbox = data['img_path'], data['img_shape'], data[
+ 'bbox']
+ as_smplx = data['as_smplx']
+ if 'gender' in data:
+ gender = data['gender'].copy()
+ for gender_str, gender_num in {
+ 'neutral': -1, 'male': 0, 'female': 1}.items():
+ gender[gender==gender_str]=gender_num
+ gender = gender.astype(int)
+ else:
+ gender = np.array([-1]*len(bbox))
+ img_whole_bbox = np.array([0, 0, img_shape[1], img_shape[0]])
+ img = load_img(img_path, order='BGR')
+ num_person = len(data['bbox'])
+ data_name = self.__class__.__name__
+ img, img2bb_trans, bb2img_trans, rot, do_flip = \
+ augmentation_instance_sample(img, img_whole_bbox, self.data_split,data,data_name)
+ cropped_img_shape=img.shape[:2]
+ num_person = len(data['bbox'])
+
+ if self.data_split == 'train':
+ # h36m gt
+ if 'joint_cam' in data:
+ joint_cam = data['joint_cam']
+ else:
+ joint_cam = None
+
+ if joint_cam is not None:
+ dummy_cord = False
+                joint_cam[:, :, :3] = joint_cam[:, :, :3] - \
+                    joint_cam[:, self.joint_set['root_joint_idx'], None, :3]  # root-relative
+ else:
+ # dummy cord as joint_cam
+ dummy_cord = True
+ joint_cam = np.zeros(
+ (num_person, self.joint_set['joint_num'], 4),
+ dtype=np.float32)
+
+ joint_img = data['joint_img']
+
+ # do rotation on keypoints
+ joint_img_aug, joint_cam_wo_ra, joint_cam_ra, joint_trunc = \
+ process_db_coord_batch_no_valid(
+ joint_img, joint_cam, do_flip, img_shape,
+ self.joint_set['flip_pairs'], img2bb_trans, rot,
+ self.joint_set['joints_name'], smpl_x.joints_name,
+ cropped_img_shape)
+ joint_img_aug[:,:,2:] = joint_img_aug[:,:,2:] * joint_trunc
+
+ # smplx coordinates and parameters
+ smplx_param = data['smplx_param']
+ if self.__class__.__name__ in ['CHI3D', 'SynBody']:
+                smplx_param['lhand_pose'] -= self.lhand_mean[None]
+                smplx_param['rhand_pose'] -= self.rhand_mean[None]
+ part_valid = {
+ 'lhand': smplx_param['lhand_valid'],
+ 'rhand': smplx_param['rhand_valid'],
+ 'face': smplx_param['face_valid']
+ }
+ smplx_pose, smplx_shape, smplx_expr, smplx_pose_valid, \
+ smplx_joint_valid, smplx_expr_valid, smplx_shape_valid = \
+ process_human_model_output_batch_ubody(
+ smplx_param, do_flip, rot, as_smplx, part_valid)
+
+ # if cam not provided, we take joint_img as smplx joint 2d,
+ # which is commonly the case for our processed humandata
+ # TODO temp fix keypoints3d for renbody
+
+
+ # change smplx_shape if use_betas_neutral
+ # processing follows that in process_human_model_output
+ if self.use_betas_neutral:
+ smplx_shape = smplx_param['betas_neutral'].reshape(
+ num_person, -1)
+ smplx_shape[(np.abs(smplx_shape) > 3).any(axis=1)] = 0.
+ smplx_shape = smplx_shape.reshape(num_person, -1)
+
+ # smplx_pose_valid = np.tile(smplx_pose_valid[:,:, None], (1, 3)).reshape(num_person,-1)
+
+ # smplx_pose = smplx_pose * smplx_pose_valid
+ # smplx_expr = smplx_expr * smplx_expr_valid[:, None]
+ smplx_joint_valid = smplx_joint_valid[:, :, None]
+
+ lhand_bbox_center_list = []
+ lhand_bbox_valid_list = []
+ lhand_bbox_size_list = []
+ lhand_bbox_list = []
+ face_bbox_center_list = []
+ face_bbox_size_list = []
+ face_bbox_valid_list = []
+ face_bbox_list = []
+ rhand_bbox_center_list = []
+ rhand_bbox_valid_list = []
+ rhand_bbox_size_list = []
+ rhand_bbox_list = []
+ body_bbox_center_list = []
+ body_bbox_size_list = []
+ body_bbox_valid_list = []
+ body_bbox_list = []
+ # hand and face bbox transform
+
+ for i in range(num_person):
+                # TODO: confirm whether an invalid body bbox raises an assertion error here
+ body_bbox, body_bbox_valid = self.process_hand_face_bbox(
+ data['bbox'][i], do_flip, img_shape, img2bb_trans,
+ cropped_img_shape)
+
+ lhand_bbox, lhand_bbox_valid = self.process_hand_face_bbox(
+ data['lhand_bbox'][i], do_flip, img_shape, img2bb_trans,
+ cropped_img_shape)
+ lhand_bbox_valid *= smplx_param['lhand_valid'][i]
+
+ rhand_bbox, rhand_bbox_valid = self.process_hand_face_bbox(
+ data['rhand_bbox'][i], do_flip, img_shape, img2bb_trans,
+ cropped_img_shape)
+ rhand_bbox_valid *= smplx_param['rhand_valid'][i]
+
+ face_bbox, face_bbox_valid = self.process_hand_face_bbox(
+ data['face_bbox'][i], do_flip, img_shape, img2bb_trans,
+ cropped_img_shape)
+ face_bbox_valid *= smplx_param['face_valid'][i]
+
+ if do_flip:
+ lhand_bbox, rhand_bbox = rhand_bbox, lhand_bbox
+ lhand_bbox_valid, rhand_bbox_valid = rhand_bbox_valid, lhand_bbox_valid
+
+ body_bbox_list.append(body_bbox)
+ lhand_bbox_list.append(lhand_bbox)
+ rhand_bbox_list.append(rhand_bbox)
+ face_bbox_list.append(face_bbox)
+
+ lhand_bbox_center = (lhand_bbox[0] + lhand_bbox[1]) / 2.
+ rhand_bbox_center = (rhand_bbox[0] + rhand_bbox[1]) / 2.
+ face_bbox_center = (face_bbox[0] + face_bbox[1]) / 2.
+ body_bbox_center = (body_bbox[0] + body_bbox[1]) / 2.
+ lhand_bbox_size = lhand_bbox[1] - lhand_bbox[0]
+ rhand_bbox_size = rhand_bbox[1] - rhand_bbox[0]
+
+ face_bbox_size = face_bbox[1] - face_bbox[0]
+ body_bbox_size = body_bbox[1] - body_bbox[0]
+ lhand_bbox_center_list.append(lhand_bbox_center)
+ lhand_bbox_valid_list.append(lhand_bbox_valid)
+ lhand_bbox_size_list.append(lhand_bbox_size)
+ face_bbox_center_list.append(face_bbox_center)
+ face_bbox_size_list.append(face_bbox_size)
+ face_bbox_valid_list.append(face_bbox_valid)
+ rhand_bbox_center_list.append(rhand_bbox_center)
+ rhand_bbox_valid_list.append(rhand_bbox_valid)
+ rhand_bbox_size_list.append(rhand_bbox_size)
+ body_bbox_center_list.append(body_bbox_center)
+ body_bbox_size_list.append(body_bbox_size)
+ body_bbox_valid_list.append(body_bbox_valid)
+
+
+ body_bbox = np.stack(body_bbox_list, axis=0)
+ lhand_bbox = np.stack(lhand_bbox_list, axis=0)
+ rhand_bbox = np.stack(rhand_bbox_list, axis=0)
+ face_bbox = np.stack(face_bbox_list, axis=0)
+ lhand_bbox_center = np.stack(lhand_bbox_center_list, axis=0)
+ lhand_bbox_valid = np.stack(lhand_bbox_valid_list, axis=0)
+ lhand_bbox_size = np.stack(lhand_bbox_size_list, axis=0)
+ face_bbox_center = np.stack(face_bbox_center_list, axis=0)
+ face_bbox_size = np.stack(face_bbox_size_list, axis=0)
+ face_bbox_valid = np.stack(face_bbox_valid_list, axis=0)
+ body_bbox_center = np.stack(body_bbox_center_list, axis=0)
+ body_bbox_size = np.stack(body_bbox_size_list, axis=0)
+ body_bbox_valid = np.stack(body_bbox_valid_list, axis=0)
+ rhand_bbox_center = np.stack(rhand_bbox_center_list, axis=0)
+ rhand_bbox_valid = np.stack(rhand_bbox_valid_list, axis=0)
+ rhand_bbox_size = np.stack(rhand_bbox_size_list, axis=0)
+
+ inputs = {'img': img}
+
+ is_3D = True
+ # joint_img_aug[:,:,2] = joint_img_aug[:,:,2] * body_bbox_valid[:,None]
+
+ # assign 2d kps valid to 3d kps
+ joint_cam_wo_ra[..., -1] = joint_img_aug[..., -1] * smplx_joint_valid[..., 0]
+ joint_cam_ra[..., -1] = joint_img_aug[..., -1] * smplx_joint_valid[..., 0]
+ joint_img_aug[...,-1] = joint_img_aug[...,-1] * smplx_joint_valid[...,0]
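+            # only persons whose body bbox survives the crop/flip
+            # (body_bbox_valid > 0) are kept in the targets and meta_info below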
+ targets = {
+ # keypoints2d, [0,img_w],[0,img_h] -> [0,1] -> [0,output_hm_shape]
+ 'joint_img': joint_img_aug[body_bbox_valid>0],
+ # joint_cam, kp3d wo ra # raw kps3d probably without ra
+ 'joint_cam': joint_cam_wo_ra[body_bbox_valid>0],
+ # kps3d with body, face, hand ra
+ 'smplx_joint_cam': joint_cam_ra[body_bbox_valid>0],
+ 'smplx_pose': smplx_pose[body_bbox_valid>0],
+ 'smplx_shape': smplx_shape[body_bbox_valid>0],
+ 'smplx_expr': smplx_expr[body_bbox_valid>0],
+ 'lhand_bbox_center': lhand_bbox_center[body_bbox_valid>0],
+ 'lhand_bbox_size': lhand_bbox_size[body_bbox_valid>0],
+ 'rhand_bbox_center': rhand_bbox_center[body_bbox_valid>0],
+ 'rhand_bbox_size': rhand_bbox_size[body_bbox_valid>0],
+ 'face_bbox_center': face_bbox_center[body_bbox_valid>0],
+ 'face_bbox_size': face_bbox_size[body_bbox_valid>0],
+ 'body_bbox_center': body_bbox_center[body_bbox_valid>0],
+ 'body_bbox_size': body_bbox_size[body_bbox_valid>0],
+ 'body_bbox': body_bbox.reshape(-1,4)[body_bbox_valid>0],
+ 'lhand_bbox': lhand_bbox.reshape(-1,4)[body_bbox_valid>0],
+ 'rhand_bbox': rhand_bbox.reshape(-1,4)[body_bbox_valid>0],
+ 'face_bbox': face_bbox.reshape(-1,4)[body_bbox_valid>0],
+ 'gender': gender[body_bbox_valid>0]}
+
+ meta_info = {
+ 'joint_trunc': joint_trunc[body_bbox_valid>0],
+ 'smplx_pose_valid': smplx_pose_valid[body_bbox_valid>0],
+ 'smplx_shape_valid': smplx_shape_valid[body_bbox_valid>0],
+ 'smplx_expr_valid': smplx_expr_valid[body_bbox_valid>0],
+ 'is_3D': is_3D,
+ 'lhand_bbox_valid': lhand_bbox_valid[body_bbox_valid>0],
+ 'rhand_bbox_valid': rhand_bbox_valid[body_bbox_valid>0],
+ 'face_bbox_valid': face_bbox_valid[body_bbox_valid>0],
+ 'body_bbox_valid': body_bbox_valid[body_bbox_valid>0],
+ 'img_shape': np.array(img.shape[:2]),
+ 'ori_shape':data['img_shape'],
+ 'idx': idx,
+ }
+
+
+ result = {**inputs, **targets, **meta_info}
+
+ result = self.normalize(result)
+ result = self.format(result)
+ return result
+
+
+
+ if self.data_split == 'test':
+ self.cam_param = {}
+ if 'joint_cam' not in data:
+ joint_cam = None
+ else:
+ joint_cam = data['joint_cam']
+
+ if joint_cam is not None:
+ dummy_cord = False
+ joint_cam[:,:,:3] = joint_cam[:,:,:3] - joint_cam[
+ :, self.joint_set['root_joint_idx'], None, :3] # root-relative
+ else:
+ # dummy cord as joint_cam
+ dummy_cord = True
+ joint_cam = np.zeros((num_person, 137, 4), dtype=np.float32)
+
+ joint_img = data['joint_img']
+
+ joint_img_aug, joint_cam_wo_ra, joint_cam_ra, joint_trunc = \
+ process_db_coord_batch_no_valid(
+ joint_img, joint_cam, do_flip, img_shape,
+ self.joint_set['flip_pairs'], img2bb_trans, rot,
+ self.joint_set['joints_name'], smpl_x.joints_name,
+ cropped_img_shape)
+ joint_img_aug[:,:,2:] = joint_img_aug[:,:,2:] * joint_trunc
+
+
+ # smplx coordinates and parameters
+ smplx_param = data['smplx_param']
+ # smplx_cam_trans = np.array(
+ # smplx_param['trans']) if 'trans' in smplx_param else None
+            # TODO: remove this, separate smpl and smplx
+ part_valid = {
+ 'lhand': smplx_param['lhand_valid'],
+ 'rhand': smplx_param['rhand_valid'],
+ 'face': smplx_param['face_valid']
+ }
+ smplx_pose, smplx_shape, smplx_expr, smplx_pose_valid, \
+ smplx_joint_valid, smplx_expr_valid, smplx_shape_valid = \
+ process_human_model_output_batch_ubody(
+ smplx_param, do_flip, rot, as_smplx, part_valid)
+ # if cam not provided, we take joint_img as smplx joint 2d,
+ # which is commonly the case for our processed humandata
+ if self.use_betas_neutral:
+ smplx_shape = smplx_param['betas_neutral'].reshape(
+ num_person, -1)
+ smplx_shape[(np.abs(smplx_shape) > 3).any(axis=1)] = 0.
+ smplx_shape = smplx_shape.reshape(num_person, -1)
+
+ # smplx_pose_valid = np.tile(smplx_pose_valid[:,:, None], (1, 3)).reshape(num_person,-1)
+ smplx_joint_valid = smplx_joint_valid[:, :, None]
+ # smplx_pose = smplx_pose*smplx_pose_valid
+ # smplx_expr = smplx_expr*smplx_expr_valid
+
+ # if not (smplx_shape == 0).all():
+ # smplx_shape_valid = True
+ # else:
+ # smplx_shape_valid = False
+ lhand_bbox_center_list = []
+ lhand_bbox_valid_list = []
+ lhand_bbox_size_list = []
+ lhand_bbox_list = []
+ face_bbox_center_list = []
+ face_bbox_size_list = []
+ face_bbox_valid_list = []
+ face_bbox_list = []
+ rhand_bbox_center_list = []
+ rhand_bbox_valid_list = []
+ rhand_bbox_size_list = []
+ rhand_bbox_list = []
+ body_bbox_center_list = []
+ body_bbox_size_list = []
+ body_bbox_valid_list = []
+ body_bbox_list = []
+
+ for i in range(num_person):
+                lhand_bbox, lhand_bbox_valid = self.process_hand_face_bbox(
+                    data['lhand_bbox'][i], do_flip, img_shape, img2bb_trans,
+                    cropped_img_shape)
+                lhand_bbox_valid *= smplx_param['lhand_valid'][i]
+
+                rhand_bbox, rhand_bbox_valid = self.process_hand_face_bbox(
+                    data['rhand_bbox'][i], do_flip, img_shape, img2bb_trans,
+                    cropped_img_shape)
+                rhand_bbox_valid *= smplx_param['rhand_valid'][i]
+
+                face_bbox, face_bbox_valid = self.process_hand_face_bbox(
+                    data['face_bbox'][i], do_flip, img_shape, img2bb_trans,
+                    cropped_img_shape)
+                face_bbox_valid *= smplx_param['face_valid'][i]
+
+                body_bbox, body_bbox_valid = self.process_hand_face_bbox(
+                    data['bbox'][i], do_flip, img_shape, img2bb_trans,
+                    cropped_img_shape)
+
+ if do_flip:
+ lhand_bbox, rhand_bbox = rhand_bbox, lhand_bbox
+ lhand_bbox_valid, rhand_bbox_valid = rhand_bbox_valid, lhand_bbox_valid
+
+ body_bbox_list.append(body_bbox)
+ lhand_bbox_list.append(lhand_bbox)
+ rhand_bbox_list.append(rhand_bbox)
+ face_bbox_list.append(face_bbox)
+
+ lhand_bbox_center = (lhand_bbox[0] + lhand_bbox[1]) / 2.
+ rhand_bbox_center = (rhand_bbox[0] + rhand_bbox[1]) / 2.
+ face_bbox_center = (face_bbox[0] + face_bbox[1]) / 2.
+ body_bbox_center = (body_bbox[0] + body_bbox[1]) / 2.
+ lhand_bbox_size = lhand_bbox[1] - lhand_bbox[0]
+ rhand_bbox_size = rhand_bbox[1] - rhand_bbox[0]
+
+ face_bbox_size = face_bbox[1] - face_bbox[0]
+ body_bbox_size = body_bbox[1] - body_bbox[0]
+ lhand_bbox_center_list.append(lhand_bbox_center)
+ lhand_bbox_valid_list.append(lhand_bbox_valid)
+ lhand_bbox_size_list.append(lhand_bbox_size)
+ face_bbox_center_list.append(face_bbox_center)
+ face_bbox_size_list.append(face_bbox_size)
+ face_bbox_valid_list.append(face_bbox_valid)
+ rhand_bbox_center_list.append(rhand_bbox_center)
+ rhand_bbox_valid_list.append(rhand_bbox_valid)
+ rhand_bbox_size_list.append(rhand_bbox_size)
+ body_bbox_center_list.append(body_bbox_center)
+ body_bbox_size_list.append(body_bbox_size)
+ body_bbox_valid_list.append(body_bbox_valid)
+
+ body_bbox = np.stack(body_bbox_list, axis=0)
+ lhand_bbox = np.stack(lhand_bbox_list, axis=0)
+ rhand_bbox = np.stack(rhand_bbox_list, axis=0)
+ face_bbox = np.stack(face_bbox_list, axis=0)
+ lhand_bbox_center = np.stack(lhand_bbox_center_list, axis=0)
+ lhand_bbox_valid = np.stack(lhand_bbox_valid_list, axis=0)
+ lhand_bbox_size = np.stack(lhand_bbox_size_list, axis=0)
+ face_bbox_center = np.stack(face_bbox_center_list, axis=0)
+ face_bbox_size = np.stack(face_bbox_size_list, axis=0)
+ face_bbox_valid = np.stack(face_bbox_valid_list, axis=0)
+ body_bbox_center = np.stack(body_bbox_center_list, axis=0)
+ body_bbox_size = np.stack(body_bbox_size_list, axis=0)
+ body_bbox_valid = np.stack(body_bbox_valid_list, axis=0)
+ rhand_bbox_center = np.stack(rhand_bbox_center_list, axis=0)
+ rhand_bbox_valid = np.stack(rhand_bbox_valid_list, axis=0)
+ rhand_bbox_size = np.stack(rhand_bbox_size_list, axis=0)
+
+ inputs = {'img': img}
+ joint_img_aug[:,:,2] = joint_img_aug[:,:,2] * body_bbox_valid[:,None]
+
+ # assign 2d kps valid to 3d kps
+ joint_cam_wo_ra[..., -1] = joint_img_aug[..., -1] * smplx_joint_valid[..., 0]
+ joint_cam_ra[..., -1] = joint_img_aug[..., -1] * smplx_joint_valid[..., 0]
+ joint_img_aug[...,-1] = joint_img_aug[...,-1] * smplx_joint_valid[...,0]
+ targets = {
+ # keypoints2d, [0,img_w],[0,img_h] -> [0,1] -> [0,output_hm_shape]
+ 'joint_img': joint_img_aug,
+ # projected smplx if valid cam_param, else same as keypoints2d
+ # joint_cam, kp3d wo ra # raw kps3d probably without ra
+ 'joint_cam': joint_cam_wo_ra,
+ 'ann_idx': idx,
+ # kps3d with body, face, hand ra
+ 'smplx_joint_cam': joint_cam_ra,
+ 'smplx_pose': smplx_pose,
+ 'smplx_shape': smplx_shape,
+ 'smplx_expr': smplx_expr,
+ 'lhand_bbox_center': lhand_bbox_center,
+ 'lhand_bbox_size': lhand_bbox_size,
+ 'rhand_bbox_center': rhand_bbox_center,
+ 'rhand_bbox_size': rhand_bbox_size,
+ 'face_bbox_center': face_bbox_center,
+ 'face_bbox_size': face_bbox_size,
+ 'body_bbox_center': body_bbox_center,
+ 'body_bbox_size': body_bbox_size,
+ 'body_bbox': body_bbox.reshape(-1,4),
+ 'lhand_bbox': lhand_bbox.reshape(-1,4),
+ 'rhand_bbox': rhand_bbox.reshape(-1,4),
+ 'face_bbox': face_bbox.reshape(-1,4),
+ 'gender': gender,
+ 'bb2img_trans': bb2img_trans,
+ }
+
+ if self.body_only:
+ meta_info = {
+ 'joint_trunc': joint_trunc,
+ 'smplx_pose_valid': smplx_pose_valid,
+ 'smplx_shape_valid': float(smplx_shape_valid),
+ 'smplx_expr_valid': smplx_expr_valid,
+ 'is_3D': float(False) if dummy_cord else float(True),
+ 'lhand_bbox_valid': lhand_bbox_valid,
+ 'rhand_bbox_valid': rhand_bbox_valid,
+ 'face_bbox_valid': face_bbox_valid,
+ 'body_bbox_valid': body_bbox_valid,
+ 'img_shape': np.array(img.shape[:2]),
+ 'ori_shape':data['img_shape']
+ }
+ else:
+ meta_info = {
+ 'joint_trunc': joint_trunc,
+ 'smplx_pose_valid': smplx_pose_valid,
+ 'smplx_shape_valid': smplx_shape_valid,
+ 'smplx_expr_valid': smplx_expr_valid,
+ 'is_3D': float(False) if dummy_cord else float(True),
+ 'lhand_bbox_valid': lhand_bbox_valid,
+ 'rhand_bbox_valid': rhand_bbox_valid,
+ 'face_bbox_valid': face_bbox_valid,
+ 'body_bbox_valid': body_bbox_valid,
+ 'img_shape': np.array(img.shape[:2]),
+ 'ori_shape':data['img_shape']}
+
+ result = {**inputs, **targets, **meta_info}
+ result = self.normalize(result)
+ result = self.format(result)
+            return result
+
+    def print_eval_result(self, eval_result):
+
+ print('UBody test results are dumped at: ' +
+ osp.join(cfg.result_dir, 'predictions'))
+
+ if self.data_split == 'test' and self.test_set == 'test': # do not print. just submit the results to the official evaluation server
+ return
+
+ print('======UBody-val======')
+ print('PA MPVPE (All): %.2f mm' % np.mean(eval_result['pa_mpvpe_all']))
+ print('PA MPVPE (L-Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_l_hand']))
+ print('PA MPVPE (R-Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_r_hand']))
+ print('PA MPVPE (Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_hand']))
+ print('PA MPVPE (Face): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_face']))
+ print()
+
+ print('MPVPE (All): %.2f mm' % np.mean(eval_result['mpvpe_all']))
+ print('MPVPE (L-Hands): %.2f mm' %
+ np.mean(eval_result['mpvpe_l_hand']))
+ print('MPVPE (R-Hands): %.2f mm' %
+ np.mean(eval_result['mpvpe_r_hand']))
+ print('MPVPE (Hands): %.2f mm' % np.mean(eval_result['mpvpe_hand']))
+ print('MPVPE (Face): %.2f mm' % np.mean(eval_result['mpvpe_face']))
+
+        with open(os.path.join(cfg.result_dir, 'result.txt'), 'w') as f:
+            f.write('UBody-val dataset:\n')
+            f.write('PA MPVPE (All): %.2f mm\n' %
+                    np.mean(eval_result['pa_mpvpe_all']))
+            f.write('PA MPVPE (L-Hands): %.2f mm\n' %
+                    np.mean(eval_result['pa_mpvpe_l_hand']))
+            f.write('PA MPVPE (R-Hands): %.2f mm\n' %
+                    np.mean(eval_result['pa_mpvpe_r_hand']))
+            f.write('PA MPVPE (Hands): %.2f mm\n' %
+                    np.mean(eval_result['pa_mpvpe_hand']))
+            f.write('PA MPVPE (Face): %.2f mm\n' %
+                    np.mean(eval_result['pa_mpvpe_face']))
+            f.write('MPVPE (All): %.2f mm\n' %
+                    np.mean(eval_result['mpvpe_all']))
+            f.write('MPVPE (L-Hands): %.2f mm\n' %
+                    np.mean(eval_result['mpvpe_l_hand']))
+            f.write('MPVPE (R-Hands): %.2f mm\n' %
+                    np.mean(eval_result['mpvpe_r_hand']))
+            f.write('MPVPE (Hands): %.2f mm\n' %
+                    np.mean(eval_result['mpvpe_hand']))
+            f.write('MPVPE (Face): %.2f mm\n' %
+                    np.mean(eval_result['mpvpe_face']))
diff --git a/datasets/__init__.py b/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/datasets/dataset.py b/datasets/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..01c9a3a3fdebcfeb047c6b67483d6d47d4e52b94
--- /dev/null
+++ b/datasets/dataset.py
@@ -0,0 +1,118 @@
+import random
+import numpy as np
+from torch.utils.data.dataset import Dataset
+from config.config import cfg
+
+class MultipleDatasets(Dataset):
+ def __init__(self,
+ dbs,
+ partition,
+ make_same_len=True,
+ total_len=None,
+ verbose=False):
+ self.dbs = dbs
+ self.db_num = len(self.dbs)
+ self.max_db_data_num = max([len(db) for db in dbs])
+ self.db_len_cumsum = np.cumsum([len(db) for db in dbs])
+ self.make_same_len = make_same_len
+ # self.partition = partition
+ self.partition = {k: v for k, v in sorted(partition.items(), key=lambda item: item[1])}
+ self.dataset = {}
+ for db in dbs:
+ self.dataset.update({db.__class__.__name__: db})
+
+ if verbose:
+ print('datasets:', [len(self.dbs[i]) for i in range(self.db_num)])
+ print(
+ f'Sample Ratio: {self.partition}')
+
+ def __len__(self):
+ return self.max_db_data_num
+
+ def __getitem__(self, index):
+ p = np.random.rand()
+ v = list(self.partition.values())
+ k = list(self.partition.keys())
+ for i,v_i in enumerate(v):
+ if p<=v_i:
+ return self.dataset[k[i]][index % len(self.dataset[k[i]])]
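+
+    # Note on the expected `partition` format (an assumption inferred from the
+    # sampling loop above, not documented elsewhere): values are cumulative
+    # probability thresholds in ascending order, ending at 1.0, keyed by the
+    # dataset class names. For example, with hypothetical dataset names:
+    #     partition = {'DatasetA': 0.4, 'DatasetB': 1.0}
+    # draws ~40% of samples from DatasetA and ~60% from DatasetB; if the
+    # largest value were below 1.0, some draws would match no threshold and
+    # __getitem__ would fall through and return None.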
+
+
+class MultipleDatasets_debug(Dataset):
+ def __init__(self, dbs, make_same_len=True, total_len=None, verbose=False):
+ self.dbs = dbs
+ self.db_num = len(self.dbs)
+ self.max_db_data_num = max([len(db) for db in dbs])
+ self.db_len_cumsum = np.cumsum([len(db) for db in dbs])
+ self.make_same_len = make_same_len
+
+ if total_len == 'auto':
+ self.total_len = self.db_len_cumsum[-1]
+ self.auto_total_len = True
+ else:
+ self.total_len = total_len
+ self.auto_total_len = False
+
+ if total_len is not None:
+ self.per_db_len = self.total_len // self.db_num
+ if verbose:
+ print('datasets:', [len(self.dbs[i]) for i in range(self.db_num)])
+ print(
+ f'Auto total length: {self.auto_total_len}, {self.total_len}')
+
+ def __len__(self):
+ # all dbs have the same length
+ if self.make_same_len:
+ if self.total_len is None:
+ # match the longest length
+ return self.max_db_data_num * self.db_num
+ else:
+ # each dataset has the same length and total len is fixed
+ return self.total_len
+ else:
+ # each db has different length, simply concat
+ return sum([len(db) for db in self.dbs])
+
+ def __getitem__(self, index):
+ if self.make_same_len:
+ if self.total_len is None:
+ # match the longest length
+ db_idx = index // self.max_db_data_num
+ data_idx = index % self.max_db_data_num
+ if data_idx >= len(self.dbs[db_idx]) * (
+ self.max_db_data_num //
+ len(self.dbs[db_idx])): # last batch: random sampling
+ data_idx = random.randint(0, len(self.dbs[db_idx]) - 1)
+ else: # before last batch: use modular
+ data_idx = data_idx % len(self.dbs[db_idx])
+ else:
+ db_idx = index // self.per_db_len
+ data_idx = index % self.per_db_len
+ if db_idx > (self.db_num - 1):
+ # last batch: randomly choose one dataset
+ db_idx = random.randint(0, self.db_num - 1)
+
+ if len(self.dbs[db_idx]) < self.per_db_len and \
+ data_idx >= len(self.dbs[db_idx]) * (self.per_db_len // len(self.dbs[db_idx])):
+ # last batch: random sampling in this dataset
+ data_idx = random.randint(0, len(self.dbs[db_idx]) - 1)
+ else:
+ # before last batch: use modular
+ data_idx = data_idx % len(self.dbs[db_idx])
+
+ else:
+ for i in range(self.db_num):
+ if index < self.db_len_cumsum[i]:
+ db_idx = i
+ break
+ if db_idx == 0:
+ data_idx = index
+ else:
+ data_idx = index - self.db_len_cumsum[db_idx - 1]
+
+ return self.dbs[db_idx][data_idx]
diff --git a/datasets/humandata.py b/datasets/humandata.py
new file mode 100644
index 0000000000000000000000000000000000000000..82115c4518736a273d05d8b644708154b4f71c9f
--- /dev/null
+++ b/datasets/humandata.py
@@ -0,0 +1,1301 @@
+import os
+import os.path as osp
+import numpy as np
+import torch
+import cv2
+import json
+import copy
+from pycocotools.coco import COCO
+from config.config import cfg
+from util.human_models import smpl_x
+from util.preprocessing import (
+ load_img, process_bbox, augmentation_instance_sample, process_human_model_output_batch_simplify,process_db_coord_batch_no_valid)
+from util.transforms import world2cam, cam2pixel, rigid_align
+from detrsmpl.utils.geometry import batch_rodrigues, project_points_new, weak_perspective_projection, perspective_projection
+import tqdm
+import time
+import random
+from detrsmpl.utils.demo_utils import box2cs, xywh2xyxy, xyxy2xywh
+import torch.distributed as dist
+
+KPS2D_KEYS = [
+ 'keypoints2d_ori', 'keypoints2d_smplx', 'keypoints2d_smpl',
+ 'keypoints2d_original','keypoints2d_gta','keypoints2d'
+]
+KPS3D_KEYS = [
+ 'keypoints3d_cam', 'keypoints3d', 'keypoints3d_smplx', 'keypoints3d_smpl',
+ 'keypoints3d_original', 'keypoints3d_gta','keypoints3d'
+]
+# keypoints3d_cam (root-aligned) has the highest priority, followed by the legacy key keypoints3d
+# when keypoints3d_smplx is present, prefer it over keypoints3d_original
+
+from util.formatting import DefaultFormatBundle
+from detrsmpl.data.datasets.pipelines.transforms import Normalize
+
+class Cache():
+ """A custom implementation for OSX pipeline."""
+ def __init__(self, load_path=None):
+ if load_path is not None:
+ self.load(load_path)
+
+ def load(self, load_path):
+ self.load_path = load_path
+ self.cache = np.load(load_path, allow_pickle=True)
+ self.data_len = self.cache['data_len']
+ self.data_strategy = self.cache['data_strategy']
+ assert self.data_len == len(self.cache) - 2 # data_len, data_strategy
+ self.cache = None
+
+ @classmethod
+ def save(cls, save_path, data_list, data_strategy):
+ assert save_path is not None, 'save_path is None'
+ data_len = len(data_list)
+ cache = {}
+ for i, data in enumerate(data_list):
+ cache[str(i)] = data
+ assert len(cache) == data_len
+ # update meta
+ cache.update({'data_len': data_len, 'data_strategy': data_strategy})
+ # import pdb; pdb.set_trace()
+ np.savez_compressed(save_path, **cache)
+ print(f'Cache saved to {save_path}.')
+
+ # def shuffle(self):
+ # random.shuffle(self.mapping)
+
+ def __len__(self):
+ return self.data_len
+
+ def __getitem__(self, idx):
+ if self.cache is None:
+ self.cache = np.load(self.load_path, allow_pickle=True)
+ # mapped_idx = self.mapping[idx]
+ # cache_data = self.cache[str(mapped_idx)]
+ # print(self.cache.files)
+ cache_data = self.cache[str(idx)]
+ data = cache_data.item()
+ return data
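+
+    # Minimal usage sketch (illustrative only; the file name is hypothetical):
+    #     Cache.save('ubody_train_cache.npz', datalist, data_strategy='balance')
+    #     cache = Cache('ubody_train_cache.npz')
+    #     sample = cache[0]   # the .npz is re-opened lazily on first access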
+
+
+class HumanDataset(torch.utils.data.Dataset):
+
+ # same mapping for 144->137 and 190->137
+ SMPLX_137_MAPPING = [
+ 0, 1, 2, 4, 5, 7, 8, 12, 16, 17, 18, 19, 20, 21, 60, 61, 62, 63, 64,
+ 65, 59, 58, 57, 56, 55, 37, 38, 39, 66, 25, 26, 27, 67, 28, 29, 30, 68,
+ 34, 35, 36, 69, 31, 32, 33, 70, 52, 53, 54, 71, 40, 41, 42, 72, 43, 44,
+ 45, 73, 49, 50, 51, 74, 46, 47, 48, 75, 22, 15, 56, 57, 76, 77, 78, 79,
+ 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
+ 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
+ 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
+ 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
+ 140, 141, 142, 143
+ ]
+
+ def __init__(self, transform, data_split):
+ self.transform = transform
+ self.data_split = data_split
+
+ # dataset information, to be filled by child class
+ self.img_dir = None
+ self.annot_path = None
+ self.annot_path_cache = None
+ self.use_cache = False
+ self.img_shape = None # (h, w)
+ self.cam_param = None # {'focal_length': (fx, fy), 'princpt': (cx, cy)}
+ self.use_betas_neutral = False
+ self.body_only = False
+ self.joint_set = {
+ 'joint_num': smpl_x.joint_num,
+ 'joints_name': smpl_x.joints_name,
+ 'flip_pairs': smpl_x.flip_pairs
+ }
+ self.joint_set['root_joint_idx'] = self.joint_set['joints_name'].index(
+ 'Pelvis')
+ self.format = DefaultFormatBundle()
+ self.normalize = Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])
+ self.keypoints2d = None
+ # self.rank = dist.get_rank()
+ self.lhand_mean = smpl_x.layer['neutral'].left_hand_mean.reshape(15, 3).cpu().numpy()
+ self.rhand_mean = smpl_x.layer['neutral'].right_hand_mean.reshape(15, 3).cpu().numpy()
+ # self.log_file_path = f'indices_node{rank}.txt'
+ def load_cache(self, annot_path_cache):
+ datalist = Cache(annot_path_cache)
+ # assert datalist.data_strategy == getattr(cfg, 'data_strategy', None), \
+ # f'Cache data strategy {datalist.data_strategy} does not match current data strategy ' \
+ # f'{getattr(cfg, "data_strategy", None)}'
+ return datalist
+
+ def save_cache(self, annot_path_cache, datalist):
+ print(
+ f'[{self.__class__.__name__}] Caching datalist to {self.annot_path_cache}...'
+ )
+ Cache.save(annot_path_cache,
+ datalist,
+ data_strategy=getattr(cfg, 'data_strategy', None))
+
+ def load_data(self, train_sample_interval=1,
+ hand_bbox_ratio=1, body_bbox_ratio=1):
+
+ content = np.load(self.annot_path, allow_pickle=True)
+ try:
+ frame_range = content['frame_range']
+ except KeyError:
+ self.num_data = len(content['image_path'])
+ frame_range = \
+ np.array([[i, i + 1] for i in range(self.num_data)])
+
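+        # frame_range appears to map each image to a [start, end) slice into the
+        # flat per-instance arrays (bbox_xywh, keypoints, smplx, ...), grouping
+        # all person instances of that image; when the npz has no 'frame_range'
+        # key, every image is treated as owning exactly one instance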
+ num_examples = len(frame_range)
+ if 'meta' in content:
+ meta = content['meta'].item()
+ print('meta keys:', meta.keys())
+ else:
+ meta = None
+ print(
+ 'No meta info provided! Please give height and width manually')
+
+ print(
+ f'Start loading humandata {self.annot_path} into memory...\nDataset includes: {content.files}'
+ )
+ tic = time.time()
+ image_path = content['image_path']
+ if meta is not None and 'height' in meta and len(meta['height'])>0:
+ height = np.array(meta['height'])
+ width = np.array(meta['width'])
+ image_shape = np.stack([height, width], axis=-1)
+ else:
+ image_shape = None
+
+ if meta is not None and 'gender' in meta and len(meta['gender']) != 0:
+ gender = np.array(meta['gender'])
+ else:
+ gender = None
+ bbox_xywh = content['bbox_xywh']
+
+ if 'smplx' in content:
+ smplx = content['smplx'].item()
+ as_smplx = 'smplx'
+ elif 'smpl' in content:
+ smplx = content['smpl'].item()
+ as_smplx = 'smpl'
+ elif 'smplh' in content:
+ smplx = content['smplh'].item()
+ as_smplx = 'smplh'
+ # TODO: temp solution, should be more general. But SHAPY is very special
+ elif self.__class__.__name__ == 'SHAPY':
+ smplx = {}
+ else:
+            raise KeyError('No SMPL/SMPL-X annotations available, please check keys:\n'
+                           f'{content.files}')
+
+ print('Smplx param', smplx.keys())
+
+ if 'lhand_bbox_xywh' in content and 'rhand_bbox_xywh' in content:
+ lhand_bbox_xywh = content['lhand_bbox_xywh']
+ rhand_bbox_xywh = content['rhand_bbox_xywh']
+ else:
+ lhand_bbox_xywh = np.zeros_like(bbox_xywh)
+ rhand_bbox_xywh = np.zeros_like(bbox_xywh)
+
+ if 'face_bbox_xywh' in content:
+ face_bbox_xywh = content['face_bbox_xywh']
+ else:
+ face_bbox_xywh = np.zeros_like(bbox_xywh)
+
+ if meta is not None and 'smplx_valid' in meta:
+ smplx_valid = meta['smplx_valid']
+ else:
+ smplx_valid = np.ones(len(bbox_xywh))
+
+ decompressed = False
+ if content['__keypoints_compressed__']:
+ decompressed_kps = self.decompress_keypoints(content)
+ decompressed = True
+
+ keypoints3d = None
+ valid_kps3d = False
+ keypoints3d_mask = None
+ valid_kps3d_mask = False
+
+ # processing keypoints
+ for kps3d_key in KPS3D_KEYS:
+ if kps3d_key in content:
+ keypoints3d = decompressed_kps[kps3d_key][:, self.SMPLX_137_MAPPING, :] if decompressed \
+ else content[kps3d_key][:, self.SMPLX_137_MAPPING, :]
+ valid_kps3d = True
+ if keypoints3d.shape[-1] == 4:
+ valid_kps3d_mask = True
+ break
+ if self.keypoints2d is not None:
+ keypoints2d = decompressed_kps[self.keypoints2d][:, self.SMPLX_137_MAPPING, :] if decompressed \
+ else content[self.keypoints2d][:, self.SMPLX_137_MAPPING, :]
+
+
+ else:
+ for kps2d_key in KPS2D_KEYS:
+ if kps2d_key in content:
+ keypoints2d = decompressed_kps[kps2d_key][:, self.SMPLX_137_MAPPING, :] if decompressed \
+ else content[kps2d_key][:, self.SMPLX_137_MAPPING, :]
+ break
+ if keypoints2d.shape[-1] == 3:
+ valid_kps3d_mask = True
+
+ print('Done. Time: {:.2f}s'.format(time.time() - tic))
+
+ datalist = []
+ # num_examples
+
+ # processing each image, filter according to bbox valid
+ for i in tqdm.tqdm(range(int(num_examples))):
+
+ if self.data_split == 'train' and i % train_sample_interval != 0:
+ continue
+ frame_start, frame_end = frame_range[i]
+ img_path = osp.join(self.img_dir, image_path[frame_start])
+ # im_shape = cv2.imread(img_path).shape[:2]
+ img_shape = image_shape[
+ frame_start] if image_shape is not None else self.img_shape
+
+
+ bbox_list = bbox_xywh[frame_start:frame_end, :4]
+
+ valid_idx = []
+ body_bbox_list = []
+
+ # if hasattr(cfg, 'bbox_ratio'):
+ # bbox_ratio = cfg.bbox_ratio * 0.833 # preprocess body bbox is giving 1.2 box padding
+ # else:
+ # bbox_ratio = 1.25
+ # if self.__class__.__name__ == 'SPEC':
+ # bbox_ratio = 1.25
+
+ for bbox_i, bbox in enumerate(bbox_list):
+
+ bbox = process_bbox(bbox,
+ img_width=img_shape[1],
+ img_height=img_shape[0],
+ ratio=body_bbox_ratio)
+ if bbox is None:
+ continue
+ else:
+ valid_idx.append(frame_start + bbox_i)
+ bbox[2:] += bbox[:2]
+ body_bbox_list.append(bbox)
+
+ if len(valid_idx) == 0:
+ continue
+ valid_num = len(valid_idx)
+ # hand/face bbox
+ lhand_bbox_list = []
+ rhand_bbox_list = []
+ face_bbox_list = []
+ smplx_valid_list = []
+ for bbox_i in valid_idx:
+ smplx_valid_list.append(smplx_valid[bbox_i])
+ lhand_bbox = lhand_bbox_xywh[bbox_i]
+ rhand_bbox = rhand_bbox_xywh[bbox_i]
+ face_bbox = face_bbox_xywh[bbox_i]
+ if lhand_bbox[-1] > 0: # conf > 0
+ lhand_bbox = lhand_bbox[:4]
+ # if hasattr(cfg, 'bbox_ratio'):
+ lhand_bbox = process_bbox(lhand_bbox,
+ img_width=img_shape[1],
+ img_height=img_shape[0],
+ ratio=hand_bbox_ratio)
+ if lhand_bbox is not None:
+ lhand_bbox[2:] += lhand_bbox[:2] # xywh -> xyxy
+ else:
+ lhand_bbox = None
+ if rhand_bbox[-1] > 0:
+ rhand_bbox = rhand_bbox[:4]
+ # if hasattr(cfg, 'bbox_ratio'):
+ rhand_bbox = process_bbox(rhand_bbox,
+ img_width=img_shape[1],
+ img_height=img_shape[0],
+ ratio=hand_bbox_ratio)
+ if rhand_bbox is not None:
+ rhand_bbox[2:] += rhand_bbox[:2] # xywh -> xyxy
+ else:
+ rhand_bbox = None
+ if face_bbox[-1] > 0:
+ face_bbox = face_bbox[:4]
+ # if hasattr(cfg, 'bbox_ratio'):
+ face_bbox = process_bbox(face_bbox,
+ img_width=img_shape[1],
+ img_height=img_shape[0],
+ ratio=hand_bbox_ratio)
+ if face_bbox is not None:
+ face_bbox[2:] += face_bbox[:2] # xywh -> xyxy
+ else:
+ face_bbox = None
+ lhand_bbox_list.append(lhand_bbox)
+ rhand_bbox_list.append(rhand_bbox)
+ face_bbox_list.append(face_bbox)
+
+ joint_img = keypoints2d[valid_idx]
+
+ if valid_kps3d:
+ joint_cam = keypoints3d[valid_idx]
+ else:
+ joint_cam = None
+ if 'leye_pose_0' in smplx.keys():
+ smplx.pop('leye_pose_0')
+ if 'leye_pose_1' in smplx.keys():
+ smplx.pop('leye_pose_1')
+ if 'leye_pose' in smplx.keys():
+ smplx.pop('leye_pose')
+ if 'reye_pose_0' in smplx.keys():
+ smplx.pop('reye_pose_0')
+ if 'reye_pose_1' in smplx.keys():
+ smplx.pop('reye_pose_1')
+ if 'reye_pose' in smplx.keys():
+ smplx.pop('reye_pose')
+
+
+ smplx_param = {k: v[valid_idx] for k, v in smplx.items()}
+ gender_ = gender[valid_idx] \
+ if gender is not None else np.array(['neutral']*(valid_num))
+ lhand_bbox_valid = lhand_bbox_xywh[valid_idx,4]
+ rhand_bbox_valid = rhand_bbox_xywh[valid_idx,4]
+ face_bbox_valid = face_bbox_xywh[valid_idx,4]
+
+ # TODO: set invalid if None?
+ smplx_param['root_pose'] = smplx_param.pop('global_orient', None)
+ smplx_param['shape'] = smplx_param.pop('betas', None)
+ smplx_param['trans'] = smplx_param.pop('transl', np.zeros([len(valid_idx),3]))
+ smplx_param['lhand_pose'] = smplx_param.pop('left_hand_pose', None)
+ smplx_param['rhand_pose'] = smplx_param.pop(
+ 'right_hand_pose', None)
+ smplx_param['expr'] = smplx_param.pop('expression', None)
+
+ # TODO do not fix betas, give up shape supervision
+ if 'betas_neutral' in smplx_param and self.data_split == 'train':
+ smplx_param['shape'] = smplx_param.pop('betas_neutral')
+ # smplx_param['shape'] = np.zeros(10, dtype=np.float32)
+
+ # # TODO fix shape of poses
+ if self.__class__.__name__ == 'Talkshow':
+ smplx_param['body_pose'] = smplx_param['body_pose'].reshape(
+ -1, 21, 3)
+ smplx_param['lhand_pose'] = smplx_param['lhand_pose'].reshape(
+ -1, 15, 3)
+                smplx_param['rhand_pose'] = smplx_param['rhand_pose'].reshape(
+                    -1, 15, 3)
+ smplx_param['expr'] = smplx_param['expr'][:, :10]
+
+ if self.__class__.__name__ == 'BEDLAM':
+ smplx_param['shape'] = smplx_param['shape'][:, :10]
+ # smplx_param['expr'] = None
+ if self.__class__.__name__ == 'GTA':
+ smplx_param['shape'] = np.zeros(
+ [valid_num, 10],
+ dtype=np.float32)
+ if self.__class__.__name__ == 'COCO_NA':
+ # smplx_param['expr'] = None
+ smplx_param['body_pose'] = smplx_param['body_pose'].reshape(
+ -1, 21, 3)
+ smplx_param['lhand_pose'] = smplx_param['lhand_pose'].reshape(
+ -1, 15, 3)
+ smplx_param['rhand_pose'] = smplx_param['rhand_pose'].reshape(
+ -1, 15, 3)
+ if as_smplx == 'smpl':
+ smplx_param['shape'] = np.zeros(
+ [valid_num, 10],
+ dtype=np.float32) # drop smpl betas for smplx
+ smplx_param['body_pose'] = smplx_param[
+ 'body_pose'].reshape(-1,23,3)[:, :21, :] # use smpl body_pose on smplx
+ if as_smplx == 'smplh':
+ smplx_param['shape'] = np.zeros(
+ [valid_num, 10],
+ dtype=np.float32) # drop smpl betas for smplx
+
+ if smplx_param['lhand_pose'] is None or self.body_only == True:
+ smplx_param['lhand_valid'] = np.zeros(valid_num, dtype=np.bool8)
+ else:
+ smplx_param['lhand_valid'] = lhand_bbox_valid.astype(np.bool8)
+
+ if smplx_param['rhand_pose'] is None or self.body_only == True:
+ smplx_param['rhand_valid'] = np.zeros(valid_num, dtype=np.bool8)
+ else:
+ smplx_param['rhand_valid'] = rhand_bbox_valid.astype(np.bool8)
+
+ if smplx_param['expr'] is None or self.body_only == True:
+ smplx_param['face_valid'] = np.zeros(valid_num, dtype=np.bool8)
+ else:
+ smplx_param['face_valid'] = face_bbox_valid.astype(np.bool8)
+
+ smplx_param['smplx_valid'] = np.array(smplx_valid_list).astype(np.bool8)
+ if joint_cam is not None and np.any(np.isnan(joint_cam)):
+ continue
+
+
+ if self.__class__.__name__ == 'SPEC':
+ joint_img[:,:,2] = joint_img[:,:,2]>0
+ joint_cam[:,:,3] = joint_cam[:,:,0]!=0
+ datalist.append({
+ 'img_path': img_path,
+ 'img_shape': img_shape,
+ 'bbox': body_bbox_list,
+ 'lhand_bbox': lhand_bbox_list,
+ 'rhand_bbox': rhand_bbox_list,
+ 'face_bbox': face_bbox_list,
+ 'joint_img': joint_img,
+ 'joint_cam': joint_cam,
+ 'smplx_param': smplx_param,
+ 'as_smplx': as_smplx,
+ 'gender': gender_
+ })
+
+ # save memory
+ del content, image_path, bbox_xywh, lhand_bbox_xywh, rhand_bbox_xywh, face_bbox_xywh, keypoints3d, keypoints2d
+
+ if self.data_split == 'train':
+ print(f'[{self.__class__.__name__} train] original size:',
+ int(num_examples), '. Sample interval:',
+ train_sample_interval, '. Sampled size:', len(datalist))
+
+ if getattr(cfg, 'data_strategy',
+ None) == 'balance' and self.data_split == 'train':
+ print(
+ f'[{self.__class__.__name__}] Using [balance] strategy with datalist shuffled...'
+ )
+ random.shuffle(datalist)
+
+ return datalist
+
+ def __len__(self):
+ return len(self.datalist)
+
+ def __getitem__(self, idx):
+ # rank = self.rank
+ # local_rank = rank % torch.cuda.device_count()
+ # with open(f'index_log_{rank}.txt', 'a') as f:
+ # f.write(f'{rank}-{local_rank}-{idx}\n')
+ try:
+ data = copy.deepcopy(self.datalist[idx])
+ except Exception as e:
+ print(f'[{self.__class__.__name__}] Error loading data {idx}')
+ print(e)
+ exit(0)
+ img_path, img_shape, bbox = \
+ data['img_path'], data['img_shape'], data['bbox']
+ as_smplx = data['as_smplx']
+ gender = data['gender'].copy()
+ for gender_str, gender_num in {
+ 'neutral': -1, 'male': 0, 'female': 1}.items():
+ gender[gender==gender_str]=gender_num
+ gender = gender.astype(int)
+
+ img_whole_bbox = np.array([0, 0, img_shape[1], img_shape[0]])
+ img = load_img(img_path, order='BGR')
+
+ num_person = len(data['bbox'])
+ data_name = self.__class__.__name__
+ try:
+ # dist.barrier()
+ img, img2bb_trans, bb2img_trans, rot, do_flip = \
+ augmentation_instance_sample(img, img_whole_bbox, self.data_split, data, data_name)
+ except Exception as e:
+            rank = getattr(self, 'rank', 0)  # self.rank is unset unless dist.get_rank() is enabled in __init__
+ local_rank = rank % torch.cuda.device_count()
+ with open(f'index_log_{rank}.txt', 'a') as f:
+ f.write(f'{rank}-{local_rank}-{idx}\n')
+ f.write(f'[{self.__class__.__name__}] Error loading data {idx}\n')
+ f.write(f'Error in augmentation_instance_sample for {img_path}\n')
+ # print(f'[{self.__class__.__name__}] Error loading data {idx}')
+ # print(f'Error in augmentation_instance_sample for {img_path}')
+ raise e
+ cropped_img_shape = img.shape[:2]
+
+ if self.data_split == 'train':
+ joint_cam = data['joint_cam'] # num, 137,4
+ if joint_cam is not None:
+ dummy_cord = False
+ joint_cam[:,:,:3] = \
+ joint_cam[:,:,:3] - joint_cam[:, self.joint_set['root_joint_idx'], None, :3] # root-relative
+ else:
+ # dummy cord as joint_cam
+ dummy_cord = True
+ joint_cam = np.zeros(
+ (num_person, self.joint_set['joint_num'], 4),
+ dtype=np.float32)
+
+ joint_img = data['joint_img']
+ # do rotation on keypoints
+ joint_img_aug, joint_cam_wo_ra, joint_cam_ra, joint_trunc = \
+ process_db_coord_batch_no_valid(
+ joint_img, joint_cam, do_flip, img_shape,
+ self.joint_set['flip_pairs'], img2bb_trans, rot,
+ self.joint_set['joints_name'], smpl_x.joints_name,
+ cropped_img_shape)
+ joint_img_aug[:,:,2:] = joint_img_aug[:,:,2:] * joint_trunc
+
+ # smplx coordinates and parameters
+ smplx_param = data['smplx_param']
+
+
+ if self.__class__.__name__ in [ 'CHI3D', 'SynBody', 'UBody_MM']:
+ smplx_param['lhand_pose']-=self.lhand_mean[None]
+ smplx_param['rhand_pose']-=self.rhand_mean[None]
+ # smplx_param
+ smplx_pose, smplx_shape, smplx_expr, smplx_pose_valid, \
+ smplx_joint_valid, smplx_expr_valid, smplx_shape_valid = \
+ process_human_model_output_batch_simplify(
+ smplx_param, do_flip, rot, as_smplx, data_name)
+ smplx_joint_valid = smplx_joint_valid[:, :, None]
+ # if cam not provided, we take joint_img as smplx joint 2d,
+ # which is commonly the case for our processed humandata
+ # change smplx_shape if use_betas_neutral
+ # processing follows that in process_human_model_output
+ if self.use_betas_neutral:
+ smplx_shape = smplx_param['betas_neutral'].reshape(
+ num_person, -1)
+ smplx_shape[(np.abs(smplx_shape) > 3).any(axis=1)] = 0.
+ smplx_shape = smplx_shape.reshape(num_person, -1)
+
+ if self.__class__.__name__ == 'MPII_MM' :
+ for name in ('L_Ankle', 'R_Ankle', 'L_Wrist', 'R_Wrist'):
+ smplx_pose_valid[:, smpl_x.orig_joints_name.index(name)] = 0
+ for name in ('L_Big_toe', 'L_Small_toe', 'L_Heel', 'R_Big_toe', 'R_Small_toe', 'R_Heel'):
+ smplx_joint_valid[:,smpl_x.joints_name.index(name)] = 0
+
+
+ lhand_bbox_center_list = []
+ lhand_bbox_valid_list = []
+ lhand_bbox_size_list = []
+ lhand_bbox_list = []
+ face_bbox_center_list = []
+ face_bbox_size_list = []
+ face_bbox_valid_list = []
+ face_bbox_list = []
+ rhand_bbox_center_list = []
+ rhand_bbox_valid_list = []
+ rhand_bbox_size_list = []
+ rhand_bbox_list = []
+ body_bbox_center_list = []
+ body_bbox_size_list = []
+ body_bbox_valid_list = []
+ body_bbox_list = []
+ # hand and face bbox transform
+
+
+ for i in range(num_person):
+ body_bbox, body_bbox_valid = self.process_hand_face_bbox(
+ data['bbox'][i], do_flip, img_shape, img2bb_trans,
+ cropped_img_shape)
+
+ lhand_bbox, lhand_bbox_valid = self.process_hand_face_bbox(
+ data['lhand_bbox'][i], do_flip, img_shape, img2bb_trans,
+ cropped_img_shape)
+ lhand_bbox_valid *= smplx_param['lhand_valid'][i]
+
+ rhand_bbox, rhand_bbox_valid = self.process_hand_face_bbox(
+ data['rhand_bbox'][i], do_flip, img_shape, img2bb_trans,
+ cropped_img_shape)
+ rhand_bbox_valid *= smplx_param['rhand_valid'][i]
+
+ face_bbox, face_bbox_valid = self.process_hand_face_bbox(
+ data['face_bbox'][i], do_flip, img_shape, img2bb_trans,
+ cropped_img_shape)
+ face_bbox_valid *= smplx_param['face_valid'][i]
+
+ # BEDLAM and COCO_NA do not have face expression
+ # if self.__class__.__name__ != 'BEDLAM':
+ # face_bbox_valid *= smplx_param['face_valid'][i]
+
+ if do_flip:
+ lhand_bbox, rhand_bbox = rhand_bbox, lhand_bbox
+ lhand_bbox_valid, rhand_bbox_valid = rhand_bbox_valid, lhand_bbox_valid
+
+ body_bbox_list.append(body_bbox)
+ lhand_bbox_list.append(lhand_bbox)
+ rhand_bbox_list.append(rhand_bbox)
+ face_bbox_list.append(face_bbox)
+
+ lhand_bbox_center = (lhand_bbox[0] + lhand_bbox[1]) / 2.
+ rhand_bbox_center = (rhand_bbox[0] + rhand_bbox[1]) / 2.
+ face_bbox_center = (face_bbox[0] + face_bbox[1]) / 2.
+ body_bbox_center = (body_bbox[0] + body_bbox[1]) / 2.
+ lhand_bbox_size = lhand_bbox[1] - lhand_bbox[0]
+ rhand_bbox_size = rhand_bbox[1] - rhand_bbox[0]
+
+ face_bbox_size = face_bbox[1] - face_bbox[0]
+ body_bbox_size = body_bbox[1] - body_bbox[0]
+ lhand_bbox_center_list.append(lhand_bbox_center)
+ lhand_bbox_valid_list.append(lhand_bbox_valid)
+ lhand_bbox_size_list.append(lhand_bbox_size)
+ face_bbox_center_list.append(face_bbox_center)
+ face_bbox_size_list.append(face_bbox_size)
+ face_bbox_valid_list.append(face_bbox_valid)
+ rhand_bbox_center_list.append(rhand_bbox_center)
+ rhand_bbox_valid_list.append(rhand_bbox_valid)
+ rhand_bbox_size_list.append(rhand_bbox_size)
+ body_bbox_center_list.append(body_bbox_center)
+ body_bbox_size_list.append(body_bbox_size)
+ body_bbox_valid_list.append(body_bbox_valid)
+
+
+ body_bbox = np.stack(body_bbox_list, axis=0)
+ lhand_bbox = np.stack(lhand_bbox_list, axis=0)
+ rhand_bbox = np.stack(rhand_bbox_list, axis=0)
+ face_bbox = np.stack(face_bbox_list, axis=0)
+ lhand_bbox_center = np.stack(lhand_bbox_center_list, axis=0)
+ lhand_bbox_valid = np.stack(lhand_bbox_valid_list, axis=0)
+ lhand_bbox_size = np.stack(lhand_bbox_size_list, axis=0)
+ face_bbox_center = np.stack(face_bbox_center_list, axis=0)
+ face_bbox_size = np.stack(face_bbox_size_list, axis=0)
+ face_bbox_valid = np.stack(face_bbox_valid_list, axis=0)
+ body_bbox_center = np.stack(body_bbox_center_list, axis=0)
+ body_bbox_size = np.stack(body_bbox_size_list, axis=0)
+ body_bbox_valid = np.stack(body_bbox_valid_list, axis=0)
+ rhand_bbox_center = np.stack(rhand_bbox_center_list, axis=0)
+ rhand_bbox_valid = np.stack(rhand_bbox_valid_list, axis=0)
+ rhand_bbox_size = np.stack(rhand_bbox_size_list, axis=0)
+
+ inputs = {'img': img}
+
+ # joint_img_aug[:,:,2] = joint_img_aug[:,:,2] * body_bbox_valid[:,None]
+
+ is_3D = float(False) if dummy_cord else float(True)
+ if self.__class__.__name__ == 'COCO_NA':
+ is_3D = False
+ if self.__class__.__name__ == 'GTA_Human2':
+ smplx_shape_valid = smplx_shape_valid * 0
+ if self.__class__.__name__ == 'PoseTrack' or self.__class__.__name__ == 'MPII_MM' \
+ or self.__class__.__name__ == 'CrowdPose' or self.__class__.__name__ == 'UBody_MM' \
+ or self.__class__.__name__ == 'COCO_NA':
+ joint_cam_ra[...,-1] = joint_cam_ra[...,-1] * smplx_joint_valid[...,0]
+ joint_cam_wo_ra[...,-1] = joint_cam_wo_ra[...,-1] * smplx_joint_valid[...,0]
+ joint_img_aug[...,-1] = joint_img_aug[...,-1] * smplx_joint_valid[...,0]
+ # if body_bbox_valid.sum() > 0:
+
+
+ targets = {
+ # keypoints2d, [0,img_w],[0,img_h] -> [0,1] -> [0,output_hm_shape]
+ 'joint_img': joint_img_aug[body_bbox_valid>0],
+ # joint_cam, kp3d wo ra # raw kps3d probably without ra
+ 'joint_cam': joint_cam_wo_ra[body_bbox_valid>0],
+ # kps3d with body, face, hand ra
+ 'smplx_joint_cam': joint_cam_ra[body_bbox_valid>0],
+ 'smplx_pose': smplx_pose[body_bbox_valid>0],
+ 'smplx_shape': smplx_shape[body_bbox_valid>0],
+ 'smplx_expr': smplx_expr[body_bbox_valid>0],
+ 'lhand_bbox_center': lhand_bbox_center[body_bbox_valid>0],
+ 'lhand_bbox_size': lhand_bbox_size[body_bbox_valid>0],
+ 'rhand_bbox_center': rhand_bbox_center[body_bbox_valid>0],
+ 'rhand_bbox_size': rhand_bbox_size[body_bbox_valid>0],
+ 'face_bbox_center': face_bbox_center[body_bbox_valid>0],
+ 'face_bbox_size': face_bbox_size[body_bbox_valid>0],
+ 'body_bbox_center': body_bbox_center[body_bbox_valid>0],
+ 'body_bbox_size': body_bbox_size[body_bbox_valid>0],
+ 'body_bbox': body_bbox.reshape(-1,4)[body_bbox_valid>0],
+ 'lhand_bbox': lhand_bbox.reshape(-1,4)[body_bbox_valid>0],
+ 'rhand_bbox': rhand_bbox.reshape(-1,4)[body_bbox_valid>0],
+ 'face_bbox': face_bbox.reshape(-1,4)[body_bbox_valid>0],
+ 'gender': gender[body_bbox_valid>0]}
+
+ meta_info = {
+ 'joint_trunc': joint_trunc[body_bbox_valid>0],
+ 'smplx_pose_valid': smplx_pose_valid[body_bbox_valid>0],
+ 'smplx_shape_valid': smplx_shape_valid[body_bbox_valid>0],
+ 'smplx_expr_valid': smplx_expr_valid[body_bbox_valid>0],
+ 'is_3D': is_3D,
+ 'lhand_bbox_valid': lhand_bbox_valid[body_bbox_valid>0],
+ 'rhand_bbox_valid': rhand_bbox_valid[body_bbox_valid>0],
+ 'face_bbox_valid': face_bbox_valid[body_bbox_valid>0],
+ 'body_bbox_valid': body_bbox_valid[body_bbox_valid>0],
+ 'img_shape': np.array(img.shape[:2]),
+ 'ori_shape':data['img_shape'],
+ 'idx': idx
+
+ }
+
+ result = {**inputs, **targets, **meta_info}
+
+ result = self.normalize(result)
+ result = self.format(result)
+ return result
+
+
+
+ if self.data_split == 'test':
+ self.cam_param = {}
+ joint_cam = data['joint_cam']
+
+ if joint_cam is not None:
+ dummy_cord = False
+ joint_cam[:,:,:3] = joint_cam[:,:,:3] - joint_cam[
+ :, self.joint_set['root_joint_idx'], None, :3] # root-relative
+ else:
+ # dummy cord as joint_cam
+ dummy_cord = True
+ joint_cam = np.zeros(
+ (num_person, self.joint_set['joint_num'], 3),
+ dtype=np.float32)
+
+ joint_img = data['joint_img']
+
+
+ joint_img_aug, joint_cam_wo_ra, joint_cam_ra, joint_trunc = \
+ process_db_coord_batch_no_valid(
+ joint_img, joint_cam, do_flip, img_shape,
+ self.joint_set['flip_pairs'], img2bb_trans, rot,
+ self.joint_set['joints_name'], smpl_x.joints_name,
+ cropped_img_shape)
+
+
+
+ # smplx coordinates and parameters
+ smplx_param = data['smplx_param']
+ # smplx_cam_trans = np.array(
+ # smplx_param['trans']) if 'trans' in smplx_param else None
+            # TODO: remove this, separate smpl and smplx
+ smplx_pose, smplx_shape, smplx_expr, smplx_pose_valid, \
+ smplx_joint_valid, smplx_expr_valid, smplx_shape_valid = \
+ process_human_model_output_batch_simplify(
+ smplx_param, do_flip, rot, as_smplx)
+ # if cam not provided, we take joint_img as smplx joint 2d,
+ # which is commonly the case for our processed humandata
+ if self.use_betas_neutral:
+ smplx_shape = smplx_param['betas_neutral'].reshape(
+ num_person, -1)
+ smplx_shape[(np.abs(smplx_shape) > 3).any(axis=1)] = 0.
+ smplx_shape = smplx_shape.reshape(num_person, -1)
+ # smplx_pose_valid = np.tile(smplx_pose_valid[:,:, None], (1, 3)).reshape(num_person,-1)
+ smplx_joint_valid = smplx_joint_valid[:, :, None]
+
+ # if not (smplx_shape == 0).all():
+ # smplx_shape_valid = True
+ # else:
+ # smplx_shape_valid = False
+ lhand_bbox_center_list = []
+ lhand_bbox_valid_list = []
+ lhand_bbox_size_list = []
+ lhand_bbox_list = []
+ face_bbox_center_list = []
+ face_bbox_size_list = []
+ face_bbox_valid_list = []
+ face_bbox_list = []
+ rhand_bbox_center_list = []
+ rhand_bbox_valid_list = []
+ rhand_bbox_size_list = []
+ rhand_bbox_list = []
+ body_bbox_center_list = []
+ body_bbox_size_list = []
+ body_bbox_valid_list = []
+ body_bbox_list = []
+
+ for i in range(num_person):
+ lhand_bbox, lhand_bbox_valid = self.process_hand_face_bbox(
+ data['lhand_bbox'][i], do_flip, img_shape, img2bb_trans,
+ cropped_img_shape)
+ rhand_bbox, rhand_bbox_valid = self.process_hand_face_bbox(
+ data['rhand_bbox'][i], do_flip, img_shape, img2bb_trans,
+ cropped_img_shape)
+ face_bbox, face_bbox_valid = self.process_hand_face_bbox(
+ data['face_bbox'][i], do_flip, img_shape, img2bb_trans,
+ cropped_img_shape)
+
+ body_bbox, body_bbox_valid = self.process_hand_face_bbox(
+ data['bbox'][i], do_flip, img_shape, img2bb_trans,
+ cropped_img_shape)
+
+ if do_flip:
+ lhand_bbox, rhand_bbox = rhand_bbox, lhand_bbox
+ lhand_bbox_valid, rhand_bbox_valid = rhand_bbox_valid, lhand_bbox_valid
+
+ body_bbox_list.append(body_bbox)
+ lhand_bbox_list.append(lhand_bbox)
+ rhand_bbox_list.append(rhand_bbox)
+ face_bbox_list.append(face_bbox)
+
+ lhand_bbox_center = (lhand_bbox[0] + lhand_bbox[1]) / 2.
+ rhand_bbox_center = (rhand_bbox[0] + rhand_bbox[1]) / 2.
+ face_bbox_center = (face_bbox[0] + face_bbox[1]) / 2.
+ body_bbox_center = (body_bbox[0] + body_bbox[1]) / 2.
+ lhand_bbox_size = lhand_bbox[1] - lhand_bbox[0]
+ rhand_bbox_size = rhand_bbox[1] - rhand_bbox[0]
+
+ face_bbox_size = face_bbox[1] - face_bbox[0]
+ body_bbox_size = body_bbox[1] - body_bbox[0]
+ lhand_bbox_center_list.append(lhand_bbox_center)
+ lhand_bbox_valid_list.append(lhand_bbox_valid)
+ lhand_bbox_size_list.append(lhand_bbox_size)
+ face_bbox_center_list.append(face_bbox_center)
+ face_bbox_size_list.append(face_bbox_size)
+ face_bbox_valid_list.append(face_bbox_valid)
+ rhand_bbox_center_list.append(rhand_bbox_center)
+ rhand_bbox_valid_list.append(rhand_bbox_valid)
+ rhand_bbox_size_list.append(rhand_bbox_size)
+ body_bbox_center_list.append(body_bbox_center)
+ body_bbox_size_list.append(body_bbox_size)
+ body_bbox_valid_list.append(body_bbox_valid)
+
+ body_bbox = np.stack(body_bbox_list, axis=0)
+ lhand_bbox = np.stack(lhand_bbox_list, axis=0)
+ rhand_bbox = np.stack(rhand_bbox_list, axis=0)
+ face_bbox = np.stack(face_bbox_list, axis=0)
+ lhand_bbox_center = np.stack(lhand_bbox_center_list, axis=0)
+ lhand_bbox_valid = np.stack(lhand_bbox_valid_list, axis=0)
+ lhand_bbox_size = np.stack(lhand_bbox_size_list, axis=0)
+ face_bbox_center = np.stack(face_bbox_center_list, axis=0)
+ face_bbox_size = np.stack(face_bbox_size_list, axis=0)
+ face_bbox_valid = np.stack(face_bbox_valid_list, axis=0)
+ body_bbox_center = np.stack(body_bbox_center_list, axis=0)
+ body_bbox_size = np.stack(body_bbox_size_list, axis=0)
+ body_bbox_valid = np.stack(body_bbox_valid_list, axis=0)
+ rhand_bbox_center = np.stack(rhand_bbox_center_list, axis=0)
+ rhand_bbox_valid = np.stack(rhand_bbox_valid_list, axis=0)
+ rhand_bbox_size = np.stack(rhand_bbox_size_list, axis=0)
+
+
+ inputs = {'img': img}
+
+ targets = {
+ # keypoints2d, [0,img_w],[0,img_h] -> [0,1] -> [0,output_hm_shape]
+ 'joint_img': joint_img_aug,
+ # projected smplx if valid cam_param, else same as keypoints2d
+ # joint_cam, kp3d wo ra # raw kps3d probably without ra
+ 'joint_cam': joint_cam_wo_ra,
+ 'ann_idx': idx,
+ # kps3d with body, face, hand ra
+ 'smplx_joint_cam': joint_cam_ra,
+ 'smplx_pose': smplx_pose,
+ 'smplx_shape': smplx_shape,
+ 'smplx_expr': smplx_expr,
+ 'lhand_bbox_center': lhand_bbox_center,
+ 'lhand_bbox_size': lhand_bbox_size,
+ 'rhand_bbox_center': rhand_bbox_center,
+ 'rhand_bbox_size': rhand_bbox_size,
+ 'face_bbox_center': face_bbox_center,
+ 'face_bbox_size': face_bbox_size,
+ 'body_bbox_center': body_bbox_center,
+ 'body_bbox_size': body_bbox_size,
+ 'body_bbox': body_bbox.reshape(-1,4),
+ 'lhand_bbox': lhand_bbox.reshape(-1,4),
+ 'rhand_bbox': rhand_bbox.reshape(-1,4),
+ 'face_bbox': face_bbox.reshape(-1,4),
+ 'gender': gender,
+ 'bb2img_trans': bb2img_trans,
+ }
+
+ if self.body_only:
+ meta_info = {
+ 'joint_trunc': joint_trunc,
+ 'smplx_pose_valid': smplx_pose_valid,
+ 'smplx_shape_valid': float(smplx_shape_valid),
+ 'smplx_expr_valid': smplx_expr_valid,
+ 'is_3D': float(False) if dummy_cord else float(True),
+ 'lhand_bbox_valid': lhand_bbox_valid,
+ 'rhand_bbox_valid': rhand_bbox_valid,
+ 'face_bbox_valid': face_bbox_valid,
+ 'body_bbox_valid': body_bbox_valid,
+ 'img_shape': np.array(img.shape[:2]),
+ 'ori_shape':data['img_shape'],
+ 'idx': idx
+
+ }
+ else:
+ meta_info = {
+ 'joint_trunc': joint_trunc,
+ 'smplx_pose_valid': smplx_pose_valid,
+ 'smplx_shape_valid': smplx_shape_valid,
+ 'smplx_expr_valid': smplx_expr_valid,
+ 'is_3D': float(False) if dummy_cord else float(True),
+ 'lhand_bbox_valid': lhand_bbox_valid,
+ 'rhand_bbox_valid': rhand_bbox_valid,
+ 'face_bbox_valid': face_bbox_valid,
+ 'body_bbox_valid': body_bbox_valid,
+ 'img_shape': np.array(img.shape[:2]),
+ 'ori_shape':data['img_shape'],
+ 'idx': idx
+ }
+
+ result = {**inputs, **targets, **meta_info}
+ result = self.normalize(result)
+ result = self.format(result)
+ return result
+
+ def process_hand_face_bbox(self, bbox, do_flip, img_shape, img2bb_trans,
+ input_img_shape):
+ if bbox is None:
+ bbox = np.array([0, 0, 1, 1],
+ dtype=np.float32).reshape(2, 2) # dummy value
+ bbox_valid = float(False) # dummy value
+ else:
+ # reshape to top-left (x,y) and bottom-right (x,y)
+ bbox = bbox.reshape(2, 2)
+
+ # flip augmentation
+ if do_flip:
+ bbox[:, 0] = img_shape[1] - bbox[:, 0] - 1
+ bbox[0, 0], bbox[1, 0] = bbox[1, 0].copy(), bbox[
+ 0, 0].copy() # xmin <-> xmax swap
+
+ # make four points of the bbox
+ bbox = bbox.reshape(4).tolist()
+ xmin, ymin, xmax, ymax = bbox
+ bbox = np.array(
+ [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]],
+ dtype=np.float32).reshape(4, 2)
+
+ # affine transformation (crop, rotation, scale)
+ bbox_xy1 = np.concatenate((bbox, np.ones_like(bbox[:, :1])), 1)
+ bbox = np.dot(img2bb_trans,
+ bbox_xy1.transpose(1, 0)).transpose(1, 0)[:, :2]
+
+ # print(bbox)
+ # bbox[:, 0] = bbox[:, 0] / input_img_shape[1] * cfg.output_hm_shape[2]
+ # bbox[:, 1] = bbox[:, 1] / input_img_shape[0] * cfg.output_hm_shape[1]
+
+ bbox[:, 0] /= input_img_shape[1]
+ bbox[:, 1] /= input_img_shape[0]
+
+ # make box a rectangle without rotation
+ if np.max(bbox[:,0])<=0 or np.min(bbox[:,0])>=1 or np.max(bbox[:,1])<=0 or np.min(bbox[:,1])>=1:
+ bbox_valid = float(False)
+ bbox = np.array([0, 0, 1, 1], dtype=np.float32)
+ else:
+ xmin = np.max([np.min(bbox[:, 0]), 0])
+ xmax = np.min([np.max(bbox[:, 0]), 1])
+ ymin = np.max([np.min(bbox[:, 1]), 0])
+ ymax = np.min([np.max(bbox[:, 1]), 1])
+ bbox = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
+
+ bbox = np.clip(bbox,0,1)
+ bbox_valid = float(True)
+ bbox = bbox.reshape(2, 2)
+
+ return bbox, bbox_valid
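+
+    # Returned values: `bbox` is a (2, 2) array of normalized corners
+    # [[xmin, ymin], [xmax, ymax]] relative to the cropped/augmented image
+    # (clipped to [0, 1]); `bbox_valid` is 1.0 when the annotation exists and
+    # still overlaps the crop after flipping and the affine transform, else 0.0.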
+
+ def evaluate(self, outs, cur_sample_idx=None):
+ annots = self.datalist
+ sample_num = len(outs)
+ eval_result = {
+ 'pa_mpvpe_all': [],
+ 'pa_mpvpe_l_hand': [],
+ 'pa_mpvpe_r_hand': [],
+ 'pa_mpvpe_hand': [],
+ 'pa_mpvpe_face': [],
+ 'mpvpe_all': [],
+ 'mpvpe_l_hand': [],
+ 'mpvpe_r_hand': [],
+ 'mpvpe_hand': [],
+ 'mpvpe_face': [],
+ 'pa_mpjpe_body': [],
+ 'pa_mpjpe_l_hand': [],
+ 'pa_mpjpe_r_hand': [],
+ 'pa_mpjpe_hand': []
+ }
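+        # 'pa_*' entries are computed after rigid_align (Procrustes alignment),
+        # while the plain 'mpvpe_*' entries only translate the prediction so
+        # that the pelvis / wrist / neck matches the ground truth.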
+
+ for n in range(sample_num):
+ out = outs[n]
+ ann_idx = out['gt_ann_idx']
+ mesh_gt = out['smplx_mesh_cam_pseudo_gt']
+ mesh_out = out['smplx_mesh_cam']
+ cam_trans = out['cam_trans']
+ ann_idx = out['gt_ann_idx']
+ img_path = []
+ for ann_id in ann_idx:
+ img_path.append(annots[ann_id]['img_path'])
+ eval_result['img_path'] = img_path
+ eval_result['ann_idx'] = ann_idx
+
+ img = out['img']
+ # MPVPE from all vertices
+ mesh_out_align = mesh_out - np.dot(
+ smpl_x.J_regressor,
+ mesh_out)[smpl_x.J_regressor_idx['pelvis'], None, :] + np.dot(
+ smpl_x.J_regressor,
+ mesh_gt)[smpl_x.J_regressor_idx['pelvis'], None, :]
+ eval_result['mpvpe_all'].append(
+ np.sqrt(np.sum(
+ (mesh_out_align - mesh_gt)**2, 1)).mean() * 1000)
+ mesh_out_align = rigid_align(mesh_out, mesh_gt)
+ eval_result['pa_mpvpe_all'].append(
+ np.sqrt(np.sum(
+ (mesh_out_align - mesh_gt)**2, 1)).mean() * 1000)
+ # MPVPE from hand vertices
+ mesh_gt_lhand = mesh_gt[smpl_x.hand_vertex_idx['left_hand'], :]
+ mesh_out_lhand = mesh_out[smpl_x.hand_vertex_idx['left_hand'], :]
+ mesh_gt_rhand = mesh_gt[smpl_x.hand_vertex_idx['right_hand'], :]
+ mesh_out_rhand = mesh_out[smpl_x.hand_vertex_idx['right_hand'], :]
+ mesh_out_lhand_align = mesh_out_lhand - np.dot(
+ smpl_x.J_regressor,
+ mesh_out)[smpl_x.J_regressor_idx['lwrist'], None, :] + np.dot(
+ smpl_x.J_regressor,
+ mesh_gt)[smpl_x.J_regressor_idx['lwrist'], None, :]
+ mesh_out_rhand_align = mesh_out_rhand - np.dot(
+ smpl_x.J_regressor,
+ mesh_out)[smpl_x.J_regressor_idx['rwrist'], None, :] + np.dot(
+ smpl_x.J_regressor,
+ mesh_gt)[smpl_x.J_regressor_idx['rwrist'], None, :]
+ eval_result['mpvpe_l_hand'].append(
+ np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, 1)).mean() *
+ 1000)
+ eval_result['mpvpe_r_hand'].append(
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, 1)).mean() *
+ 1000)
+ eval_result['mpvpe_hand'].append(
+ (np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, 1)).mean() *
+ 1000 +
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, 1)).mean() *
+ 1000) / 2.)
+ mesh_out_lhand_align = rigid_align(mesh_out_lhand, mesh_gt_lhand)
+ mesh_out_rhand_align = rigid_align(mesh_out_rhand, mesh_gt_rhand)
+ eval_result['pa_mpvpe_l_hand'].append(
+ np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, 1)).mean() *
+ 1000)
+ eval_result['pa_mpvpe_r_hand'].append(
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, 1)).mean() *
+ 1000)
+ eval_result['pa_mpvpe_hand'].append(
+ (np.sqrt(np.sum(
+ (mesh_out_lhand_align - mesh_gt_lhand)**2, 1)).mean() *
+ 1000 +
+ np.sqrt(np.sum(
+ (mesh_out_rhand_align - mesh_gt_rhand)**2, 1)).mean() *
+ 1000) / 2.)
+
+ if self.__class__.__name__ == 'UBody':
+ joint_gt_body_wo_trans = np.dot(smpl_x.j14_regressor,
+ mesh_gt)
+ img_wh = out['gt_img_shape'].flip(-1)
+ joint_gt_body_proj = project_points_new(
+ points_3d=joint_gt_body_wo_trans,
+ pred_cam=cam_trans,
+ focal_length=5000,
+ camera_center=img_wh/2
+ ) # origin image space
+ joint_gt_lhand_wo_trans = np.dot(
+ smpl_x.orig_hand_regressor['left'], mesh_gt)
+ joint_gt_lhand_proj = project_points_new(
+ points_3d=joint_gt_lhand_wo_trans,
+ pred_cam=cam_trans,
+ focal_length=5000,
+ camera_center=img_wh/2
+ ) # origin image space
+ joint_gt_rhand_wo_trans = np.dot(
+                    smpl_x.orig_hand_regressor['right'], mesh_gt)
+ joint_gt_rhand_proj = project_points_new(
+ points_3d=joint_gt_rhand_wo_trans,
+ pred_cam=cam_trans,
+ focal_length=5000,
+ camera_center=img_wh/2
+ ) # origin image space
+ mesh_gt_proj = project_points_new(
+ points_3d=mesh_gt,
+ pred_cam=cam_trans,
+ focal_length=5000,
+ camera_center=img_wh/2)
+ joint_gt_body_valid = self.validate_within_img(
+ img, joint_gt_body_proj)
+ joint_gt_lhand_valid = self.validate_within_img(
+ img, joint_gt_lhand_proj)
+ joint_gt_rhand_valid = self.validate_within_img(
+ img, joint_gt_rhand_proj)
+ mesh_valid = self.validate_within_img(img, mesh_gt_proj)
+ mesh_lhand_valid = mesh_valid[smpl_x.hand_vertex_idx['left_hand']]
+ mesh_rhand_valid = mesh_valid[smpl_x.hand_vertex_idx['right_hand']]
+ mesh_face_valid = mesh_valid[smpl_x.face_vertex_idx]
+
+ # MPVPE from face vertices
+ mesh_gt_face = mesh_gt[smpl_x.face_vertex_idx, :]
+ mesh_out_face = mesh_out[smpl_x.face_vertex_idx, :]
+ mesh_out_face_align = mesh_out_face - np.dot(
+ smpl_x.J_regressor,
+ mesh_out)[smpl_x.J_regressor_idx['neck'], None, :] + np.dot(
+ smpl_x.J_regressor,
+ mesh_gt)[smpl_x.J_regressor_idx['neck'], None, :]
+ eval_result['mpvpe_face'].append(
+ np.sqrt(np.sum(
+ (mesh_out_face_align - mesh_gt_face)**2, 1)).mean() * 1000)
+ mesh_out_face_align = rigid_align(mesh_out_face, mesh_gt_face)
+ eval_result['pa_mpvpe_face'].append(
+ np.sqrt(np.sum(
+ (mesh_out_face_align - mesh_gt_face)**2, 1)).mean() * 1000)
+
+ # MPJPE from body joints
+ joint_gt_body = np.dot(smpl_x.j14_regressor, mesh_gt)
+ joint_out_body = np.dot(smpl_x.j14_regressor, mesh_out)
+ joint_out_body_align = rigid_align(joint_out_body, joint_gt_body)
+ eval_result['pa_mpjpe_body'].append(
+ np.sqrt(np.sum((joint_out_body_align - joint_gt_body)**2,
+ 1))[joint_gt_body_valid].mean() * 1000)
+
+ # eval_result['pa_mpjpe_body'].append(
+ # np.sqrt(np.sum(
+ # (joint_out_body_align - joint_gt_body)**2, 1)).mean() *
+ # 1000)
+
+ # MPJPE from hand joints
+ joint_gt_lhand = np.dot(smpl_x.orig_hand_regressor['left'],
+ mesh_gt)
+ joint_out_lhand = np.dot(smpl_x.orig_hand_regressor['left'],
+ mesh_out)
+ joint_out_lhand_align = rigid_align(joint_out_lhand,
+ joint_gt_lhand)
+ joint_gt_rhand = np.dot(smpl_x.orig_hand_regressor['right'],
+ mesh_gt)
+ joint_out_rhand = np.dot(smpl_x.orig_hand_regressor['right'],
+ mesh_out)
+ joint_out_rhand_align = rigid_align(joint_out_rhand,
+ joint_gt_rhand)
+            # if self.__class__.__name__ == 'UBody':
+            pa_mpjpe_hand = []
+ if sum(joint_gt_lhand_valid) != 0:
+ pa_mpjpe_lhand = np.sqrt(
+ np.sum((joint_out_lhand_align - joint_gt_lhand)**2,
+ 1))[joint_gt_lhand_valid].mean() * 1000
+ pa_mpjpe_hand.append(pa_mpjpe_lhand)
+ eval_result['pa_mpjpe_l_hand'].append(pa_mpjpe_lhand)
+ if sum(joint_gt_rhand_valid) != 0:
+ pa_mpjpe_rhand = np.sqrt(
+ np.sum((joint_out_rhand_align - joint_gt_rhand)**2,
+ 1))[joint_gt_rhand_valid].mean() * 1000
+ pa_mpjpe_hand.append(pa_mpjpe_rhand)
+ eval_result['pa_mpjpe_r_hand'].append(pa_mpjpe_rhand)
+ if len(pa_mpjpe_hand) > 0:
+ eval_result['pa_mpjpe_hand'].append(np.mean(pa_mpjpe_hand))
+
+ eval_result['pa_mpjpe_l_hand'].append(
+ np.sqrt(np.sum(
+ (joint_out_lhand_align - joint_gt_lhand)**2, 1)).mean() *
+ 1000)
+ eval_result['pa_mpjpe_r_hand'].append(
+ np.sqrt(np.sum(
+ (joint_out_rhand_align - joint_gt_rhand)**2, 1)).mean() *
+ 1000)
+ eval_result['pa_mpjpe_hand'].append(
+ (np.sqrt(np.sum(
+ (joint_out_lhand_align - joint_gt_lhand)**2, 1)).mean() *
+ 1000 +
+ np.sqrt(np.sum(
+ (joint_out_rhand_align - joint_gt_rhand)**2, 1)).mean() *
+ 1000) / 2.)
+ return eval_result
+
+ def print_eval_result(self, eval_result):
+ print(f'======{cfg.testset}======')
+ print('PA MPVPE (All): %.2f mm' % np.mean(eval_result['pa_mpvpe_all']))
+ print('PA MPVPE (L-Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_l_hand']))
+ print('PA MPVPE (R-Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_r_hand']))
+ print('PA MPVPE (Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_hand']))
+ print('PA MPVPE (Face): %.2f mm' %
+ np.mean(eval_result['pa_mpvpe_face']))
+ print()
+
+ print('MPVPE (All): %.2f mm' % np.mean(eval_result['mpvpe_all']))
+ print('MPVPE (L-Hands): %.2f mm' %
+ np.mean(eval_result['mpvpe_l_hand']))
+ print('MPVPE (R-Hands): %.2f mm' %
+ np.mean(eval_result['mpvpe_r_hand']))
+ print('MPVPE (Hands): %.2f mm' % np.mean(eval_result['mpvpe_hand']))
+ print('MPVPE (Face): %.2f mm' % np.mean(eval_result['mpvpe_face']))
+ print()
+
+ print('PA MPJPE (Body): %.2f mm' %
+ np.mean(eval_result['pa_mpjpe_body']))
+ print('PA MPJPE (L-Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpjpe_l_hand']))
+ print('PA MPJPE (R-Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpjpe_r_hand']))
+ print('PA MPJPE (Hands): %.2f mm' %
+ np.mean(eval_result['pa_mpjpe_hand']))
+
+        with open(os.path.join(cfg.result_dir, 'result.txt'), 'w') as f:
+            f.write(f'{cfg.testset} dataset \n')
+            f.write('PA MPVPE (All): %.2f mm\n' %
+                    np.mean(eval_result['pa_mpvpe_all']))
+            f.write('PA MPVPE (L-Hands): %.2f mm\n' %
+                    np.mean(eval_result['pa_mpvpe_l_hand']))
+            f.write('PA MPVPE (R-Hands): %.2f mm\n' %
+                    np.mean(eval_result['pa_mpvpe_r_hand']))
+            f.write('PA MPVPE (Hands): %.2f mm\n' %
+                    np.mean(eval_result['pa_mpvpe_hand']))
+            f.write('PA MPVPE (Face): %.2f mm\n' %
+                    np.mean(eval_result['pa_mpvpe_face']))
+            f.write('MPVPE (All): %.2f mm\n' %
+                    np.mean(eval_result['mpvpe_all']))
+            f.write('MPVPE (L-Hands): %.2f mm\n' %
+                    np.mean(eval_result['mpvpe_l_hand']))
+            f.write('MPVPE (R-Hands): %.2f mm\n' %
+                    np.mean(eval_result['mpvpe_r_hand']))
+            f.write('MPVPE (Hands): %.2f mm\n' %
+                    np.mean(eval_result['mpvpe_hand']))
+            f.write('MPVPE (Face): %.2f mm\n' %
+                    np.mean(eval_result['mpvpe_face']))
+            f.write('PA MPJPE (Body): %.2f mm\n' %
+                    np.mean(eval_result['pa_mpjpe_body']))
+            f.write('PA MPJPE (L-Hands): %.2f mm\n' %
+                    np.mean(eval_result['pa_mpjpe_l_hand']))
+            f.write('PA MPJPE (R-Hands): %.2f mm\n' %
+                    np.mean(eval_result['pa_mpjpe_r_hand']))
+            f.write('PA MPJPE (Hands): %.2f mm\n' %
+                    np.mean(eval_result['pa_mpjpe_hand']))
+
+    def validate_within_img_batch(
+            self, img_wh, points):  # check whether the points are within the image
+        # img_wh: (batch, 2), points: (batch, num_points, 2)
+
+        valid_mask = np.logical_and((points - img_wh[:, None]) < 0, points > 0)
+        valid_mask = np.logical_and(valid_mask[:, :, 0], valid_mask[:, :, 1])
+
+ return valid_mask
+    def decompress_keypoints(self, humandata) -> dict:
+        """If a key contains 'keypoints' and f'{key}_mask' is in humandata,
+        invalid zeros will be inserted at the right places and the
+        decompressed arrays will be returned in a dict.
+
+        Raises:
+            KeyError:
+                A key containing 'keypoints' has been found
+                but its corresponding mask is missing.
+        """
+ assert bool(humandata['__keypoints_compressed__']) is True
+ key_pairs = []
+ for key in humandata.files:
+ if key not in KPS2D_KEYS + KPS3D_KEYS:
+ continue
+ mask_key = f'{key}_mask'
+ if mask_key in humandata.files:
+ print(f'Decompress {key}...')
+ key_pairs.append([key, mask_key])
+ decompressed_dict = {}
+ for kpt_key, mask_key in key_pairs:
+ mask_array = np.asarray(humandata[mask_key])
+ compressed_kpt = humandata[kpt_key]
+ kpt_array = \
+ self.add_zero_pad(compressed_kpt, mask_array)
+ decompressed_dict[kpt_key] = kpt_array
+ del humandata
+ return decompressed_dict
+
+ def add_zero_pad(self, compressed_array: np.ndarray,
+ mask_array: np.ndarray) -> np.ndarray:
+ """Pad zeros to a compressed keypoints array.
+
+ Args:
+ compressed_array (np.ndarray):
+ A compressed keypoints array.
+ mask_array (np.ndarray):
+ The mask records compression relationship.
+
+ Returns:
+ np.ndarray:
+ A keypoints array in full-size.
+ """
+ assert mask_array.sum() == compressed_array.shape[1]
+ data_len, _, dim = compressed_array.shape
+ mask_len = mask_array.shape[0]
+ ret_value = np.zeros(shape=[data_len, mask_len, dim],
+ dtype=compressed_array.dtype)
+ valid_mask_index = np.where(mask_array == 1)[0]
+ ret_value[:, valid_mask_index, :] = compressed_array
+ return ret_value
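+
+# Usage sketch (illustrative, not part of the original file): how
+# `add_zero_pad` restores a compressed keypoints array. With a toy mask of
+# length 5 containing 3 valid slots and a compressed array of shape (2, 3, 3),
+# the padded result has shape (2, 5, 3), with zeros at the masked-out indices
+# (`dataset` below is a hypothetical instance of this dataset class):
+#
+#     mask = np.array([1, 0, 1, 1, 0])
+#     compressed = np.ones((2, 3, 3), dtype=np.float32)
+#     full = dataset.add_zero_pad(compressed, mask)
+#     # full.shape == (2, 5, 3); full[:, [0, 2, 3], :] == 1, rest are zeros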
diff --git a/detrsmpl/__init__.py b/detrsmpl/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac52099185f80381d9e8402ac74fc292f332ad91
--- /dev/null
+++ b/detrsmpl/__init__.py
@@ -0,0 +1,28 @@
+import mmcv
+
+from .version import __version__
+
+
+def digit_version(version_str):
+ digit_version = []
+ for x in version_str.split('.'):
+ if x.isdigit():
+ digit_version.append(int(x))
+ elif x.find('rc') != -1:
+ patch_version = x.split('rc')
+ digit_version.append(int(patch_version[0]) - 1)
+ digit_version.append(int(patch_version[1]))
+ return digit_version
+
+
+mmcv_minimum_version = '1.3.17'
+mmcv_maximum_version = '1.7.1'
+mmcv_version = digit_version(mmcv.__version__)
+
+
+assert (mmcv_version >= digit_version(mmcv_minimum_version)
+ and mmcv_version <= digit_version(mmcv_maximum_version)), \
+ f'MMCV=={mmcv.__version__} is used but incompatible. ' \
+    f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'
+
+__all__ = ['__version__']
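+
+# Usage sketch (illustrative): how `digit_version` orders version strings.
+# Release-candidate suffixes sort below the corresponding release:
+#
+#     digit_version('1.7.0')     # -> [1, 7, 0]
+#     digit_version('1.7.0rc1')  # -> [1, 7, -1, 1], which sorts before [1, 7, 0]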
diff --git a/detrsmpl/apis/__init__.py b/detrsmpl/apis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1a2da2c226d4ec3b30cd2a32ac90e60ba803408
--- /dev/null
+++ b/detrsmpl/apis/__init__.py
@@ -0,0 +1,12 @@
+
+from detrsmpl.apis.test import (
+ collect_results_cpu,
+ collect_results_gpu,
+ multi_gpu_test,
+ single_gpu_test,
+)
+from detrsmpl.apis.train import set_random_seed, train_model
+
+__all__ = [
+    'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test',
+    'single_gpu_test', 'set_random_seed', 'train_model'
+]
diff --git a/detrsmpl/apis/inference.py b/detrsmpl/apis/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb18c9c2498ee02548b089092a73487998671cce
--- /dev/null
+++ b/detrsmpl/apis/inference.py
@@ -0,0 +1,518 @@
+import cv2
+import mmcv
+import numpy as np
+import torch
+from mmcv.parallel import collate
+from mmcv.runner import load_checkpoint
+
+from detrsmpl.data.datasets.pipelines import Compose
+from detrsmpl.models.architectures.builder import build_architecture
+from detrsmpl.models.backbones.builder import build_backbone
+from detrsmpl.utils.demo_utils import box2cs, xywh2xyxy, xyxy2xywh
+
+
+def init_model(config, checkpoint=None, device='cuda:0'):
+ """Initialize a model from config file.
+
+ Args:
+ config (str or :obj:`mmcv.Config`): Config file path or the config
+ object.
+ checkpoint (str, optional): Checkpoint path. If left as None, the model
+ will not load any weights.
+
+ Returns:
+ nn.Module: The constructed model.
+ (nn.Module, None): The constructed extractor model
+ """
+ if isinstance(config, str):
+ config = mmcv.Config.fromfile(config)
+ elif not isinstance(config, mmcv.Config):
+ raise TypeError('config must be a filename or Config object, '
+ f'but got {type(config)}')
+ config.data.test.test_mode = True
+
+ model = build_architecture(config.model)
+ if checkpoint is not None:
+ # load model checkpoint
+ load_checkpoint(model, checkpoint, map_location=device)
+ # save the config in the model for convenience
+ model.cfg = config
+ model.to(device)
+ model.eval()
+
+ extractor = None
+ if config.model.type == 'VideoBodyModelEstimator':
+ extractor = build_backbone(config.extractor.backbone)
+ if config.extractor.checkpoint is not None:
+ # load model checkpoint
+ load_checkpoint(extractor, config.extractor.checkpoint)
+ extractor.cfg = config
+ extractor.to(device)
+ extractor.eval()
+ return model, extractor
+
+
+class LoadImage:
+ """A simple pipeline to load image."""
+ def __init__(self, color_type='color', channel_order='bgr'):
+ self.color_type = color_type
+ self.channel_order = channel_order
+
+ def __call__(self, results):
+ """Call function to load images into results.
+
+ Args:
+ results (dict): A result dict contains the image_path.
+
+ Returns:
+ dict: ``results`` will be returned containing loaded image.
+ """
+ if isinstance(results['image_path'], str):
+ results['image_file'] = results['image_path']
+ img = mmcv.imread(results['image_path'], self.color_type,
+ self.channel_order)
+ elif isinstance(results['image_path'], np.ndarray):
+ results['image_file'] = ''
+ if self.color_type == 'color' and self.channel_order == 'rgb':
+ img = cv2.cvtColor(results['image_path'], cv2.COLOR_BGR2RGB)
+ else:
+ img = results['image_path']
+ else:
+ raise TypeError('"image_path" must be a numpy array or a str or '
+ 'a pathlib.Path object')
+
+ results['img'] = img
+ return results
+
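+# Usage sketch (illustrative; 'demo.jpg' is a placeholder path): LoadImage
+# accepts either an image path or an already decoded BGR array and always
+# populates results['img']:
+#
+#     pipeline = LoadImage(color_type='color', channel_order='rgb')
+#     results = pipeline({'image_path': 'demo.jpg'})
+#     # results['img'] is an RGB ndarray, results['image_file'] == 'demo.jpg'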
+
+def inference_image_based_model(
+ model,
+ img_or_path,
+ det_results,
+ bbox_thr=None,
+ format='xywh',
+):
+ """Inference a single image with a list of person bounding boxes.
+
+ Args:
+ model (nn.Module): The loaded pose model.
+ img_or_path (Union[str, np.ndarray]): Image filename or loaded image.
+ det_results (List(dict)): the item in the dict may contain
+ 'bbox' and/or 'track_id'.
+ 'bbox' (4, ) or (5, ): The person bounding box, which contains
+ 4 box coordinates (and score).
+ 'track_id' (int): The unique id for each human instance.
+ bbox_thr (float, optional): Threshold for bounding boxes.
+ Only bboxes with higher scores will be fed into the pose detector.
+ If bbox_thr is None, ignore it. Defaults to None.
+ format (str, optional): bbox format ('xyxy' | 'xywh'). Default: 'xywh'.
+ 'xyxy' means (left, top, right, bottom),
+ 'xywh' means (left, top, width, height).
+
+ Returns:
+ list[dict]: Each item in the list is a dictionary,
+ containing the bbox: (left, top, right, bottom, [score]),
+ SMPL parameters, vertices, kp3d, and camera.
+ """
+    # only two kinds of bbox format are supported.
+ assert format in ['xyxy', 'xywh']
+ mesh_results = []
+ if len(det_results) == 0:
+ return []
+
+ # Change for-loop preprocess each bbox to preprocess all bboxes at once.
+ bboxes = np.array([box['bbox'] for box in det_results])
+
+ # Select bboxes by score threshold
+ if bbox_thr is not None:
+ assert bboxes.shape[1] == 5
+ valid_idx = np.where(bboxes[:, 4] > bbox_thr)[0]
+ bboxes = bboxes[valid_idx]
+ det_results = [det_results[i] for i in valid_idx]
+
+ if format == 'xyxy':
+ bboxes_xyxy = bboxes
+ bboxes_xywh = xyxy2xywh(bboxes)
+ else:
+ # format is already 'xywh'
+ bboxes_xywh = bboxes
+ bboxes_xyxy = xywh2xyxy(bboxes)
+
+ # if bbox_thr remove all bounding box
+ if len(bboxes_xywh) == 0:
+ return []
+
+ cfg = model.cfg
+ device = next(model.parameters()).device
+
+ # build the data pipeline
+ inference_pipeline = [LoadImage()] + cfg.inference_pipeline
+ inference_pipeline = Compose(inference_pipeline)
+
+ assert len(bboxes[0]) in [4, 5]
+
+ batch_data = []
+ input_size = cfg['img_res']
+ aspect_ratio = 1 if isinstance(input_size,
+ int) else input_size[0] / input_size[1]
+
+ for i, bbox in enumerate(bboxes_xywh):
+ center, scale = box2cs(bbox, aspect_ratio, bbox_scale_factor=1.25)
+ # prepare data
+ data = {
+ 'image_path': img_or_path,
+ 'center': center,
+ 'scale': scale,
+ 'rotation': 0,
+ 'bbox_score': bbox[4] if len(bbox) == 5 else 1,
+ 'sample_idx': i,
+ }
+ data = inference_pipeline(data)
+ batch_data.append(data)
+
+ batch_data = collate(batch_data, samples_per_gpu=1)
+
+ if next(model.parameters()).is_cuda:
+        # scatter does not work, so just move the image to the cuda device
+ batch_data['img'] = batch_data['img'].to(device)
+
+ # get all img_metas of each bounding box
+ batch_data['img_metas'] = [
+ img_metas[0] for img_metas in batch_data['img_metas'].data
+ ]
+
+ # forward the model
+ with torch.no_grad():
+ results = model(
+ img=batch_data['img'],
+ img_metas=batch_data['img_metas'],
+ sample_idx=batch_data['sample_idx'],
+ )
+
+ for idx in range(len(det_results)):
+ mesh_result = det_results[idx].copy()
+ mesh_result['bbox'] = bboxes_xyxy[idx]
+ mesh_result['camera'] = results['camera'][idx]
+ mesh_result['smpl_pose'] = results['smpl_pose'][idx]
+ mesh_result['smpl_beta'] = results['smpl_beta'][idx]
+ mesh_result['vertices'] = results['vertices'][idx]
+ mesh_result['keypoints_3d'] = results['keypoints_3d'][idx]
+ mesh_results.append(mesh_result)
+ return mesh_results
+
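+# Usage sketch (illustrative; the config/checkpoint paths and detector output
+# below are placeholders, not files shipped with this repo):
+#
+#     model, _ = init_model('configs/example.py', 'example.pth', device='cuda:0')
+#     det_results = [{'bbox': np.array([x, y, w, h, score])}]  # one person
+#     meshes = inference_image_based_model(
+#         model, 'demo.jpg', det_results, bbox_thr=0.5, format='xywh')
+#     # each item holds 'bbox', 'camera', 'smpl_pose', 'smpl_beta',
+#     # 'vertices' and 'keypoints_3d' for one detected person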
+
+def inference_video_based_model(model,
+ extracted_results,
+ with_track_id=True,
+ causal=True):
+ """Inference SMPL parameters from extracted featutres using a video-based
+ model.
+
+ Args:
+ model (nn.Module): The loaded mesh estimation model.
+ extracted_results (List[List[Dict]]): Multi-frame feature extraction
+ results stored in a nested list. Each element of the outer list
+ is the feature extraction results of a single frame, and each
+ element of the inner list is the feature information of one person,
+ which contains:
+ features (ndarray): extracted features
+ track_id (int): unique id of each person, required when
+                ``with_track_id==True``
+            bbox ((4, ) or (5, )): left, top, right, bottom, [score]
+ with_track_id: If True, the element in extracted_results is expected to
+ contain "track_id", which will be used to gather the feature
+ sequence of a person from multiple frames. Otherwise, the extracted
+ results in each frame are expected to have a consistent number and
+ order of identities. Default is True.
+ causal (bool): If True, the target frame is the first frame in
+ a sequence. Otherwise, the target frame is in the middle of a
+ sequence.
+
+ Returns:
+ list[dict]: Each item in the list is a dictionary, which contains:
+ SMPL parameters, vertices, kp3d, and camera.
+ """
+ cfg = model.cfg
+ device = next(model.parameters()).device
+ seq_len = cfg.data.test.seq_len
+ mesh_results = []
+ # build the data pipeline
+ inference_pipeline = Compose(cfg.inference_pipeline)
+ target_idx = 0 if causal else len(extracted_results) // 2
+
+ input_features = _gather_input_features(extracted_results)
+ feature_sequences = _collate_feature_sequence(input_features,
+ with_track_id, target_idx)
+ if not feature_sequences:
+ return mesh_results
+
+ batch_data = []
+
+ for i, seq in enumerate(feature_sequences):
+
+ data = {
+ 'features': seq['features'],
+ 'sample_idx': i,
+ }
+
+ data = inference_pipeline(data)
+ batch_data.append(data)
+
+ batch_data = collate(batch_data, samples_per_gpu=len(batch_data))
+
+ if next(model.parameters()).is_cuda:
+        # scatter does not work, so just move the features to the cuda device
+ batch_data['features'] = batch_data['features'].to(device)
+
+ with torch.no_grad():
+ results = model(features=batch_data['features'],
+ img_metas=batch_data['img_metas'],
+ sample_idx=batch_data['sample_idx'])
+
+ results['camera'] = results['camera'].reshape(-1, seq_len, 3)
+ results['smpl_pose'] = results['smpl_pose'].reshape(-1, seq_len, 24, 3, 3)
+ results['smpl_beta'] = results['smpl_beta'].reshape(-1, seq_len, 10)
+ results['vertices'] = results['vertices'].reshape(-1, seq_len, 6890, 3)
+ results['keypoints_3d'] = results['keypoints_3d'].reshape(
+ -1, seq_len, 17, 3)
+
+ for idx in range(len(feature_sequences)):
+ mesh_result = dict()
+ mesh_result['camera'] = results['camera'][idx, target_idx]
+ mesh_result['smpl_pose'] = results['smpl_pose'][idx, target_idx]
+ mesh_result['smpl_beta'] = results['smpl_beta'][idx, target_idx]
+ mesh_result['vertices'] = results['vertices'][idx, target_idx]
+ mesh_result['keypoints_3d'] = results['keypoints_3d'][idx, target_idx]
+ mesh_result['bbox'] = extracted_results[target_idx][idx]['bbox']
+ # 'track_id' is not included in results generated by mmdet
+ if 'track_id' in extracted_results[target_idx][idx].keys():
+ mesh_result['track_id'] = extracted_results[target_idx][idx][
+ 'track_id']
+ mesh_results.append(mesh_result)
+ return mesh_results
+
+
+def feature_extract(
+ model,
+ img_or_path,
+ det_results,
+ bbox_thr=None,
+ format='xywh',
+):
+ """Extract image features with a list of person bounding boxes.
+
+ Args:
+ model (nn.Module): The loaded feature extraction model.
+ img_or_path (Union[str, np.ndarray]): Image filename or loaded image.
+ det_results (List(dict)): the item in the dict may contain
+ 'bbox' and/or 'track_id'.
+ 'bbox' (4, ) or (5, ): The person bounding box, which contains
+ 4 box coordinates (and score).
+ 'track_id' (int): The unique id for each human instance.
+ bbox_thr (float, optional): Threshold for bounding boxes.
+ If bbox_thr is None, ignore it. Defaults to None.
+ format (str, optional): bbox format. Default: 'xywh'.
+ 'xyxy' means (left, top, right, bottom),
+ 'xywh' means (left, top, width, height).
+
+ Returns:
+ list[dict]: The bbox & pose info,
+ containing the bbox: (left, top, right, bottom, [score])
+ and the features.
+ """
+    # only two kinds of bbox format are supported.
+ assert format in ['xyxy', 'xywh']
+
+ cfg = model.cfg
+ device = next(model.parameters()).device
+
+ feature_results = []
+ if len(det_results) == 0:
+ return feature_results
+
+ # Change for-loop preprocess each bbox to preprocess all bboxes at once.
+ bboxes = np.array([box['bbox'] for box in det_results])
+ assert len(bboxes[0]) in [4, 5]
+
+ # Select bboxes by score threshold
+ if bbox_thr is not None:
+ assert bboxes.shape[1] == 5
+ valid_idx = np.where(bboxes[:, 4] > bbox_thr)[0]
+ bboxes = bboxes[valid_idx]
+ det_results = [det_results[i] for i in valid_idx]
+
+ # if bbox_thr remove all bounding box
+ if len(bboxes) == 0:
+ return feature_results
+
+ if format == 'xyxy':
+ bboxes_xyxy = bboxes
+ bboxes_xywh = xyxy2xywh(bboxes)
+ else:
+ # format is already 'xywh'
+ bboxes_xywh = bboxes
+ bboxes_xyxy = xywh2xyxy(bboxes)
+
+ # build the data pipeline
+ extractor_pipeline = [LoadImage()] + cfg.extractor_pipeline
+ extractor_pipeline = Compose(extractor_pipeline)
+ batch_data = []
+ input_size = cfg['img_res']
+ aspect_ratio = 1 if isinstance(input_size,
+ int) else input_size[0] / input_size[1]
+
+ for i, bbox in enumerate(bboxes_xywh):
+ center, scale = box2cs(bbox, aspect_ratio, bbox_scale_factor=1.25)
+ # prepare data
+ data = {
+ 'image_path': img_or_path,
+ 'center': center,
+ 'scale': scale,
+ 'rotation': 0,
+ 'bbox_score': bbox[4] if len(bbox) == 5 else 1,
+ 'sample_idx': i,
+ }
+ data = extractor_pipeline(data)
+ batch_data.append(data)
+
+ batch_data = collate(batch_data, samples_per_gpu=1)
+
+ if next(model.parameters()).is_cuda:
+        # scatter does not work, so just move the image to the cuda device
+ batch_data['img'] = batch_data['img'].to(device)
+
+ # get all img_metas of each bounding box
+ batch_data['img_metas'] = [
+ img_metas[0] for img_metas in batch_data['img_metas'].data
+ ]
+
+ # forward the model
+ with torch.no_grad():
+ results = model(batch_data['img'])
+
+ if isinstance(results, list) or isinstance(results, tuple):
+ results = results[-1].mean(dim=-1).mean(dim=-1)
+
+ for idx in range(len(det_results)):
+ feature_result = det_results[idx].copy()
+ feature_result['bbox'] = bboxes_xyxy[idx]
+ feature_result['features'] = results[idx].cpu().numpy()
+ feature_results.append(feature_result)
+
+ return feature_results
+
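+# Usage sketch (illustrative; `extractor`, `model`, `frames` and
+# `det_results_per_frame` are placeholders): the video pipeline pairs
+# `feature_extract` with `inference_video_based_model`. Per-frame features are
+# extracted first, then the whole sequence is fed to the temporal model:
+#
+#     extracted_results = []
+#     for frame, dets in zip(frames, det_results_per_frame):
+#         extracted_results.append(
+#             feature_extract(extractor, frame, dets, bbox_thr=0.5))
+#     mesh_results = inference_video_based_model(
+#         model, extracted_results, with_track_id=True, causal=True)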
+
+def _gather_input_features(extracted_results):
+ """Gather input features.
+
+ Args:
+ extracted_results (List[List[Dict]]):
+ Multi-frame feature extraction results
+
+ Returns:
+ List[List[dict]]: Multi-frame feature extraction results
+ stored in a nested list. Each element of the outer list is the
+ feature extraction results of a single frame, and each element of
+ the inner list is the extracted results of one person,
+ which contains:
+ features (ndarray): extracted features
+ track_id (int): unique id of each person, required when
+                ``with_track_id==True``
+ """
+ sequence_inputs = []
+ for frame in extracted_results:
+ frame_inputs = []
+ for res in frame:
+ inputs = dict()
+ if 'features' in res:
+ inputs['features'] = res['features']
+ if 'track_id' in res:
+ inputs['track_id'] = res['track_id']
+ frame_inputs.append(inputs)
+ sequence_inputs.append(frame_inputs)
+ return sequence_inputs
+
+
+def _collate_feature_sequence(extracted_features,
+ with_track_id=True,
+ target_frame=0):
+ """Reorganize multi-frame feature extraction results into individual
+ feature sequences.
+
+ Args:
+ extracted_features (List[List[Dict]]): Multi-frame feature extraction
+ results stored in a nested list. Each element of the outer list
+ is the feature extraction results of a single frame, and each
+ element of the inner list is the extracted results of one person,
+ which contains:
+ features (ndarray): extracted features
+ track_id (int): unique id of each person, required when
+                ``with_track_id==True``
+ with_track_id (bool): If True, the element in pose_results is expected
+ to contain "track_id", which will be used to gather the pose
+ sequence of a person from multiple frames. Otherwise, the pose
+ results in each frame are expected to have a consistent number and
+ order of identities. Default is True.
+ target_frame (int): The index of the target frame. Default: 0.
+ """
+ T = len(extracted_features)
+ assert T > 0
+
+ target_frame = (T + target_frame) % T # convert negative index to positive
+
+ N = len(
+ extracted_features[target_frame]) # use identities in the target frame
+ if N == 0:
+ return []
+
+ C = extracted_features[target_frame][0]['features'].shape[0]
+
+ track_ids = None
+ if with_track_id:
+ track_ids = [
+ res['track_id'] for res in extracted_features[target_frame]
+ ]
+
+ feature_sequences = []
+ for idx in range(N):
+ feature_seq = dict()
+ # gather static information
+ for k, v in extracted_features[target_frame][idx].items():
+ if k != 'features':
+ feature_seq[k] = v
+ # gather keypoints
+ if not with_track_id:
+ feature_seq['features'] = np.stack(
+ [frame[idx]['features'] for frame in extracted_features])
+ else:
+ features = np.zeros((T, C), dtype=np.float32)
+ features[target_frame] = extracted_features[target_frame][idx][
+ 'features']
+ # find the left most frame containing track_ids[idx]
+ for frame_idx in range(target_frame - 1, -1, -1):
+ contains_idx = False
+ for res in extracted_features[frame_idx]:
+ if res['track_id'] == track_ids[idx]:
+ features[frame_idx] = res['features']
+ contains_idx = True
+ break
+ if not contains_idx:
+ # replicate the left most frame
+ features[frame_idx] = features[frame_idx + 1]
+
+ # find the right most frame containing track_idx[idx]
+ for frame_idx in range(target_frame + 1, T):
+ contains_idx = False
+ for res in extracted_features[frame_idx]:
+ if res['track_id'] == track_ids[idx]:
+ features[frame_idx] = res['features']
+ contains_idx = True
+ break
+ if not contains_idx:
+ # replicate the right most frame
+ features[frame_idx] = features[frame_idx - 1]
+ # break
+ feature_seq['features'] = features
+ feature_sequences.append(feature_seq)
+
+ return feature_sequences
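+
+# Shape sketch (illustrative): for T frames, N identities in the target frame
+# and C-dim features, `_collate_feature_sequence` returns N dicts whose
+# 'features' entry has shape (T, C); frames where a track is missing are
+# filled by replicating the nearest frame that does contain it. For example:
+#
+#     feats = [[{'features': np.zeros(2048), 'track_id': 7}],
+#              [{'features': np.ones(2048), 'track_id': 7}],
+#              [{'features': np.zeros(2048), 'track_id': 7}]]
+#     seqs = _collate_feature_sequence(feats, with_track_id=True,
+#                                      target_frame=1)
+#     # seqs[0]['features'].shape == (3, 2048)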
diff --git a/detrsmpl/apis/test.py b/detrsmpl/apis/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..98ed3dc5355967b9c3cff61aeabc2a0fab730e9e
--- /dev/null
+++ b/detrsmpl/apis/test.py
@@ -0,0 +1,172 @@
+import os.path as osp
+import pickle
+import shutil
+import tempfile
+import time
+
+import mmcv
+import torch
+import torch.distributed as dist
+from mmcv.runner import get_dist_info
+
+
+def single_gpu_test(model, data_loader):
+ """Test with single gpu."""
+ model.eval()
+ results = []
+ dataset = data_loader.dataset
+ prog_bar = mmcv.ProgressBar(len(dataset))
+ for i, data in enumerate(data_loader):
+ with torch.no_grad():
+ result = model(return_loss=False, **data)
+
+ batch_size = len(result)
+ if isinstance(result, list):
+ results.extend(result)
+ else:
+ results.append(result)
+
+ if 'img' in data.keys():
+ batch_size = data['img'].size(0)
+ else:
+ batch_size = data['features'].size(0)
+ for _ in range(batch_size):
+ prog_bar.update()
+ return results
+
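+# Usage sketch (illustrative): `single_gpu_test` expects a model already
+# wrapped for single-GPU inference (e.g. with mmcv's MMDataParallel) and a
+# test dataloader built elsewhere:
+#
+#     from mmcv.parallel import MMDataParallel
+#     model = MMDataParallel(model, device_ids=[0])
+#     outputs = single_gpu_test(model, data_loader)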
+
+def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
+ """Test model with multiple gpus.
+
+ This method tests model with multiple gpus and collects the results
+ under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
+    it encodes results to gpu tensors and uses gpu communication for results
+ collection. On cpu mode it saves the results on different gpus to 'tmpdir'
+ and collects them by the rank 0 worker.
+
+ Args:
+ model (nn.Module): Model to be tested.
+ data_loader (nn.Dataloader): Pytorch data loader.
+ tmpdir (str): Path of directory to save the temporary results from
+ different gpus under cpu mode.
+ gpu_collect (bool): Option to use either gpu or cpu to collect results.
+
+ Returns:
+ list: The prediction results.
+ """
+ model.eval()
+ results = []
+ dataset = data_loader.dataset
+ rank, world_size = get_dist_info()
+ if rank == 0:
+ # Check if tmpdir is valid for cpu_collect
+ if (not gpu_collect) and (tmpdir is not None and osp.exists(tmpdir)):
+            raise OSError(f'The tmpdir {tmpdir} already exists. '
+                          'Since tmpdir will be deleted after testing, '
+                          'please make sure you specify an empty one.')
+ prog_bar = mmcv.ProgressBar(len(dataset))
+ time.sleep(2) # This line can prevent deadlock problem in some cases.
+ for i, data in enumerate(data_loader):
+ with torch.no_grad():
+ result = model(return_loss=False, **data)
+ if isinstance(result, list):
+ results.extend(result)
+ else:
+ results.append(result)
+
+ if rank == 0:
+ if 'img' in data.keys():
+ batch_size = data['img'].size(0)
+ else:
+ batch_size = data['features'].size(0)
+ for _ in range(batch_size * world_size):
+ prog_bar.update()
+
+ # collect results from all ranks
+ if gpu_collect:
+ results = collect_results_gpu(results, len(dataset))
+ else:
+ results = collect_results_cpu(results, len(dataset), tmpdir)
+ return results
+
+
+def collect_results_cpu(result_part, size, tmpdir=None):
+ """Collect results in cpu."""
+ rank, world_size = get_dist_info()
+ # create a tmp dir if it is not specified
+ if tmpdir is None:
+ MAX_LEN = 512
+ # 32 is whitespace
+ dir_tensor = torch.full((MAX_LEN, ),
+ 32,
+ dtype=torch.uint8,
+ device='cuda')
+ if rank == 0:
+ mmcv.mkdir_or_exist('.dist_test')
+ tmpdir = tempfile.mkdtemp(dir='.dist_test')
+ tmpdir = torch.tensor(bytearray(tmpdir.encode()),
+ dtype=torch.uint8,
+ device='cuda')
+ dir_tensor[:len(tmpdir)] = tmpdir
+ dist.broadcast(dir_tensor, 0)
+ tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+ else:
+ mmcv.mkdir_or_exist(tmpdir)
+ # dump the part result to the dir
+ mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
+ dist.barrier()
+ # collect all parts
+ if rank != 0:
+ return None
+ else:
+ # load results of all parts from tmp dir
+ part_list = []
+ for i in range(world_size):
+ part_file = osp.join(tmpdir, f'part_{i}.pkl')
+ part_result = mmcv.load(part_file)
+ part_list.append(part_result)
+ # import ipdb;ipdb.set_trace()
+ # sort the results
+ ordered_results = []
+ for res in zip(*part_list):
+ ordered_results.extend(list(res))
+ # the dataloader may pad some samples
+ ordered_results = ordered_results[:size]
+ # remove tmp dir
+ shutil.rmtree(tmpdir)
+ return ordered_results
+
+
+def collect_results_gpu(result_part, size):
+ """Collect results in gpu."""
+ rank, world_size = get_dist_info()
+ # dump result part to tensor with pickle
+ part_tensor = torch.tensor(bytearray(pickle.dumps(result_part)),
+ dtype=torch.uint8,
+ device='cuda')
+ # gather all result part tensor shape
+ shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+ shape_list = [shape_tensor.clone() for _ in range(world_size)]
+ dist.all_gather(shape_list, shape_tensor)
+ # padding result part tensor to max length
+ shape_max = torch.tensor(shape_list).max()
+ part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+ part_send[:shape_tensor[0]] = part_tensor
+ part_recv_list = [
+ part_tensor.new_zeros(shape_max) for _ in range(world_size)
+ ]
+ # gather all result part
+ dist.all_gather(part_recv_list, part_send)
+
+ if rank == 0:
+ part_list = []
+ for recv, shape in zip(part_recv_list, shape_list):
+ part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())
+ part_list.append(part_result)
+ # sort the results
+ ordered_results = []
+ for res in zip(*part_list):
+ ordered_results.extend(list(res))
+ # the dataloader may pad some samples
+ ordered_results = ordered_results[:size]
+ return ordered_results
diff --git a/detrsmpl/apis/train.py b/detrsmpl/apis/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dbcf0a580fd5e3377b44b9ff695dce6943d1521
--- /dev/null
+++ b/detrsmpl/apis/train.py
@@ -0,0 +1,163 @@
+import random
+import warnings
+
+import numpy as np
+import torch
+from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from mmcv.runner import (
+ DistSamplerSeedHook,
+ Fp16OptimizerHook,
+ OptimizerHook,
+ build_runner,
+)
+
+from detrsmpl.core.distributed_wrapper import DistributedDataParallelWrapper
+from detrsmpl.core.evaluation import DistEvalHook, EvalHook
+from detrsmpl.core.optimizer import build_optimizers
+from detrsmpl.data.datasets import build_dataloader, build_dataset
+from detrsmpl.utils.logger import get_root_logger
+
+
+def set_random_seed(seed, deterministic=False):
+ """Set random seed.
+
+ Args:
+ seed (int): Seed to be used.
+ deterministic (bool): Whether to set the deterministic option for
+ CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+ to True and `torch.backends.cudnn.benchmark` to False.
+ Default: False.
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ if deterministic:
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
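+# Usage sketch (illustrative): typically called once from a training script
+# before datasets and models are built:
+#
+#     set_random_seed(0, deterministic=True)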
+
+def train_model(model,
+ dataset,
+ cfg,
+ distributed=False,
+ validate=False,
+ timestamp=None,
+ device='cuda',
+ meta=None):
+ """Main api for training model."""
+ logger = get_root_logger(cfg.log_level)
+
+ # prepare data loaders
+ dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
+ data_loaders = [
+ build_dataloader(
+ ds,
+ cfg.data.samples_per_gpu,
+ cfg.data.workers_per_gpu,
+ # cfg.gpus will be ignored if distributed
+ num_gpus=len(cfg.gpu_ids),
+ dist=distributed,
+ round_up=True,
+ seed=cfg.seed) for ds in dataset
+ ]
+
+    # determine whether to use adversarial training or not
+ use_adverserial_train = cfg.get('use_adversarial_train', False)
+
+ # put model on gpus
+ if distributed:
+ find_unused_parameters = cfg.get('find_unused_parameters', False)
+ # Sets the `find_unused_parameters` parameter in
+ # torch.nn.parallel.DistributedDataParallel
+ if use_adverserial_train:
+ # Use DistributedDataParallelWrapper for adversarial training
+ model = DistributedDataParallelWrapper(
+ model,
+ device_ids=[torch.cuda.current_device()],
+ broadcast_buffers=False,
+ find_unused_parameters=find_unused_parameters)
+ else:
+ model = MMDistributedDataParallel(
+ model.cuda(),
+ device_ids=[torch.cuda.current_device()],
+ broadcast_buffers=False,
+ find_unused_parameters=find_unused_parameters)
+ else:
+ if device == 'cuda':
+ model = MMDataParallel(model.cuda(cfg.gpu_ids[0]),
+ device_ids=cfg.gpu_ids)
+ elif device == 'cpu':
+ model = model.cpu()
+ else:
+            raise ValueError(f'unsupported device name {device}.')
+
+ # build runner
+ optimizer = build_optimizers(model, cfg.optimizer)
+ if cfg.get('runner') is None:
+ cfg.runner = {
+ 'type': 'EpochBasedRunner',
+ 'max_epochs': cfg.total_epochs
+ }
+ warnings.warn(
+ 'config is now expected to have a `runner` section, '
+ 'please set `runner` in your config.', UserWarning)
+
+ runner = build_runner(cfg.runner,
+ default_args=dict(model=model,
+ batch_processor=None,
+ optimizer=optimizer,
+ work_dir=cfg.work_dir,
+ logger=logger,
+ meta=meta))
+
+    # an ugly workaround to make the .log and .log.json filenames the same
+ runner.timestamp = timestamp
+
+ if use_adverserial_train:
+ # The optimizer step process is included in the train_step function
+ # of the model, so the runner should NOT include optimizer hook.
+ optimizer_config = None
+ else:
+ # fp16 setting
+ fp16_cfg = cfg.get('fp16', None)
+ if fp16_cfg is not None:
+ optimizer_config = Fp16OptimizerHook(**cfg.optimizer_config,
+ **fp16_cfg,
+ distributed=distributed)
+ elif distributed and 'type' not in cfg.optimizer_config:
+ optimizer_config = OptimizerHook(**cfg.optimizer_config)
+ else:
+ optimizer_config = cfg.optimizer_config
+
+ # register hooks
+ runner.register_training_hooks(cfg.lr_config,
+ optimizer_config,
+ cfg.checkpoint_config,
+ cfg.log_config,
+ cfg.get('momentum_config', None),
+ custom_hooks_config=cfg.get(
+ 'custom_hooks', None))
+ if distributed:
+ runner.register_hook(DistSamplerSeedHook())
+
+ # register eval hooks
+ if validate:
+ val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
+ val_dataloader = build_dataloader(
+ val_dataset,
+ samples_per_gpu=cfg.data.samples_per_gpu,
+ workers_per_gpu=cfg.data.workers_per_gpu,
+ dist=distributed,
+ shuffle=False,
+ round_up=True)
+ eval_cfg = cfg.get('evaluation', {})
+ eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
+ eval_hook = DistEvalHook if distributed else EvalHook
+ runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
+
+ if cfg.resume_from:
+ runner.resume(cfg.resume_from)
+ elif cfg.load_from:
+ runner.load_checkpoint(cfg.load_from)
+ runner.run(data_loaders, cfg.workflow)
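+
+# Usage sketch (illustrative; `cfg` is an mmcv Config providing the fields
+# used above, e.g. data, optimizer, runner, lr_config, checkpoint_config,
+# log_config, work_dir and workflow, and `model` is built elsewhere, e.g. with
+# build_architecture):
+#
+#     datasets = [build_dataset(cfg.data.train)]
+#     train_model(model, datasets, cfg, distributed=False,
+#                 validate=True, device='cuda')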
diff --git a/detrsmpl/core/__init__.py b/detrsmpl/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/detrsmpl/core/cameras/__init__.py b/detrsmpl/core/cameras/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..248d02bb8989969384e9685bd7e518cca8142e6d
--- /dev/null
+++ b/detrsmpl/core/cameras/__init__.py
@@ -0,0 +1,19 @@
+from detrsmpl.core.cameras import builder, camera_parameters, cameras
+from detrsmpl.core.cameras.builder import CAMERAS, build_cameras
+from detrsmpl.core.cameras.cameras import (
+ FoVOrthographicCameras,
+ FoVPerspectiveCameras,
+ MMCamerasBase,
+ OrthographicCameras,
+ PerspectiveCameras,
+ WeakPerspectiveCameras,
+ compute_direction_cameras,
+ compute_orbit_cameras,
+)
+
+__all__ = [
+ 'CAMERAS', 'FoVOrthographicCameras', 'FoVPerspectiveCameras',
+ 'MMCamerasBase', 'OrthographicCameras', 'PerspectiveCameras',
+ 'WeakPerspectiveCameras', 'build_cameras', 'builder', 'camera_parameters',
+ 'cameras', 'compute_orbit_cameras', 'compute_direction_cameras'
+]
diff --git a/detrsmpl/core/cameras/builder.py b/detrsmpl/core/cameras/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..711ecadf36ae56c0f1b16554c443be7a4e41b415
--- /dev/null
+++ b/detrsmpl/core/cameras/builder.py
@@ -0,0 +1,8 @@
+from mmcv.utils import Registry
+
+CAMERAS = Registry('cameras')
+
+
+def build_cameras(cfg):
+ """Build cameras."""
+ return CAMERAS.build(cfg)
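+
+# Usage sketch (illustrative, assuming `import torch`): cameras registered in
+# detrsmpl/core/cameras/cameras.py (e.g. 'PerspectiveCameras') can be built
+# from a config dict, mirroring the call in camera_parameters.py:
+#
+#     K = torch.eye(4)[None]   # (1, 4, 4) intrinsic matrix, placeholder values
+#     R = torch.eye(3)[None]   # (1, 3, 3) rotation
+#     T = torch.zeros(1, 3)    # (1, 3) translation
+#     cam = build_cameras(
+#         dict(type='PerspectiveCameras', K=K.float(), R=R.float(),
+#              T=T.float(), convention='opencv', in_ndc=False,
+#              resolution=(1080, 1920)))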
diff --git a/detrsmpl/core/cameras/camera_parameters.py b/detrsmpl/core/cameras/camera_parameters.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c8534f70c2ea352feb3d8e54b5a69aedf715a2f
--- /dev/null
+++ b/detrsmpl/core/cameras/camera_parameters.py
@@ -0,0 +1,678 @@
+import json
+import warnings
+from enum import Enum
+from typing import Any, List, Tuple, Union
+
+import numpy as np
+import torch
+
+from detrsmpl.core.cameras.cameras import PerspectiveCameras
+from detrsmpl.core.conventions.cameras.convert_convention import (
+ convert_camera_matrix,
+ convert_K_3x3_to_4x4,
+ convert_K_4x4_to_3x3,
+)
+from .builder import build_cameras
+
+_CAMERA_PARAMETER_SUPPORTED_KEYS_ = {
+ 'H': {
+ 'type': int,
+ },
+ 'W': {
+ 'type': int,
+ },
+ 'in_mat': {
+ 'type': list,
+ 'len': 3,
+ },
+ 'rotation_mat': {
+ 'type': list,
+ 'len': 3,
+ },
+ 'translation': {
+ 'type': list,
+ 'len': 3,
+ },
+ 'k1': {
+ 'type': float,
+ },
+ 'k2': {
+ 'type': float,
+ },
+ 'k3': {
+ 'type': float,
+ },
+ 'k4': {
+ 'type': float,
+ },
+ 'k5': {
+ 'type': float,
+ },
+ 'k6': {
+ 'type': float,
+ },
+ 'p1': {
+ 'type': float,
+ },
+ 'p2': {
+ 'type': float,
+ },
+}
+
+
+class _TypeValidation(Enum):
+ MATCH = 0
+ ARRAY = 1
+ FAIL = 2
+
+
+class CameraParameter:
+ logger = None
+ SUPPORTED_KEYS = _CAMERA_PARAMETER_SUPPORTED_KEYS_
+
+ def __init__(self,
+ name: str = 'default',
+ H: int = 1080,
+ W: int = 1920) -> None:
+ """
+ Args:
+ name (str, optional):
+ Name of this camera. Defaults to "default".
+ H (int, optional):
+ Height of a frame, in pixel. Defaults to 1080.
+ W (int, optional):
+ Width of a frame, in pixel. Defaults to 1920.
+ """
+ self.name = name
+ self.parameters_dict = {}
+ in_mat = __zero_mat_list__(3)
+ self.parameters_dict['in_mat'] = in_mat
+ for distort_name in __distort_coefficient_names__:
+ self.parameters_dict[distort_name] = 0.0
+ _, H = self.validate_item('H', H)
+ self.parameters_dict['H'] = H
+ _, W = self.validate_item('W', W)
+ self.parameters_dict['W'] = W
+ r_mat = __zero_mat_list__(3)
+ self.parameters_dict['rotation_mat'] = r_mat
+ t_list = [0.0, 0.0, 0.0]
+ self.parameters_dict['translation'] = t_list
+
+ def reset_distort(self) -> None:
+ """Reset all distort coefficients to zero."""
+ for distort_name in __distort_coefficient_names__:
+ self.parameters_dict[distort_name] = 0.0
+
+ def get_opencv_distort_mat(self) -> np.ndarray:
+ """Get a numpy array of 8 distort coefficients, which is the distCoeffs
+ arg of cv2.undistort.
+
+ Returns:
+ ndarray:
+ (k_1, k_2, p_1, p_2, k_3, k_4, k_5, k_6) of 8 elements.
+ """
+ dist_coeffs = [
+ self.get_value('k1'),
+ self.get_value('k2'),
+ self.get_value('p1'),
+ self.get_value('p2'),
+ self.get_value('k3'),
+ self.get_value('k4'),
+ self.get_value('k5'),
+ self.get_value('k6'),
+ ]
+ dist_coeffs = np.array(dist_coeffs)
+ return dist_coeffs
+
+ def set_KRT(self,
+ K_mat: np.ndarray,
+ R_mat: np.ndarray,
+ T_vec: np.ndarray,
+ inverse_extrinsic: bool = False) -> None:
+ """Set intrinsic and extrinsic of a camera.
+
+ Args:
+ K_mat (np.ndarray):
+ In shape [3, 3].
+ R_mat (np.ndarray):
+ Rotation from world to view in default.
+ In shape [3, 3].
+ T_vec (np.ndarray):
+ Translation from world to view in default.
+ In shape [3,].
+ inverse_extrinsic (bool, optional):
+ If true, R_mat and T_vec transform a point
+ from view to world. Defaults to False.
+ """
+ k_shape = K_mat.shape
+ assert k_shape[0] == k_shape[1] == 3
+ r_shape = R_mat.shape
+ assert r_shape[0] == r_shape[1] == 3
+ assert T_vec.ndim == 1 and T_vec.shape[0] == 3
+ self.set_mat_np('in_mat', K_mat)
+ if inverse_extrinsic:
+ R_mat = np.linalg.inv(R_mat)
+ T_vec = -np.dot(R_mat, T_vec).reshape((3))
+ self.set_mat_np('rotation_mat', R_mat)
+ self.set_value('translation', T_vec.tolist())
+
+ def get_KRT(self, k_dim=3) -> List[np.ndarray]:
+ """Get intrinsic and extrinsic of a camera.
+
+ Args:
+ k_dim (int, optional):
+ Dimension of the returned mat K.
+ Defaults to 3.
+
+ Raises:
+ ValueError: k_dim is neither 3 nor 4.
+
+ Returns:
+ List[np.ndarray]:
+ K_mat (np.ndarray):
+ In shape [3, 3].
+ R_mat (np.ndarray):
+ Rotation from world to view in default.
+ In shape [3, 3].
+ T_vec (np.ndarray):
+ Translation from world to view in default.
+ In shape [3,].
+ """
+ K_3x3 = self.get_mat_np('in_mat')
+ R_mat = self.get_mat_np('rotation_mat')
+ T_vec = np.asarray(self.get_value('translation'))
+ if k_dim == 3:
+ return [K_3x3, R_mat, T_vec]
+ elif k_dim == 4:
+ K_3x3 = np.expand_dims(K_3x3, 0) # shape (1, 3, 3)
+ K_4x4 = convert_K_3x3_to_4x4(
+ K=K_3x3, is_perspective=True) # shape (1, 4, 4)
+ K_4x4 = K_4x4[0, :, :]
+ return [K_4x4, R_mat, T_vec]
+ else:
+ raise ValueError(f'K mat cannot be converted to {k_dim}x{k_dim}')
+
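+    # Usage sketch (illustrative): `set_KRT` and `get_KRT` round-trip a
+    # calibration stored in the default world-to-view direction:
+    #
+    #     cam_param = CameraParameter(name='cam0', H=1080, W=1920)
+    #     cam_param.set_KRT(K_mat=np.eye(3), R_mat=np.eye(3),
+    #                       T_vec=np.zeros(3))
+    #     K, R, T = cam_param.get_KRT(k_dim=3)  # 3x3 K, 3x3 R, (3,) T
+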
+ def set_mat_np(self, mat_key: str, mat_numpy: np.ndarray) -> None:
+ """Set a matrix-type parameter to mat_numpy.
+
+ Args:
+ mat_key (str):
+ Key of the target matrix. in_mat or rotation_mat.
+ mat_numpy (ndarray):
+ Matrix in numpy format.
+
+ Raises:
+ TypeError:
+ mat_numpy is not an np.ndarray.
+ """
+ if not isinstance(mat_numpy, np.ndarray):
+ raise TypeError
+ self.set_mat_list(mat_key, mat_numpy.tolist())
+
+ def set_mat_list(self, mat_key: str, mat_list: List[list]) -> None:
+ """Set a matrix-type parameter to mat_list.
+
+ Args:
+ mat_key (str):
+ Key of the target matrix. in_mat or rotation_mat.
+ mat_list (List[list]):
+ Matrix in list format.
+ """
+ _, mat_list = self.validate_item(mat_key, mat_list)
+ self.parameters_dict[mat_key] = mat_list
+
+ def set_value(self, key: str, value: Any) -> None:
+ """Set a parameter to value.
+
+ Args:
+ key (str):
+ Name of the parameter.
+ value (object):
+ New value of the parameter.
+ """
+ _, value = self.validate_item(key, value)
+ self.parameters_dict[key] = value
+
+ def get_value(self, key: str) -> Any:
+ """Get a parameter by key.
+
+ Args:
+ key (str):
+ Name of the parameter.
+ Raises:
+ KeyError: key not in self.parameters_dict
+
+ Returns:
+ object:
+ Value of the parameter.
+ """
+ if key not in self.parameters_dict:
+ raise KeyError(key)
+ else:
+ return self.parameters_dict[key]
+
+ def get_mat_np(self, key: str) -> np.ndarray:
+ """Get a a matrix-type parameter by key.
+
+ Args:
+ key (str):
+ Name of the parameter.
+ Raises:
+ KeyError: key not in self.parameters_dict
+
+ Returns:
+ ndarray:
+ Value of the parameter.
+ """
+ if key not in self.parameters_dict:
+ raise KeyError(key)
+ else:
+ mat_list = self.parameters_dict[key]
+ mat_np = np.array(mat_list).reshape((3, 3))
+ return mat_np
+
+ def to_string(self) -> str:
+ """Convert self.to_dict() to a string.
+
+ Returns:
+ str:
+ A dict in json string format.
+ """
+ dump_dict = self.to_dict()
+ ret_str = json.dumps(dump_dict)
+ return ret_str
+
+ def to_dict(self) -> dict:
+ """Dump camera name and parameters to dict.
+
+ Returns:
+ dict:
+ Put self.name and self.parameters_dict
+ in one dict.
+ """
+ dump_dict = self.parameters_dict.copy()
+ dump_dict['name'] = self.name
+ return dump_dict
+
+    def dump(self, json_path: str) -> None:
+        """Dump camera name and parameters to a json file.
+
+        Args:
+            json_path (str):
+                Path of the output json file.
+        """
+ dump_dict = self.to_dict()
+ with open(json_path, 'w') as f_write:
+ json.dump(dump_dict, f_write)
+
+ def load(self, json_path: str) -> None:
+ """Load camera name and parameters from a file."""
+ with open(json_path, 'r') as f_read:
+ dumped_dict = json.load(f_read)
+ self.load_from_dict(dumped_dict)
+
+ def load_from_dict(self, json_dict: dict) -> None:
+ """Load name and parameters from a dict.
+
+ Args:
+ json_dict (dict):
+ A dict comes from self.to_dict().
+ """
+ for key in json_dict.keys():
+ if key == 'name':
+ self.name = json_dict[key]
+ elif key == 'rotation':
+ self.parameters_dict['rotation_mat'] = np.array(
+ json_dict[key]).reshape(3, 3).tolist()
+ elif key == 'translation':
+ self.parameters_dict[key] = np.array(json_dict[key]).reshape(
+ (3)).tolist()
+ else:
+ self.parameters_dict[key] = json_dict[key]
+ if '_mat' in key:
+ self.parameters_dict[key] = np.array(
+ self.parameters_dict[key]).reshape(3, 3).tolist()
+
+ def load_from_chessboard(self,
+ chessboard_dict: dict,
+ name: str,
+ inverse: bool = True) -> None:
+ """Load name and parameters from a dict.
+
+ Args:
+ chessboard_dict (dict):
+ A dict loaded from json.load(chessboard_file).
+ name (str):
+ Name of this camera.
+            inverse (bool, optional):
+                Whether to invert the rotation and translation mat.
+                Defaults to True.
+ """
+ camera_param_dict = \
+ __parse_chessboard_param__(chessboard_dict, name, inverse=inverse)
+ self.load_from_dict(camera_param_dict)
+
+ def load_kinect_from_smc(self, smc_reader, kinect_id: int) -> None:
+ """Load name and parameters of a kinect from an SmcReader instance.
+
+ Args:
+ smc_reader (mmhuman3d.data.data_structures.smc_reader.SMCReader):
+ An SmcReader instance containing kinect camera parameters.
+ kinect_id (int):
+ Id of the target kinect.
+ """
+ name = kinect_id
+ extrinsics_dict = \
+ smc_reader.get_kinect_color_extrinsics(
+ kinect_id, homogeneous=False
+ )
+ rot_np = extrinsics_dict['R']
+ trans_np = extrinsics_dict['T']
+ intrinsics_np = \
+ smc_reader.get_kinect_color_intrinsics(
+ kinect_id
+ )
+ resolution = \
+ smc_reader.get_kinect_color_resolution(
+ kinect_id
+ )
+ rmatrix = np.linalg.inv(rot_np).reshape(3, 3)
+ tvec = -np.dot(rmatrix, trans_np)
+ self.name = name
+ self.set_mat_np('in_mat', intrinsics_np)
+ self.set_mat_np('rotation_mat', rmatrix)
+ self.set_value('translation', tvec.tolist())
+ self.set_value('H', resolution[1])
+ self.set_value('W', resolution[0])
+
+ def load_iphone_from_smc(self,
+ smc_reader,
+ iphone_id: int = 0,
+ frame_id: int = 0) -> None:
+ """Load name and parameters of an iPhone from an SmcReader instance.
+
+ Args:
+ smc_reader (mmhuman3d.data.data_structures.smc_reader.SMCReader):
+ An SmcReader instance containing kinect camera parameters.
+ iphone_id (int):
+ Id of the target iphone.
+ Defaults to 0.
+ frame_id (int):
+ Frame ID of one selected frame.
+ It only influences the intrinsics.
+ Defaults to 0.
+ """
+ name = f'iPhone_{iphone_id}'
+ extrinsics_mat = \
+ smc_reader.get_iphone_extrinsics(
+ iphone_id, homogeneous=True
+ )
+ rot_np = extrinsics_mat[:3, :3]
+ trans_np = extrinsics_mat[:3, 3]
+ intrinsics_np = \
+ smc_reader.get_iphone_intrinsics(
+ iphone_id, frame_id
+ )
+ resolution = \
+ smc_reader.get_iphone_color_resolution(
+ iphone_id
+ )
+ rmatrix = np.linalg.inv(rot_np).reshape(3, 3)
+ tvec = -np.dot(rmatrix, trans_np)
+ self.name = name
+ self.set_mat_np('in_mat', intrinsics_np)
+ self.set_mat_np('rotation_mat', rmatrix)
+ self.set_value('translation', tvec.tolist())
+ self.set_value('H', resolution[1])
+ self.set_value('W', resolution[0])
+
+ @classmethod
+ def load_from_perspective_cameras(cls,
+ cam,
+ name: str,
+ resolution: Union[List, Tuple] = None):
+ """Load parameters from a PerspectiveCameras and return a
+ CameraParameter.
+
+ Args:
+ cam (mmhuman3d.core.cameras.cameras.PerspectiveCameras):
+ An instance.
+ name (str):
+ Name of this camera.
+ """
+ assert isinstance(cam, PerspectiveCameras
+ ), 'Wrong input, support PerspectiveCameras only!'
+ if len(cam) > 1:
+ warnings.warn('Will only use the first camera in the batch.')
+ cam = cam[0]
+
+ resolution = resolution if resolution is not None else cam.resolution[
+ 0].tolist()
+
+ height, width = int(resolution[0]), int(resolution[1])
+
+        cam_param = CameraParameter(name=name, H=height, W=width)
+
+ k_4x4 = cam.K # shape (1, 4, 4)
+ r_3x3 = cam.R # shape (1, 3, 3)
+ t_3 = cam.T # shape (1, 3)
+ is_perspective = cam.is_perspective()
+ in_ndc = cam.in_ndc()
+
+ k_4x4, r_3x3, t_3 = convert_camera_matrix(K=k_4x4,
+ R=r_3x3,
+ T=t_3,
+ is_perspective=False,
+ in_ndc_dst=False,
+ in_ndc_src=in_ndc,
+ convention_src='pytorch3d',
+ convention_dst='opencv',
+ resolution_src=(height,
+ width),
+ resolution_dst=(height,
+ width))
+
+ k_3x3 = \
+ convert_K_4x4_to_3x3(k_4x4, is_perspective=is_perspective)
+
+ k_3x3 = k_3x3.numpy()[0]
+ r_3x3 = r_3x3.numpy()[0]
+ t_3 = t_3.numpy()[0]
+ cam_param.name = name
+ cam_param.set_mat_np('in_mat', k_3x3)
+ cam_param.set_mat_np('rotation_mat', r_3x3)
+ cam_param.set_value('translation', t_3.tolist())
+ cam_param.parameters_dict.update(H=height)
+ cam_param.parameters_dict.update(W=width)
+ return cam_param
+
+ def export_to_perspective_cameras(self) -> PerspectiveCameras:
+ """Export to a opencv defined screen space PerspectiveCameras.
+
+ Returns:
+ Same defined PerspectiveCameras of batch_size 1.
+ """
+ height = self.parameters_dict['H']
+ width = self.parameters_dict['W']
+ k_4x4, rotation, translation = self.get_KRT(k_dim=4)
+        k_4x4 = np.expand_dims(k_4x4, 0)  # shape (1, 4, 4)
+ rotation = np.expand_dims(rotation, 0) # shape (1, 3, 3)
+ translation = np.expand_dims(translation, 0) # shape (1, 3)
+ new_K = torch.from_numpy(k_4x4)
+ new_R = torch.from_numpy(rotation)
+ new_T = torch.from_numpy(translation)
+ cam = build_cameras(
+ dict(type='PerspectiveCameras',
+ K=new_K.float(),
+ R=new_R.float(),
+ T=new_T.float(),
+ convention='opencv',
+ in_ndc=False,
+ resolution=(height, width)))
+ return cam
+
+ def validate_item(self, key: Any, val: Any) -> List:
+ """Check whether the key and its value matches definition in
+ CameraParameter.SUPPORTED_KEYS.
+
+ Args:
+ key (Any):
+ Key in CameraParameter.
+ val (Any):
+ Value to the key.
+
+ Raises:
+ KeyError:
+ key cannot be found in
+ CameraParameter.SUPPORTED_KEYS.
+ TypeError:
+ Value's type doesn't match definition.
+ Returns:
+ key (Any): The input key.
+ val (Any): The value casted into correct format.
+ """
+ self.__check_key__(key)
+ formatted_val = self.__validate_value_type__(key, val)
+ return key, formatted_val
+
+ def __check_key__(self, key: Any) -> None:
+ """Check whether the key matches definition in
+ CameraParameter.SUPPORTED_KEYS.
+
+ Args:
+ key (Any):
+ Key in CameraParameter.
+
+ Raises:
+ KeyError:
+ key cannot be found in
+ CameraParameter.SUPPORTED_KEYS.
+ """
+ if key not in self.__class__.SUPPORTED_KEYS:
+ err_msg = 'Key check failed in CameraParameter:\n'
+ err_msg += f'key={str(key)}\n'
+ raise KeyError(err_msg)
+
+ def __validate_value_type__(self, key: Any, val: Any) -> Any:
+ """Check whether the type of value matches definition in
+ CameraParameter.SUPPORTED_KEYS.
+
+ Args:
+ key (Any):
+ Key in CameraParameter.
+ val (Any):
+ Value to the key.
+
+ Raises:
+ TypeError:
+ Value is supported but doesn't match definition.
+
+ Returns:
+ val (Any): The value casted into correct format.
+ """
+ np_type_mapping = {int: np.integer, float: np.floating}
+ supported_keys = self.__class__.SUPPORTED_KEYS
+ validation_result = _TypeValidation.FAIL
+ ret_val = None
+ if supported_keys[key]['type'] == int or\
+ supported_keys[key]['type'] == float:
+ type_str = str(type(val))
+ class_name = type_str.split('\'')[1]
+ if type(val) == self.__class__.SUPPORTED_KEYS[key]['type']:
+ validation_result = _TypeValidation.MATCH
+ ret_val = val
+ elif class_name.startswith('numpy'):
+ # a value is required, not array
+ if np.issubdtype(type(val),
+ np_type_mapping[supported_keys[key]['type']]):
+ validation_result = _TypeValidation.MATCH
+ ret_val = val.astype(supported_keys[key]['type'])
+ elif np.issubdtype(type(val), np.ndarray):
+ validation_result = _TypeValidation.ARRAY
+ elif class_name.startswith('torch'):
+ # only one element tensors
+ # can be converted to Python scalars
+ if len(val.size()) == 0:
+ val_item = val.item()
+ if type(val_item) == supported_keys[key]['type']:
+ validation_result = _TypeValidation.MATCH
+ ret_val = val_item
+ else:
+ validation_result = _TypeValidation.ARRAY
+ else:
+ if type(val) == self.__class__.SUPPORTED_KEYS[key]['type']:
+ validation_result = _TypeValidation.MATCH
+ ret_val = val
+ if validation_result != _TypeValidation.MATCH:
+ err_msg = 'Type check failed in CameraParameter:\n'
+ err_msg += f'key={str(key)}\n'
+ err_msg += f'type(val)={type(val)}\n'
+ if validation_result == _TypeValidation.ARRAY:
+ err_msg += 'A single value is expected, ' +\
+ 'neither an array nor a slice.\n'
+ raise TypeError(err_msg)
+ return ret_val
+
+
+def __parse_chessboard_param__(chessboard_camera_param, name, inverse=True):
+ """Parse a dict loaded from chessboard file into another dict needed by
+ CameraParameter.
+
+ Args:
+ chessboard_camera_param (dict):
+ A dict loaded from json.load(chessboard_file).
+ name (str):
+ Name of this camera.
+ inverse (bool, optional):
+            Whether to invert the rotation and translation mat.
+ Defaults to True.
+
+ Returns:
+ dict:
+ A dict of parameters in CameraParameter.to_dict() format.
+ """
+ camera_param_dict = {}
+ camera_param_dict['H'] = chessboard_camera_param['imgSize'][1]
+ camera_param_dict['W'] = chessboard_camera_param['imgSize'][0]
+ camera_param_dict['in_mat'] = chessboard_camera_param['K']
+ camera_param_dict['k1'] = 0
+ camera_param_dict['k2'] = 0
+ camera_param_dict['k3'] = 0
+ camera_param_dict['k4'] = 0
+ camera_param_dict['k5'] = 0
+ camera_param_dict['p1'] = 0
+ camera_param_dict['p2'] = 0
+ camera_param_dict['name'] = name
+ camera_param_dict['rotation'] = chessboard_camera_param['R']
+ camera_param_dict['translation'] = chessboard_camera_param['T']
+ if inverse:
+ rmatrix = np.linalg.inv(
+ np.array(camera_param_dict['rotation']).reshape(3, 3))
+ camera_param_dict['rotation'] = rmatrix.tolist()
+ tmatrix = np.array(camera_param_dict['translation']).reshape((3, 1))
+ tvec = -np.dot(rmatrix, tmatrix)
+ camera_param_dict['translation'] = tvec.reshape((3)).tolist()
+ return camera_param_dict
+
+
+__distort_coefficient_names__ = [
+ 'k1', 'k2', 'k3', 'k4', 'k5', 'k6', 'p1', 'p2'
+]
+
+
+def __zero_mat_list__(n=3):
+ """Return a zero mat in list format.
+
+ Args:
+ n (int, optional):
+ Length of the edge.
+ Defaults to 3.
+
+ Returns:
+ list:
+ List[List[int]]
+ """
+ ret_list = [[0] * n for _ in range(n)]
+ return ret_list
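+
+
+# Illustrative, hypothetical usage sketch (not part of the original module);
+# it only runs when this file is executed directly. The chessboard dict below
+# is made up for demonstration and mirrors the keys consumed by
+# __parse_chessboard_param__ above.
+if __name__ == '__main__':
+    assert __zero_mat_list__(2) == [[0, 0], [0, 0]]
+    demo_chessboard_param = {
+        'imgSize': [1920, 1080],
+        'K': [[1000.0, 0.0, 960.0], [0.0, 1000.0, 540.0], [0.0, 0.0, 1.0]],
+        'R': [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
+        'T': [0.0, 0.0, 3.0],
+    }
+    demo_param = __parse_chessboard_param__(
+        demo_chessboard_param, name='demo_cam', inverse=True)
+    # With inverse=True the extrinsics become R' = R^-1, T' = -R^-1 @ T.
+    assert demo_param['H'] == 1080 and demo_param['W'] == 1920
+    assert demo_param['translation'] == [0.0, 0.0, -3.0]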
diff --git a/detrsmpl/core/cameras/cameras.py b/detrsmpl/core/cameras/cameras.py
new file mode 100644
index 0000000000000000000000000000000000000000..907d591e4f6b2edf2d6fc37b0265c06ebbe3f600
--- /dev/null
+++ b/detrsmpl/core/cameras/cameras.py
@@ -0,0 +1,1426 @@
+import math
+from typing import Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from pytorch3d.renderer import cameras
+from pytorch3d.structures import Meshes
+from pytorch3d.transforms import Transform3d
+
+from detrsmpl.core.conventions.cameras.convert_convention import (
+ convert_camera_matrix,
+ convert_ndc_to_screen,
+ convert_screen_to_ndc,
+ convert_world_view,
+)
+from detrsmpl.utils.transforms import ee_to_rotmat
+from .builder import CAMERAS
+
+
+class MMCamerasBase(cameras.CamerasBase):
+ """Inherited from Pytorch3D CamerasBase and provide some new functions."""
+ def __init__(self, **kwargs) -> None:
+ """Initialize your cameras with `build_cameras` following:
+
+ 1): provide `K`, `R`, `T`, `resolution`/`image_size`, `in_ndc`
+ directly.
+ `K` should be shape of (N, 3, 3) or (N, 4, 4).
+ `R` should be shape of (N, 3, 3).
+ `T` should be shape of (N, 3).
+ 2): if `K` is not provided, will use `get_default_projection_matrix`
+ to generate K from camera intrinsic parameters.
+ E.g., you can pass `focal_length`, `principal_point` for
+ perspective camers.
+ If these args are not provided, will use default values.
+ 3): if `R` is not provided, will use Identity matrix as default.
+ 4): if `T` is not provided, will use zeros matrix as default.
+ 5): `convention` means your source parameter camera convention.
+ This mainly depends on how you get the matrixs. E.g., you get the
+ `K` `R`, `T` by calibration with opencv, you should set
+ `convention = opencv`. To figure out your camera convention,
+ please see the definition of its extrinsic and intrinsic matrixs.
+ For projection and rendering, the matrixs will be converted to
+ `pytorch3d` finally since the `transforms3d` called in rendering
+ and projection are defined as `pytorch3d` convention.
+ 6): `image_size` equals `resolution`.
+ 7): `in_ndc` could be set for 'PerspectiveCameras' and
+ 'OrthographicCameras', other cameras are fixed for this arg.
+ `in_ndc = True` means your projection matrix is defined as `camera
+ space to NDC space`. Under this cirecumstance you need to set
+ `image_size` or `resolution` (they are equal) when you need to do
+ `transform_points_screen`. You can also override resolution
+ in `transform_points_screen` function.
+ `in_ndc = False` means your projections matrix is defined as
+ `cameras space to screen space`. Under this cirecumstance you do
+ not need to set `image_size` or `resolution` (they are equal) when
+ you need to do `transform_points_screen` since the projection
+ matrix is defined as view space to screen space.
+ """
+ for k in kwargs:
+ if isinstance(kwargs.get(k), np.ndarray):
+ kwargs.update({k: torch.Tensor(kwargs[k])})
+ convention = kwargs.pop('convention', 'pytorch3d').lower()
+ in_ndc = kwargs.pop('in_ndc', kwargs.get('_in_ndc'))
+ kwargs.update(_in_ndc=in_ndc)
+ is_perspective = kwargs.get('_is_perspective')
+ kwargs.pop('is_perspective', None)
+
+ image_size = kwargs.get('image_size', kwargs.get('resolution', None))
+
+ if image_size is not None:
+ if isinstance(image_size, (int, float)):
+ image_size = (image_size, image_size)
+ if isinstance(image_size, (tuple, list)):
+ image_size = torch.Tensor(image_size)
+ if isinstance(image_size, torch.Tensor):
+ if image_size.numel() == 1:
+ image_size = image_size.repeat(2)
+ image_size = image_size.view(-1, 2)
+
+ if kwargs.get('K') is None:
+ focal_length = kwargs.get('focal_length', None)
+ if focal_length is not None:
+ if not isinstance(focal_length, Iterable):
+ focal_length = [focal_length, focal_length]
+ if not torch.is_tensor(focal_length):
+ focal_length = torch.FloatTensor(focal_length).view(-1, 2)
+ elif focal_length.numel() == 1:
+ focal_length = focal_length.repeat(2).view(-1, 2)
+ kwargs.update(focal_length=focal_length)
+
+ principal_point = kwargs.get('principal_point', None)
+ if principal_point is not None:
+ if isinstance(principal_point, (tuple, list)):
+ principal_point = torch.FloatTensor(principal_point)
+ principal_point = principal_point.view(-1, 2)
+ kwargs.update(principal_point=principal_point)
+
+ K = self.get_default_projection_matrix(**kwargs)
+
+ K, _, _ = convert_camera_matrix(K=K,
+ is_perspective=is_perspective,
+ convention_src='pytorch3d',
+ convention_dst='pytorch3d',
+ in_ndc_src=in_ndc,
+ in_ndc_dst=in_ndc,
+ resolution_dst=image_size,
+ resolution_src=image_size)
+ kwargs.update(K=K)
+
+ K, R, T = convert_camera_matrix(K=kwargs.get('K'),
+ R=kwargs.get('R', None),
+ T=kwargs.get('T', None),
+ convention_src=convention,
+ convention_dst='pytorch3d',
+ is_perspective=is_perspective,
+ in_ndc_src=in_ndc,
+ in_ndc_dst=in_ndc,
+ resolution_src=image_size,
+ resolution_dst=image_size)
+
+ if image_size is not None:
+ if image_size.shape[0] == 1:
+ image_size = image_size.repeat(K.shape[0], 1)
+ kwargs.update(image_size=image_size)
+ kwargs.update(resolution=image_size)
+
+ kwargs.update(K=K, R=R, T=T)
+
+ super().__init__(**kwargs)
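+
+    # Illustrative, hypothetical construction sketch (not part of the
+    # original file): intrinsics/extrinsics calibrated with OpenCV can be
+    # passed with `convention='opencv'` and are converted to the pytorch3d
+    # convention internally, e.g.
+    #   cam = PerspectiveCameras(K=K_cv, R=R_cv, T=T_cv, convention='opencv',
+    #                            in_ndc=False, image_size=(1080, 1920))
+    # where K_cv, R_cv, T_cv are hypothetical (N, 3, 3)/(N, 3, 3)/(N, 3)
+    # tensors.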
+
+ def get_camera_plane_normals(self, **kwargs) -> torch.Tensor:
+ """Get the identity normal vector which stretchs out of the camera
+ plane.
+
+ Could pass `R` to override the camera extrinsic rotation matrix.
+ Returns:
+ torch.Tensor: shape will be (N, 3)
+ """
+ normals = torch.Tensor([0, 0, 1]).view(1, 3).to(self.device)
+ w2v_trans = self.get_world_to_view_transform(**kwargs)
+ normals = w2v_trans.inverse().transform_normals(normals)
+ return normals.view(-1, 3)
+
+ def compute_depth_of_points(self, points: torch.Tensor) -> torch.Tensor:
+ """Compute depth of points to the camera plane.
+
+ Args:
+ points ([torch.Tensor]): shape should be (batch_size, ..., 3).
+
+ Returns:
+            torch.Tensor: shape will be (batch_size, ..., 1)
+ """
+ world_to_view_transform = self.get_world_to_view_transform()
+ world_to_view_points = world_to_view_transform.transform_points(
+ points.to(self.device))
+ return world_to_view_points[..., 2:3]
+
+ def compute_normal_of_meshes(self, meshes: Meshes) -> torch.Tensor:
+ """Compute normal of meshes in the camera view.
+
+ Args:
+            meshes (Meshes): meshes whose vertex normals are transformed
+                into the camera view.
+
+        Returns:
+            torch.Tensor: shape will be (batch_size, num_verts, 3)
+ """
+ world_to_view_transform = self.get_world_to_view_transform()
+ world_to_view_normals = world_to_view_transform.transform_normals(
+ meshes.verts_normals_padded().to(self.device))
+ return world_to_view_normals
+
+ def __repr__(self):
+ """Rewrite __repr__
+
+ Returns:
+ str: print the information of cameras (N, in_ndc, device).
+ """
+ main_str = super().__repr__()
+ main_str = main_str.split(')')[0]
+ main_str += f'N: {self.__len__()}, in_ndc: {self.in_ndc()}, '
+ main_str += f'device: {self.device})'
+ return main_str
+
+ def get_image_size(self):
+ """Returns the image size, if provided, expected in the form of
+ (height, width) The image size is used for conversion of projected
+ points to screen coordinates."""
+ if hasattr(self, 'image_size'):
+ image_size = self.image_size
+ if hasattr(self, 'resolution'):
+ if self.resolution is not None:
+ image_size = self.resolution
+ else:
+ image_size = None
+
+ return image_size
+
+ def __getitem__(
+ self, index: Union[slice, int, torch.Tensor, List,
+ Tuple]) -> 'MMCamerasBase':
+ """Slice the cameras by batch dim.
+
+ Args:
+ index (Union[slice, int, torch.Tensor, List, Tuple]):
+ index for slicing.
+
+ Returns:
+ MMCamerasBase: sliced cameras.
+ """
+ if isinstance(index, int):
+ index = [index]
+ return self.__class__(K=self.K[index],
+ R=self.R[index],
+ T=self.T[index],
+ image_size=self.get_image_size()[index]
+ if self.get_image_size() is not None else None,
+ in_ndc=self.in_ndc(),
+ convention='pytorch3d',
+ device=self.device)
+
+ def extend(self, N) -> 'MMCamerasBase':
+ """Create new camera class which contains each input camera N times.
+
+ Args:
+ N: number of new copies of each camera.
+
+ Returns:
+ MMCamerasBase object.
+ """
+ return self.__class__(K=self.K.repeat(N, 1, 1),
+ R=self.R.repeat(N, 1, 1),
+ T=self.T.repeat(N, 1),
+ image_size=self.get_image_size(),
+ in_ndc=self.in_ndc(),
+ convention='pytorch3d',
+ device=self.device)
+
+ def extend_(self, N):
+ """extend camera inplace."""
+ self.K = self.K.repeat(N, 1, 1)
+ self.R = self.R.repeat(N, 1, 1)
+ self.T = self.T.repeat(N, 1)
+ self._N = self._N * N
+
+ @classmethod
+    def get_default_projection_matrix(cls, **kwargs):
+ """Class method. Calculate the projective transformation matrix by
+ default parameters.
+
+ Args:
+ **kwargs: parameters for the projection can be passed in as keyword
+ arguments to override the default values set in `__init__`.
+
+ Return:
+ a `torch.Tensor` which represents a batch of projection matrices K
+ of shape (N, 4, 4)
+ """
+ raise NotImplementedError()
+
+ def to_screen_(self, **kwargs) -> 'MMCamerasBase':
+ """Convert to screen inplace."""
+ if self.in_ndc():
+ if self.get_image_size() is None:
+ self.image_size = kwargs.get('image_size')
+ else:
+ self.image_size = self.get_image_size()
+ self.K = convert_ndc_to_screen(K=self.K,
+ resolution=self.image_size,
+ is_perspective=self._is_perspective)
+ self._in_ndc = False
+ else:
+ print('Redundant operation, already in screen.')
+
+ def to_ndc_(self, **kwargs) -> 'MMCamerasBase':
+ """Convert to ndc inplace."""
+ if self.in_ndc():
+ print('Redundant operation, already in ndc.')
+ else:
+ if self.get_image_size() is None:
+ self.image_size = kwargs.get('image_size')
+ else:
+ self.image_size = self.get_image_size()
+ self.K = convert_screen_to_ndc(K=self.K,
+ resolution=self.image_size,
+ is_perspective=self._is_perspective)
+ self._in_ndc = True
+
+ def to_screen(self, **kwargs) -> 'MMCamerasBase':
+ """Convert to screen."""
+ if self.in_ndc():
+ if self.get_image_size() is None:
+ self.image_size = kwargs.get('image_size')
+ else:
+ self.image_size = self.get_image_size()
+
+ K = convert_ndc_to_screen(K=self.K,
+ resolution=self.image_size,
+ is_perspective=self._is_perspective)
+ return self.__class__(K=K,
+ R=self.R,
+ T=self.T,
+ in_ndc=False,
+ resolution=self.image_size)
+ else:
+ print('Redundant operation, already in screen.')
+
+ def to_ndc(self, **kwargs) -> 'MMCamerasBase':
+ """Convert to ndc."""
+ if self.in_ndc():
+ print('Redundant operation, already in ndc.')
+ else:
+ if self.get_image_size() is None:
+ self.image_size = kwargs.get('image_size')
+ else:
+ self.image_size = self.get_image_size()
+ K = convert_screen_to_ndc(K=self.K,
+ resolution=self.image_size,
+ is_perspective=self._is_perspective)
+ return self.__class__(K=K,
+ R=self.R,
+ T=self.T,
+ in_ndc=True,
+ resolution=self.image_size)
+
+ def detach(self) -> 'MMCamerasBase':
+ image_size = self.image_size.detach(
+ ) if self.image_size is not None else None
+ return self.__class__(K=self.K.detach(),
+ R=self.R.detach(),
+ T=self.T.detach(),
+ in_ndc=self.in_ndc(),
+ device=self.device,
+ resolution=image_size)
+
+ def concat(self, others) -> 'MMCamerasBase':
+ if isinstance(others, type(self)):
+ others = [others]
+ else:
+ raise TypeError('Could only concat with same type cameras.')
+ return concat_cameras([self] + others)
+
+
+@CAMERAS.register_module(name=('WeakPerspectiveCameras', 'WeakPerspective',
+ 'weakperspective'))
+class WeakPerspectiveCameras(MMCamerasBase):
+ """Inherited from [Pytorch3D cameras](https://github.com/facebookresearch/
+ pytorch3d/blob/main/pytorch3d/renderer/cameras.py) and mimiced the code
+ style. And re-inmplemented functions: compute_projection_matrix,
+ get_projection_transform, unproject_points, is_perspective, in_ndc for
+ render.
+
+ K modified from [VIBE](https://github.com/mkocabas/VIBE/blob/master/
+ lib/utils/renderer.py) and changed to opencv convention.
+ Original license please see docs/additional_license/md.
+
+ This intrinsic matrix is orthographics indeed, but could serve as
+ weakperspective for single smpl mesh.
+ """
+ def __init__(
+ self,
+ scale_x: Union[torch.Tensor, float] = 1.0,
+ scale_y: Union[torch.Tensor, float] = 1.0,
+ transl_x: Union[torch.Tensor, float] = 0.0,
+ transl_y: Union[torch.Tensor, float] = 0.0,
+ znear: Union[torch.Tensor, float] = -1.0,
+ aspect_ratio: Union[torch.Tensor, float] = 1.0,
+ K: Optional[torch.Tensor] = None,
+ R: Optional[torch.Tensor] = None,
+ T: Optional[torch.Tensor] = None,
+ device: Union[torch.device, str] = 'cpu',
+ convention: str = 'pytorch3d',
+ **kwargs,
+ ):
+ """Initialize. If K is provided, don't need scale_x, scale_y, transl_x,
+ transl_y, znear, aspect_ratio.
+
+ Args:
+ scale_x (Union[torch.Tensor, float], optional):
+ Scale in x direction.
+ Defaults to 1.0.
+ scale_y (Union[torch.Tensor, float], optional):
+ Scale in y direction.
+ Defaults to 1.0.
+ transl_x (Union[torch.Tensor, float], optional):
+ Translation in x direction.
+ Defaults to 0.0.
+ transl_y (Union[torch.Tensor, float], optional):
+ Translation in y direction.
+ Defaults to 0.0.
+ znear (Union[torch.Tensor, float], optional):
+                near clipping plane of the view frustum.
+ Defaults to -1.0.
+ aspect_ratio (Union[torch.Tensor, float], optional):
+ aspect ratio of the image pixels. 1.0 indicates square pixels.
+ Defaults to 1.0.
+ K (Optional[torch.Tensor], optional): Intrinsic matrix of shape
+ (N, 4, 4). If provided, don't need scale_x, scale_y, transl_x,
+ transl_y, znear, aspect_ratio.
+ Defaults to None.
+ R (Optional[torch.Tensor], optional):
+ Rotation matrix of shape (N, 3, 3).
+ Defaults to None.
+ T (Optional[torch.Tensor], optional):
+ Translation matrix of shape (N, 3).
+ Defaults to None.
+ device (Union[torch.device, str], optional):
+ torch device. Defaults to 'cpu'.
+ """
+ kwargs.update(
+ _in_ndc=True,
+ _is_perspective=False,
+ )
+ kwargs.pop('in_ndc', None)
+ kwargs.pop('is_perspective', None)
+ super().__init__(scale_x=scale_x,
+ scale_y=scale_y,
+ transl_x=transl_x,
+ transl_y=transl_y,
+ znear=znear,
+ aspect_ratio=aspect_ratio,
+ K=K,
+ R=R,
+ T=T,
+ device=device,
+ convention=convention,
+ **kwargs)
+
+ @staticmethod
+ def convert_orig_cam_to_matrix(
+ orig_cam: torch.Tensor,
+ **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Compute intrinsic camera matrix from orig_cam parameter of smpl.
+
+ .. code-block:: python
+
+ r > 1::
+
+ K = [[sx*r, 0, 0, tx*sx*r],
+ [0, sy, 0, ty*sy],
+ [0, 0, 1, 0],
+ [0, 0, 0, 1]]
+
+ or r < 1::
+
+ K = [[sx, 0, 0, tx*sx],
+ [0, sy/r, 0, ty*sy/r],
+ [0, 0, 1, 0],
+ [0, 0, 0, 1],]
+
+ rotation matrix: (N, 3, 3)::
+
+ [[1, 0, 0],
+ [0, 1, 0],
+ [0, 0, 1]]
+
+ translation matrix: (N, 3)::
+
+ [0, 0, -znear]
+
+ Args:
+ orig_cam (torch.Tensor): shape should be (N, 4).
+ znear (Union[torch.Tensor, float], optional):
+                near clipping plane of the view frustum.
+ Defaults to 0.0.
+ aspect_ratio (Union[torch.Tensor, float], optional):
+ aspect ratio of the image pixels. 1.0 indicates square pixels.
+ Defaults to 1.0.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ opencv intrinsic matrix: (N, 4, 4)
+ """
+ znear = kwargs.get('znear', -1.0)
+ aspect_ratio = kwargs.get('aspect_ratio', 1.0)
+ _N = orig_cam.shape[0]
+ scale_x, scale_y, transl_x, transl_y = orig_cam[:, 0], orig_cam[:, 1],\
+ orig_cam[:, 2], orig_cam[:, 3]
+ K = torch.zeros((_N, 4, 4), dtype=torch.float32)
+ if aspect_ratio >= 1.0:
+ K[:, 0, 0] = scale_x * aspect_ratio
+ K[:, 1, 1] = scale_y
+ K[:, 0, 3] = transl_x * scale_x * aspect_ratio
+ K[:, 1, 3] = transl_y * scale_y
+ else:
+ K[:, 0, 0] = scale_x
+ K[:, 1, 1] = scale_y / aspect_ratio
+ K[:, 0, 3] = transl_x * scale_x
+ K[:, 1, 3] = transl_y * scale_y / aspect_ratio
+
+ K[:, 3, 3] = 1
+ K[:, 2, 2] = 1
+ R = torch.eye(3, 3)[None].repeat(_N, 1, 1)
+ T = torch.zeros(_N, 3)
+ T[:, 2] = znear
+ return K, R, T
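+
+    # Illustrative sketch (hypothetical values, not in the original file):
+    #   orig_cam = torch.Tensor([[1.0, 1.0, 0.1, 0.2]])  # [sx, sy, tx, ty]
+    #   K, R, T = WeakPerspectiveCameras.convert_orig_cam_to_matrix(orig_cam)
+    #   # K[0, 0, 0] == 1.0, K[0, 0, 3] == 0.1, R[0] is identity,
+    #   # T[0, 2] == -1.0 (the default znear).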
+
+ @staticmethod
+ def convert_K_to_orig_cam(
+ K: torch.Tensor,
+ aspect_ratio: Union[torch.Tensor, float] = 1.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Compute intrinsic camera matrix from pred camera parameter of smpl.
+
+ Args:
+ K (torch.Tensor):
+ opencv orthographics intrinsic matrix: (N, 4, 4)
+
+ .. code-block:: python
+
+ K = [[sx*r, 0, 0, tx*sx*r],
+ [0, sy, 0, ty*sy],
+ [0, 0, 1, 0],
+ [0, 0, 0, 1],]
+
+ aspect_ratio (Union[torch.Tensor, float], optional):
+ aspect ratio of the image pixels. 1.0 indicates square pixels.
+ Defaults to 1.0.
+
+ Returns:
+
+ orig_cam (torch.Tensor): shape should be (N, 4).
+ """
+ _N = K.shape[0]
+ s_x = K[:, 0, 0] / aspect_ratio
+ s_y = K[:, 1, 1] / aspect_ratio
+ t_x = K[:, 0, 3] / (aspect_ratio * s_x)
+ t_y = K[:, 1, 3] / s_y
+        orig_cam = torch.stack([s_x, s_y, t_x, t_y], -1).view(_N, 4)
+ return orig_cam
+
+ @classmethod
+ def get_default_projection_matrix(cls, **args):
+ """Class method. Calculate the projective transformation matrix by
+ default parameters.
+
+ Args:
+ **kwargs: parameters for the projection can be passed in as keyword
+ arguments to override the default values set in `__init__`.
+
+ Return:
+ a `torch.Tensor` which represents a batch of projection matrices K
+ of shape (N, 4, 4)
+ """
+ orig_cam = args.get('orig_cam', None)
+ scale_x = args.get('scale_x', 1.0)
+ scale_y = args.get('scale_y', 1.0)
+ transl_x = args.get('transl_x', 0.0)
+ transl_y = args.get('transl_y', 0.0)
+ aspect_ratio = args.get('aspect_ratio', 1.0)
+ batch_size = args.get('batch_size', 1)
+ device = args.get('device', 'cpu')
+
+ if orig_cam is not None:
+            # `orig_cam` is already contained in `args`; pass it only as a
+            # keyword to avoid a duplicate-argument error.
+            K, _, _ = cls.convert_orig_cam_to_matrix(**args)
+ else:
+ K = torch.zeros((1, 4, 4), dtype=torch.float32)
+
+ K[:, 0, 0] = scale_x * aspect_ratio
+ K[:, 1, 1] = scale_y
+ K[:, 3, 3] = 1
+ K[:, 0, 3] = transl_x * scale_x * aspect_ratio
+ K[:, 1, 3] = transl_y * scale_y
+ K[:, 2, 2] = 1
+ K = K.repeat(batch_size, 1, 1).to(device)
+ return K
+
+ def compute_projection_matrix(self, scale_x, scale_y, transl_x, transl_y,
+ aspect_ratio) -> torch.Tensor:
+ """Compute the calibration matrix K of shape (N, 4, 4)
+
+ Args:
+ scale_x (Union[torch.Tensor, float], optional):
+ Scale in x direction.
+ scale_y (Union[torch.Tensor, float], optional):
+ Scale in y direction.
+ transl_x (Union[torch.Tensor, float], optional):
+ Translation in x direction.
+ transl_y (Union[torch.Tensor, float], optional):
+ Translation in y direction.
+ aspect_ratio (Union[torch.Tensor, float], optional):
+ aspect ratio of the image pixels. 1.0 indicates square pixels.
+
+ Returns:
+ torch.FloatTensor of the calibration matrix with shape (N, 4, 4)
+ """
+ K = torch.zeros((self._N, 4, 4),
+ dtype=torch.float32,
+ device=self.device)
+
+ K[:, 0, 0] = scale_x * aspect_ratio
+ K[:, 1, 1] = scale_y
+ K[:, 3, 3] = 1
+ K[:, 0, 3] = transl_x * scale_x * aspect_ratio
+ K[:, 1, 3] = transl_y * scale_y
+ K[:, 2, 2] = 1
+ return K
+
+ def get_projection_transform(self, **kwargs) -> Transform3d:
+ """Calculate the orthographic projection matrix. Use column major
+ order.
+
+ Args:
+ **kwargs: parameters for the projection can be passed in to
+ override the default values set in __init__.
+ Return:
+ a Transform3d object which represents a batch of projection
+ matrices of shape (N, 4, 4)
+ """
+ K = kwargs.get('K', self.K)
+ if K is not None:
+ if K.shape != (self._N, 4, 4):
+ msg = f'Expected K to have shape of ({self._N}, 4, 4)'
+ raise ValueError(msg)
+ else:
+ K = self.compute_projection_matrix(
+ kwargs.get('scale_x', self.scale_x),
+ kwargs.get('scale_y', self.scale_y),
+                kwargs.get('transl_x', self.transl_x),
+                kwargs.get('transl_y', self.transl_y),
+ kwargs.get('aspect_ratio', self.aspect_ratio))
+
+ transform = Transform3d(matrix=K.transpose(1, 2).contiguous(),
+ device=self.device)
+ return transform
+
+ def unproject_points(self,
+ xy_depth: torch.Tensor,
+ world_coordinates: bool = True,
+ **kwargs) -> torch.Tensor:
+ """Sends points from camera coordinates (NDC or screen) back to camera
+ view or world coordinates depending on the `world_coordinates` boolean
+ argument of the function."""
+ if world_coordinates:
+ to_camera_transform = self.get_full_projection_transform(**kwargs)
+ else:
+ to_camera_transform = self.get_projection_transform(**kwargs)
+
+ unprojection_transform = to_camera_transform.inverse()
+ return unprojection_transform.transform_points(xy_depth)
+
+ def is_perspective(self):
+ """Boolean of whether is perspective."""
+ return False
+
+ def in_ndc(self):
+ """Boolean of whether in NDC."""
+ return True
+
+ def to_ndc_(self, **kwargs):
+ """Not implemented."""
+ raise NotImplementedError()
+
+ def to_screen_(self, **kwargs):
+ """Not implemented."""
+ raise NotImplementedError()
+
+ def to_ndc(self, **kwargs):
+ """Not implemented."""
+ raise NotImplementedError()
+
+ def to_screen(self, **kwargs):
+ """Not implemented."""
+ raise NotImplementedError()
+
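+# Illustrative sketch (not part of the original file): wrapping the matrices
+# from `convert_orig_cam_to_matrix` into a camera and projecting points,
+# assuming `orig_cam` is a (N, 4) tensor and `verts` a (N, V, 3) tensor:
+#   K, R, T = WeakPerspectiveCameras.convert_orig_cam_to_matrix(orig_cam)
+#   cam = WeakPerspectiveCameras(K=K, R=R, T=T, resolution=(224, 224))
+#   verts_ndc = cam.transform_points(verts)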
+
+@CAMERAS.register_module(name=('PerspectiveCameras', 'perspective',
+ 'Perspective'))
+class PerspectiveCameras(cameras.PerspectiveCameras, MMCamerasBase):
+ """Inherited from Pytorch3D `PerspectiveCameras`."""
+ def __init__(
+ self,
+ focal_length=1.0,
+ principal_point=((0.0, 0.0), ),
+ R: Optional[torch.Tensor] = None,
+ T: Optional[torch.Tensor] = None,
+ K: Optional[torch.Tensor] = None,
+ device: Union[torch.device, str] = 'cpu',
+ in_ndc: bool = True,
+ convention: str = 'pytorch3d',
+ image_size: Optional[Union[List, Tuple, torch.Tensor]] = None,
+ **kwargs,
+ ) -> None:
+ """
+ Args:
+ focal_length (float, torch.Tensor, optional): Defaults to 1.0.
+ principal_point (tuple, optional): Defaults to ((0.0, 0.0), ).
+ R (Optional[torch.Tensor], optional): Defaults to None.
+ T (Optional[torch.Tensor], optional): Defaults to None.
+ K (Optional[torch.Tensor], optional): Defaults to None.
+ device (Union[torch.device, str], optional): Defaults to 'cpu'.
+ in_ndc (bool, optional): Defaults to True.
+ convention (str, optional): Defaults to 'pytorch3d'.
+ image_size (Optional[Union[List, Tuple, torch.Tensor]], optional):
+ Defaults to None.
+
+ """
+ if image_size is not None:
+ kwargs.update({'image_size': image_size})
+ kwargs.update(
+ _in_ndc=in_ndc,
+ _is_perspective=True,
+ )
+ kwargs.pop('is_perspective', None)
+ kwargs.pop('in_ndc', None)
+
+ super(cameras.PerspectiveCameras,
+ self).__init__(device=device,
+ focal_length=focal_length,
+ principal_point=principal_point,
+ R=R,
+ T=T,
+ K=K,
+ convention=convention,
+ **kwargs)
+ if image_size is not None:
+ if (self.image_size < 1).any(): # pyre-ignore
+ raise ValueError('Image_size provided has invalid values')
+ else:
+ self.image_size = None
+
+ def __getitem__(self, index: Union[slice, int, torch.Tensor, List, Tuple]):
+ """Slice the cameras by batch dim.
+
+ Args:
+ index (Union[slice, int, torch.Tensor, List, Tuple]):
+ index for slicing.
+
+ Returns:
+ MMCamerasBase: sliced cameras.
+ """
+ return super(cameras.PerspectiveCameras, self).__getitem__(index)
+
+ @classmethod
+ def get_default_projection_matrix(cls, **args) -> torch.Tensor:
+ """Class method. Calculate the projective transformation matrix by
+ default parameters.
+
+ Args:
+ **kwargs: parameters for the projection can be passed in as keyword
+ arguments to override the default values set in `__init__`.
+
+ Return:
+ a `torch.Tensor` which represents a batch of projection matrices K
+ of shape (N, 4, 4)
+ """
+ batch_size = args.get('batch_size', 1)
+ device = args.get('device', 'cpu')
+ focal_length = args.get('focal_length')
+ principal_point = args.get('principal_point')
+
+ return cameras._get_sfm_calibration_matrix(
+ N=batch_size,
+ device=device,
+ focal_length=focal_length,
+ principal_point=principal_point,
+ orthographic=False)
+
+ def get_ndc_camera_transform(self, **kwargs) -> Transform3d:
+ kwargs.pop('cameras', None)
+ return super().get_ndc_camera_transform(**kwargs)
+
+ def transform_points_screen(self,
+ points,
+ eps: Optional[float] = None,
+ **kwargs) -> torch.Tensor:
+ kwargs.pop('cameras', None)
+ return super().transform_points_screen(points, eps, **kwargs)
+
+
+@CAMERAS.register_module(name=('FoVPerspectiveCameras', 'FoVPerspective',
+ 'fovperspective'))
+class FoVPerspectiveCameras(cameras.FoVPerspectiveCameras, MMCamerasBase):
+ """Inherited from Pytorch3D `FoVPerspectiveCameras`."""
+ def __init__(
+ self,
+ znear=1.0,
+ zfar=100.0,
+ aspect_ratio=1.0,
+ fov=60.0,
+ degrees: bool = True,
+ R: Optional[torch.Tensor] = None,
+ T: Optional[torch.Tensor] = None,
+ K: Optional[torch.Tensor] = None,
+ device: Union[torch.device, str] = 'cpu',
+ convention: str = 'pytorch3d',
+ **kwargs,
+ ) -> None:
+ """Initialize a camera.
+
+ Args:
+ znear (float, optional): Defaults to 1.0.
+ zfar (float, optional): Defaults to 100.0.
+ aspect_ratio (float, optional): Defaults to 1.0.
+ fov (float, optional): Defaults to 60.0.
+ degrees (bool, optional): Defaults to True.
+ R (Optional[torch.Tensor], optional): Defaults to None.
+ T (Optional[torch.Tensor], optional): Defaults to None.
+ K (Optional[torch.Tensor], optional): Defaults to None.
+ device (Union[torch.device, str], optional): Defaults to 'cpu'.
+ convention (str, optional): Defaults to 'pytorch3d'.
+ """
+ kwargs.update(
+ _in_ndc=True,
+ _is_perspective=True,
+ )
+ kwargs.pop('in_ndc', None)
+ kwargs.pop('is_perspective', None)
+ super(cameras.FoVPerspectiveCameras, self).__init__(
+ device=device,
+ znear=znear,
+ zfar=zfar,
+ aspect_ratio=aspect_ratio,
+ fov=fov,
+ R=R,
+ T=T,
+ K=K,
+ convention=convention,
+ **kwargs,
+ )
+ self.degrees = degrees
+
+ def __getitem__(self, index: Union[slice, int, torch.Tensor, List, Tuple]):
+ """Slice the cameras by batch dim.
+
+ Args:
+ index (Union[slice, int, torch.Tensor, List, Tuple]):
+ index for slicing.
+
+ Returns:
+ MMCamerasBase: sliced cameras.
+ """
+ return super(cameras.FoVPerspectiveCameras, self).__getitem__(index)
+
+ def get_ndc_camera_transform(self, **kwargs) -> Transform3d:
+ kwargs.pop('cameras', None)
+ return super().get_ndc_camera_transform(**kwargs)
+
+ def transform_points_screen(self,
+ points,
+ eps: Optional[float] = None,
+ **kwargs) -> torch.Tensor:
+ kwargs.pop('cameras', None)
+ return super().transform_points_screen(points, eps, **kwargs)
+
+ @classmethod
+ def get_default_projection_matrix(cls, **args) -> torch.Tensor:
+ """Class method. Calculate the projective transformation matrix by
+ default parameters.
+
+ Args:
+ **kwargs: parameters for the projection can be passed in as keyword
+ arguments to override the default values set in `__init__`.
+
+ Return:
+ a `torch.Tensor` which represents a batch of projection matrices K
+ of shape (N, 4, 4)
+ """
+ znear = args.get('znear', 1.0)
+ zfar = args.get('zfar', 100.0)
+ aspect_ratio = args.get('aspect_ratio', 1.0)
+ fov = args.get('fov', 60.0)
+ degrees = args.get('degrees', True)
+ batch_size = args.get('batch_size', 1)
+
+ K = torch.zeros((1, 4, 4), dtype=torch.float32)
+ if degrees:
+ fov = (math.pi / 180) * fov
+
+ if not torch.is_tensor(fov):
+ fov = torch.tensor(fov)
+ tanHalfFov = torch.tan((fov / 2))
+ max_y = tanHalfFov * znear
+ min_y = -max_y
+ max_x = max_y * aspect_ratio
+ min_x = -max_x
+
+ z_sign = 1.0
+
+ K[:, 0, 0] = 2.0 * znear / (max_x - min_x)
+ K[:, 1, 1] = 2.0 * znear / (max_y - min_y)
+ K[:, 0, 2] = (max_x + min_x) / (max_x - min_x)
+ K[:, 1, 2] = (max_y + min_y) / (max_y - min_y)
+ K[:, 3, 2] = z_sign
+
+ K[:, 2, 2] = z_sign * zfar / (zfar - znear)
+ K[:, 2, 3] = -(zfar * znear) / (zfar - znear)
+ K = K.repeat(batch_size, 1, 1)
+ return K
+
+ def to_ndc_(self, **kwargs):
+ """Not implemented."""
+ raise NotImplementedError()
+
+ def to_screen_(self, **kwargs):
+ """Not implemented."""
+ raise NotImplementedError()
+
+ def to_ndc(self, **kwargs):
+ """Not implemented."""
+ raise NotImplementedError()
+
+ def to_screen(self, **kwargs):
+ """Not implemented."""
+ raise NotImplementedError()
+
+
+@CAMERAS.register_module(name=('OrthographicCameras', 'Orthographic',
+ 'orthographic'))
+class OrthographicCameras(cameras.OrthographicCameras, MMCamerasBase):
+ """Inherited from Pytorch3D `OrthographicCameras`."""
+ def __init__(
+ self,
+ focal_length=1.0,
+ principal_point=((0.0, 0.0), ),
+ R: Optional[torch.Tensor] = None,
+ T: Optional[torch.Tensor] = None,
+ K: Optional[torch.Tensor] = None,
+ device: Union[torch.Tensor, str] = 'cpu',
+ in_ndc: bool = True,
+ image_size: Optional[torch.Tensor] = None,
+ convention: str = 'pytorch3d',
+ **kwargs,
+ ) -> None:
+ """Initialize OrthographicCameras.
+
+ Args:
+ focal_length (float, optional): Defaults to 1.0.
+ principal_point (tuple, optional): Defaults to ((0.0, 0.0), ).
+ R (Optional[torch.Tensor], optional): Defaults to None.
+ T (Optional[torch.Tensor], optional): Defaults to None.
+ K (Optional[torch.Tensor], optional): Defaults to None.
+ device (Union[torch.Tensor, str], optional): Defaults to 'cpu'.
+ in_ndc (bool, optional): Defaults to True.
+ image_size (Optional[torch.Tensor], optional): Defaults to None.
+ convention (str, optional): Defaults to 'pytorch3d'.
+
+ Raises:
+            ValueError: If the provided image_size contains values below 1.
+ """
+ if image_size is not None:
+ kwargs.update({'image_size': image_size})
+ kwargs.update(
+ _is_perspective=False,
+ _in_ndc=in_ndc,
+ )
+ kwargs.pop('is_perspective', None)
+ kwargs.pop('in_ndc', None)
+ super(cameras.OrthographicCameras,
+ self).__init__(device=device,
+ focal_length=focal_length,
+ principal_point=principal_point,
+ R=R,
+ T=T,
+ K=K,
+ convention=convention,
+ **kwargs)
+ if image_size is not None:
+ if (self.image_size < 1).any(): # pyre-ignore
+ raise ValueError('Image_size provided has invalid values')
+ else:
+ self.image_size = None
+
+ def get_ndc_camera_transform(self, **kwargs) -> Transform3d:
+ kwargs.pop('cameras', None)
+ return super().get_ndc_camera_transform(**kwargs)
+
+ def transform_points_screen(self,
+ points,
+ eps: Optional[float] = None,
+ **kwargs) -> torch.Tensor:
+ kwargs.pop('cameras', None)
+ return super().transform_points_screen(points, eps, **kwargs)
+
+ def __getitem__(self, index: Union[slice, int, torch.Tensor, List, Tuple]):
+ """Slice the cameras by batch dim.
+
+ Args:
+ index (Union[slice, int, torch.Tensor, List, Tuple]):
+ index for slicing.
+
+ Returns:
+ MMCamerasBase: sliced cameras.
+ """
+ return super(cameras.OrthographicCameras, self).__getitem__(index)
+
+ @classmethod
+ def get_default_projection_matrix(cls, **args) -> torch.Tensor:
+ """Class method. Calculate the projective transformation matrix by
+ default parameters.
+
+ .. code-block:: python
+
+ fx = focal_length[:,0]
+ fy = focal_length[:,1]
+ px = principal_point[:,0]
+ py = principal_point[:,1]
+
+ K = [[fx, 0, 0, px],
+ [0, fy, 0, py],
+ [0, 0, 1, 0],
+ [0, 0, 0, 1],]
+
+ Args:
+ **kwargs: parameters for the projection can be passed in as keyword
+ arguments to override the default values.
+
+ Return:
+ a `torch.Tensor` which represents a batch of projection matrices K
+ of shape (N, 4, 4)
+ """
+ batch_size = args.get('batch_size', 1)
+ device = args.get('device', 'cpu')
+ focal_length = args.get('focal_length')
+ principal_point = args.get('principal_point')
+
+ return cameras._get_sfm_calibration_matrix(
+ N=batch_size,
+ device=device,
+ focal_length=focal_length,
+ principal_point=principal_point,
+ orthographic=True)
+
+
+@CAMERAS.register_module(name=('FoVOrthographicCameras', 'FoVOrthographic',
+ 'fovorthographic'))
+class FoVOrthographicCameras(cameras.FoVOrthographicCameras, MMCamerasBase):
+ """Inherited from Pytorch3D `FoVOrthographicCameras`."""
+ def __init__(
+ self,
+ znear: Union[torch.Tensor, int, float] = 1.0,
+ zfar: Union[torch.Tensor, int, float] = 100.0,
+ max_y: Union[torch.Tensor, int, float] = 1.0,
+ min_y: Union[torch.Tensor, int, float] = -1.0,
+ max_x: Union[torch.Tensor, int, float] = 1.0,
+ min_x: Union[torch.Tensor, int, float] = -1.0,
+ scale_xyz: Union[Iterable[float],
+ Iterable[int]] = ((1.0, 1.0, 1.0), ), # (1, 3)
+ R: Optional[torch.Tensor] = None,
+ T: Optional[torch.Tensor] = None,
+ K: Optional[torch.Tensor] = None,
+ device: Union[torch.device, str] = 'cpu',
+ convention: str = 'pytorch3d',
+ **kwargs):
+ """reimplemented __init__, add `convention`.
+
+ Args:
+ znear (Union[torch.Tensor, int, float], optional):
+ Defaults to 1.0.
+ zfar (Union[torch.Tensor, int, float], optional):
+ Defaults to 100.0.
+ max_y (Union[torch.Tensor, int, float], optional):
+ Defaults to 1.0.
+ min_y (Union[torch.Tensor, int, float], optional):
+ Defaults to -1.0.
+ max_x (Union[torch.Tensor, int, float], optional):
+ Defaults to 1.0.
+ min_x (Union[torch.Tensor, int, float], optional):
+ Defaults to -1.0.
+ scale_xyz (Union[Iterable[float], Iterable[int]], optional):
+ Defaults to ((1.0, 1.0, 1.0), ).
+ T (Optional[torch.Tensor], optional): Defaults to None.
+ K (Optional[torch.Tensor], optional): Defaults to None.
+ device (Union[torch.device, str], optional): Defaults to 'cpu'.
+ convention (str, optional): Defaults to 'pytorch3d'.
+ """
+ kwargs.update(_is_perspective=False, _in_ndc=True)
+ kwargs.pop('in_ndc', None)
+ kwargs.pop('is_perspective', None)
+ super(cameras.FoVOrthographicCameras,
+ self).__init__(device=device,
+ znear=znear,
+ zfar=zfar,
+ max_y=max_y,
+ min_y=min_y,
+ max_x=max_x,
+ min_x=min_x,
+ scale_xyz=scale_xyz,
+ R=R,
+ T=T,
+ K=K,
+ convention=convention,
+ **kwargs)
+
+ def __getitem__(self, index: Union[slice, int, torch.Tensor, List, Tuple]):
+ """Slice the cameras by batch dim.
+
+ Args:
+ index (Union[slice, int, torch.Tensor, List, Tuple]):
+ index for slicing.
+
+ Returns:
+ MMCamerasBase: sliced cameras.
+ """
+ return super(cameras.FoVOrthographicCameras, self).__getitem__(index)
+
+ @classmethod
+ def get_default_projection_matrix(cls, **args) -> torch.Tensor:
+ """Class method. Calculate the projective transformation matrix by
+ default parameters.
+
+ .. code-block:: python
+
+ scale_x = 2 / (max_x - min_x)
+ scale_y = 2 / (max_y - min_y)
+ scale_z = 2 / (far-near)
+ mid_x = (max_x + min_x) / (max_x - min_x)
+            mid_y = (max_y + min_y) / (max_y - min_y)
+ mid_z = (far + near) / (far - near)
+
+ K = [[scale_x, 0, 0, -mid_x],
+                 [0, scale_y, 0, -mid_y],
+ [0, 0, -scale_z, -mid_z],
+ [0, 0, 0, 1],]
+
+ Args:
+ **kwargs: parameters for the projection can be passed in as keyword
+ arguments to override the default values.
+
+ Return:
+ a `torch.Tensor` which represents a batch of projection matrices K
+ of shape (N, 4, 4)
+ """
+ znear = args.get('znear', 1.0)
+ zfar = args.get('zfar', 100.0)
+ max_y = args.get('max_y', 1.0)
+ min_y = args.get('min_y', -1.0)
+ max_x = args.get('max_x', 1.0)
+ min_x = args.get('min_x', -1.0)
+ scale_xyz = args.get(
+ 'scale_xyz',
+ ((1.0, 1.0, 1.0), ),
+ )
+ batch_size = args.get('batch_size', 1)
+
+ K = torch.zeros((1, 4, 4), dtype=torch.float32)
+ ones = torch.ones((1), dtype=torch.float32)
+ z_sign = +1.0
+
+ if not isinstance(scale_xyz, torch.Tensor):
+ scale_xyz = torch.Tensor(scale_xyz)
+ K[:, 0, 0] = (2.0 / (max_x - min_x)) * scale_xyz[:, 0]
+ K[:, 1, 1] = (2.0 / (max_y - min_y)) * scale_xyz[:, 1]
+ K[:, 0, 3] = -(max_x + min_x) / (max_x - min_x)
+ K[:, 1, 3] = -(max_y + min_y) / (max_y - min_y)
+ K[:, 3, 3] = ones
+
+        # NOTE: This maps the z coordinate to the range [0, 1] and replaces
+        # the OpenGL z normalization to [-1, 1]
+ K[:, 2, 2] = z_sign * (1.0 / (zfar - znear)) * scale_xyz[:, 2]
+ K[:, 2, 3] = -znear / (zfar - znear)
+ K = K.repeat(batch_size, 1, 1)
+ return K
+
+ def to_ndc_(self, **kwargs):
+ """Not implemented."""
+ raise NotImplementedError()
+
+ def to_screen_(self, **kwargs):
+ """Not implemented."""
+ raise NotImplementedError()
+
+ def to_ndc(self, **kwargs):
+ """Not implemented."""
+ raise NotImplementedError()
+
+ def to_screen(self, **kwargs):
+ """Not implemented."""
+ raise NotImplementedError()
+
+ def get_ndc_camera_transform(self, **kwargs) -> Transform3d:
+ kwargs.pop('cameras', None)
+ return super().get_ndc_camera_transform(**kwargs)
+
+ def transform_points_screen(self,
+ points,
+ eps: Optional[float] = None,
+ **kwargs) -> torch.Tensor:
+ kwargs.pop('cameras', None)
+ return super().transform_points_screen(points, eps, **kwargs)
+
+
+def concat_cameras(cameras_list: List[MMCamerasBase]) -> MMCamerasBase:
+ """Concat a list of cameras of the same type.
+
+ Args:
+ cameras_list (List[cameras.CamerasBase]): a list of cameras.
+
+ Returns:
+        MMCamerasBase: the returned cameras concatenated along the batch
+            dim.
+ """
+ K = []
+ R = []
+ T = []
+ is_perspective = cameras_list[0].is_perspective()
+ in_ndc = cameras_list[0].in_ndc()
+ cam_cls = type(cameras_list[0])
+ image_size = cameras_list[0].get_image_size()
+ device = cameras_list[0].device
+ for cam in cameras_list:
+ assert type(cam) is cam_cls
+ assert cam.in_ndc() is in_ndc
+ assert cam.is_perspective() is is_perspective
+ assert cam.device is device
+ K.append(cam.K)
+ R.append(cam.R)
+ T.append(cam.T)
+ K = torch.cat(K)
+ R = torch.cat(R)
+ T = torch.cat(T)
+ concated_cameras = cam_cls(K=K,
+ R=R,
+ T=T,
+ device=device,
+ is_perspective=is_perspective,
+ in_ndc=in_ndc,
+ image_size=image_size)
+ return concated_cameras
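+
+# Illustrative sketch (not part of the original file): merging two compatible
+# camera batches, e.g.
+#   merged = concat_cameras([cam_a, cam_b])
+# stacks K, R and T along the batch dim; all cameras must share the same
+# class, ndc flag, perspective flag and device.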
+
+
+def compute_orbit_cameras(
+ K: Union[torch.Tensor, np.ndarray, None] = None,
+ elev: float = 0,
+ azim: float = 0,
+ dist: float = 2.7,
+ at: Union[torch.Tensor, List, Tuple] = (0, 0, 0),
+ batch_size: int = 1,
+ orbit_speed: Union[float, Tuple[float, float]] = 0,
+ dist_speed: Optional[float] = 0,
+ convention: str = 'pytorch3d',
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Generate a sequence of moving cameras following an orbit.
+
+ Args:
+ K (Union[torch.Tensor, np.ndarray, None], optional):
+ Intrinsic matrix. Will generate a default K if None.
+ Defaults to None.
+ elev (float, optional): This is the angle between the
+ vector from the object to the camera, and the horizontal
+ plane y = 0 (xz-plane).
+ Defaults to 0.
+ azim (float, optional): angle in degrees or radians. The vector
+ from the object to the camera is projected onto a horizontal
+ plane y = 0. azim is the angle between the projected vector and a
+ reference vector at (0, 0, 1) on the reference plane (the
+ horizontal plane).
+ Defaults to 0.
+ dist (float, optional): distance of the camera from the object.
+ Defaults to 2.7.
+ at (Union[torch.Tensor, List, Tuple], optional):
+ the position of the object(s) in world coordinates.
+ Defaults to (0, 0, 0).
+ batch_size (int, optional): number of frames. Defaults to 1.
+ orbit_speed (Union[float, Tuple[float, float]], optional):
+ degree speed of camera moving along the orbit.
+            Could be one or two numbers. A single number sets only the
+            azimuth speed; two numbers set (azimuth, elevation) speeds.
+ Defaults to 0.
+ dist_speed (Optional[float], optional):
+ speed of camera moving along the center line.
+ Defaults to 0.
+ convention (str, optional): Camera convention. Defaults to 'pytorch3d'.
+
+ Returns:
+        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: computed K, R, T.
+ """
+ if not isinstance(orbit_speed, Iterable):
+ orbit_speed = (orbit_speed, 0.0)
+ if not isinstance(at, torch.Tensor):
+ at = torch.Tensor(at)
+ at = at.view(1, 3)
+ if batch_size > 1 and orbit_speed[0] != 0:
+ azim = torch.linspace(azim, azim + batch_size * orbit_speed[0],
+ batch_size)
+ if batch_size > 1 and orbit_speed[1] != 0:
+ elev = torch.linspace(elev, elev + batch_size * orbit_speed[1],
+ batch_size)
+ if batch_size > 1 and dist_speed != 0:
+ dist = torch.linspace(dist, dist + batch_size * dist_speed, batch_size)
+
+ if convention == 'opencv':
+ rotation_compensate = ee_to_rotmat(
+ torch.Tensor([math.pi, 0, 0]).view(1, 3))
+ at = rotation_compensate.permute(0, 2, 1) @ at.view(-1, 3, 1)
+ at = at.view(1, 3)
+ R, T = cameras.look_at_view_transform(dist=dist,
+ elev=elev,
+ azim=azim,
+ at=at)
+ if K is None:
+ K = FoVPerspectiveCameras.get_default_projection_matrix(
+ batch_size=batch_size)
+ if convention == 'opencv':
+ rotation_compensate = ee_to_rotmat(
+ torch.Tensor([math.pi, 0, 0]).view(1, 3))
+ R = rotation_compensate.permute(0, 2, 1) @ R
+ return K, R, T
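+
+# Illustrative sketch (hypothetical values, not in the original file): a
+# 30-frame orbit around the origin, advancing the azimuth by 3 degrees per
+# frame with a fixed elevation and distance:
+#   K, R, T = compute_orbit_cameras(elev=10, azim=0, dist=2.7,
+#                                   batch_size=30, orbit_speed=3)
+#   # K: (30, 4, 4), R: (30, 3, 3), T: (30, 3)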
+
+
+def compute_direction_cameras(
+ K: Union[torch.Tensor, np.ndarray, None] = None,
+ at: Union[torch.Tensor, List, Tuple, None] = None,
+ eye: Union[torch.Tensor, List, Tuple, None] = None,
+ plane: Union[Iterable[torch.Tensor], None] = None,
+ dist: float = 1.0,
+ batch_size: int = 1,
+ dist_speed: float = 0.0,
+ z_vec: Union[torch.Tensor, List, Tuple, None] = None,
+ y_vec: Union[torch.Tensor, List, Tuple] = (0, 1, 0),
+ convention: str = 'pytorch3d',
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Generate a sequence of moving cameras along a direction.
+ We need a `z_vec`, `y_vec` to generate `x_vec` so as to get the `R` matrix.
+ And we need `eye` as `T` matrix.
+ `K` matrix could be set or use default.
+ We recommend `y_vec` as default (0, 1, 0), and it will be orthogonal
+ decomposed. The `x_vec` will be generated by cross production from
+ `y_vec` and `x_vec`.
+ You can set `z_vec` by: 1. set `at`, `dist`, `dist_speed`, `plane`,
+ `batch_size` to get `eye`, then get `z_vec`.
+ 2. set `at`, `eye` directly and get `z_vec`.
+ 3. set `z_vec` directly and:
+ 1). set `eye` and `dist`.
+ 2). set `at`, `dist`, `dist_speed`,
+ `batch_size` then get `eye`.
+ When we have `eye`, `z_vec`, `y_vec`, we will have `R` and `T`.
+
+ Args:
+ K (Union[torch.Tensor, np.ndarray, None], optional):
+ Intrinsic matrix. Will generate a default K if None.
+ Defaults to None.
+ at (Union[torch.Tensor, List, Tuple], optional):
+ the position of the object(s) in world coordinates.
+ Required.
+ Defaults to None.
+ eye (Union[torch.Tensor, List, Tuple], optional):
+ the position of the camera(s) in world coordinates.
+ If eye is not None, it will override the camera position derived
+ from plane, dist, dist_speed.
+ Defaults to None.
+ plane (Optional[Iterable[torch.Tensor, List, Tuple]], optional):
+ The plane of your z direction normal.
+ Should be a tuple or list containing two vectors of shape (N, 3).
+ Defaults to None.
+ dist (float, optional): distance to at.
+ Defaults to 1.0.
+ dist_speed (float, optional): distance moving speed.
+            Defaults to 0.0.
+ batch_size (int, optional): number of frames.
+ Defaults to 1.
+ z_vec (Union[torch.Tensor, List, Tuple], optional):
+ z direction of shape (-1, 3). If z_vec is not None, it will
+ override plane, dist, dist_speed.
+ Defaults to None.
+ y_vec (Union[torch.Tensor, List, Tuple], optional):
+ Will only be used when z_vec is used.
+ Defaults to (0, 1, 0).
+ convention (str, optional): Camera convention.
+ Defaults to 'pytorch3d'.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: computed K, R, T.
+ """
+ def norm_vec(vec):
+ return vec / torch.sqrt((vec * vec).sum())
+
+ if z_vec is None:
+ assert at is not None
+ at = torch.Tensor(at).view(-1, 3)
+ if eye is None:
+ assert plane is not None
+ dist = torch.linspace(dist, dist + batch_size * dist_speed,
+ batch_size)
+ vec1 = torch.Tensor(plane[0]).view(-1, 3)
+ norm_vec1 = norm_vec(vec1)
+ vec2 = torch.Tensor(plane[1]).view(-1, 3)
+ norm_vec2 = norm_vec(vec2)
+ norm = torch.cross(norm_vec1, norm_vec2)
+ normed_norm = norm_vec(norm)
+ eye = at + normed_norm * dist
+ else:
+ eye = torch.Tensor(eye).view(-1, 3)
+ norm = eye - at
+ normed_norm = norm_vec(norm)
+
+ z_vec = -normed_norm
+ else:
+ z_vec = torch.Tensor(z_vec).view(-1, 3)
+ z_vec = norm_vec(z_vec)
+ if eye is None:
+ assert at is not None
+ at = torch.Tensor(at).view(-1, 3)
+ dist = torch.linspace(dist, dist + batch_size * dist_speed,
+ batch_size)
+ eye = -z_vec * dist + at
+ eye = torch.Tensor(eye).view(-1, 3)
+ assert eye is not None
+ z_vec = norm_vec(z_vec)
+ normed_norm = -z_vec
+
+ z_vec = z_vec.view(-1, 3)
+ y_vec = torch.Tensor(y_vec).view(-1, 3)
+
+ y_vec = y_vec - torch.bmm(y_vec.view(-1, 1, 3), z_vec.view(-1, 3, 1)).view(
+ -1, 1) * z_vec
+ y_vec = norm_vec(y_vec)
+ x_vec = torch.cross(y_vec, z_vec)
+ R = torch.cat(
+ [x_vec.view(-1, 3, 1),
+ y_vec.view(-1, 3, 1),
+ z_vec.view(-1, 3, 1)], 1).view(-1, 3, 3)
+ T = eye
+
+ R = R.permute(0, 2, 1)
+ _, T = convert_world_view(R=R, T=T)
+
+ if K is None:
+ K = FoVPerspectiveCameras.get_default_projection_matrix(
+ batch_size=batch_size)
+ K, R, T = convert_camera_matrix(K=K,
+ R=R,
+ T=T,
+ is_perspective=True,
+ convention_src='pytorch3d',
+ convention_dst=convention)
+ return K, R, T
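+
+
+# Illustrative sketch (hypothetical values, not in the original file): a
+# single camera placed one unit away from the origin along +z, looking at
+# the origin:
+#   K, R, T = compute_direction_cameras(at=(0, 0, 0), eye=(0, 0, 1))
+#   # K: (1, 4, 4), R: (1, 3, 3), T: (1, 3)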
diff --git a/detrsmpl/core/conventions/__init__.py b/detrsmpl/core/conventions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/detrsmpl/core/conventions/cameras/__init__.py b/detrsmpl/core/conventions/cameras/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/detrsmpl/core/conventions/cameras/convert_convention.py b/detrsmpl/core/conventions/cameras/convert_convention.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb76635a81fdd1b4d005a0f6dee5c64e5f2299f9
--- /dev/null
+++ b/detrsmpl/core/conventions/cameras/convert_convention.py
@@ -0,0 +1,649 @@
+import warnings
+from typing import Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from detrsmpl.utils.transforms import ee_to_rotmat, rotmat_to_ee
+
+CAMERA_CONVENTIONS = {
+ 'pytorch3d': {
+ 'axis': '-xyz',
+ 'left_mm_extrinsic': False,
+ 'view_to_world': False,
+ 'left_mm_intrinsic': True,
+ },
+ 'pyrender': {
+ 'axis': 'xy-z',
+ 'left_mm_extrinsic': True,
+ 'view_to_world': False,
+ 'left_mm_intrinsic': True,
+ },
+ 'opengl': {
+ 'axis': 'xy-z',
+ 'left_mm_extrinsic': True,
+ 'view_to_world': False,
+ 'left_mm_intrinsic': True,
+ },
+ 'open3d': {
+ 'axis': 'x-yz',
+ 'left_mm_extrinsic': False,
+ 'view_to_world': False,
+ 'left_mm_intrinsic': False,
+ },
+ 'opencv': {
+ 'axis': 'x-yz',
+ 'left_mm_extrinsic': True,
+ 'view_to_world': True,
+ 'left_mm_intrinsic': True,
+ },
+ 'unity': {
+ 'axis': 'xyz',
+ 'left_mm_extrinsic': True,
+ 'view_to_world': False,
+ 'left_mm_intrinsic': True,
+ },
+ 'blender': {
+ 'axis': 'xy-z',
+ 'left_mm_extrinsic': True,
+ 'view_to_world': False,
+ 'left_mm_intrinsic': True,
+ },
+ 'maya': {
+ 'axis': 'xy-z',
+ 'left_mm_extrinsic': True,
+ 'view_to_world': False,
+ 'left_mm_intrinsic': True,
+ }
+}
+
+
+def enc_camera_convention(convention, camera_conventions=CAMERA_CONVENTIONS):
+ """convert camera convention to axis direction and order."""
+ if convention in camera_conventions:
+ convention = camera_conventions[convention]['axis']
+ else:
+        assert set(convention).issubset({'x', 'y', 'z', '+', '-'}), \
+            'Wrong convention string, choose either in ' \
+            f'set({camera_conventions.keys()}) or define by xyz.'
+ sign = [1, 1, 1]
+ convention = '_' + convention
+ count = 0
+ axis_order = ''
+ for i in range(len(convention)):
+ if convention[i] in 'xyz':
+ axis_order += convention[i]
+ if convention[i - 1] == '-':
+ sign[count] *= -1
+ count += 1
+ return sign, axis_order
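+
+# Illustrative sketch (not part of the original file): the encoding returns
+# per-axis signs plus the axis order, e.g.
+#   enc_camera_convention('opencv')    -> ([1, -1, 1], 'xyz')
+#   enc_camera_convention('pytorch3d') -> ([-1, 1, 1], 'xyz')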
+
+
+def convert_camera_matrix(
+ K: Optional[Union[torch.Tensor, np.ndarray]] = None,
+ R: Optional[Union[torch.Tensor, np.ndarray]] = None,
+ T: Optional[Union[torch.Tensor, np.ndarray]] = None,
+ is_perspective: bool = True,
+ convention_src: str = 'opencv',
+ convention_dst: str = 'pytorch3d',
+ in_ndc_src: bool = True,
+ in_ndc_dst: bool = True,
+ resolution_src: Optional[Union[int, Tuple[int, int], torch.Tensor,
+ np.ndarray]] = None,
+ resolution_dst: Optional[Union[int, Tuple[int, int], torch.Tensor,
+ np.ndarray]] = None,
+ camera_conventions: dict = CAMERA_CONVENTIONS,
+) -> Tuple[Union[torch.Tensor, np.ndarray], Union[torch.Tensor, np.ndarray],
+ Union[torch.Tensor, np.ndarray]]:
+ """Convert the intrinsic matrix K and extrinsic matrix [R|T] from source
+ convention to destination convention.
+
+ Args:
+ K (Union[torch.Tensor, np.ndarray]): Intrinsic matrix,
+ shape should be (batch_size, 4, 4) or (batch_size, 3, 3).
+ Will be ignored if None.
+ R (Optional[Union[torch.Tensor, np.ndarray]], optional):
+ Extrinsic rotation matrix. Shape should be (batch_size, 3, 3).
+ Will be identity if None.
+ Defaults to None.
+ T (Optional[Union[torch.Tensor, np.ndarray]], optional):
+ Extrinsic translation matrix. Shape should be (batch_size, 3).
+ Will be zeros if None.
+ Defaults to None.
+ is_perspective (bool, optional): whether is perspective projection.
+ Defaults to True.
+
+ _____________________________________________________________________
+ # Camera dependent args
+ convention_src (str, optional): convention of source camera,
+ convention_dst (str, optional): convention of destination camera,
+
+ We define the convention of cameras by the order of right, front and
+ up.
+ E.g., the first one is pyrender and its convention should be
+ '+x+z+y'. '+' could be ignored.
+ The second one is opencv and its convention should be '+x-z-y'.
+ The third one is pytorch3d and its convention should be '-xzy'.
+ opengl(pyrender) opencv pytorch3d
+ y z y
+ | / |
+ | / |
+ |_______x /________x x________ |
+ / | /
+ / | /
+ z / y | z /
+
+ in_ndc_src (bool, optional): Whether is the source camera defined
+ in ndc.
+ Defaults to True.
+ in_ndc_dst (bool, optional): Whether is the destination camera defined
+ in ndc.
+ Defaults to True.
+
+        In camera_convention, we define these args as:
+            1). `left_mm_ex` means the extrinsic matrix [`R`|`T`] is defined
+                for left matrix multiplication.
+            2). `left_mm_in` means the intrinsic matrix `K` is defined for
+                left matrix multiplication.
+            3). `view_to_world` means the extrinsic matrix [`R`|`T`] is
+                defined as view to world.
+
+ resolution_src (Optional[Union[int, Tuple[int, int], torch.Tensor,
+ np.ndarray]], optional):
+ Source camera image size of (height, width).
+ Required if defined in screen.
+ Will be square if int.
+ Shape should be (2,) if `array` or `tensor`.
+ Defaults to None.
+ resolution_dst (Optional[Union[int, Tuple[int, int], torch.Tensor,
+ np.ndarray]], optional):
+ Destination camera image size of (height, width).
+ Required if defined in screen.
+ Will be square if int.
+ Shape should be (2,) if `array` or `tensor`.
+ Defaults to None.
+ camera_conventions: (dict, optional): `dict` containing
+ pre-defined camera convention information.
+ Defaults to CAMERA_CONVENTIONS.
+
+ Raises:
+ TypeError: K, R, T should all be `torch.Tensor` or `np.ndarray`.
+
+ Returns:
+ Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None],
+ Union[torch.Tensor, None]]:
+ Converted K, R, T matrix of `tensor`.
+ """
+ convention_dst = convention_dst.lower()
+ convention_src = convention_src.lower()
+
+ assert convention_dst in CAMERA_CONVENTIONS
+ assert convention_src in CAMERA_CONVENTIONS
+
+ left_mm_ex_src = CAMERA_CONVENTIONS[convention_src].get(
+ 'left_mm_extrinsic', True)
+ view_to_world_src = CAMERA_CONVENTIONS[convention_src].get(
+ 'view_to_world', False)
+ left_mm_in_src = CAMERA_CONVENTIONS[convention_src].get(
+ 'left_mm_intrinsic', False)
+
+ left_mm_ex_dst = CAMERA_CONVENTIONS[convention_dst].get(
+ 'left_mm_extrinsic', True)
+ view_to_world_dst = CAMERA_CONVENTIONS[convention_dst].get(
+ 'view_to_world', False)
+ left_mm_in_dst = CAMERA_CONVENTIONS[convention_dst].get(
+ 'left_mm_intrinsic', False)
+
+ sign_src, axis_src = enc_camera_convention(convention_src,
+ camera_conventions)
+ sign_dst, axis_dst = enc_camera_convention(convention_dst,
+ camera_conventions)
+ sign = torch.Tensor(sign_dst) / torch.Tensor(sign_src)
+
+ type_ = []
+ for x in [K, R, T]:
+ if x is not None:
+ type_.append(type(x))
+ if len(type_) > 0:
+ if not all(x == type_[0] for x in type_):
+ raise TypeError('Input type should be the same.')
+
+ use_numpy = False
+ if np.ndarray in type_:
+ use_numpy = True
+ # convert raw matrix to tensor
+ if isinstance(K, np.ndarray):
+ new_K = torch.Tensor(K)
+ elif K is None:
+ new_K = None
+ elif isinstance(K, torch.Tensor):
+ new_K = K.clone()
+ else:
+ raise TypeError(
+ f'K should be `torch.Tensor` or `numpy.ndarray`, type(K): '
+ f'{type(K)}')
+
+ if isinstance(R, np.ndarray):
+ new_R = torch.Tensor(R).view(-1, 3, 3)
+ elif R is None:
+ new_R = torch.eye(3, 3)[None]
+ elif isinstance(R, torch.Tensor):
+ new_R = R.clone().view(-1, 3, 3)
+ else:
+ raise TypeError(
+ f'R should be `torch.Tensor` or `numpy.ndarray`, type(R): '
+ f'{type(R)}')
+
+ if isinstance(T, np.ndarray):
+ new_T = torch.Tensor(T).view(-1, 3)
+ elif T is None:
+ new_T = torch.zeros(1, 3)
+ elif isinstance(T, torch.Tensor):
+ new_T = T.clone().view(-1, 3)
+ else:
+ raise TypeError(
+ f'T should be `torch.Tensor` or `numpy.ndarray`, type(T): '
+ f'{type(T)}')
+
+ if axis_dst != axis_src:
+ new_R = ee_to_rotmat(rotmat_to_ee(new_R, convention=axis_src),
+ convention=axis_dst)
+
+ # convert extrinsic to world_to_view
+ if view_to_world_src is True:
+ new_R, new_T = convert_world_view(new_R, new_T)
+
+ # right mm to left mm
+ if (not left_mm_ex_src) and left_mm_ex_dst:
+ new_R *= sign.to(new_R.device)
+ new_R = new_R.permute(0, 2, 1)
+ # left mm to right mm
+ elif left_mm_ex_src and (not left_mm_ex_dst):
+ new_R = new_R.permute(0, 2, 1)
+ new_R *= sign.to(new_R.device)
+ # right_mm to right mm
+ elif (not left_mm_ex_dst) and (not left_mm_ex_src):
+ new_R *= sign.to(new_R.device)
+ # left mm to left mm
+ elif left_mm_ex_src and left_mm_ex_dst:
+ new_R *= sign.view(3, 1).to(new_R.device)
+ new_T *= sign.to(new_T.device)
+
+    # convert extrinsic to the destination definition
+ if view_to_world_dst is True:
+ new_R, new_T = convert_world_view(new_R, new_T)
+
+ # in ndc or in screen
+ if in_ndc_dst is False and in_ndc_src is True:
+ assert resolution_dst is not None, \
+ 'dst in screen, should specify resolution_dst.'
+
+ if in_ndc_src is False and in_ndc_dst is True:
+ assert resolution_src is not None, \
+            'src in screen, should specify resolution_src.'
+ if resolution_dst is None:
+ resolution_dst = 2.0
+ if resolution_src is None:
+ resolution_src = 2.0
+
+ if new_K is not None:
+ if left_mm_in_src is False and left_mm_in_dst is True:
+ new_K = new_K.permute(0, 2, 1)
+ if new_K.shape[-2:] == (3, 3):
+ new_K = convert_K_3x3_to_4x4(new_K, is_perspective)
+ # src in ndc, dst in screen
+
+ if in_ndc_src is True and (in_ndc_dst is False):
+ new_K = convert_ndc_to_screen(K=new_K,
+ is_perspective=is_perspective,
+ sign=sign.to(new_K.device),
+ resolution=resolution_dst)
+ # src in screen, dst in ndc
+ elif in_ndc_src is False and in_ndc_dst is True:
+ new_K = convert_screen_to_ndc(K=new_K,
+ is_perspective=is_perspective,
+ sign=sign.to(new_K.device),
+ resolution=resolution_src)
+ # src in ndc, dst in ndc
+ elif in_ndc_src is True and in_ndc_dst is True:
+ if is_perspective:
+ new_K[:, 0, 2] *= sign[0].to(new_K.device)
+ new_K[:, 1, 2] *= sign[1].to(new_K.device)
+ else:
+ new_K[:, 0, 3] *= sign[0].to(new_K.device)
+ new_K[:, 1, 3] *= sign[1].to(new_K.device)
+ # src in screen, dst in screen
+ else:
+ pass
+
+ if left_mm_in_src is True and left_mm_in_dst is False:
+ new_K = new_K.permute(0, 2, 1)
+
+ num_batch = max(new_K.shape[0], new_R.shape[0], new_T.shape[0])
+ if new_K.shape[0] == 1:
+ new_K = new_K.repeat(num_batch, 1, 1)
+ if new_R.shape[0] == 1:
+ new_R = new_R.repeat(num_batch, 1, 1)
+ if new_T.shape[0] == 1:
+ new_T = new_T.repeat(num_batch, 1)
+
+ if use_numpy:
+ if isinstance(new_K, torch.Tensor):
+ new_K = new_K.cpu().numpy()
+ if isinstance(new_R, torch.Tensor):
+ new_R = new_R.cpu().numpy()
+ if isinstance(new_T, torch.Tensor):
+ new_T = new_T.cpu().numpy()
+ return new_K, new_R, new_T
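+
+# Illustrative sketch (hypothetical names, not in the original file): convert
+# an OpenCV-calibrated camera with screen-space intrinsics to the pytorch3d
+# NDC convention:
+#   K_p3d, R_p3d, T_p3d = convert_camera_matrix(
+#       K=K_cv, R=R_cv, T=T_cv, is_perspective=True,
+#       convention_src='opencv', convention_dst='pytorch3d',
+#       in_ndc_src=False, in_ndc_dst=True, resolution_src=(1080, 1920))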
+
+
+def convert_K_3x3_to_4x4(
+ K: Union[torch.Tensor, np.ndarray],
+ is_perspective: bool = True) -> Union[torch.Tensor, np.ndarray]:
+ """Convert opencv 3x3 intrinsic matrix to 4x4.
+
+ Args:
+ K (Union[torch.Tensor, np.ndarray]):
+ Input 3x3 intrinsic matrix, left mm defined.
+ [[fx, 0, px],
+ [0, fy, py],
+ [0, 0, 1]]
+ is_perspective (bool, optional): whether is perspective projection.
+ Defaults to True.
+
+ Raises:
+ TypeError: K is not `Tensor` or `array`.
+ ValueError: Shape is not (batch, 3, 3) or (3, 3)
+
+ Returns:
+ Union[torch.Tensor, np.ndarray]:
+ Output intrinsic matrix.
+ for perspective:
+ [[fx, 0, px, 0],
+ [0, fy, py, 0],
+ [0, 0, 0, 1],
+ [0, 0, 1, 0]]
+
+ for orthographics:
+ [[fx, 0, 0, px],
+ [0, fy, 0, py],
+ [0, 0, 1, 0],
+ [0, 0, 0, 1]]
+ """
+ if isinstance(K, torch.Tensor):
+ K = K.clone()
+ elif isinstance(K, np.ndarray):
+ K = K.copy()
+
+ else:
+ raise TypeError('K should be `torch.Tensor` or `numpy.ndarray`, '
+ f'type(K): {type(K)}.')
+ if K.shape[-2:] == (4, 4):
+ warnings.warn(
+ f'shape of K already is {K.shape}, will skip converting.')
+ return K
+ use_numpy = False
+ if K.ndim == 2:
+ K = K[None].reshape(-1, 3, 3)
+ elif K.ndim == 3:
+ K = K.reshape(-1, 3, 3)
+ else:
+ raise ValueError(f'Wrong ndim of K: {K.ndim}')
+
+ if isinstance(K, np.ndarray):
+ use_numpy = True
+ if is_perspective:
+ if use_numpy:
+ K_ = np.zeros((K.shape[0], 4, 4))
+ else:
+ K_ = torch.zeros(K.shape[0], 4, 4)
+ K_[:, :2, :3] = K[:, :2, :3]
+ K_[:, 3, 2] = 1
+ K_[:, 2, 3] = 1
+ else:
+ if use_numpy:
+ K_ = np.eye(4, 4)[None].repeat(K.shape[0], 0)
+ else:
+ K_ = torch.eye(4, 4)[None].repeat(K.shape[0], 1, 1)
+ K_[:, :2, :2] = K[:, :2, :2]
+ K_[:, :2, 3:] = K[:, :2, 2:]
+ return K_
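+
+
+# A minimal usage sketch (fx/fy/px/py values are made up): lift an OpenCV-style
+# 3x3 intrinsic matrix to the 4x4 perspective form used by this module.
+#
+#   K3 = torch.Tensor([[[1000., 0., 512.],
+#                       [0., 1000., 384.],
+#                       [0., 0., 1.]]])
+#   K4 = convert_K_3x3_to_4x4(K3, is_perspective=True)  # shape (1, 4, 4)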
+
+
+def convert_K_4x4_to_3x3(
+ K: Union[torch.Tensor, np.ndarray],
+ is_perspective: bool = True) -> Union[torch.Tensor, np.ndarray]:
+ """Convert opencv 4x4 intrinsic matrix to 3x3.
+
+ Args:
+ K (Union[torch.Tensor, np.ndarray]):
+ Input 4x4 intrinsic matrix, left mm defined.
+ for perspective:
+ [[fx, 0, px, 0],
+ [0, fy, py, 0],
+ [0, 0, 0, 1],
+ [0, 0, 1, 0]]
+
+ for orthographic:
+ [[fx, 0, 0, px],
+ [0, fy, 0, py],
+ [0, 0, 1, 0],
+ [0, 0, 0, 1]]
+ is_perspective (bool, optional): whether is perspective projection.
+ Defaults to True.
+
+ Raises:
+ TypeError: type K should be `Tensor` or `array`.
+ ValueError: Shape is not (batch, 3, 3) or (3, 3).
+
+ Returns:
+ Union[torch.Tensor, np.ndarray]:
+ Output 3x3 intrinsic matrix, left mm defined.
+ [[fx, 0, px],
+ [0, fy, py],
+ [0, 0, 1]]
+ """
+
+ if isinstance(K, torch.Tensor):
+ K = K.clone()
+ elif isinstance(K, np.ndarray):
+ K = K.copy()
+ else:
+ raise TypeError('K should be `torch.Tensor` or `numpy.ndarray`, '
+ f'type(K): {type(K)}.')
+ if K.shape[-2:] == (3, 3):
+ warnings.warn(
+ f'shape of K already is {K.shape}, will skip converting.')
+ return K
+ use_numpy = isinstance(K, np.ndarray)
+ if K.ndim == 2:
+ K = K[None].reshape(-1, 4, 4)
+ elif K.ndim == 3:
+ K = K.reshape(-1, 4, 4)
+ else:
+ raise ValueError(f'Wrong ndim of K: {K.ndim}')
+
+ if use_numpy:
+ K_ = np.eye(3, 3)[None].repeat(K.shape[0], 0)
+ else:
+ K_ = torch.eye(3, 3)[None].repeat(K.shape[0], 1, 1)
+ if is_perspective:
+ K_[:, :2, :3] = K[:, :2, :3]
+ else:
+ K_[:, :2, :2] = K[:, :2, :2]
+ K_[:, :2, 2:3] = K[:, :2, 3:4]
+ return K_
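+
+
+# Round-trip sketch (same made-up K3 as in the sketch above): for perspective
+# cameras this function undoes convert_K_3x3_to_4x4.
+#
+#   K3_back = convert_K_4x4_to_3x3(convert_K_3x3_to_4x4(K3),
+#                                  is_perspective=True)
+#   # torch.allclose(K3, K3_back) -> True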
+
+
+def convert_ndc_to_screen(
+ K: Union[torch.Tensor, np.ndarray],
+ resolution: Union[int, Tuple[int, int], List[int], torch.Tensor,
+ np.ndarray],
+ sign: Optional[Iterable[int]] = None,
+ is_perspective: bool = True) -> Union[torch.Tensor, np.ndarray]:
+ """Convert intrinsic matrix from ndc to screen.
+
+ Args:
+ K (Union[torch.Tensor, np.ndarray]):
+ Input 4x4 intrinsic matrix, left mm defined.
+ resolution (Union[int, Tuple[int, int], torch.Tensor, np.ndarray]):
+ (height, width) of image.
+ sign (Optional[Iterable[int]], optional): xyz axis sign.
+ Defaults to None.
+ is_perspective (bool, optional): whether is perspective projection.
+ Defaults to True.
+
+ Raises:
+ TypeError: K should be Tensor or array.
+ ValueError: shape of K should be (batch, 4, 4)
+
+ Returns:
+ Union[torch.Tensor, np.ndarray]: output intrinsic matrix.
+ """
+ sign = [1, 1, 1] if sign is None else sign
+ if isinstance(K, torch.Tensor):
+ K = K.clone()
+ elif isinstance(K, np.ndarray):
+ K = K.copy()
+ else:
+ raise TypeError(
+ f'K should be `torch.Tensor` or `np.ndarray`, type(K): {type(K)}')
+ if K.ndim == 2:
+ K = K[None].reshape(-1, 4, 4)
+ elif K.ndim == 3:
+ K = K.reshape(-1, 4, 4)
+ else:
+ raise ValueError(f'Wrong ndim of K: {K.ndim}')
+
+ if isinstance(resolution, (int, float)):
+ w_dst = h_dst = resolution
+ elif isinstance(resolution, (list, tuple)):
+ h_dst, w_dst = resolution
+ elif isinstance(resolution, (torch.Tensor, np.ndarray)):
+ resolution = resolution.reshape(-1, 2)
+ h_dst, w_dst = resolution[:, 0], resolution[:, 1]
+
+ aspect_ratio = w_dst / h_dst
+ K[:, 0, 0] *= w_dst / 2
+ K[:, 1, 1] *= h_dst / 2
+ if aspect_ratio > 1:
+ K[:, 0, 0] /= aspect_ratio
+ else:
+ K[:, 1, 1] *= aspect_ratio
+ if is_perspective:
+ K[:, 0, 2] *= sign[0]
+ K[:, 1, 2] *= sign[1]
+ K[:, 0, 2] = (K[:, 0, 2] + 1) * (w_dst / 2)
+ K[:, 1, 2] = (K[:, 1, 2] + 1) * (h_dst / 2)
+ else:
+ K[:, 0, 3] *= sign[0]
+ K[:, 1, 3] *= sign[1]
+ K[:, 0, 3] = (K[:, 0, 3] + 1) * (w_dst / 2)
+ K[:, 1, 3] = (K[:, 1, 3] + 1) * (h_dst / 2)
+ return K
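+
+
+# A minimal sketch (`K_ndc` is an assumed batched 4x4 NDC intrinsic matrix):
+# rescale it to a (height, width) = (480, 640) screen, so focal lengths and
+# principal point come out in pixels.
+#
+#   K_screen = convert_ndc_to_screen(K_ndc, resolution=(480, 640),
+#                                    is_perspective=True)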
+
+
+def convert_screen_to_ndc(
+ K: Union[torch.Tensor, np.ndarray],
+ resolution: Union[int, Tuple[int, int], torch.Tensor, np.ndarray],
+ sign: Optional[Iterable[int]] = None,
+ is_perspective: bool = True) -> Union[torch.Tensor, np.ndarray]:
+ """Convert intrinsic matrix from screen to ndc.
+
+ Args:
+ K (Union[torch.Tensor, np.ndarray]): input intrinsic matrix.
+ resolution (Union[int, Tuple[int, int], torch.Tensor, np.ndarray]):
+ (height, width) of image.
+ sign (Optional[Iterable[int]], optional): xyz axis sign.
+ Defaults to None.
+ is_perspective (bool, optional): whether is perspective projection.
+ Defaults to True.
+
+ Raises:
+ TypeError: K should be Tensor or array.
+ ValueError: shape of K should be (batch, 4, 4)
+
+ Returns:
+ Union[torch.Tensor, np.ndarray]: output intrinsic matrix.
+ """
+ if sign is None:
+ sign = [1, 1, 1]
+
+ if isinstance(K, torch.Tensor):
+ K = K.clone()
+ elif isinstance(K, np.ndarray):
+ K = K.copy()
+ else:
+ raise TypeError(
+ f'K should be `torch.Tensor` or `np.ndarray`, type(K): {type(K)}')
+ if K.ndim == 2:
+ K = K[None].reshape(-1, 4, 4)
+ elif K.ndim == 3:
+ K = K.reshape(-1, 4, 4)
+ else:
+ raise ValueError(f'Wrong ndim of K: {K.ndim}')
+
+ if isinstance(resolution, (int, float)):
+ w_src = h_src = resolution
+ elif isinstance(resolution, (list, tuple)):
+ h_src, w_src = resolution
+ elif isinstance(resolution, (torch.Tensor, np.ndarray)):
+ resolution = resolution.reshape(-1, 2)
+ h_src, w_src = resolution[:, 0], resolution[:, 1]
+
+ aspect_ratio = w_src / h_src
+ K[:, 0, 0] /= w_src / 2
+ K[:, 1, 1] /= h_src / 2
+ if aspect_ratio > 1:
+ K[:, 0, 0] *= aspect_ratio
+ else:
+ K[:, 1, 1] /= aspect_ratio
+ if is_perspective:
+ K[:, 0, 2] = K[:, 0, 2] / (w_src / 2) - 1
+ K[:, 1, 2] = K[:, 1, 2] / (h_src / 2) - 1
+ K[:, 0, 2] *= sign[0]
+ K[:, 1, 2] *= sign[1]
+ else:
+ K[:, 0, 3] = K[:, 0, 3] / (w_src / 2) - 1
+ K[:, 1, 3] = K[:, 1, 3] / (h_src / 2) - 1
+ K[:, 0, 3] *= sign[0]
+ K[:, 1, 3] *= sign[1]
+ return K
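+
+
+# A minimal sketch (`K_screen` is an assumed batched 4x4 pixel-space intrinsic
+# matrix): with the default sign and the principal point at the image centre
+# (cx, cy) = (320, 240) of a (480, 640) image, the NDC principal point becomes
+# (0, 0).
+#
+#   K_ndc = convert_screen_to_ndc(K_screen, resolution=(480, 640),
+#                                 is_perspective=True)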
+
+
+def convert_world_view(
+ R: Union[torch.Tensor, np.ndarray], T: Union[torch.Tensor, np.ndarray]
+) -> Tuple[Union[torch.Tensor, np.ndarray], Union[torch.Tensor, np.ndarray]]:
+ """Convert between view_to_world and world_to_view defined extrinsic
+ matrix.
+
+ Args:
+ R (Union[torch.Tensor, np.ndarray]): extrinsic rotation matrix.
+ shape should be (batch, 3, 3).
+ T (Union[torch.Tensor, np.ndarray]): extrinsic translation vector,
+ shape should be (batch, 3).
+
+ Raises:
+ TypeError: R and T should be of the same type.
+
+ Returns:
+ Tuple[Union[torch.Tensor, np.ndarray], Union[torch.Tensor,
+ np.ndarray]]: output R, T.
+ """
+ if not (type(R) is type(T)):
+ raise TypeError(
+ f'R: {type(R)}, T: {type(T)} should have the same type.')
+ if isinstance(R, torch.Tensor):
+ R = R.clone()
+ T = T.clone()
+ R = R.permute(0, 2, 1)
+ T = -(R @ T.view(-1, 3, 1)).view(-1, 3)
+ elif isinstance(R, np.ndarray):
+ R = R.copy()
+ T = T.copy()
+ R = R.transpose(0, 2, 1)
+ T = -(R @ T.reshape(-1, 3, 1)).reshape(-1, 3)
+ else:
+ raise TypeError(f'R: {type(R)}, T: {type(T)} should be torch.Tensor '
+ f'or numpy.ndarray.')
+ return R, T
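+
+
+# Sketch (R, T assumed to be batched (batch, 3, 3) / (batch, 3) extrinsics):
+# for an orthonormal rotation the conversion is its own inverse, so applying
+# it twice recovers the original R and T.
+#
+#   R2, T2 = convert_world_view(*convert_world_view(R, T))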
diff --git a/detrsmpl/core/conventions/cameras/convert_projection.py b/detrsmpl/core/conventions/cameras/convert_projection.py
new file mode 100644
index 0000000000000000000000000000000000000000..e051d2f195a018ed5948c21fa790a91fafe693c1
--- /dev/null
+++ b/detrsmpl/core/conventions/cameras/convert_projection.py
@@ -0,0 +1,108 @@
+from typing import Tuple, Union
+
+import numpy as np
+import torch
+
+from .convert_convention import convert_camera_matrix
+
+
+def convert_perspective_to_weakperspective(
+ K: Union[torch.Tensor, np.ndarray],
+ zmean: Union[torch.Tensor, np.ndarray, float, int],
+ resolution: Union[int, Tuple[int, int], torch.Tensor,
+ np.ndarray] = None,
+ in_ndc: bool = False,
+ convention: str = 'opencv') -> Union[torch.Tensor, np.ndarray]:
+ """Convert perspective to weakperspective intrinsic matrix.
+
+ Args:
+ K (Union[torch.Tensor, np.ndarray]): input intrinsic matrix, shape
+ should be (batch, 4, 4) or (batch, 3, 3).
+ zmean (Union[torch.Tensor, np.ndarray, int, float]): zmean for object.
+ shape should be (batch, ) or singleton number.
+ resolution (Union[int, Tuple[int, int], torch.Tensor, np.ndarray],
+ optional): (height, width) of image. Defaults to None.
+ in_ndc (bool, optional): whether defined in ndc. Defaults to False.
+ convention (str, optional): camera convention. Defaults to 'opencv'.
+
+ Returns:
+ Union[torch.Tensor, np.ndarray]: output weak-perspective intrinsic
+ matrix, shape is (batch, 4, 4).
+ """
+ assert K is not None, 'K is required.'
+ K, _, _ = convert_camera_matrix(K=K,
+ convention_src=convention,
+ convention_dst='pytorch3d',
+ is_perspective=True,
+ in_ndc_src=in_ndc,
+ in_ndc_dst=True,
+ resolution_src=resolution)
+ if isinstance(zmean, np.ndarray):
+ zmean = torch.Tensor(zmean)
+ elif isinstance(zmean, (float, int)):
+ zmean = torch.Tensor([zmean])
+ zmean = zmean.view(-1)
+ num_frame = max(zmean.shape[0], K.shape[0])
+ new_K = torch.eye(4, 4)[None].repeat(num_frame, 1, 1)
+ fx = K[:, 0, 0]
+ fy = K[:, 1, 1]
+ cx = K[:, 0, 2]
+ cy = K[:, 1, 2]
+ new_K[:, 0, 0] = fx / zmean
+ new_K[:, 1, 1] = fy / zmean
+ new_K[:, 0, 3] = cx
+ new_K[:, 1, 3] = cy
+ return new_K
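+
+
+# A minimal usage sketch (illustrative values; `K_opencv` is an assumed
+# screen-space OpenCV intrinsic matrix): approximate the camera by a
+# weak-perspective one at mean depth zmean; the result is a (batch, 4, 4)
+# matrix in the PyTorch3D NDC convention.
+#
+#   wp_K = convert_perspective_to_weakperspective(
+#       K=K_opencv, zmean=5.0, resolution=(480, 640),
+#       in_ndc=False, convention='opencv')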
+
+
+def convert_weakperspective_to_perspective(
+ K: Union[torch.Tensor, np.ndarray],
+ zmean: Union[torch.Tensor, np.ndarray, int, float],
+ resolution: Union[int, Tuple[int, int], torch.Tensor,
+ np.ndarray] = None,
+ in_ndc: bool = False,
+ convention: str = 'opencv') -> Union[torch.Tensor, np.ndarray]:
+ """Convert perspective to weakperspective intrinsic matrix.
+
+ Args:
+ K (Union[torch.Tensor, np.ndarray]): input weak-perspective intrinsic
+ matrix, shape should be (batch, 4, 4).
+ zmean (Union[torch.Tensor, np.ndarray, int, float]): zmean for object.
+ shape should be (batch, ) or singleton number.
+ resolution (Union[int, Tuple[int, int], torch.Tensor, np.ndarray],
+ optional): (height, width) of image. Defaults to None.
+ in_ndc (bool, optional): whether defined in ndc. Defaults to False.
+ convention (str, optional): camera convention. Defaults to 'opencv'.
+
+ Returns:
+ Union[torch.Tensor, np.ndarray]: output perspective intrinsic
+ matrix, shape is (batch, 4, 4).
+ """
+ if K.ndim == 2:
+ K = K[None]
+ if isinstance(zmean, np.ndarray):
+ zmean = torch.Tensor(zmean)
+ elif isinstance(zmean, (float, int)):
+ zmean = torch.Tensor([zmean])
+ zmean = zmean.view(-1)
+ _N = max(K.shape[0], zmean.shape[0])
+ s1 = K[:, 0, 0]
+ s2 = K[:, 1, 1]
+ c1 = K[:, 0, 3]
+ c2 = K[:, 1, 3]
+ new_K = torch.zeros(_N, 4, 4)
+ new_K[:, 0, 0] = zmean * s1
+ new_K[:, 1, 1] = zmean * s2
+ new_K[:, 0, 2] = c1
+ new_K[:, 1, 2] = c2
+ new_K[:, 2, 3] = 1
+ new_K[:, 3, 2] = 1
+
+ new_K, _, _ = convert_camera_matrix(K=new_K,
+ convention_src=convention,
+ convention_dst='pytorch3d',
+ is_perspective=True,
+ in_ndc_src=in_ndc,
+ in_ndc_dst=True,
+ resolution_src=resolution)
+ return new_K
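+
+
+# A minimal usage sketch (`wp_K` is an assumed weak-perspective (batch, 4, 4)
+# matrix such as the one produced above): recover a perspective intrinsic
+# matrix at mean depth zmean.
+#
+#   K = convert_weakperspective_to_perspective(
+#       K=wp_K, zmean=5.0, in_ndc=True, convention='pytorch3d')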
diff --git a/detrsmpl/core/conventions/joints_mapping/__init__.py b/detrsmpl/core/conventions/joints_mapping/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/detrsmpl/core/conventions/joints_mapping/standard_joint_angles.py b/detrsmpl/core/conventions/joints_mapping/standard_joint_angles.py
new file mode 100644
index 0000000000000000000000000000000000000000..44ca32a659224afb8465ffa09d625e70a087b409
--- /dev/null
+++ b/detrsmpl/core/conventions/joints_mapping/standard_joint_angles.py
@@ -0,0 +1,54 @@
+import torch
+
+TRANSFORMATION_AA_TO_SJA = torch.Tensor([
+ [[1, 0, 0], [0, 0, 1], [0, -1, 0]], # 00, 'left_hip',
+ [[1, 0, 0], [0, 0, 1], [0, -1, 0]], # 01, 'right_hip',
+ [[1, 0, 0], [0, 0, -1], [0, 1, 0]], # 02, 'spine1',
+ [[1, 0, 0], [0, 0, 1], [0, -1, 0]], # 03, 'left_knee',
+ [[1, 0, 0], [0, 0, 1], [0, -1, 0]], # 04, 'right_knee',
+ [[1, 0, 0], [0, 0, -1], [0, 1, 0]], # 05, 'spine2',
+ [[1, 0, 0], [0, 1, 0], [0, 0, 1]], # 06, 'left_ankle',
+ [[1, 0, 0], [0, 1, 0], [0, 0, 1]], # 07, 'right_ankle',
+ [[1, 0, 0], [0, 0, -1], [0, 1, 0]], # 08, 'spine3',
+ [[1, 0, 0], [0, 1, 0], [0, 0, 1]], # 09, 'left_foot',
+ [[1, 0, 0], [0, 1, 0], [0, 0, 1]], # 10, 'right_foot',
+ [[1, 0, 0], [0, 0, -1], [0, 1, 0]], # 11, 'neck',
+ [[0, 0, -1], [0, 1, 0], [1, 0, 0]], # 12, 'left_collar',
+ [[0, 0, 1], [0, 1, 0], [-1, 0, 0]], # 13, 'right_collar',
+ [[1, 0, 0], [0, 0, -1], [0, 1, 0]], # 14, 'head',
+ [[0, 0, -1], [0, 1, 0], [1, 0, 0]], # 15, 'left_shoulder',
+ [[0, 0, 1], [0, 1, 0], [-1, 0, 0]], # 16, 'right_shoulder',
+ [[0, 0, -1], [0, 1, 0], [1, 0, 0]], # 17, 'left_elbow',
+ [[0, 0, 1], [0, 1, 0], [-1, 0, 0]], # 18, 'right_elbow',
+ [[0, 0, -1], [0, 1, 0], [1, 0, 0]], # 19, 'left_wrist',
+ [[0, 0, 1], [0, 1, 0], [-1, 0, 0]], # 20, 'right_wrist',
+])
+
+TRANSFORMATION_SJA_TO_AA = \
+ torch.inverse(TRANSFORMATION_AA_TO_SJA)
+
+# TODO: spines and shoulders may need further adjustment
+STANDARD_JOINT_ANGLE_LIMITS = torch.deg2rad(
+ torch.Tensor([
+ [[-45, 155], [-88, 17], [-105, 85]], # 00, 'left_hip',
+ [[-45, 155], [-17, 88], [-85, 105]], # 01, 'right_hip',
+ [[-25, 15], [-20, 20], [-30, 30]], # 02, 'spine1',
+ [[0, 150], [0, 0], [0, 0]], # 03, 'left_knee',
+ [[0, 150], [0, 0], [0, 0]], # 04, 'right_knee',
+ [[-25, 15], [-15, 15], [-25, 25]], # 05, 'spine2',
+ [[-31, 63], [-26, 26], [-74, 15]], # 06, 'left_ankle',
+ [[-31, 63], [-26, 26], [-15, 74]], # 07, 'right_ankle',
+ [[-25, 15], [-15, 15], [-25, 25]], # 08, 'spine3',
+ [[-60, 45], [0, 0], [-45, 45]], # 09, 'left_foot',
+ [[-60, 45], [0, 0], [-45, 45]], # 10, 'right_foot',
+ [[-37, 22], [-30, 30], [-45, 45]], # 11, 'neck',
+ [[-30, 30], [-30, 10], [0, 0]], # 12, 'left_collar',
+ [[-30, 30], [-10, 30], [0, 0]], # 13, 'right_collar',
+ [[-37, 22], [-30, 30], [-45, 45]], # 14, 'head',
+ [[-90, 135], [-97, 91], [-90, 135]], # 15, 'left_shoulder',
+ [[-135, 90], [-91, 97], [-135, 90]], # 16, 'right_shoulder',
+ [[0, 0], [-150, 0], [0, 0]], # 17, 'left_elbow',
+ [[0, 0], [0, 150], [0, 0]], # 18, 'right_elbow',
+ [[-90, 90], [-45, 45], [-180, 60]], # 19, 'left_wrist',
+ [[-90, 90], [-45, 45], [-60, 180]], # 20, 'right_wrist',
+ ]))
diff --git a/detrsmpl/core/conventions/keypoints_mapping/__init__.py b/detrsmpl/core/conventions/keypoints_mapping/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe966a6158eca3ecabe31b4bc6e44c4622983668
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/__init__.py
@@ -0,0 +1,399 @@
+from collections import defaultdict
+from typing import List, Tuple, Union
+
+import numpy as np
+import torch
+from mmcv.utils import print_log
+
+from detrsmpl.core.conventions.keypoints_mapping import (
+ agora,
+ coco,
+ coco_wholebody,
+ crowdpose,
+ face3d,
+ flame,
+ gta,
+ h36m,
+ human_data,
+ hybrik,
+ instavariety,
+ lsp,
+ mano,
+ mpi_inf_3dhp,
+ mpii,
+ openpose,
+ penn_action,
+ posetrack,
+ pw3d,
+ smpl,
+ smplx,
+ spin_smplx,
+ star,
+)
+
+KEYPOINTS_FACTORY = {
+ 'human_data': human_data.HUMAN_DATA,
+ 'agora': agora.AGORA_KEYPOINTS,
+ 'coco': coco.COCO_KEYPOINTS,
+ 'coco_wholebody': coco_wholebody.COCO_WHOLEBODY_KEYPOINTS,
+ 'crowdpose': crowdpose.CROWDPOSE_KEYPOINTS,
+ 'smplx': smplx.SMPLX_KEYPOINTS,
+ 'smplx_137': smplx.SMPLX_137_KEYPOINTS,
+ 'smplx_lhand': smplx.SMPLX_LHAND,
+ 'smplx_rhand': smplx.SMPLX_RHAND,
+ 'smplx_face': smplx.SMPLX_FACE,
+ 'smplx_aios': smplx.AiOS_35_KEYPOINTS,
+ 'smpl': smpl.SMPL_KEYPOINTS,
+ 'smpl_45': smpl.SMPL_45_KEYPOINTS,
+ 'smpl_54': smpl.SMPL_54_KEYPOINTS,
+ 'smpl_49': smpl.SMPL_49_KEYPOINTS,
+ 'smpl_24': smpl.SMPL_24_KEYPOINTS,
+ 'star': star.STAR_KEYPOINTS,
+ 'mpi_inf_3dhp': mpi_inf_3dhp.MPI_INF_3DHP_KEYPOINTS,
+ 'mpi_inf_3dhp_test': mpi_inf_3dhp.MPI_INF_3DHP_TEST_KEYPOINTS,
+ 'penn_action': penn_action.PENN_ACTION_KEYPOINTS,
+ 'h36m': h36m.H36M_KEYPOINTS,
+ 'h36m_mmpose': h36m.H36M_KEYPOINTS_MMPOSE,
+ 'h36m_smplx': h36m.H36M_KEYPOINTS_SMPLX,
+ 'pw3d': pw3d.PW3D_KEYPOINTS,
+ 'mpii': mpii.MPII_KEYPOINTS,
+ 'lsp': lsp.LSP_KEYPOINTS,
+ 'posetrack': posetrack.POSETRACK_KEYPOINTS,
+ 'instavariety': instavariety.INSTAVARIETY_KEYPOINTS,
+ 'openpose_25': openpose.OPENPOSE_25_KEYPOINTS,
+ 'openpose_118': openpose.OPENPOSE_118_KEYPOINTS,
+ 'openpose_135': openpose.OPENPOSE_135_KEYPOINTS,
+ 'openpose_137': openpose.OPENPOSE_137_KEYPOINTS,
+ 'hybrik_29': hybrik.HYBRIK_29_KEYPOINTS,
+ 'hybrik_hp3d': mpi_inf_3dhp.HYBRIK_MPI_INF_3DHP_KEYPOINTS,
+ 'gta': gta.GTA_KEYPOINTS,
+ 'flame': flame.FLAME_73_KEYPOINTS,
+ 'face3d': face3d.FACE3D_IND,
+ 'spin_smplx': spin_smplx.SPIN_SMPLX_KEYPOINTS,
+ 'mano': mano.MANO_KEYPOINTS,
+ 'mano_left': mano.MANO_LEFT_KEYPOINTS,
+ 'mano_right': mano.MANO_RIGHT_KEYPOINTS,
+ 'mano_hands': mano.MANO_HANDS_KEYPOINTS,
+ 'mano_left_reorder': mano.MANO_LEFT_REORDER_KEYPOINTS,
+ 'mano_right_reorder': mano.MANO_RIGHT_REORDER_KEYPOINTS,
+ 'mano_hands_reorder': mano.MANO_HANDS_REORDER_KEYPOINTS,
+}
+
+__KEYPOINTS_MAPPING_CACHE__ = defaultdict(dict)
+
+
+def convert_kps(
+ keypoints: Union[np.ndarray, torch.Tensor],
+ src: str,
+ dst: str,
+ approximate: bool = False,
+ mask: Union[np.ndarray, torch.Tensor] = None,
+ keypoints_factory: dict = KEYPOINTS_FACTORY,
+ return_mask: bool = True
+) -> Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]:
+ """Convert keypoints following the mapping correspondence between src and
+ dst keypoints definition. Supported conventions by now: agora, coco, smplx,
+ smpl, mpi_inf_3dhp, mpi_inf_3dhp_test, h36m, h36m_mmpose, pw3d, mpii, lsp.
+ Args:
+ keypoints [Union[np.ndarray, torch.Tensor]]: input keypoints array,
+ could be (f * n * J * 3/2) or (f * J * 3/2).
+ You can set keypoints as np.zeros((1, J, 2))
+ if you only need mask.
+ src (str): source data type from keypoints_factory.
+ dst (str): destination data type from keypoints_factory.
+ approximate (bool): control whether approximate mapping is allowed.
+ mask (Union[np.ndarray, torch.Tensor], optional):
+ The original mask to mark the existence of the keypoints.
+ None represents all ones mask.
+ Defaults to None.
+ keypoints_factory (dict, optional): A class to store the attributes.
+ Defaults to keypoints_factory.
+ return_mask (bool, optional): whether to return a mask as part of the
+ output. Returning a mask is unnecessary if the keypoints already
+ carry confidence values, since invalid keypoints get zero confidence.
+ Defaults to True.
+ Returns:
+ Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]
+ : tuple of (out_keypoints, mask). out_keypoints and mask will be of
+ the same type.
+ """
+ assert keypoints.ndim in {3, 4}
+ if isinstance(keypoints, torch.Tensor):
+
+ def new_array_func(shape, value, device_data, if_uint8):
+ if if_uint8:
+ dtype = torch.uint8
+ else:
+ dtype = None
+ if value == 1:
+ return torch.ones(size=shape,
+ dtype=dtype,
+ device=device_data.device)
+ elif value == 0:
+ return torch.zeros(size=shape,
+ dtype=dtype,
+ device=device_data.device)
+ else:
+ raise ValueError
+
+ def to_type_uint8_func(data):
+ return data.to(dtype=torch.uint8)
+
+ elif isinstance(keypoints, np.ndarray):
+
+ def new_array_func(shape, value, device_data, if_uint8):
+ if if_uint8:
+ dtype = np.uint8
+ else:
+ dtype = None
+ if value == 1:
+ return np.ones(shape=shape, dtype=dtype)
+ elif value == 0:
+ return np.zeros(shape=shape, dtype=dtype)
+ else:
+ raise ValueError
+
+ def to_type_uint8_func(data):
+ return data.astype(np.uint8)
+
+ else:
+ raise TypeError('Type of keypoints is neither' +
+ ' torch.Tensor nor np.ndarray.\n' +
+ f'Type of keypoints: {type(keypoints)}')
+
+ if mask is not None:
+ assert type(mask) == type(keypoints)
+ else:
+ mask = new_array_func(shape=(keypoints.shape[-2], ),
+ value=1,
+ device_data=keypoints,
+ if_uint8=True)
+
+ if src == dst:
+ if return_mask:
+ return keypoints, mask
+ else:
+ return keypoints
+
+ src_names = keypoints_factory[src.lower()]
+ dst_names = keypoints_factory[dst.lower()]
+ extra_dims = keypoints.shape[:-2]
+ keypoints = keypoints.reshape(-1, len(src_names), keypoints.shape[-1])
+
+ out_keypoints = new_array_func(shape=(keypoints.shape[0], len(dst_names),
+ keypoints.shape[-1]),
+ value=0,
+ device_data=keypoints,
+ if_uint8=False)
+
+ original_mask = mask
+ if original_mask is not None:
+ original_mask = original_mask.reshape(-1)
+ assert original_mask.shape[0] == len(
+ src_names), f'The length of mask should be {len(src_names)}'
+
+ mask = new_array_func(shape=(len(dst_names), ),
+ value=0,
+ device_data=keypoints,
+ if_uint8=True)
+
+ dst_idxs, src_idxs, _ = \
+ get_mapping(src, dst, approximate, keypoints_factory)
+ out_keypoints[:, dst_idxs] = keypoints[:, src_idxs]
+ out_shape = extra_dims + (len(dst_names), keypoints.shape[-1])
+ out_keypoints = out_keypoints.reshape(out_shape)
+ mask[dst_idxs] = to_type_uint8_func(original_mask[src_idxs]) \
+ if original_mask is not None else 1.0
+
+ if return_mask:
+ return out_keypoints, mask
+ else:
+ return out_keypoints
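+
+
+# A minimal usage sketch (dummy zero keypoints, purely illustrative): map COCO
+# keypoints into the 'human_data' convention; `mask` marks which target joints
+# were actually filled from the source.
+#
+#   kps = np.zeros((1, len(KEYPOINTS_FACTORY['coco']), 3))
+#   out_kps, mask = convert_kps(kps, src='coco', dst='human_data')
+#   # out_kps.shape == (1, len(KEYPOINTS_FACTORY['human_data']), 3)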
+
+
+def compress_converted_kps(
+ zero_pad_array: Union[np.ndarray, torch.Tensor],
+ mask_array: Union[np.ndarray, torch.Tensor],
+) -> Union[np.ndarray, torch.Tensor]:
+ """Compress keypoints that are zero-padded after applying convert_kps.
+
+ Args:
+ zero_pad_array (Union[np.ndarray, torch.Tensor]): zero-padded
+ keypoints array produced by convert_kps, could be
+ (f * n * J * 3/2) or (f * J * 3/2).
+ mask_array (Union[np.ndarray, torch.Tensor]):
+ The mask marking the existence of the keypoints.
+ Returns:
+ Union[np.ndarray, torch.Tensor]: out_keypoints
+ """
+
+ assert mask_array.shape[0] == zero_pad_array.shape[1]
+ valid_mask_index = np.where(mask_array == 1)[0]
+ compressed_array = np.take(zero_pad_array, valid_mask_index, axis=1)
+ return compressed_array
+
+
+def get_mapping(src: str,
+ dst: str,
+ approximate: bool = False,
+ keypoints_factory: dict = KEYPOINTS_FACTORY):
+ """Get mapping list from src to dst.
+
+ Args:
+ src (str): source data type from keypoints_factory.
+ dst (str): destination data type from keypoints_factory.
+ approximate (bool): control whether approximate mapping is allowed.
+ keypoints_factory (dict, optional): A class to store the attributes.
+ Defaults to keypoints_factory.
+
+ Returns:
+ list:
+ [dst_to_intersection_idx, src_to_intersection_idx,
+ intersection_names]
+ """
+ if src in __KEYPOINTS_MAPPING_CACHE__ and \
+ dst in __KEYPOINTS_MAPPING_CACHE__[src] and \
+ __KEYPOINTS_MAPPING_CACHE__[src][dst][3] == approximate:
+ return __KEYPOINTS_MAPPING_CACHE__[src][dst][:3]
+ else:
+ src_names = keypoints_factory[src.lower()]
+ dst_names = keypoints_factory[dst.lower()]
+
+ dst_idxs, src_idxs, intersection = [], [], []
+ unmapped_names, approximate_names = [], []
+ for dst_idx, dst_name in enumerate(dst_names):
+ matched = False
+ try:
+ src_idx = src_names.index(dst_name)
+ except ValueError:
+ src_idx = -1
+ if src_idx >= 0:
+ matched = True
+ dst_idxs.append(dst_idx)
+ src_idxs.append(src_idx)
+ intersection.append(dst_name)
+ # approximate mapping
+ if approximate and not matched:
+
+ try:
+ part_list = human_data.APPROXIMATE_MAP[dst_name]
+ except KeyError:
+ continue
+ for approximate_name in part_list:
+ try:
+ src_idx = src_names.index(approximate_name)
+ except ValueError:
+ src_idx = -1
+ if src_idx >= 0:
+ dst_idxs.append(dst_idx)
+ src_idxs.append(src_idx)
+ intersection.append(dst_name)
+ unmapped_names.append(src_names[src_idx])
+ approximate_names.append(dst_name)
+ break
+
+ if unmapped_names:
+ warn_message = \
+ f'Approximate mapping {unmapped_names}' +\
+ f' to {approximate_names}'
+ print_log(msg=warn_message)
+
+ mapping_list = [dst_idxs, src_idxs, intersection, approximate]
+
+ if src not in __KEYPOINTS_MAPPING_CACHE__:
+ __KEYPOINTS_MAPPING_CACHE__[src] = {}
+ __KEYPOINTS_MAPPING_CACHE__[src][dst] = mapping_list
+ return mapping_list[:3]
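+
+
+# Sketch: the two index lists are aligned element-wise, which is exactly how
+# convert_kps uses them: dst_kps[:, dst_idxs] = src_kps[:, src_idxs].
+#
+#   dst_idxs, src_idxs, names = get_mapping('coco', 'human_data')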
+
+
+def get_flip_pairs(convention: str = 'smplx',
+ keypoints_factory: dict = KEYPOINTS_FACTORY) -> List[List[int]]:
+ """Get indices of left, right keypoint pairs from specified convention.
+
+ Args:
+ convention (str): data type from keypoints_factory.
+ keypoints_factory (dict, optional): A class to store the attributes.
+ Defaults to keypoints_factory.
+ Returns:
+ List[List[int]]: left, right keypoint index pairs
+ """
+ flip_pairs = []
+ keypoints = keypoints_factory[convention]
+ left_kps = [kp for kp in keypoints if 'left_' in kp]
+ for left_kp in left_kps:
+ right_kp = left_kp.replace('left_', 'right_')
+ flip_pairs.append([keypoints.index(kp) for kp in [left_kp, right_kp]])
+ return flip_pairs
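+
+
+# Sketch: each entry is a [left_index, right_index] pair; for the 'coco'
+# convention the first pair is [1, 2] (left_eye, right_eye).
+#
+#   flip_pairs = get_flip_pairs('coco')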
+
+
+def get_keypoint_idxs_by_part(
+ part: str,
+ convention: str = 'smplx',
+ keypoints_factory: dict = KEYPOINTS_FACTORY) -> List[int]:
+ """Get part keypoints indices from specified part and convention.
+
+ Args:
+ part (str): part to search from
+ convention (str): data type from keypoints_factory.
+ keypoints_factory (dict, optional): A class to store the attributes.
+ Defaults to keypoints_factory.
+ Returns:
+ List[int]: part keypoint indices
+ """
+ humandata_parts = human_data.HUMAN_DATA_PARTS
+ keypoints = keypoints_factory[convention]
+ if part not in humandata_parts.keys():
+ raise ValueError(f'{part} is not in allowed parts: '
+ f'{list(humandata_parts.keys())}')
+ part_keypoints = list(set(humandata_parts[part]) & set(keypoints))
+ part_keypoints_idx = [keypoints.index(kp) for kp in part_keypoints]
+ return part_keypoints_idx
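+
+
+# Sketch: collect the indices of all left-hand joints defined in the 'smplx'
+# convention; note the result follows set-intersection order, not the
+# convention's own ordering.
+#
+#   lhand_idxs = get_keypoint_idxs_by_part('left_hand', 'smplx')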
+
+
+def get_keypoint_idx(name: str,
+ convention: str = 'smplx',
+ approximate: bool = False,
+ keypoints_factory: dict = KEYPOINTS_FACTORY) -> int:
+ """Get keypoint index from specified convention with keypoint name.
+
+ Args:
+ name (str): keypoint name
+ convention (str): data type from keypoints_factory.
+ approximate (bool): control whether approximate mapping is allowed.
+ keypoints_factory (dict, optional): A class to store the attributes.
+ Defaults to keypoints_factory.
+ Returns:
+ int: keypoint index, -1 if the name cannot be found
+ """
+ keypoints = keypoints_factory[convention]
+ try:
+ idx = keypoints.index(name)
+ except ValueError:
+ idx = -1 # not matched
+ if approximate and idx == -1:
+ try:
+ part_list = human_data.APPROXIMATE_MAP[name]
+ except KeyError:
+ return idx
+ for approximate_name in part_list:
+ try:
+ idx = keypoints.index(approximate_name)
+ except ValueError:
+ idx = -1
+ if idx >= 0:
+ return idx
+ return idx
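+
+
+# Sketch: look up a keypoint index by name; -1 means the name (and, when
+# approximate=True, every approximate substitute) is absent from the
+# convention.
+#
+#   nose_idx = get_keypoint_idx('nose', 'coco')  # 0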
+
+
+def get_keypoint_num(convention: str = 'smplx',
+ keypoints_factory: dict = KEYPOINTS_FACTORY) -> int:
+ """Get number of keypoints of specified convention.
+
+ Args:
+ convention (str): data type from keypoints_factory.
+ keypoints_factory (dict, optional): A class to store the attributes.
+ Defaults to keypoints_factory.
+ Returns:
+ int: number of keypoints
+ """
+ keypoints = keypoints_factory[convention]
+ return len(keypoints)
diff --git a/detrsmpl/core/conventions/keypoints_mapping/agora.py b/detrsmpl/core/conventions/keypoints_mapping/agora.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a1e08f739cc77d70b0f8a24bcaf265dcedd33ec
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/agora.py
@@ -0,0 +1,129 @@
+AGORA_KEYPOINTS = [
+ 'pelvis',
+ 'left_hip',
+ 'right_hip',
+ 'spine_1',
+ 'left_knee',
+ 'right_knee',
+ 'spine_2',
+ 'left_ankle',
+ 'right_ankle',
+ 'spine_3',
+ 'left_foot',
+ 'right_foot',
+ 'neck',
+ 'left_collar',
+ 'right_collar',
+ 'head',
+ 'left_shoulder',
+ 'right_shoulder',
+ 'left_elbow',
+ 'right_elbow',
+ 'left_wrist',
+ 'right_wrist',
+ 'jaw',
+ 'left_eyeball',
+ 'right_eyeball',
+ 'left_index_1',
+ 'left_index_2',
+ 'left_index_3',
+ 'left_middle_1',
+ 'left_middle_2',
+ 'left_middle_3',
+ 'left_pinky_1',
+ 'left_pinky_2',
+ 'left_pinky_3',
+ 'left_ring_1',
+ 'left_ring_2',
+ 'left_ring_3',
+ 'left_thumb_1',
+ 'left_thumb_2',
+ 'left_thumb_3',
+ 'right_index_1',
+ 'right_index_2',
+ 'right_index_3',
+ 'right_middle_1',
+ 'right_middle_2',
+ 'right_middle_3',
+ 'right_pinky_1',
+ 'right_pinky_2',
+ 'right_pinky_3',
+ 'right_ring_1',
+ 'right_ring_2',
+ 'right_ring_3',
+ 'right_thumb_1',
+ 'right_thumb_2',
+ 'right_thumb_3',
+ 'nose',
+ 'right_eye',
+ 'left_eye',
+ 'right_ear',
+ 'left_ear',
+ 'left_bigtoe',
+ 'left_smalltoe',
+ 'left_heel',
+ 'right_bigtoe',
+ 'right_smalltoe',
+ 'right_heel',
+ 'left_thumb',
+ 'left_index',
+ 'left_middle',
+ 'left_ring',
+ 'left_pinky',
+ 'right_thumb',
+ 'right_index',
+ 'right_middle',
+ 'right_ring',
+ 'right_pinky',
+ 'right_eyebrow_1',
+ 'right_eyebrow_2',
+ 'right_eyebrow_3',
+ 'right_eyebrow_4',
+ 'right_eyebrow_5',
+ 'left_eyebrow_5',
+ 'left_eyebrow_4',
+ 'left_eyebrow_3',
+ 'left_eyebrow_2',
+ 'left_eyebrow_1',
+ 'nosebridge_1',
+ 'nosebridge_2',
+ 'nosebridge_3',
+ 'nosebridge_4',
+ 'right_nose_2', # original name: nose_1
+ 'right_nose_1', # original name: nose_2
+ 'nose_middle', # original name: nose_3
+ 'left_nose_1', # original name: nose_4
+ 'left_nose_2', # original name: nose_5
+ 'right_eye_1',
+ 'right_eye_2',
+ 'right_eye_3',
+ 'right_eye_4',
+ 'right_eye_5',
+ 'right_eye_6',
+ 'left_eye_4',
+ 'left_eye_3',
+ 'left_eye_2',
+ 'left_eye_1',
+ 'left_eye_6',
+ 'left_eye_5',
+ 'right_mouth_1', # original name: mouth_1
+ 'right_mouth_2', # original name: mouth_2
+ 'right_mouth_3', # original name: mouth_3
+ 'mouth_top', # original name: mouth_4
+ 'left_mouth_3', # original name: mouth_5
+ 'left_mouth_2', # original name: mouth_6
+ 'left_mouth_1', # original name: mouth_7
+ 'left_mouth_5', # original name: mouth_8
+ 'left_mouth_4', # original name: mouth_9
+ 'mouth_bottom', # original name: mouth_10
+ 'right_mouth_4', # original name: mouth_11
+ 'right_mouth_5', # original name: mouth_12
+ 'right_lip_1', # original name: lip_1
+ 'right_lip_2', # original name: lip_2
+ 'lip_top', # original name: lip_3
+ 'left_lip_2', # original name: lip_4
+ 'left_lip_1', # original name: lip_5
+ 'left_lip_3', # original name: lip_6
+ 'lip_bottom', # original name: lip_7
+ 'right_lip_3', # original name: lip_8
+]
diff --git a/detrsmpl/core/conventions/keypoints_mapping/coco.py b/detrsmpl/core/conventions/keypoints_mapping/coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c39e1ae5f14b767356751cd5195fbe224d0fd88
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/coco.py
@@ -0,0 +1,19 @@
+COCO_KEYPOINTS = [
+ 'nose',
+ 'left_eye',
+ 'right_eye',
+ 'left_ear',
+ 'right_ear',
+ 'left_shoulder',
+ 'right_shoulder',
+ 'left_elbow',
+ 'right_elbow',
+ 'left_wrist',
+ 'right_wrist',
+ 'left_hip_extra',
+ 'right_hip_extra',
+ 'left_knee',
+ 'right_knee',
+ 'left_ankle',
+ 'right_ankle',
+]
diff --git a/detrsmpl/core/conventions/keypoints_mapping/coco_wholebody.py b/detrsmpl/core/conventions/keypoints_mapping/coco_wholebody.py
new file mode 100644
index 0000000000000000000000000000000000000000..39e1e7e1c5cace9092b591ea3d3f0e0e637c4b60
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/coco_wholebody.py
@@ -0,0 +1,135 @@
+COCO_WHOLEBODY_KEYPOINTS = [
+ 'nose',
+ 'left_eye',
+ 'right_eye',
+ 'left_ear',
+ 'right_ear',
+ 'left_shoulder',
+ 'right_shoulder',
+ 'left_elbow',
+ 'right_elbow',
+ 'left_wrist',
+ 'right_wrist',
+ 'left_hip',
+ 'right_hip',
+ 'left_knee',
+ 'right_knee',
+ 'left_ankle',
+ 'right_ankle',
+ 'left_bigtoe',
+ 'left_smalltoe',
+ 'left_heel',
+ 'right_bigtoe',
+ 'right_smalltoe',
+ 'right_heel',
+ 'right_contour_1', # original name: face_contour_1
+ 'right_contour_2', # original name: face_contour_2
+ 'right_contour_3', # original name: face_contour_3
+ 'right_contour_4', # original name: face_contour_4
+ 'right_contour_5', # original name: face_contour_5
+ 'right_contour_6', # original name: face_contour_6
+ 'right_contour_7', # original name: face_contour_7
+ 'right_contour_8', # original name: face_contour_8
+ 'contour_middle', # original name: face_contour_9
+ 'left_contour_8', # original name: face_contour_10
+ 'left_contour_7', # original name: face_contour_11
+ 'left_contour_6', # original name: face_contour_12
+ 'left_contour_5', # original name: face_contour_13
+ 'left_contour_4', # original name: face_contour_14
+ 'left_contour_3', # original name: face_contour_15
+ 'left_contour_2', # original name: face_contour_16
+ 'left_contour_1', # original name: face_contour_17
+ 'right_eyebrow_1',
+ 'right_eyebrow_2',
+ 'right_eyebrow_3',
+ 'right_eyebrow_4',
+ 'right_eyebrow_5',
+ 'left_eyebrow_5',
+ 'left_eyebrow_4',
+ 'left_eyebrow_3',
+ 'left_eyebrow_2',
+ 'left_eyebrow_1',
+ 'nosebridge_1',
+ 'nosebridge_2',
+ 'nosebridge_3',
+ 'nosebridge_4',
+ 'right_nose_2', # original name: nose_1
+ 'right_nose_1', # original name: nose_2
+ 'nose_middle', # original name: nose_3
+ 'left_nose_1', # original name: nose_4
+ 'left_nose_2', # original name: nose_5
+ 'right_eye_1',
+ 'right_eye_2',
+ 'right_eye_3',
+ 'right_eye_4',
+ 'right_eye_5',
+ 'right_eye_6',
+ 'left_eye_4',
+ 'left_eye_3',
+ 'left_eye_2',
+ 'left_eye_1',
+ 'left_eye_6',
+ 'left_eye_5',
+ 'right_mouth_1', # original name: mouth_1
+ 'right_mouth_2', # original name: mouth_2
+ 'right_mouth_3', # original name: mouth_3
+ 'mouth_top', # original name: mouth_4
+ 'left_mouth_3', # original name: mouth_5
+ 'left_mouth_2', # original name: mouth_6
+ 'left_mouth_1', # original name: mouth_7
+ 'left_mouth_5', # original name: mouth_8
+ 'left_mouth_4', # original name: mouth_9
+ 'mouth_bottom', # original name: mouth_10
+ 'right_mouth_4', # original name: mouth_11
+ 'right_mouth_5', # original name: mouth_12
+ 'right_lip_1', # original name: lip_1
+ 'right_lip_2', # original name: lip_2
+ 'lip_top', # original name: lip_3
+ 'left_lip_2', # original name: lip_4
+ 'left_lip_1', # original name: lip_5
+ 'left_lip_3', # original name: lip_6
+ 'lip_bottom', # original name: lip_7
+ 'right_lip_3', # original name: lip_8
+ 'left_hand_root',
+ 'left_thumb_1',
+ 'left_thumb_2',
+ 'left_thumb_3',
+ 'left_thumb',
+ 'left_index_1',
+ 'left_index_2',
+ 'left_index_3',
+ 'left_index',
+ 'left_middle_1',
+ 'left_middle_2',
+ 'left_middle_3',
+ 'left_middle',
+ 'left_ring_1',
+ 'left_ring_2',
+ 'left_ring_3',
+ 'left_ring',
+ 'left_pinky_1',
+ 'left_pinky_2',
+ 'left_pinky_3',
+ 'left_pinky',
+ 'right_hand_root',
+ 'right_thumb_1',
+ 'right_thumb_2',
+ 'right_thumb_3',
+ 'right_thumb',
+ 'right_index_1',
+ 'right_index_2',
+ 'right_index_3',
+ 'right_index',
+ 'right_middle_1',
+ 'right_middle_2',
+ 'right_middle_3',
+ 'right_middle',
+ 'right_ring_1',
+ 'right_ring_2',
+ 'right_ring_3',
+ 'right_ring',
+ 'right_pinky_1',
+ 'right_pinky_2',
+ 'right_pinky_3',
+ 'right_pinky',
+]
diff --git a/detrsmpl/core/conventions/keypoints_mapping/crowdpose.py b/detrsmpl/core/conventions/keypoints_mapping/crowdpose.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f6b154c6b3d1a434225d19ac0755adb39dc52e2
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/crowdpose.py
@@ -0,0 +1,5 @@
+CROWDPOSE_KEYPOINTS = [
+ 'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
+ 'left_wrist', 'right_wrist', 'left_hip', 'right_hip', 'left_knee',
+ 'right_knee', 'left_ankle', 'right_ankle', 'head', 'neck'
+]
diff --git a/detrsmpl/core/conventions/keypoints_mapping/face3d.py b/detrsmpl/core/conventions/keypoints_mapping/face3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb4c3312005f687c6f50f1175009e0d2d11118fb
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/face3d.py
@@ -0,0 +1,4 @@
+FACE3D_IND = [
+ 'right_eye_1', 'right_eye_4', 'left_eye_4', 'left_eye_1', 'nose_middle',
+ 'right_mouth_1', 'left_mouth_1'
+]
diff --git a/detrsmpl/core/conventions/keypoints_mapping/flame.py b/detrsmpl/core/conventions/keypoints_mapping/flame.py
new file mode 100644
index 0000000000000000000000000000000000000000..15a6d8051ce369abecbb319f956fdbdb35fbc979
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/flame.py
@@ -0,0 +1,75 @@
+FLAME_73_KEYPOINTS = [
+ 'head',
+ 'neck',
+ 'jaw',
+ 'left_eye',
+ 'right_eye',
+ 'right_eyebrow_1',
+ 'right_eyebrow_2',
+ 'right_eyebrow_3',
+ 'right_eyebrow_4',
+ 'right_eyebrow_5',
+ 'left_eyebrow_5',
+ 'left_eyebrow_4',
+ 'left_eyebrow_3',
+ 'left_eyebrow_2',
+ 'left_eyebrow_1',
+ 'nosebridge_1',
+ 'nosebridge_2',
+ 'nosebridge_3',
+ 'nosebridge_4',
+ 'right_nose_2',
+ 'right_nose_1',
+ 'nose_middle',
+ 'left_nose_1',
+ 'left_nose_2',
+ 'right_eye_1',
+ 'right_eye_2',
+ 'right_eye_3',
+ 'right_eye_4',
+ 'right_eye_5',
+ 'right_eye_6',
+ 'left_eye_4',
+ 'left_eye_3',
+ 'left_eye_2',
+ 'left_eye_1',
+ 'left_eye_6',
+ 'left_eye_5',
+ 'right_mouth_1',
+ 'right_mouth_2',
+ 'right_mouth_3',
+ 'mouth_top',
+ 'left_mouth_3',
+ 'left_mouth_2',
+ 'left_mouth_1',
+ 'left_mouth_5',
+ 'left_mouth_4',
+ 'mouth_bottom',
+ 'right_mouth_4',
+ 'right_mouth_5',
+ 'right_lip_1',
+ 'right_lip_2',
+ 'lip_top',
+ 'left_lip_2',
+ 'left_lip_1',
+ 'left_lip_3',
+ 'lip_bottom',
+ 'right_lip_3',
+ 'right_contour_1',
+ 'right_contour_2',
+ 'right_contour_3',
+ 'right_contour_4',
+ 'right_contour_5',
+ 'right_contour_6',
+ 'right_contour_7',
+ 'right_contour_8',
+ 'contour_middle',
+ 'left_contour_8',
+ 'left_contour_7',
+ 'left_contour_6',
+ 'left_contour_5',
+ 'left_contour_4',
+ 'left_contour_3',
+ 'left_contour_2',
+ 'left_contour_1',
+]
diff --git a/detrsmpl/core/conventions/keypoints_mapping/gta.py b/detrsmpl/core/conventions/keypoints_mapping/gta.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d9446ce6c7d56a382b76fb1a6d640df58c16937
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/gta.py
@@ -0,0 +1,205 @@
+# ORIGINAL_NAMES = [
+# 'head_top', # 00, extrapolate 02-01
+# 'head_center', # 01
+# 'neck', # 02
+# 'right_clavicle', # 03
+# 'right_shoulder', # 04
+# 'right_elbow', # 05
+# 'right_wrist', # 06
+# 'left_clavicle', # 07
+# 'left_shoulder', # 08
+# 'left_elbow', # 09
+# 'left_wrist', # 10
+# 'spine0', # 11
+# 'spine1', # 12
+# 'spine2', # 13
+# 'spine3', # 14
+# 'spine4', # 15
+# 'right_hip', # 16
+# 'right_knee', # 17
+# 'right_ankle', # 18
+# 'left_hip', # 19
+# 'left_knee', # 20
+# 'left_ankle', # 21
+# 'SKEL_ROOT', # 22
+# 'FB_R_Brow_Out_000', # 23
+# 'SKEL_L_Toe0', # 24
+# 'MH_R_Elbow', # 25
+# 'SKEL_L_Finger01', # 26
+# 'SKEL_L_Finger02', # 27
+# 'SKEL_L_Finger31', # 28
+# 'SKEL_L_Finger32', # 29
+# 'SKEL_L_Finger41', # 30
+# 'SKEL_L_Finger42', # 31
+# 'SKEL_L_Finger11', # 32
+# 'SKEL_L_Finger12', # 33
+# 'SKEL_L_Finger21', # 34
+# 'SKEL_L_Finger22', # 35
+# 'RB_L_ArmRoll', # 36
+# 'IK_R_Hand', # 37
+# 'RB_R_ThighRoll', # 38
+# 'FB_R_Lip_Corner_000', # 39
+# 'SKEL_Pelvis', # 40
+# 'IK_Head', # 41
+# 'MH_R_Knee', # 42
+# 'FB_LowerLipRoot_000', # 43
+# 'FB_R_Lip_Top_000', # 44
+# 'FB_R_CheekBone_000', # 45
+# 'FB_UpperLipRoot_000', # 46
+# 'FB_L_Lip_Top_000', # 47
+# 'FB_LowerLip_000', # 48
+# 'SKEL_R_Toe0', # 49
+# 'FB_L_CheekBone_000', # 50
+# 'MH_L_Elbow', # 51
+# 'RB_L_ThighRoll', # 52
+# 'PH_R_Foot', # 53
+# 'FB_L_Eye_000', # 54
+# 'SKEL_L_Finger00', # 55
+# 'SKEL_L_Finger10', # 56
+# 'SKEL_L_Finger20', # 57
+# 'SKEL_L_Finger30', # 58
+# 'SKEL_L_Finger40', # 59
+# 'FB_R_Eye_000', # 60
+# 'PH_R_Hand', # 61
+# 'FB_L_Lip_Corner_000', # 62
+# 'IK_R_Foot', # 63
+# 'RB_Neck_1', # 64
+# 'IK_L_Hand', # 65
+# 'RB_R_ArmRoll', # 66
+# 'FB_Brow_Centre_000', # 67
+# 'FB_R_Lid_Upper_000', # 68
+# 'RB_R_ForeArmRoll', # 69
+# 'FB_L_Lid_Upper_000', # 70
+# 'MH_L_Knee', # 71
+# 'FB_Jaw_000', # 72
+# 'FB_L_Lip_Bot_000', # 73
+# 'FB_Tongue_000', # 74
+# 'FB_R_Lip_Bot_000', # 75
+# 'IK_Root', # 76
+# 'PH_L_Foot', # 77
+# 'FB_L_Brow_Out_000', # 78
+# 'SKEL_R_Finger00', # 79
+# 'SKEL_R_Finger10', # 80
+# 'SKEL_R_Finger20', # 81
+# 'SKEL_R_Finger30', # 82
+# 'SKEL_R_Finger40', # 83
+# 'PH_L_Hand', # 84
+# 'RB_L_ForeArmRoll', # 85
+# 'FB_UpperLip_000', # 86
+# 'SKEL_R_Finger01', # 87
+# 'SKEL_R_Finger02', # 88
+# 'SKEL_R_Finger31', # 89
+# 'SKEL_R_Finger32', # 90
+# 'SKEL_R_Finger41', # 91
+# 'SKEL_R_Finger42', # 92
+# 'SKEL_R_Finger11', # 93
+# 'SKEL_R_Finger12', # 94
+# 'SKEL_R_Finger21', # 95
+# 'SKEL_R_Finger22', # 96
+# 'FACIAL_facialRoot', # 97
+# 'IK_L_Foot', # 98
+# 'interpolated_nose' # 99, mid-point of 45-50
+# ]
+
+GTA_KEYPOINTS = [
+ 'gta_head_top', # 00
+ 'head', # 01 - head_center
+ 'neck', # 02 - neck
+ 'gta_right_clavicle', # 03
+ 'right_shoulder', # 04 - right_shoulder
+ 'right_elbow', # 05 - right_elbow
+ 'right_wrist', # 06 - right_wrist
+ 'gta_left_clavicle', # 07
+ 'left_shoulder', # 08 - left_shoulder
+ 'left_elbow', # 09 - left_elbow
+ 'left_wrist', # 10 - left_wrist
+ 'spine_2', # 11 - spine0
+ 'gta_spine1', # 12
+ 'spine_1', # 13 - spine2
+ 'pelvis', # 14 - pelvis
+ 'gta_spine4', # 15
+ 'right_hip', # 16 - right_hip
+ 'right_knee', # 17 - right_knee
+ 'right_ankle', # 18 - right_ankle
+ 'left_hip', # 19 - left_hip
+ 'left_knee', # 20 - left_knee
+ 'left_ankle', # 21 - left_ankle
+ 'gta_SKEL_ROOT', # 22
+ 'gta_FB_R_Brow_Out_000', # 23
+ 'left_foot', # 24 - SKEL_L_Toe0
+ 'gta_MH_R_Elbow', # 25
+ 'left_thumb_2', # 26 - SKEL_L_Finger01
+ 'left_thumb_3', # 27 - SKEL_L_Finger02
+ 'left_ring_2', # 28 - SKEL_L_Finger31
+ 'left_ring_3', # 29 - SKEL_L_Finger32
+ 'left_pinky_2', # 30 - SKEL_L_Finger41
+ 'left_pinky_3', # 31 - SKEL_L_Finger42
+ 'left_index_2', # 32 - SKEL_L_Finger11
+ 'left_index_3', # 33 - SKEL_L_Finger12
+ 'left_middle_2', # 34 - SKEL_L_Finger21
+ 'left_middle_3', # 35 - SKEL_L_Finger22
+ 'gta_RB_L_ArmRoll', # 36
+ 'gta_IK_R_Hand', # 37
+ 'gta_RB_R_ThighRoll', # 38
+ 'gta_FB_R_Lip_Corner_000', # 39
+ 'gta_SKEL_Pelvis', # 40
+ 'gta_IK_Head', # 41
+ 'gta_MH_R_Knee', # 42
+ 'gta_FB_LowerLipRoot_000', # 43
+ 'gta_FB_R_Lip_Top_000', # 44
+ 'gta_FB_R_CheekBone_000', # 45
+ 'gta_FB_UpperLipRoot_000', # 46
+ 'gta_FB_L_Lip_Top_000', # 47
+ 'gta_FB_LowerLip_000', # 48
+ 'right_foot', # 49 - SKEL_R_Toe0
+ 'gta_FB_L_CheekBone_000', # 50
+ 'gta_MH_L_Elbow', # 51
+ 'gta_RB_L_ThighRoll', # 52
+ 'gta_PH_R_Foot', # 53
+ 'left_eye', # 54 - FB_L_Eye_000
+ 'gta_SKEL_L_Finger00', # 55
+ 'left_index_1', # 56 - SKEL_L_Finger10
+ 'left_middle_1', # 57 - SKEL_L_Finger20
+ 'left_ring_1', # 58 - SKEL_L_Finger30
+ 'left_pinky_1', # 59 - SKEL_L_Finger40
+ 'right_eye', # 60 - FB_R_Eye_000
+ 'gta_PH_R_Hand', # 61
+ 'gta_FB_L_Lip_Corner_000', # 62
+ 'gta_IK_R_Foot', # 63
+ 'gta_RB_Neck_1', # 64
+ 'gta_IK_L_Hand', # 65
+ 'gta_RB_R_ArmRoll', # 66
+ 'gta_FB_Brow_Centre_000', # 67
+ 'gta_FB_R_Lid_Upper_000', # 68
+ 'gta_RB_R_ForeArmRoll', # 69
+ 'gta_FB_L_Lid_Upper_000', # 70
+ 'gta_MH_L_Knee', # 71
+ 'gta_FB_Jaw_000', # 72
+ 'gta_FB_L_Lip_Bot_000', # 73
+ 'gta_FB_Tongue_000', # 74
+ 'gta_FB_R_Lip_Bot_000', # 75
+ 'gta_IK_Root', # 76
+ 'gta_PH_L_Foot', # 77
+ 'gta_FB_L_Brow_Out_000', # 78
+ 'gta_SKEL_R_Finger00', # 79
+ 'right_index_1', # 80 - SKEL_R_Finger10
+ 'right_middle_1', # 81 - SKEL_R_Finger20
+ 'right_ring_1', # 82 - SKEL_R_Finger30
+ 'right_pinky_1', # 83 - SKEL_R_Finger40
+ 'gta_PH_L_Hand', # 84
+ 'gta_RB_L_ForeArmRoll', # 85
+ 'gta_FB_UpperLip_000', # 86
+ 'right_thumb_2', # 87 - SKEL_R_Finger01
+ 'right_thumb_3', # 88 - SKEL_R_Finger02
+ 'right_ring_2', # 89 - SKEL_R_Finger31
+ 'right_ring_3', # 90 - SKEL_R_Finger32
+ 'right_pinky_2', # 91 - SKEL_R_Finger41
+ 'right_pinky_3', # 92 - SKEL_R_Finger42
+ 'right_index_2', # 93 - SKEL_R_Finger11
+ 'right_index_3', # 94 - SKEL_R_Finger12
+ 'right_middle_2', # 95 - SKEL_R_Finger21
+ 'right_middle_3', # 96 - SKEL_R_Finger22
+ 'gta_FACIAL_facialRoot', # 97
+ 'gta_IK_L_Foot', # 98
+ 'nose' # 99 - interpolated nose
+]
diff --git a/detrsmpl/core/conventions/keypoints_mapping/h36m.py b/detrsmpl/core/conventions/keypoints_mapping/h36m.py
new file mode 100644
index 0000000000000000000000000000000000000000..34d7cd01de2b723398810256a475f94298ed2bc2
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/h36m.py
@@ -0,0 +1,59 @@
+H36M_KEYPOINTS = [
+ 'pelvis_extra',
+ 'left_hip_extra',
+ 'left_knee',
+ 'left_ankle',
+ 'right_hip_extra',
+ 'right_knee',
+ 'right_ankle',
+ 'spine_extra',
+ 'neck_extra',
+ 'head_extra',
+ 'headtop',
+ 'left_shoulder',
+ 'left_elbow',
+ 'left_wrist',
+ 'right_shoulder',
+ 'right_elbow',
+ 'right_wrist',
+]
+
+H36M_KEYPOINTS_MMPOSE = [
+ 'pelvis_extra',
+ 'right_hip_extra',
+ 'right_knee',
+ 'right_ankle',
+ 'left_hip_extra',
+ 'left_knee',
+ 'left_ankle',
+ 'spine_extra',
+ 'neck_extra',
+ 'head_extra',
+ 'headtop',
+ 'left_shoulder',
+ 'left_elbow',
+ 'left_wrist',
+ 'right_shoulder',
+ 'right_elbow',
+ 'right_wrist',
+]
+
+H36M_KEYPOINTS_SMPLX = [
+ 'pelvis',
+ 'left_hip',
+ 'left_knee',
+ 'left_ankle',
+ 'right_hip',
+ 'right_knee',
+ 'right_ankle',
+ 'spine',
+ 'neck', # 'thorax',
+ 'neck/nose',
+ 'head', # 'head_h36m',
+ 'left_shoulder',
+ 'left_elbow',
+ 'left_wrist',
+ 'right_shoulder',
+ 'right_elbow',
+ 'right_wrist'
+]
diff --git a/detrsmpl/core/conventions/keypoints_mapping/human_data.py b/detrsmpl/core/conventions/keypoints_mapping/human_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..c96f9e8af2ba621987f98a3d52071e327c990ea4
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/human_data.py
@@ -0,0 +1,534 @@
+from collections import defaultdict
+
+HUMAN_DATA = [
+ 'pelvis',
+ 'left_hip',
+ 'right_hip',
+ 'spine_1',
+ 'left_knee',
+ 'right_knee',
+ 'spine_2',
+ 'left_ankle',
+ 'right_ankle',
+ 'spine_3',
+ 'left_foot',
+ 'right_foot',
+ 'neck',
+ 'left_collar',
+ 'right_collar',
+ 'head',
+ 'left_shoulder',
+ 'right_shoulder',
+ 'left_elbow',
+ 'right_elbow',
+ 'left_wrist',
+ 'right_wrist',
+ 'jaw',
+ 'left_eyeball',
+ 'right_eyeball',
+ 'left_index_1',
+ 'left_index_2',
+ 'left_index_3',
+ 'left_middle_1',
+ 'left_middle_2',
+ 'left_middle_3',
+ 'left_pinky_1',
+ 'left_pinky_2',
+ 'left_pinky_3',
+ 'left_ring_1',
+ 'left_ring_2',
+ 'left_ring_3',
+ 'left_thumb_1',
+ 'left_thumb_2',
+ 'left_thumb_3',
+ 'right_index_1',
+ 'right_index_2',
+ 'right_index_3',
+ 'right_middle_1',
+ 'right_middle_2',
+ 'right_middle_3',
+ 'right_pinky_1',
+ 'right_pinky_2',
+ 'right_pinky_3',
+ 'right_ring_1',
+ 'right_ring_2',
+ 'right_ring_3',
+ 'right_thumb_1',
+ 'right_thumb_2',
+ 'right_thumb_3',
+ 'nose',
+ 'right_eye',
+ 'left_eye',
+ 'right_ear',
+ 'left_ear',
+ 'left_bigtoe',
+ 'left_smalltoe',
+ 'left_heel',
+ 'right_bigtoe',
+ 'right_smalltoe',
+ 'right_heel',
+ 'left_thumb',
+ 'left_index',
+ 'left_middle',
+ 'left_ring',
+ 'left_pinky',
+ 'right_thumb',
+ 'right_index',
+ 'right_middle',
+ 'right_ring',
+ 'right_pinky',
+ 'right_eyebrow_1',
+ 'right_eyebrow_2',
+ 'right_eyebrow_3',
+ 'right_eyebrow_4',
+ 'right_eyebrow_5',
+ 'left_eyebrow_5',
+ 'left_eyebrow_4',
+ 'left_eyebrow_3',
+ 'left_eyebrow_2',
+ 'left_eyebrow_1',
+ 'nosebridge_1',
+ 'nosebridge_2',
+ 'nosebridge_3',
+ 'nosebridge_4',
+ 'right_nose_2', # original name: nose_1
+ 'right_nose_1', # original name: nose_2
+ 'nose_middle', # original name: nose_3
+ 'left_nose_1', # original name: nose_4
+ 'left_nose_2', # original name: nose_5
+ 'right_eye_1',
+ 'right_eye_2',
+ 'right_eye_3',
+ 'right_eye_4',
+ 'right_eye_5',
+ 'right_eye_6',
+ 'left_eye_4',
+ 'left_eye_3',
+ 'left_eye_2',
+ 'left_eye_1',
+ 'left_eye_6',
+ 'left_eye_5',
+ 'right_mouth_1', # original name: mouth_1
+ 'right_mouth_2', # original name: mouth_2
+ 'right_mouth_3', # original name: mouth_3
+ 'mouth_top', # original name: mouth_4
+ 'left_mouth_3', # original name: mouth_5
+ 'left_mouth_2', # original name: mouth_6
+ 'left_mouth_1', # original name: mouth_7
+ 'left_mouth_5', # original name: mouth_8
+ 'left_mouth_4', # original name: mouth_9
+ 'mouth_bottom', # original name: mouth_10
+ 'right_mouth_4', # original name: mouth_11
+ 'right_mouth_5', # original name: mouth_12
+ 'right_lip_1', # original name: lip_1
+ 'right_lip_2', # original name: lip_2
+ 'lip_top', # original name: lip_3
+ 'left_lip_2', # original name: lip_4
+ 'left_lip_1', # original name: lip_5
+ 'left_lip_3', # original name: lip_6
+ 'lip_bottom', # original name: lip_7
+ 'right_lip_3', # original name: lip_8
+ 'right_contour_1', # original name: face_contour_1
+ 'right_contour_2', # original name: face_contour_2
+ 'right_contour_3', # original name: face_contour_3
+ 'right_contour_4', # original name: face_contour_4
+ 'right_contour_5', # original name: face_contour_5
+ 'right_contour_6', # original name: face_contour_6
+ 'right_contour_7', # original name: face_contour_7
+ 'right_contour_8', # original name: face_contour_8
+ 'contour_middle', # original name: face_contour_9
+ 'left_contour_8', # original name: face_contour_10
+ 'left_contour_7', # original name: face_contour_11
+ 'left_contour_6', # original name: face_contour_12
+ 'left_contour_5', # original name: face_contour_13
+ 'left_contour_4', # original name: face_contour_14
+ 'left_contour_3', # original name: face_contour_15
+ 'left_contour_2', # original name: face_contour_16
+ 'left_contour_1', # original name: face_contour_17
+ # J_regressor_extra
+ 'right_hip_extra',
+ 'left_hip_extra',
+ 'neck_extra', # LSP
+ 'headtop', # LSP mpii peen_action mpi_inf_3dhp
+ 'pelvis_extra', # MPII
+ 'thorax_extra', # MPII
+ 'spine_extra', # H36M
+ 'jaw_extra', # H36M
+ 'head_extra', # H36M
+ # openpose
+ 'nose_openpose',
+ 'neck_openpose',
+ 'right_shoulder_openpose',
+ 'right_elbow_openpose',
+ 'right_wrist_openpose',
+ 'left_shoulder_openpose',
+ 'left_elbow_openpose',
+ 'left_wrist_openpose',
+ 'pelvis_openpose',
+ 'right_hip_openpose',
+ 'right_knee_openpose',
+ 'right_ankle_openpose',
+ 'left_hip_openpose',
+ 'left_knee_openpose',
+ 'left_ankle_openpose',
+ 'right_eye_openpose',
+ 'left_eye_openpose',
+ 'right_ear_openpose',
+ 'left_ear_openpose',
+ 'left_bigtoe_openpose',
+ 'left_smalltoe_openpose',
+ 'left_heel_openpose',
+ 'right_bigtoe_openpose',
+ 'right_smalltoe_openpose',
+ 'right_heel_openpose',
+ # 3dhp
+ 'spine_4_3dhp',
+ 'left_clavicle_3dhp',
+ 'right_clavicle_3dhp',
+ 'left_hand_3dhp',
+ 'right_hand_3dhp',
+ 'left_toe_3dhp',
+ 'right_toe_3dhp',
+ 'head_h36m', # H36M GT
+ 'headtop_h36m', # H36M GT
+ 'head_bottom_pt', # pose track
+ 'left_hand', # SMPL
+ 'right_hand', # SMPL
+]
+
+APPROXIMATE_MAPPING_LIST = [
+ # extra
+ ['pelvis', 'pelvis_openpose', 'pelvis_extra'],
+ ['left_hip', 'left_hip_openpose', 'left_hip_extra'],
+ ['right_hip', 'right_hip_openpose', 'right_hip_extra'],
+ ['neck', 'neck_openpose', 'neck_extra'],
+ ['jaw', 'jaw_extra'],
+ ['head_extra', 'head_h36m'],
+ ['headtop', 'headtop_h36m'],
+ # 3dhp
+ ['left_hand', 'left_hand_3dhp'],
+ ['right_hand', 'right_hand_3dhp'],
+ # openpose
+ ['nose', 'nose_openpose'],
+ ['right_shoulder', 'right_shoulder_openpose'],
+ ['right_elbow', 'right_elbow_openpose'],
+ ['right_wrist', 'right_wrist_openpose'],
+ ['left_shoulder', 'left_shoulder_openpose'],
+ ['left_elbow', 'left_elbow_openpose'],
+ ['left_wrist', 'left_wrist_openpose'],
+ ['right_knee', 'right_knee_openpose'],
+ ['right_ankle', 'right_ankle_openpose'],
+ ['left_knee', 'left_knee_openpose'],
+ ['left_ankle', 'left_ankle_openpose'],
+ ['right_eye', 'right_eye_openpose'],
+ ['left_eye', 'left_eye_openpose'],
+ ['right_ear', 'right_ear_openpose'],
+ ['left_ear', 'left_ear_openpose'],
+ ['left_bigtoe', 'left_bigtoe_openpose'],
+ ['left_smalltoe', 'left_smalltoe_openpose'],
+ ['left_heel', 'left_heel_openpose'],
+ ['right_bigtoe', 'right_bigtoe_openpose'],
+ ['right_smalltoe', 'right_smalltoe_openpose'],
+ ['right_heel', 'right_heel_openpose'],
+]
+
+APPROXIMATE_MAP = defaultdict(list)
+for group in APPROXIMATE_MAPPING_LIST:
+ for member in group:
+ for other_member in group:
+ if member == other_member:
+ continue
+ APPROXIMATE_MAP[member].append(other_member)
+
+HUMAN_DATA_HEAD = [
+ 'head', 'jaw', 'left_eyeball', 'right_eyeball', 'nose', 'right_eye',
+ 'left_eye', 'right_ear', 'left_ear', 'right_eyebrow_1', 'right_eyebrow_2',
+ 'right_eyebrow_3', 'right_eyebrow_4', 'right_eyebrow_5', 'left_eyebrow_5',
+ 'left_eyebrow_4', 'left_eyebrow_3', 'left_eyebrow_2', 'left_eyebrow_1',
+ 'nosebridge_1', 'nosebridge_2', 'nosebridge_3', 'nosebridge_4',
+ 'right_nose_2', 'right_nose_1', 'nose_middle', 'left_nose_1',
+ 'left_nose_2', 'right_eye_1', 'right_eye_2', 'right_eye_3', 'right_eye_4',
+ 'right_eye_5', 'right_eye_6', 'left_eye_4', 'left_eye_3', 'left_eye_2',
+ 'left_eye_1', 'left_eye_6', 'left_eye_5', 'right_mouth_1', 'right_mouth_2',
+ 'right_mouth_3', 'mouth_top', 'left_mouth_3', 'left_mouth_2',
+ 'left_mouth_1', 'left_mouth_5', 'left_mouth_4', 'mouth_bottom',
+ 'right_mouth_4', 'right_mouth_5', 'right_lip_1', 'right_lip_2', 'lip_top',
+ 'left_lip_2', 'left_lip_1', 'left_lip_3', 'lip_bottom', 'right_lip_3',
+ 'right_contour_1', 'right_contour_2', 'right_contour_3', 'right_contour_4',
+ 'right_contour_5', 'right_contour_6', 'right_contour_7', 'right_contour_8',
+ 'contour_middle', 'left_contour_8', 'left_contour_7', 'left_contour_6',
+ 'left_contour_5', 'left_contour_4', 'left_contour_3', 'left_contour_2',
+ 'left_contour_1', 'headtop', 'jaw_extra', 'head_extra', 'nose_openpose',
+ 'right_eye_openpose', 'left_eye_openpose', 'right_ear_openpose',
+ 'left_ear_openpose', 'headtop_h36m', 'head_bottom_pt', 'head_h36m'
+]
+
+HUMAN_DATA_LEFT_HAND = [
+ 'left_index_1', 'left_index_2', 'left_index_3', 'left_middle_1',
+ 'left_middle_2', 'left_middle_3', 'left_pinky_1', 'left_pinky_2',
+ 'left_pinky_3', 'left_ring_1', 'left_ring_2', 'left_ring_3',
+ 'left_thumb_1', 'left_thumb_2', 'left_thumb_3', 'left_thumb', 'left_index',
+ 'left_middle', 'left_ring', 'left_pinky', 'left_hand_3dhp', 'left_hand'
+]
+
+HUMAN_DATA_RIGHT_HAND = [
+ 'right_index_1', 'right_index_2', 'right_index_3', 'right_middle_1',
+ 'right_middle_2', 'right_middle_3', 'right_pinky_1', 'right_pinky_2',
+ 'right_pinky_3', 'right_ring_1', 'right_ring_2', 'right_ring_3',
+ 'right_thumb_1', 'right_thumb_2', 'right_thumb_3', 'right_thumb',
+ 'right_index', 'right_middle', 'right_ring', 'right_pinky',
+ 'right_hand_3dhp', 'right_hand'
+]
+
+HUMAN_DATA_SHOULDER = [
+ 'left_shoulder', 'left_shoulder_openpose', 'right_shoulder',
+ 'right_shoulder_openpose'
+]
+
+HUMAN_DATA_HIP = [
+ 'left_hip', 'left_hip_openpose', 'left_hip_extra', 'right_hip',
+ 'right_hip_openpose', 'right_hip_extra'
+]
+
+HUMAN_DATA_BODY = HUMAN_DATA_SHOULDER + HUMAN_DATA_HIP + [
+ 'pelvis', 'spine_1', 'left_knee', 'right_knee', 'spine_2', 'left_ankle',
+ 'right_ankle', 'spine_3', 'left_foot', 'right_foot', 'neck', 'left_collar',
+ 'right_collar', 'left_elbow', 'right_elbow', 'left_wrist', 'right_wrist',
+ 'left_bigtoe', 'left_smalltoe', 'left_heel', 'right_bigtoe',
+ 'right_smalltoe', 'right_heel', 'neck_extra', 'pelvis_extra',
+ 'thorax_extra', 'spine_extra', 'neck_openpose', 'right_elbow_openpose',
+ 'right_wrist_openpose', 'left_elbow_openpose', 'left_wrist_openpose',
+ 'pelvis_openpose', 'right_knee_openpose', 'right_ankle_openpose',
+ 'left_knee_openpose', 'left_ankle_openpose', 'left_bigtoe_openpose',
+ 'left_smalltoe_openpose', 'left_heel_openpose', 'right_bigtoe_openpose',
+ 'right_smalltoe_openpose', 'right_heel_openpose', 'spine_4_3dhp',
+ 'left_clavicle_3dhp', 'right_clavicle_3dhp', 'left_toe_3dhp',
+ 'right_toe_3dhp'
+]
+
+HUMAN_DATA_PARTS = {
+ 'head': HUMAN_DATA_HEAD,
+ 'left_hand': HUMAN_DATA_LEFT_HAND,
+ 'right_hand': HUMAN_DATA_RIGHT_HAND,
+ 'shoulder': HUMAN_DATA_SHOULDER,
+ 'hip': HUMAN_DATA_HIP,
+ 'body': HUMAN_DATA_BODY
+}
+
+HUMAN_DATA_LIMBS = {
+ 'body': [
+ ['pelvis', 'left_hip'],
+ ['pelvis', 'right_hip'],
+ ['pelvis', 'spine_1'],
+ ['spine_1', 'spine_2'],
+ ['spine_2', 'spine_3'],
+ ['spine_3', 'neck'],
+ ['neck', 'head'],
+ ['left_ankle', 'left_knee'],
+ ['left_knee', 'left_hip'],
+ ['right_ankle', 'right_knee'],
+ ['right_knee', 'right_hip'],
+ ['right_ankle', 'right_foot'],
+ ['left_ankle', 'left_foot'],
+ ['left_hip', 'right_hip'],
+ ['left_shoulder', 'left_hip'],
+ ['right_shoulder', 'right_hip'],
+ ['left_collar', 'spine_3'],
+ ['right_collar', 'spine_3'],
+ ['right_collar', 'right_shoulder'],
+ ['left_collar', 'left_shoulder'],
+ ['left_shoulder', 'right_shoulder'],
+ ['left_shoulder', 'left_elbow'],
+ ['right_shoulder', 'right_elbow'],
+ ['left_elbow', 'left_wrist'],
+ ['right_elbow', 'right_wrist'],
+ ['left_ankle', 'left_bigtoe'],
+ ['left_ankle', 'left_smalltoe'],
+ ['left_ankle', 'left_heel'],
+ ['right_ankle', 'right_bigtoe'],
+ ['right_ankle', 'right_smalltoe'],
+ ['right_ankle', 'right_heel'],
+ ['left_shoulder', 'left_ear'],
+ ['right_shoulder', 'right_ear'],
+ ['right_ear', 'right_eye'],
+ ['right_eye', 'nose'],
+ ['nose', 'left_eye'],
+ ['left_eye', 'left_ear'],
+ ['nose', 'jaw'],
+ ['jaw', 'neck'],
+ # extra limbs
+ ['pelvis_extra', 'left_hip_extra'],
+ ['pelvis_extra', 'right_hip_extra'],
+ ['left_hip_extra', 'left_knee'],
+ ['right_hip_extra', 'right_knee'],
+ ['left_hip_extra', 'left_shoulder'],
+ ['right_hip_extra', 'right_shoulder'],
+ ['pelvis_extra', 'spine_1'],
+ ['spine_2', 'spine_extra'],
+ ['spine_extra', 'spine_3'],
+ ['spine_3', 'thorax_extra'],
+ ['thorax_extra', 'left_shoulder'],
+ ['thorax_extra', 'right_shoulder'],
+ ['thorax_extra', 'neck_extra'],
+ ['neck_extra', 'jaw_extra'],
+ ['jaw_extra', 'nose'],
+ ['head_extra', 'nose'],
+ ['head_extra', 'headtop'],
+ ['head_extra', 'neck_extra'],
+ ['neck_extra', 'headtop'],
+ ['right_hip_extra', 'left_hip_extra'],
+ ['right_eye_openpose', 'right_ear_openpose'],
+ ['left_ear_openpose', 'left_eye_openpose'],
+ ['right_shoulder_openpose', 'right_elbow_openpose'],
+ ['right_elbow_openpose', 'right_wrist_openpose'],
+ ['left_shoulder_openpose', 'right_shoulder_openpose'],
+ ['left_shoulder_openpose', 'left_elbow_openpose'],
+ ['left_elbow_openpose', 'left_wrist_openpose'],
+        ['pelvis_openpose', 'headtop'],
+ ['neck_extra', 'right_hip_openpose'],
+ ['neck_extra', 'left_hip_openpose'],
+ ['right_hip_openpose', 'right_shoulder_openpose'],
+ ['right_hip_openpose', 'right_knee_openpose'],
+ ['left_hip_openpose', 'left_shoulder_openpose'],
+ ['left_hip_openpose', 'left_knee_openpose'],
+ ['right_knee_openpose', 'right_ankle_openpose'],
+ ['left_knee_openpose', 'left_ankle_openpose'],
+ ['right_ankle_openpose', 'right_heel_openpose'],
+ ['left_ankle_openpose', 'left_heel_openpose'],
+ ['right_heel_openpose', 'right_bigtoe_openpose'],
+ ['right_heel_openpose', 'right_smalltoe_openpose'],
+ ['left_ankle_openpose', 'left_bigtoe_openpose'],
+ ['left_ankle_openpose', 'left_smalltoe_openpose'],
+ ],
+ 'face': [['right_contour_1', 'right_contour_2'],
+ ['right_contour_2', 'right_contour_3'],
+ ['right_contour_3', 'right_contour_4'],
+ ['right_contour_4', 'right_contour_5'],
+ ['right_contour_5', 'right_contour_6'],
+ ['right_contour_6', 'right_contour_7'],
+ ['right_contour_7', 'right_contour_8'],
+ ['right_contour_8', 'contour_middle'],
+ ['contour_middle', 'left_contour_8'],
+ ['left_contour_8', 'left_contour_7'],
+ ['left_contour_7', 'left_contour_6'],
+ ['left_contour_6', 'left_contour_5'],
+ ['left_contour_5', 'left_contour_4'],
+ ['left_contour_4', 'left_contour_3'],
+ ['left_contour_3', 'left_contour_2'],
+ ['left_contour_2', 'left_contour_1']],
+ 'left_hand': [['left_wrist', 'left_thumb_1'],
+ ['left_thumb_1', 'left_thumb_2'],
+ ['left_thumb_2', 'left_thumb_3'],
+ ['left_thumb_3', 'left_thumb'],
+ ['left_wrist', 'left_index_1'],
+ ['left_index_1', 'left_index_2'],
+ ['left_index_2', 'left_index_3'],
+ ['left_index_3', 'left_index'],
+ ['left_wrist', 'left_middle_1'],
+ ['left_middle_1', 'left_middle_2'],
+ ['left_middle_2', 'left_middle_3'],
+ ['left_middle_3', 'left_middle'],
+ ['left_wrist', 'left_ring_1'],
+ ['left_ring_1', 'left_ring_2'],
+ ['left_ring_2', 'left_ring_3'],
+ ['left_ring_3', 'left_ring'],
+ ['left_wrist', 'left_pinky_1'],
+ ['left_pinky_1', 'left_pinky_2'],
+ ['left_pinky_2', 'left_pinky_3'],
+ ['left_pinky_3', 'left_pinky'],
+ ['left_wrist', 'left_thumb'],
+ ['left_wrist', 'left_index'],
+ ['left_wrist', 'left_middle'],
+ ['left_wrist', 'left_ring'],
+                  ['left_wrist', 'left_pinky']],
+ 'right_hand': [['right_wrist', 'right_thumb_1'],
+ ['right_thumb_1', 'right_thumb_2'],
+ ['right_thumb_2', 'right_thumb_3'],
+ ['right_thumb_3', 'right_thumb'],
+ ['right_wrist', 'right_index_1'],
+ ['right_index_1', 'right_index_2'],
+ ['right_index_2', 'right_index_3'],
+ ['right_index_3', 'right_index'],
+ ['right_wrist', 'right_middle_1'],
+ ['right_middle_1', 'right_middle_2'],
+ ['right_middle_2', 'right_middle_3'],
+ ['right_middle_3', 'right_middle'],
+ ['right_wrist', 'right_ring_1'],
+ ['right_ring_1', 'right_ring_2'],
+ ['right_ring_2', 'right_ring_3'],
+ ['right_ring_3', 'right_ring'],
+ ['right_wrist', 'right_pinky_1'],
+ ['right_pinky_1', 'right_pinky_2'],
+ ['right_pinky_2', 'right_pinky_3'],
+ ['right_pinky_3', 'right_pinky'],
+ ['right_wrist', 'right_thumb'],
+ ['right_wrist', 'right_index'],
+ ['right_wrist', 'right_middle'],
+ ['right_wrist', 'right_ring'],
+ ['right_wrist', 'right_pinky']],
+ 'right_eye':
+ [['right_eye_1', 'right_eye_2'], ['right_eye_2', 'right_eye_3'],
+ ['right_eye_3', 'right_eye_4'], ['right_eye_4', 'right_eye_5'],
+ ['right_eye_5', 'right_eye_6'], ['right_eye_6', 'right_eye_1'],
+ ['right_eyebrow_1', 'right_eyebrow_2'],
+ ['right_eyebrow_2', 'right_eyebrow_3'],
+ ['right_eyebrow_3', 'right_eyebrow_4'],
+ ['right_eyebrow_4', 'right_eyebrow_5']],
+ 'left_eye': [['left_eye_4', 'left_eye_3'], ['left_eye_3', 'left_eye_2'],
+ ['left_eye_2', 'left_eye_1'], ['left_eye_1', 'left_eye_6'],
+ ['left_eye_6', 'left_eye_5'], ['left_eye_5', 'left_eye_4'],
+ ['left_eyebrow_1', 'left_eyebrow_2'],
+ ['left_eyebrow_2', 'left_eyebrow_3'],
+ ['left_eyebrow_3', 'left_eyebrow_4'],
+ ['left_eyebrow_4', 'left_eyebrow_5']],
+ 'mouth':
+ [['right_mouth_1', 'right_mouth_2'], ['right_mouth_2', 'right_mouth_3'],
+ ['right_mouth_3', 'mouth_top'], ['mouth_top', 'left_mouth_3'],
+ ['left_mouth_3', 'left_mouth_2'], ['left_mouth_2', 'left_mouth_1'],
+ ['left_mouth_1', 'left_mouth_5'], ['left_mouth_5', 'left_mouth_4'],
+ ['left_mouth_4', 'mouth_bottom'], ['mouth_bottom', 'right_mouth_4'],
+ ['right_mouth_4', 'right_mouth_5'], ['right_mouth_5', 'right_mouth_1'],
+ ['right_lip_1', 'right_lip_2'], ['right_lip_2', 'lip_top'],
+ ['lip_top', 'left_lip_2'], ['left_lip_2', 'left_lip_1'],
+ ['left_lip_1', 'left_lip_3'], ['left_lip_3', 'lip_bottom'],
+ ['lip_bottom', 'right_lip_3'], ['right_lip_3', 'right_lip_1'],
+
+ ['nose', 'mouth_top'], ['mouth_top', 'right_contour_1'],
+ ['mouth_top', 'left_contour_1'],
+ ['jaw', 'left_contour_1'],
+     ['jaw', 'right_contour_1']],
+ 'nose': [
+ ['nosebridge_1', 'nosebridge_2'],
+ ['nosebridge_2', 'nosebridge_3'],
+ ['nosebridge_3', 'nosebridge_4'],
+ ['right_nose_2', 'right_nose_1'],
+ ['right_nose_1', 'nose_middle'],
+ ['nose_middle', 'left_nose_1'],
+ ['left_nose_1', 'left_nose_2'],
+ ]
+}
+
+HUMAN_DATA_LIMBS_INDEX = {}
+for k in HUMAN_DATA_LIMBS:
+ HUMAN_DATA_LIMBS_INDEX[k] = [[
+ HUMAN_DATA.index(limb[0]),
+ HUMAN_DATA.index(limb[1])
+ ] for limb in HUMAN_DATA_LIMBS[k]]
+
+HUMAN_DATA_PALETTE = {
+ 'left_eye': [[0, 0, 0]],
+ 'right_eye': [[255, 255, 0]],
+ 'nose': [[0, 0, 255]],
+ 'mouth': [[0, 255, 255]],
+ 'face': [[255, 0, 0]],
+ 'left_hand': [[0, 255, 0]],
+ 'right_hand': [[255, 0, 255]],
+}
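+
+# The block below is not part of the original convention file; it is a
+# minimal, illustrative sketch of how the limb index tables above are
+# typically consumed (e.g. by a skeleton visualizer), plus a quick sanity
+# check that every limb resolves to valid HUMAN_DATA indices.
+if __name__ == '__main__':
+    for part_name, limbs in HUMAN_DATA_LIMBS_INDEX.items():
+        assert all(0 <= idx < len(HUMAN_DATA)
+                   for limb in limbs for idx in limb)
+        print(f'{part_name}: {len(limbs)} limbs')
+    # A visualizer would draw one segment per index pair, e.g.:
+    #   for start_idx, end_idx in HUMAN_DATA_LIMBS_INDEX['body']:
+    #       draw_line(keypoints2d[start_idx], keypoints2d[end_idx])
+    # where `draw_line` and `keypoints2d` are hypothetical placeholders.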
diff --git a/detrsmpl/core/conventions/keypoints_mapping/hybrik.py b/detrsmpl/core/conventions/keypoints_mapping/hybrik.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f801e846f510e87d14d9df40f5c7745a7535613
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/hybrik.py
@@ -0,0 +1,31 @@
+HYBRIK_29_KEYPOINTS = [
+ 'pelvis',
+ 'left_hip',
+ 'right_hip', # 2
+ 'spine_1',
+ 'left_knee',
+ 'right_knee', # 5
+ 'spine_2',
+ 'left_ankle',
+ 'right_ankle', # 8
+ 'spine_3',
+ 'left_foot',
+ 'right_foot', # 11
+ 'neck',
+ 'left_collar',
+ 'right_collar', # 14
+ 'jaw', # 15
+ 'left_shoulder',
+ 'right_shoulder', # 17
+ 'left_elbow',
+ 'right_elbow', # 19
+ 'left_wrist',
+ 'right_wrist', # 21
+ 'left_thumb',
+ 'right_thumb', # 23
+ 'head',
+ 'left_middle',
+ 'right_middle', # 26
+ 'left_bigtoe',
+ 'right_bigtoe' # 28
+]
diff --git a/detrsmpl/core/conventions/keypoints_mapping/instavariety.py b/detrsmpl/core/conventions/keypoints_mapping/instavariety.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e8587f831eb0191924a6fb1115988c766870ce9
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/instavariety.py
@@ -0,0 +1,27 @@
+INSTAVARIETY_KEYPOINTS = [
+ 'right_heel_openpose',
+ 'right_knee_openpose',
+ 'right_hip_openpose',
+ 'left_hip_openpose',
+ 'left_knee_openpose',
+ 'left_heel_openpose',
+ 'right_wrist_openpose',
+ 'right_elbow_openpose',
+ 'right_shoulder_openpose',
+ 'left_shoulder_openpose',
+ 'left_elbow_openpose',
+ 'left_wrist_openpose',
+ 'neck_openpose',
+ 'headtop',
+ 'nose_openpose',
+ 'left_eye_openpose',
+ 'right_eye_openpose',
+ 'left_ear_openpose',
+ 'right_ear_openpose',
+ 'left_bigtoe_openpose',
+ 'right_bigtoe_openpose',
+ 'left_smalltoe_openpose',
+ 'right_smalltoe_openpose',
+ 'left_ankle_openpose',
+ 'right_ankle_openpose',
+]
diff --git a/detrsmpl/core/conventions/keypoints_mapping/lsp.py b/detrsmpl/core/conventions/keypoints_mapping/lsp.py
new file mode 100644
index 0000000000000000000000000000000000000000..214f16d52eecd2de163ca0ffe40457bfa2860c6c
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/lsp.py
@@ -0,0 +1,16 @@
+LSP_KEYPOINTS = [
+ 'right_ankle',
+ 'right_knee',
+ 'right_hip_extra',
+ 'left_hip_extra',
+ 'left_knee',
+ 'left_ankle',
+ 'right_wrist',
+ 'right_elbow',
+ 'right_shoulder',
+ 'left_shoulder',
+ 'left_elbow',
+ 'left_wrist',
+ 'neck_extra',
+ 'headtop',
+]
diff --git a/detrsmpl/core/conventions/keypoints_mapping/mano.py b/detrsmpl/core/conventions/keypoints_mapping/mano.py
new file mode 100644
index 0000000000000000000000000000000000000000..225d27751274d62da245af7276214363944b3194
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/mano.py
@@ -0,0 +1,33 @@
+# Original order from MANO J_regressor
+MANO_RIGHT_KEYPOINTS = [
+ 'right_wrist', 'right_index_1', 'right_index_2', 'right_index_3',
+ 'right_middle_1', 'right_middle_2', 'right_middle_3', 'right_pinky_1',
+ 'right_pinky_2', 'right_pinky_3', 'right_ring_1', 'right_ring_2',
+ 'right_ring_3', 'right_thumb_1', 'right_thumb_2', 'right_thumb_3',
+ 'right_thumb', 'right_index', 'right_middle', 'right_ring', 'right_pinky'
+]
+
+MANO_LEFT_KEYPOINTS = [
+ x.replace('right_', 'left_') for x in MANO_RIGHT_KEYPOINTS
+]
+
+# Re-arranged order is compatible with the output of manolayer
+# from official [manopth](https://github.com/hassony2/manopth)
+MANO_REORDER_MAP = [
+ 0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20
+]
+
+MANO_RIGHT_REORDER_KEYPOINTS = [
+ MANO_RIGHT_KEYPOINTS[i] for i in MANO_REORDER_MAP
+]
+MANO_LEFT_REORDER_KEYPOINTS = [
+ MANO_LEFT_KEYPOINTS[i] for i in MANO_REORDER_MAP
+]
+
+# Deprecated: reserved for backward compatibility
+MANO_KEYPOINTS = MANO_RIGHT_KEYPOINTS
+# Two hands (left + right)
+MANO_HANDS_KEYPOINTS = MANO_LEFT_KEYPOINTS + MANO_RIGHT_KEYPOINTS
+# Reordered two hands (left + right)
+MANO_HANDS_REORDER_KEYPOINTS = \
+ MANO_LEFT_REORDER_KEYPOINTS + MANO_RIGHT_REORDER_KEYPOINTS
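+
+# Not part of the original file: a small illustrative check of the reorder.
+# Applying MANO_REORDER_MAP regroups the joints finger by finger (wrist
+# first, then the full thumb chain, index chain, and so on), which is the
+# ordering the comment above attributes to manopth's manolayer output.
+if __name__ == '__main__':
+    print(MANO_RIGHT_REORDER_KEYPOINTS[:5])
+    # -> ['right_wrist', 'right_thumb_1', 'right_thumb_2',
+    #     'right_thumb_3', 'right_thumb']
+    assert len(MANO_HANDS_REORDER_KEYPOINTS) == 42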
diff --git a/detrsmpl/core/conventions/keypoints_mapping/mpi_inf_3dhp.py b/detrsmpl/core/conventions/keypoints_mapping/mpi_inf_3dhp.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffc180f925d20fc3d3d4adf12198a48bea606c20
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/mpi_inf_3dhp.py
@@ -0,0 +1,81 @@
+MPI_INF_3DHP_KEYPOINTS = [
+ 'spine_3',
+ 'spine_4_3dhp',
+ 'spine_2',
+ 'spine_extra', # close to spine2
+ 'pelvis_extra',
+ 'neck_extra', # throat
+ 'head_extra',
+ 'headtop',
+ 'left_clavicle_3dhp',
+ 'left_shoulder',
+ 'left_elbow',
+ 'left_wrist',
+ 'left_hand_3dhp',
+ 'right_clavicle_3dhp',
+ 'right_shoulder',
+ 'right_elbow',
+ 'right_wrist',
+ 'right_hand_3dhp',
+ 'left_hip_extra',
+ 'left_knee',
+ 'left_ankle',
+ 'left_foot',
+ 'left_toe_3dhp',
+ 'right_hip_extra',
+ 'right_knee',
+ 'right_ankle',
+ 'right_foot',
+ 'right_toe_3dhp'
+]
+
+MPI_INF_3DHP_TEST_KEYPOINTS = [
+ 'headtop',
+ 'neck_extra',
+ 'right_shoulder',
+ 'right_elbow',
+ 'right_wrist',
+ 'left_shoulder',
+ 'left_elbow',
+ 'left_wrist',
+ 'right_hip_extra',
+ 'right_knee',
+ 'right_ankle',
+ 'left_hip_extra',
+ 'left_knee',
+ 'left_ankle',
+ 'pelvis_extra',
+ 'spine_extra', # close to spine2
+ 'head_extra'
+]
+
+HYBRIK_MPI_INF_3DHP_KEYPOINTS = [
+ 'spine_3',
+ 'spine_4_3dhp',
+ 'spine_2',
+ 'spine_extra', # close to spine2
+ 'pelvis',
+ 'neck', # throat
+ 'head_extra',
+ 'headtop',
+ 'left_clavicle_3dhp',
+ 'left_shoulder',
+ 'left_elbow',
+ 'left_wrist',
+ 'left_hand_3dhp',
+ 'right_clavicle_3dhp',
+ 'right_shoulder',
+ 'right_elbow',
+ 'right_wrist',
+ 'right_hand_3dhp',
+ 'left_hip',
+ 'left_knee',
+ 'left_ankle',
+ 'left_foot',
+ 'left_toe_3dhp',
+ 'right_hip',
+ 'right_knee',
+ 'right_ankle',
+ 'right_foot',
+ 'right_toe_3dhp'
+]
diff --git a/detrsmpl/core/conventions/keypoints_mapping/mpii.py b/detrsmpl/core/conventions/keypoints_mapping/mpii.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e78953cba05941b0054e0ffe72f962c5399a9fe
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/mpii.py
@@ -0,0 +1,18 @@
+MPII_KEYPOINTS = [
+ 'right_ankle',
+ 'right_knee',
+ 'right_hip_extra',
+ 'left_hip_extra',
+ 'left_knee',
+ 'left_ankle',
+ 'pelvis_extra',
+ 'thorax_extra',
+ 'neck_extra',
+ 'headtop',
+ 'right_wrist',
+ 'right_elbow',
+ 'right_shoulder',
+ 'left_shoulder',
+ 'left_elbow',
+ 'left_wrist',
+]
diff --git a/detrsmpl/core/conventions/keypoints_mapping/openpose.py b/detrsmpl/core/conventions/keypoints_mapping/openpose.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebe3b3cdee24a6c088f64a70238c6f8963213919
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/openpose.py
@@ -0,0 +1,452 @@
+"""These keypoint formats are taken from https://github.com/CMU-Perceptual-
+Computing-Lab/openpose/blob/master/src/openpose/pose/poseParameters.cpp.
+OpenPose mainly supports the 25- and 135-keypoint formats now; the
+118-keypoint convention can be found in
+https://github.com/vchoutas/smplify-x/issues/152#issuecomment-923715702.
+
+OPENPOSE_137_KEYPOINTS can be found in
+https://github.com/vchoutas/expose
+
+- OPENPOSE_25_KEYPOINTS: body(25)
+- OPENPOSE_118_KEYPOINTS: body(25) + hand(42) + face(51)
+- OPENPOSE_135_KEYPOINTS: body(25) + hand(40) + face(70)
+- OPENPOSE_137_KEYPOINTS: body(27) + hand(40) + face(70)
+
+Note that:
+1. 135 and coco17 share the first 17 body keypoints
+2. 25 and 118 share the first 25 body keypoints
+3. 137 and 135 share the hand and face parts
+"""
+
+OPENPOSE_135_KEYPOINTS = [
+ 'nose',
+ 'left_eye',
+ 'right_eye',
+ 'left_ear',
+ 'right_ear',
+ 'left_shoulder',
+ 'right_shoulder',
+ 'left_elbow',
+ 'right_elbow',
+ 'left_wrist',
+ 'right_wrist',
+ 'left_hip',
+ 'right_hip',
+ 'left_knee',
+ 'right_knee',
+ 'left_ankle',
+ 'right_ankle',
+ 'neck', # upper_neck
+ 'head',
+ 'left_bigtoe',
+ 'left_smalltoe',
+ 'left_heel',
+ 'right_bigtoe',
+ 'right_smalltoe',
+ 'right_heel',
+ 'left_thumb_1',
+ 'left_thumb_2',
+ 'left_thumb_3',
+ 'left_thumb',
+ 'left_index_1',
+ 'left_index_2',
+ 'left_index_3',
+ 'left_index',
+ 'left_middle_1',
+ 'left_middle_2',
+ 'left_middle_3',
+ 'left_middle',
+ 'left_ring_1',
+ 'left_ring_2',
+ 'left_ring_3',
+ 'left_ring',
+ 'left_pinky_1',
+ 'left_pinky_2',
+ 'left_pinky_3',
+ 'left_pinky',
+ 'right_thumb_1',
+ 'right_thumb_2',
+ 'right_thumb_3',
+ 'right_thumb',
+ 'right_index_1',
+ 'right_index_2',
+ 'right_index_3',
+ 'right_index',
+ 'right_middle_1',
+ 'right_middle_2',
+ 'right_middle_3',
+ 'right_middle',
+ 'right_ring_1',
+ 'right_ring_2',
+ 'right_ring_3',
+ 'right_ring',
+ 'right_pinky_1',
+ 'right_pinky_2',
+ 'right_pinky_3',
+ 'right_pinky',
+ 'right_contour_1', # original name: face_contour_1
+ 'right_contour_2', # original name: face_contour_2
+ 'right_contour_3', # original name: face_contour_3
+ 'right_contour_4', # original name: face_contour_4
+ 'right_contour_5', # original name: face_contour_5
+ 'right_contour_6', # original name: face_contour_6
+ 'right_contour_7', # original name: face_contour_7
+ 'right_contour_8', # original name: face_contour_8
+ 'contour_middle', # original name: face_contour_9
+ 'left_contour_8', # original name: face_contour_10
+ 'left_contour_7', # original name: face_contour_11
+ 'left_contour_6', # original name: face_contour_12
+ 'left_contour_5', # original name: face_contour_13
+ 'left_contour_4', # original name: face_contour_14
+ 'left_contour_3', # original name: face_contour_15
+ 'left_contour_2', # original name: face_contour_16
+ 'left_contour_1', # original name: face_contour_17
+ 'right_eyebrow_1',
+ 'right_eyebrow_2',
+ 'right_eyebrow_3',
+ 'right_eyebrow_4',
+ 'right_eyebrow_5',
+ 'left_eyebrow_5',
+ 'left_eyebrow_4',
+ 'left_eyebrow_3',
+ 'left_eyebrow_2',
+ 'left_eyebrow_1',
+ 'nosebridge_1',
+ 'nosebridge_2',
+ 'nosebridge_3',
+ 'nosebridge_4',
+ 'right_nose_2', # original name: nose_1
+ 'right_nose_1', # original name: nose_2
+ 'nose_middle', # original name: nose_3
+ 'left_nose_1', # original name: nose_4
+ 'left_nose_2', # original name: nose_5
+ 'right_eye_1',
+ 'right_eye_2',
+ 'right_eye_3',
+ 'right_eye_4',
+ 'right_eye_5',
+ 'right_eye_6',
+ 'left_eye_4',
+ 'left_eye_3',
+ 'left_eye_2',
+ 'left_eye_1',
+ 'left_eye_6',
+ 'left_eye_5',
+ 'right_mouth_1', # original name: mouth_1
+ 'right_mouth_2', # original name: mouth_2
+ 'right_mouth_3', # original name: mouth_3
+ 'mouth_top', # original name: mouth_4
+ 'left_mouth_3', # original name: mouth_5
+ 'left_mouth_2', # original name: mouth_6
+ 'left_mouth_1', # original name: mouth_7
+ 'left_mouth_5', # original name: mouth_8
+ 'left_mouth_4', # original name: mouth_9
+ 'mouth_bottom', # original name: mouth_10
+ 'right_mouth_4', # original name: mouth_11
+ 'right_mouth_5', # original name: mouth_12
+ 'right_lip_1', # original name: lip_1
+ 'right_lip_2', # original name: lip_2
+ 'lip_top', # original name: lip_3
+ 'left_lip_2', # original name: lip_4
+ 'left_lip_1', # original name: lip_5
+ 'left_lip_3', # original name: lip_6
+ 'lip_bottom', # original name: lip_7
+ 'right_lip_3', # original name: lip_8
+ 'right_eyeball',
+ 'left_eyeball'
+]
+
+# TODO: OPENPOSE-25->HumanData->SMPLX causes the whole body to be lost
+# OPENPOSE-25: nose_openpose
+# SMPLX: nose
+
+OPENPOSE_25_KEYPOINTS = [
+ 'nose_openpose',
+ 'neck_openpose', # 'upper_neck'
+ 'right_shoulder_openpose',
+ 'right_elbow_openpose',
+ 'right_wrist_openpose',
+ 'left_shoulder_openpose',
+ 'left_elbow_openpose',
+ 'left_wrist_openpose',
+ 'pelvis_openpose', # 'mid_hip'
+ 'right_hip_openpose',
+ 'right_knee_openpose',
+ 'right_ankle_openpose',
+ 'left_hip_openpose',
+ 'left_knee_openpose',
+ 'left_ankle_openpose',
+ 'right_eye_openpose',
+ 'left_eye_openpose',
+ 'right_ear_openpose',
+ 'left_ear_openpose',
+ 'left_bigtoe_openpose',
+ 'left_smalltoe_openpose',
+ 'left_heel_openpose',
+ 'right_bigtoe_openpose',
+ 'right_smalltoe_openpose',
+ 'right_heel_openpose'
+]
+
+OPENPOSE_118_KEYPOINTS = [
+ 'nose_openpose',
+ 'neck_openpose',
+ 'right_shoulder_openpose',
+ 'right_elbow_openpose',
+ 'right_wrist_openpose',
+ 'left_shoulder_openpose',
+ 'left_elbow_openpose',
+ 'left_wrist_openpose',
+ 'pelvis_openpose',
+ 'right_hip_openpose',
+ 'right_knee_openpose',
+ 'right_ankle_openpose',
+ 'left_hip_openpose',
+ 'left_knee_openpose',
+ 'left_ankle_openpose',
+ 'right_eye_openpose',
+ 'left_eye_openpose',
+ 'right_ear_openpose',
+ 'left_ear_openpose',
+ 'left_bigtoe_openpose',
+ 'left_smalltoe_openpose',
+ 'left_heel_openpose',
+ 'right_bigtoe_openpose',
+ 'right_smalltoe_openpose',
+ 'right_heel_openpose',
+ 'left_wrist',
+ 'left_thumb_1',
+ 'left_thumb_2',
+ 'left_thumb_3',
+ 'left_thumb',
+ 'left_index_1',
+ 'left_index_2',
+ 'left_index_3',
+ 'left_index',
+ 'left_middle_1',
+ 'left_middle_2',
+ 'left_middle_3',
+ 'left_middle',
+ 'left_ring_1',
+ 'left_ring_2',
+ 'left_ring_3',
+ 'left_ring',
+ 'left_pinky_1',
+ 'left_pinky_2',
+ 'left_pinky_3',
+ 'left_pinky',
+ 'right_wrist',
+ 'right_thumb_1',
+ 'right_thumb_2',
+ 'right_thumb_3',
+ 'right_thumb',
+ 'right_index_1',
+ 'right_index_2',
+ 'right_index_3',
+ 'right_index',
+ 'right_middle_1',
+ 'right_middle_2',
+ 'right_middle_3',
+ 'right_middle',
+ 'right_ring_1',
+ 'right_ring_2',
+ 'right_ring_3',
+ 'right_ring',
+ 'right_pinky_1',
+ 'right_pinky_2',
+ 'right_pinky_3',
+ 'right_pinky',
+ 'right_eyebrow_1',
+ 'right_eyebrow_2',
+ 'right_eyebrow_3',
+ 'right_eyebrow_4',
+ 'right_eyebrow_5',
+ 'left_eyebrow_5',
+ 'left_eyebrow_4',
+ 'left_eyebrow_3',
+ 'left_eyebrow_2',
+ 'left_eyebrow_1',
+ 'nosebridge_1',
+ 'nosebridge_2',
+ 'nosebridge_3',
+ 'nosebridge_4',
+ 'right_nose_2', # original name: nose_1
+ 'right_nose_1', # original name: nose_2
+ 'nose_middle', # original name: nose_3
+ 'left_nose_1', # original name: nose_4
+ 'left_nose_2', # original name: nose_5
+ 'right_eye_1',
+ 'right_eye_2',
+ 'right_eye_3',
+ 'right_eye_4',
+ 'right_eye_5',
+ 'right_eye_6',
+ 'left_eye_4',
+ 'left_eye_3',
+ 'left_eye_2',
+ 'left_eye_1',
+ 'left_eye_6',
+ 'left_eye_5',
+ 'right_mouth_1', # original name: mouth_1
+ 'right_mouth_2', # original name: mouth_2
+ 'right_mouth_3', # original name: mouth_3
+ 'mouth_top', # original name: mouth_4
+ 'left_mouth_3', # original name: mouth_5
+ 'left_mouth_2', # original name: mouth_6
+ 'left_mouth_1', # original name: mouth_7
+ 'left_mouth_5', # original name: mouth_8
+ 'left_mouth_4', # original name: mouth_9
+ 'mouth_bottom', # original name: mouth_10
+ 'right_mouth_4', # original name: mouth_11
+ 'right_mouth_5', # original name: mouth_12
+ 'right_lip_1', # original name: lip_1
+ 'right_lip_2', # original name: lip_2
+ 'lip_top', # original name: lip_3
+ 'left_lip_2', # original name: lip_4
+ 'left_lip_1', # original name: lip_5
+ 'left_lip_3', # original name: lip_6
+ 'lip_bottom', # original name: lip_7
+ 'right_lip_3', # original name: lip_8
+]
+
+OPENPOSE_JOINTS = [
+ 'nose',
+ 'neck',
+ 'right_shoulder',
+ 'right_elbow',
+ 'right_wrist',
+ 'left_shoulder',
+ 'left_elbow',
+ 'left_wrist',
+ 'pelvis',
+ 'right_hip',
+ 'right_knee',
+ 'right_ankle',
+ 'left_hip',
+ 'left_knee',
+ 'left_ankle',
+ 'right_eye',
+ 'left_eye',
+ 'right_ear',
+ 'left_ear',
+ 'left_wrist_openpose',
+ 'left_thumb_1',
+ 'left_thumb_2',
+ 'left_thumb_3',
+ 'left_thumb',
+ 'left_index_1',
+ 'left_index_2',
+ 'left_index_3',
+ 'left_index',
+ 'left_middle_1',
+ 'left_middle_2',
+ 'left_middle_3',
+ 'left_middle',
+ 'left_ring_1',
+ 'left_ring_2',
+ 'left_ring_3',
+ 'left_ring',
+ 'left_pinky_1',
+ 'left_pinky_2',
+ 'left_pinky_3',
+ 'left_pinky',
+ 'right_wrist_openpose',
+ 'right_thumb_1',
+ 'right_thumb_2',
+ 'right_thumb_3',
+ 'right_thumb',
+ 'right_index_1',
+ 'right_index_2',
+ 'right_index_3',
+ 'right_index',
+ 'right_middle_1',
+ 'right_middle_2',
+ 'right_middle_3',
+ 'right_middle',
+ 'right_ring_1',
+ 'right_ring_2',
+ 'right_ring_3',
+ 'right_ring',
+ 'right_pinky_1',
+ 'right_pinky_2',
+ 'right_pinky_3',
+ 'right_pinky',
+ # Face contour
+ 'right_contour_1',
+ 'right_contour_2',
+ 'right_contour_3',
+ 'right_contour_4',
+ 'right_contour_5',
+ 'right_contour_6',
+ 'right_contour_7',
+ 'right_contour_8',
+ 'contour_middle',
+ 'left_contour_8',
+ 'left_contour_7',
+ 'left_contour_6',
+ 'left_contour_5',
+ 'left_contour_4',
+ 'left_contour_3',
+ 'left_contour_2',
+ 'left_contour_1',
+ # Eye brows
+ 'right_eye_brow_1',
+ 'right_eye_brow_2',
+ 'right_eye_brow_3',
+ 'right_eye_brow_4',
+ 'right_eye_brow_5',
+ 'left_eye_brow_5',
+ 'left_eye_brow_4',
+ 'left_eye_brow_3',
+ 'left_eye_brow_2',
+ 'left_eye_brow_1',
+ 'nosebridge_1',
+ 'nosebridge_2',
+ 'nosebridge_3',
+ 'nosebridge_4',
+ 'right_nose_2',
+ 'right_nose_1',
+ 'nose_middle',
+ 'left_nose_1',
+ 'left_nose_2',
+ 'right_eye_1',
+ 'right_eye_2',
+ 'right_eye_3',
+ 'right_eye_4',
+ 'right_eye_5',
+ 'right_eye_6',
+ 'left_eye_4',
+ 'left_eye_3',
+ 'left_eye_2',
+ 'left_eye_1',
+ 'left_eye_6',
+ 'left_eye_5',
+ 'right_mouth_1',
+ 'right_mouth_2',
+ 'right_mouth_3',
+ 'mouth_top',
+ 'left_mouth_3',
+ 'left_mouth_2',
+ 'left_mouth_1',
+ 'left_mouth_5',
+ 'left_mouth_4',
+ 'mouth_bottom',
+ 'right_mouth_4',
+ 'right_mouth_5',
+ 'right_lip_1',
+ 'right_lip_2',
+ 'lip_top',
+ 'left_lip_2',
+ 'left_lip_1',
+ 'left_lip_3',
+ 'lip_bottom',
+ 'right_lip_3',
+ 'right_eyeball_unused', # not used in expose
+ 'left_eyeball_unused', # not used in expose
+]
+
+OPENPOSE_FEET_KEYPOINTS = [
+ 'left_bigtoe', 'left_smalltoe', 'left_heel', 'right_bigtoe',
+ 'right_smalltoe', 'right_heel'
+]
+OPENPOSE_137_KEYPOINTS = OPENPOSE_JOINTS[:19] + \
+ OPENPOSE_FEET_KEYPOINTS + OPENPOSE_JOINTS[19:]
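+
+# Not part of the original file: a sanity sketch tying the composed
+# conventions back to the sizes stated in the module docstring.
+if __name__ == '__main__':
+    assert len(OPENPOSE_25_KEYPOINTS) == 25
+    assert len(OPENPOSE_118_KEYPOINTS) == 118
+    assert len(OPENPOSE_135_KEYPOINTS) == 135
+    assert len(OPENPOSE_137_KEYPOINTS) == 137
+    # 137 = the 131 OPENPOSE_JOINTS plus the 6 feet keypoints spliced in
+    # after the first 19 body joints.
+    assert len(OPENPOSE_137_KEYPOINTS) == \
+        len(OPENPOSE_JOINTS) + len(OPENPOSE_FEET_KEYPOINTS)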
diff --git a/detrsmpl/core/conventions/keypoints_mapping/penn_action.py b/detrsmpl/core/conventions/keypoints_mapping/penn_action.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d6d70ea2baab877cbc3e83b92b742a099ba78ff
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/penn_action.py
@@ -0,0 +1,15 @@
+PENN_ACTION_KEYPOINTS = [
+ 'head',
+ 'left_shoulder',
+ 'right_shoulder',
+ 'left_elbow',
+ 'right_elbow',
+ 'left_wrist',
+ 'right_wrist',
+ 'left_hip',
+ 'right_hip',
+ 'left_knee',
+ 'right_knee',
+ 'left_ankle',
+ 'right_ankle',
+]
diff --git a/detrsmpl/core/conventions/keypoints_mapping/posetrack.py b/detrsmpl/core/conventions/keypoints_mapping/posetrack.py
new file mode 100644
index 0000000000000000000000000000000000000000..03700f7c93e09638c2d31dbd55b55eb8b6f050dd
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/posetrack.py
@@ -0,0 +1,6 @@
+POSETRACK_KEYPOINTS = [
+ 'nose', 'head_bottom_pt', 'headtop', 'left_ear', 'right_ear',
+ 'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
+ 'left_wrist', 'right_wrist', 'left_hip', 'right_hip', 'left_knee',
+ 'right_knee', 'left_ankle', 'right_ankle'
+]
diff --git a/detrsmpl/core/conventions/keypoints_mapping/pw3d.py b/detrsmpl/core/conventions/keypoints_mapping/pw3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5ec7a5167208a93f24cafd15596aa0cece7af0f
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/pw3d.py
@@ -0,0 +1,20 @@
+PW3D_KEYPOINTS = [
+ 'nose',
+ 'neck_extra',
+ 'right_shoulder',
+ 'right_elbow',
+ 'right_wrist',
+ 'left_shoulder',
+ 'left_elbow',
+ 'left_wrist',
+ 'right_hip_extra',
+ 'right_knee',
+ 'right_ankle',
+ 'left_hip_extra',
+ 'left_knee',
+ 'left_ankle',
+ 'right_eye',
+ 'left_eye',
+ 'right_ear',
+ 'left_ear',
+]
diff --git a/detrsmpl/core/conventions/keypoints_mapping/smpl.py b/detrsmpl/core/conventions/keypoints_mapping/smpl.py
new file mode 100644
index 0000000000000000000000000000000000000000..52bb70ba6e8a1563e6cbebf4bad1da393477685c
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/smpl.py
@@ -0,0 +1,126 @@
+# the keypoints defined in the SMPL paper
+SMPL_KEYPOINTS = [
+ 'pelvis',
+ 'left_hip',
+ 'right_hip',
+ 'spine_1',
+ 'left_knee',
+ 'right_knee',
+ 'spine_2',
+ 'left_ankle',
+ 'right_ankle',
+ 'spine_3',
+ 'left_foot',
+ 'right_foot',
+ 'neck',
+ 'left_collar',
+ 'right_collar',
+ 'head',
+ 'left_shoulder',
+ 'right_shoulder',
+ 'left_elbow',
+ 'right_elbow',
+ 'left_wrist',
+ 'right_wrist',
+
+ # 'left_hand',
+ # 'right_hand',
+ 'left_middle',
+ 'right_middle'
+]
+
+# the full keypoints produced by the default SMPL J_regressor
+SMPL_45_KEYPOINTS = SMPL_KEYPOINTS + [
+ 'nose',
+ 'right_eye',
+ 'left_eye',
+ 'right_ear',
+ 'left_ear',
+ 'left_bigtoe',
+ 'left_smalltoe',
+ 'left_heel',
+ 'right_bigtoe',
+ 'right_smalltoe',
+ 'right_heel',
+ 'left_thumb',
+ 'left_index',
+ 'left_middle',
+ 'left_ring',
+ 'left_pinky',
+ 'right_thumb',
+ 'right_index',
+ 'right_middle',
+ 'right_ring',
+ 'right_pinky',
+]
+
+# the full keypoints produced by the default SMPL J_regressor and
+# extra_J_regressor (provided by SPIN)
+SMPL_54_KEYPOINTS = SMPL_45_KEYPOINTS + [
+ 'right_hip_extra', # LSP
+ 'left_hip_extra', # LSP
+ 'neck_extra', # LSP
+ 'headtop', # LSP
+ 'pelvis_extra', # MPII
+ 'thorax_extra', # MPII
+ 'spine_extra', # H36M
+ 'jaw_extra', # H36M
+ 'head_extra', # H36M
+]
+
+# SMPL keypoint convention used by SPIN, EFT and so on
+SMPL_49_KEYPOINTS = [
+ # OpenPose
+ 'nose_openpose',
+ 'neck_openpose', # 'upper_neck'
+ 'right_shoulder_openpose',
+ 'right_elbow_openpose',
+ 'right_wrist_openpose',
+ 'left_shoulder_openpose',
+ 'left_elbow_openpose',
+ 'left_wrist_openpose',
+ 'pelvis_openpose',
+ 'right_hip_openpose',
+ 'right_knee_openpose',
+ 'right_ankle_openpose',
+ 'left_hip_openpose',
+ 'left_knee_openpose',
+ 'left_ankle_openpose',
+ 'right_eye_openpose',
+ 'left_eye_openpose',
+ 'right_ear_openpose',
+ 'left_ear_openpose',
+ 'left_bigtoe_openpose',
+ 'left_smalltoe_openpose',
+ 'left_heel_openpose',
+ 'right_bigtoe_openpose',
+ 'right_smalltoe_openpose',
+ 'right_heel_openpose',
+ # 24 Keypoints
+ 'right_ankle',
+ 'right_knee',
+ 'right_hip_extra', # LSP
+ 'left_hip_extra', # LSP
+ 'left_knee',
+ 'left_ankle',
+ 'right_wrist',
+ 'right_elbow',
+ 'right_shoulder',
+ 'left_shoulder',
+ 'left_elbow',
+ 'left_wrist',
+ 'neck_extra', # LSP
+    'headtop',  # LSP mpii penn_action mpi_inf_3dhp
+ 'pelvis_extra', # MPII
+ 'thorax_extra', # MPII
+ 'spine_extra', # H36M
+ 'jaw_extra', # H36M
+ 'head_extra', # H36M
+ 'nose',
+ 'left_eye',
+ 'right_eye',
+ 'left_ear',
+ 'right_ear'
+]
+
+SMPL_24_KEYPOINTS = SMPL_49_KEYPOINTS[-24:]
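+
+# Not part of the original file: an illustrative consistency check of the
+# derived conventions defined above.
+if __name__ == '__main__':
+    assert len(SMPL_KEYPOINTS) == 24
+    assert len(SMPL_45_KEYPOINTS) == 45
+    assert len(SMPL_54_KEYPOINTS) == 54
+    assert len(SMPL_49_KEYPOINTS) == 49
+    # SMPL_49 is 25 OpenPose keypoints followed by the 24 SMPL keypoints,
+    # so the last 24 entries form SMPL_24_KEYPOINTS.
+    assert SMPL_24_KEYPOINTS == SMPL_49_KEYPOINTS[25:]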
diff --git a/detrsmpl/core/conventions/keypoints_mapping/smplx.py b/detrsmpl/core/conventions/keypoints_mapping/smplx.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f247073947a438019741420a322264bcadc3a12
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/smplx.py
@@ -0,0 +1,382 @@
+SMPLX_KEYPOINTS = [
+ 'pelvis',
+ 'left_hip',
+ 'right_hip',
+ 'spine_1',
+ 'left_knee',
+ 'right_knee',
+ 'spine_2',
+ 'left_ankle',
+ 'right_ankle',
+ 'spine_3',
+ 'left_foot',
+ 'right_foot',
+ 'neck',
+ 'left_collar',
+ 'right_collar',
+ 'head',
+ 'left_shoulder',
+ 'right_shoulder',
+ 'left_elbow',
+ 'right_elbow',
+ 'left_wrist',
+ 'right_wrist',
+ 'jaw',
+ 'left_eyeball',
+ 'right_eyeball',
+ 'left_index_1',
+ 'left_index_2',
+ 'left_index_3',
+ 'left_middle_1',
+ 'left_middle_2',
+ 'left_middle_3',
+ 'left_pinky_1',
+ 'left_pinky_2',
+ 'left_pinky_3',
+ 'left_ring_1',
+ 'left_ring_2',
+ 'left_ring_3',
+ 'left_thumb_1',
+ 'left_thumb_2',
+ 'left_thumb_3',
+ 'right_index_1',
+ 'right_index_2',
+ 'right_index_3',
+ 'right_middle_1',
+ 'right_middle_2',
+ 'right_middle_3',
+ 'right_pinky_1',
+ 'right_pinky_2',
+ 'right_pinky_3',
+ 'right_ring_1',
+ 'right_ring_2',
+ 'right_ring_3',
+ 'right_thumb_1',
+ 'right_thumb_2',
+ 'right_thumb_3',
+ 'nose',
+ 'right_eye',
+ 'left_eye',
+ 'right_ear',
+ 'left_ear',
+ 'left_bigtoe',
+ 'left_smalltoe',
+ 'left_heel',
+ 'right_bigtoe',
+ 'right_smalltoe',
+ 'right_heel',
+ 'left_thumb',
+ 'left_index',
+ 'left_middle',
+ 'left_ring',
+ 'left_pinky',
+ 'right_thumb',
+ 'right_index',
+ 'right_middle',
+ 'right_ring',
+ 'right_pinky',
+ 'right_eyebrow_1',
+ 'right_eyebrow_2',
+ 'right_eyebrow_3',
+ 'right_eyebrow_4',
+ 'right_eyebrow_5',
+ 'left_eyebrow_5',
+ 'left_eyebrow_4',
+ 'left_eyebrow_3',
+ 'left_eyebrow_2',
+ 'left_eyebrow_1',
+ 'nosebridge_1',
+ 'nosebridge_2',
+ 'nosebridge_3',
+ 'nosebridge_4',
+ 'right_nose_2', # original name: nose_1
+ 'right_nose_1', # original name: nose_2
+ 'nose_middle', # original name: nose_3
+ 'left_nose_1', # original name: nose_4
+ 'left_nose_2', # original name: nose_5
+ 'right_eye_1',
+ 'right_eye_2',
+ 'right_eye_3',
+ 'right_eye_4',
+ 'right_eye_5',
+ 'right_eye_6',
+ 'left_eye_4',
+ 'left_eye_3',
+ 'left_eye_2',
+ 'left_eye_1',
+ 'left_eye_6',
+ 'left_eye_5',
+ 'right_mouth_1', # original name: mouth_1
+ 'right_mouth_2', # original name: mouth_2
+ 'right_mouth_3', # original name: mouth_3
+ 'mouth_top', # original name: mouth_4
+ 'left_mouth_3', # original name: mouth_5
+ 'left_mouth_2', # original name: mouth_6
+ 'left_mouth_1', # original name: mouth_7
+ 'left_mouth_5', # original name: mouth_8
+ 'left_mouth_4', # original name: mouth_9
+ 'mouth_bottom', # original name: mouth_10
+ 'right_mouth_4', # original name: mouth_11
+ 'right_mouth_5', # original name: mouth_12
+ 'right_lip_1', # original name: lip_1
+ 'right_lip_2', # original name: lip_2
+ 'lip_top', # original name: lip_3
+ 'left_lip_2', # original name: lip_4
+ 'left_lip_1', # original name: lip_5
+ 'left_lip_3', # original name: lip_6
+ 'lip_bottom', # original name: lip_7
+ 'right_lip_3', # original name: lip_8
+ 'right_contour_1', # original name: face_contour_1
+ 'right_contour_2', # original name: face_contour_2
+ 'right_contour_3', # original name: face_contour_3
+ 'right_contour_4', # original name: face_contour_4
+ 'right_contour_5', # original name: face_contour_5
+ 'right_contour_6', # original name: face_contour_6
+ 'right_contour_7', # original name: face_contour_7
+ 'right_contour_8', # original name: face_contour_8
+ 'contour_middle', # original name: face_contour_9
+ 'left_contour_8', # original name: face_contour_10
+ 'left_contour_7', # original name: face_contour_11
+ 'left_contour_6', # original name: face_contour_12
+ 'left_contour_5', # original name: face_contour_13
+ 'left_contour_4', # original name: face_contour_14
+ 'left_contour_3', # original name: face_contour_15
+ 'left_contour_2', # original name: face_contour_16
+ 'left_contour_1', # original name: face_contour_17
+]
+
+SMPLX_LIMBS = {
+ 'body': [['pelvis', 'left_hip'], ['pelvis', 'right_hip'],
+ ['left_hip', 'right_hip'], ['left_shoulder', 'right_shoulder'],
+ ['pelvis', 'spine_1'], ['spine_1', 'spine_2'],
+ ['spine_2', 'spine_3'], ['spine_3', 'neck'], ['neck', 'head'],
+ ['left_ankle', 'left_knee'], ['left_knee', 'left_hip'],
+ ['right_ankle', 'right_knee'], ['right_knee', 'right_hip'],
+ ['right_ankle', 'right_foot'], ['left_ankle', 'left_foot'],
+ ['left_hip', 'right_hip'], ['left_shoulder', 'left_hip'],
+ ['right_shoulder', 'right_hip'], ['left_collar', 'spine_3'],
+ ['right_collar', 'spine_3'], ['right_collar', 'right_shoulder'],
+ ['left_collar', 'left_shoulder'],
+ ['left_shoulder', 'right_shoulder'],
+ ['left_shoulder',
+ 'left_elbow'], ['right_shoulder', 'right_elbow'],
+ ['left_elbow', 'left_wrist'], ['right_elbow', 'right_wrist'],
+ ['left_ankle', 'left_bigtoe'], ['left_ankle', 'left_smalltoe'],
+ ['left_ankle', 'left_heel'], ['right_ankle', 'right_bigtoe'],
+ ['right_ankle', 'right_smalltoe'], ['right_ankle', 'right_heel'],
+ ['left_shoulder', 'left_ear'], ['right_shoulder', 'right_ear'],
+ ['right_ear', 'right_eye'], ['right_eye', 'nose'],
+ ['nose', 'left_eye'], ['left_eye', 'left_ear'], ['nose', 'jaw'],
+ ['jaw', 'neck']],
+ 'face': [['right_contour_1', 'right_contour_2'],
+ ['right_contour_2', 'right_contour_3'],
+ ['right_contour_3', 'right_contour_4'],
+ ['right_contour_4', 'right_contour_5'],
+ ['right_contour_5', 'right_contour_6'],
+ ['right_contour_6', 'right_contour_7'],
+ ['right_contour_7', 'right_contour_8'],
+ ['right_contour_8', 'contour_middle'],
+ ['contour_middle', 'left_contour_8'],
+ ['left_contour_8', 'left_contour_7'],
+ ['left_contour_7', 'left_contour_6'],
+ ['left_contour_6', 'left_contour_5'],
+ ['left_contour_5', 'left_contour_4'],
+ ['left_contour_4', 'left_contour_3'],
+ ['left_contour_3', 'left_contour_2'],
+ ['left_contour_2', 'left_contour_1']],
+ 'left_hand':
+ [['left_wrist', 'left_thumb_1'], ['left_thumb_1', 'left_thumb_2'],
+ ['left_thumb_2', 'left_thumb_3'], ['left_thumb_3', 'left_thumb'],
+ ['left_wrist', 'left_index_1'], ['left_index_1', 'left_index_2'],
+ ['left_index_2', 'left_index_3'], ['left_index_3', 'left_index'],
+ ['left_wrist', 'left_middle_1'], ['left_middle_1', 'left_middle_2'],
+ ['left_middle_2', 'left_middle_3'], ['left_middle_3', 'left_middle'],
+ ['left_wrist', 'left_ring_1'], ['left_ring_1', 'left_ring_2'],
+ ['left_ring_2', 'left_ring_3'], ['left_ring_3', 'left_ring'],
+ ['left_wrist', 'left_pinky_1'], ['left_pinky_1', 'left_pinky_2'],
+ ['left_pinky_2', 'left_pinky_3'], ['left_pinky_3', 'left_pinky']],
+ 'right_hand': [['right_wrist', 'right_thumb_1'],
+ ['right_thumb_1', 'right_thumb_2'],
+ ['right_thumb_2', 'right_thumb_3'],
+ ['right_thumb_3', 'right_thumb'],
+ ['right_wrist', 'right_index_1'],
+ ['right_index_1', 'right_index_2'],
+ ['right_index_2', 'right_index_3'],
+ ['right_index_3', 'right_index'],
+ ['right_wrist', 'right_middle_1'],
+ ['right_middle_1', 'right_middle_2'],
+ ['right_middle_2', 'right_middle_3'],
+ ['right_middle_3', 'right_middle'],
+ ['right_wrist', 'right_ring_1'],
+ ['right_ring_1', 'right_ring_2'],
+ ['right_ring_2', 'right_ring_3'],
+ ['right_ring_3', 'right_ring'],
+ ['right_wrist', 'right_pinky_1'],
+ ['right_pinky_1', 'right_pinky_2'],
+ ['right_pinky_2', 'right_pinky_3'],
+ ['right_pinky_3', 'right_pinky']],
+ 'right_eye':
+ [['right_eye_1', 'right_eye_2'], ['right_eye_2', 'right_eye_3'],
+ ['right_eye_3', 'right_eye_4'], ['right_eye_4', 'right_eye_5'],
+ ['right_eye_5', 'right_eye_6'], ['right_eye_6', 'right_eye_1'],
+ ['right_eyebrow_1', 'right_eyebrow_2'],
+ ['right_eyebrow_2', 'right_eyebrow_3'],
+ ['right_eyebrow_3', 'right_eyebrow_4'],
+ ['right_eyebrow_4', 'right_eyebrow_5']],
+ 'left_eye': [['left_eye_4', 'left_eye_3'], ['left_eye_3', 'left_eye_2'],
+ ['left_eye_2', 'left_eye_1'], ['left_eye_1', 'left_eye_6'],
+ ['left_eye_6', 'left_eye_5'], ['left_eye_5', 'left_eye_4'],
+ ['left_eyebrow_1', 'left_eyebrow_2'],
+ ['left_eyebrow_2', 'left_eyebrow_3'],
+ ['left_eyebrow_3', 'left_eyebrow_4'],
+ ['left_eyebrow_4', 'left_eyebrow_5']],
+ 'mouth':
+ [['right_mouth_1', 'right_mouth_2'], ['right_mouth_2', 'right_mouth_3'],
+ ['right_mouth_3', 'mouth_top'], ['mouth_top', 'left_mouth_3'],
+ ['left_mouth_3', 'left_mouth_2'], ['left_mouth_2', 'left_mouth_1'],
+ ['left_mouth_1', 'left_mouth_5'], ['left_mouth_5', 'left_mouth_4'],
+ ['left_mouth_4', 'mouth_bottom'], ['mouth_bottom', 'right_mouth_4'],
+ ['right_mouth_4', 'right_mouth_5'], ['right_mouth_5', 'right_mouth_1'],
+ ['right_lip_1', 'right_lip_2'], ['right_lip_2', 'lip_top'],
+ ['lip_top', 'left_lip_2'], ['left_lip_2', 'left_lip_1'],
+ ['left_lip_1', 'left_lip_3'], ['left_lip_3', 'lip_bottom'],
+ ['lip_bottom', 'right_lip_3'], ['right_lip_3', 'right_lip_1']],
+ 'nose': [
+ ['nosebridge_1', 'nosebridge_2'],
+ ['nosebridge_2', 'nosebridge_3'],
+ ['nosebridge_3', 'nosebridge_4'],
+ ['right_nose_2', 'right_nose_1'],
+ ['right_nose_1', 'nose_middle'],
+ ['nose_middle', 'left_nose_1'],
+ ['left_nose_1', 'left_nose_2'],
+ ]
+}
+
+SMPLX_LIMBS_INDEX = {}
+for k in SMPLX_LIMBS:
+ SMPLX_LIMBS_INDEX[k] = [[
+ SMPLX_KEYPOINTS.index(limb[0]),
+ SMPLX_KEYPOINTS.index(limb[1])
+ ] for limb in SMPLX_LIMBS[k]]
+
+SMPLX_PALETTE = {
+ 'left_eye': [[0, 0, 0]],
+ 'right_eye': [[0, 0, 0]],
+ 'nose': [[0, 0, 255]],
+ 'mouth': [[0, 255, 255]],
+ 'face': [[255, 0, 0]],
+ 'left_hand': [[0, 0, 0]],
+ 'right_hand': [[0, 0, 0]]
+}
+
+
+joint_idx = \
+ (0,1,2,4,5,7,8,12,16,17,18,19,20,21,60,61,62,63,64,65,59,58,57,56,55, # body joints
+ 37,38,39,66,25,26,27,67,28,29,30,68,34,35,36,69,31,32,33,70, # left hand joints
+ 52,53,54,71,40,41,42,72,43,44,45,73,49,50,51,74,46,47,48,75, # right hand joints
+ 22,15, # jaw, head
+ 57,56, # eyeballs
+ 76,77,78,79,80,81,82,83,84,85, # eyebrow
+ 86,87,88,89, # nose
+ 90,91,92,93,94, # below nose
+ 95,96,97,98,99,100,101,102,103,104,105,106, # eyes
+ 107, # right mouth
+ 108,109,110,111,112, # upper mouth
+ 113, # left mouth
+ 114,115,116,117,118, # lower mouth
+ 119, # right lip
+ 120,121,122, # upper lip
+ 123, # left lip
+ 124,125,126, # lower lip
+ 127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143 # face contour
+ )
+
+SMPLX_137_KEYPOINTS = []
+for idx in joint_idx:
+ SMPLX_137_KEYPOINTS.append(SMPLX_KEYPOINTS[idx])
+
+
+SMPLX_LHAND = [
+ # 'left_thumb_2',
+ 'left_wrist',
+ 'left_thumb',
+ # 'left_index_1',
+ 'left_index',
+ # 'left_middle_1',
+ 'left_middle',
+ # 'left_ring_1',
+ 'left_ring',
+ # 'left_pinky_1',
+ 'left_pinky',
+]
+SMPLX_RHAND = [
+ # 'right_thumb_2',
+ 'right_wrist',
+ 'right_thumb',
+ # 'right_index_1',
+ 'right_index',
+ # 'right_middle_1',
+ 'right_middle',
+ # 'right_ring_1',
+ 'right_ring',
+ # 'right_pinky_1',
+ 'right_pinky',
+]
+
+SMPLX_FACE = [
+ 'nose',
+ 'mouth_top',
+ 'jaw',
+ 'right_contour_1',
+ 'contour_middle',
+ 'left_contour_1'
+]
+
+AiOS_35_KEYPOINTS = [
+ 'nose',
+ 'left_eye',
+ 'right_eye',
+ 'left_ear',
+ 'right_ear',
+ 'left_shoulder',
+ 'right_shoulder',
+ 'left_elbow',
+ 'right_elbow',
+ 'left_wrist',
+ 'right_wrist',
+ 'left_hip_extra',
+ 'right_hip_extra',
+ 'left_knee',
+ 'right_knee',
+ 'left_ankle',
+ 'right_ankle',
+ 'left_wrist',
+ 'left_thumb',
+ # 'left_index_1',
+ 'left_index',
+ # 'left_middle_1',
+ 'left_middle',
+ # 'left_ring_1',
+ 'left_ring',
+ # 'left_pinky_1',
+ 'left_pinky',
+
+ 'right_wrist',
+ 'right_thumb',
+ # 'right_index_1',
+ 'right_index',
+ # 'right_middle_1',
+ 'right_middle',
+ # 'right_ring_1',
+ 'right_ring',
+ # 'right_pinky_1',
+ 'right_pinky',
+
+ 'nose',
+ 'mouth_top',
+ 'jaw',
+ 'right_contour_1',
+ 'contour_middle',
+ 'left_contour_1'
+
+]
\ No newline at end of file
diff --git a/detrsmpl/core/conventions/keypoints_mapping/spin_smplx.py b/detrsmpl/core/conventions/keypoints_mapping/spin_smplx.py
new file mode 100644
index 0000000000000000000000000000000000000000..0be4d124804a51b694b4f67805b8bb666b99b04b
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/spin_smplx.py
@@ -0,0 +1,35 @@
+"""SPIN in smplx convention.
+
+SPIN_SMPLX_KEYPOINTS can be found in https://github.com/vchoutas/expose
+"""
+
+# TODO: SMPL_24->HumanData->SMPLX causes hip, spine to be lost.
+# SMPL_24: left/right_hip_extra
+# SMPLX: left/right_hip
+
+SPIN_SMPLX_KEYPOINTS = [
+ 'right_ankle',
+ 'right_knee',
+ 'right_hip',
+ 'left_hip',
+ 'left_knee',
+ 'left_ankle',
+ 'right_wrist',
+ 'right_elbow',
+ 'right_shoulder',
+ 'left_shoulder',
+ 'left_elbow',
+ 'left_wrist',
+ 'neck',
+ 'head_top',
+ 'pelvis',
+ 'thorax',
+ 'spine',
+ 'h36m_jaw',
+ 'h36m_head',
+ 'nose',
+ 'left_eye',
+ 'right_eye',
+ 'left_ear',
+ 'right_ear',
+]
diff --git a/detrsmpl/core/conventions/keypoints_mapping/star.py b/detrsmpl/core/conventions/keypoints_mapping/star.py
new file mode 100644
index 0000000000000000000000000000000000000000..d774d8efc722f397bb39d036c70bf3f0332dedd5
--- /dev/null
+++ b/detrsmpl/core/conventions/keypoints_mapping/star.py
@@ -0,0 +1,26 @@
+STAR_KEYPOINTS = [
+ 'pelvis',
+ 'left_hip',
+ 'right_hip',
+ 'spine_1',
+ 'left_knee',
+ 'right_knee',
+ 'spine_2',
+ 'left_ankle',
+ 'right_ankle',
+ 'spine_3',
+ 'left_foot',
+ 'right_foot',
+ 'neck',
+ 'left_collar',
+ 'right_collar',
+ 'head',
+ 'left_shoulder',
+ 'right_shoulder',
+ 'left_elbow',
+ 'right_elbow',
+ 'left_wrist',
+ 'right_wrist',
+ 'left_hand',
+ 'right_hand',
+]
diff --git a/detrsmpl/core/conventions/segmentation/__init__.py b/detrsmpl/core/conventions/segmentation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7a843ccb92d6113c415f0cf32e95cca15f075e4
--- /dev/null
+++ b/detrsmpl/core/conventions/segmentation/__init__.py
@@ -0,0 +1,94 @@
+from .smpl import SMPL_SEGMENTATION_DICT, SMPL_SUPER_SET
+from .smplx import SMPLX_SEGMENTATION_DICT, SMPLX_SUPER_SET
+
+
+class body_segmentation(object):
+ """SMPL(X) body mesh vertex segmentation."""
+ def __init__(self, model_type='smpl') -> None:
+ if model_type == 'smpl':
+ self.DICT = SMPL_SEGMENTATION_DICT
+ self.super_set = SMPL_SUPER_SET
+ self.NUM_VERTS = 6890
+ elif model_type == 'smplx':
+ self.DICT = SMPLX_SEGMENTATION_DICT
+ self.super_set = SMPLX_SUPER_SET
+ self.NUM_VERTS = 10475
+ else:
+ raise ValueError(f'Wrong model_type: {model_type}.'
+ f' Should be in {["smpl", "smplx"]}')
+ self.model_type = model_type
+ self.len = len(list(self.DICT))
+
+ def items(self, ):
+ return zip(self.keys(), [self.__getitem__(key) for key in self.keys()])
+
+ def keys(self, ):
+ return self.DICT.keys()
+
+ def values(self, ):
+ return [self.__getitem__(key) for key in self.keys()]
+
+ def __len__(self, ):
+ return self.len
+
+ def __getitem__(self, key):
+ if key in self.DICT.keys():
+ part_segmentation = []
+ raw_segmentation = self.DICT[key]
+ for continuous in raw_segmentation:
+ if len(continuous) == 2:
+ part_segmentation.extend(
+ list(range(continuous[0], continuous[1] + 1)))
+ elif len(continuous) == 1:
+ part_segmentation.extend(continuous)
+ return part_segmentation
+ elif key in self.super_set.keys():
+ super_part_segmentation = []
+ for body_part_key in self.super_set[key]:
+ super_part_segmentation += self.__getitem__(body_part_key)
+ return super_part_segmentation
+ elif key.lower() == 'all':
+ return list(range(self.NUM_VERTS))
+ else:
+ raise KeyError(f'{key} not in {self.model_type} conventions.')
+
+
+def _preprocess_segmentation_dict(segmentation_dict):
+ """help to preprocess the indexes to list."""
+ final_dict = {}
+ for k in segmentation_dict:
+ final_dict[k] = [[]]
+ final_part_indexes = final_dict[k]
+ part_indexes = segmentation_dict[k]
+ part_indexes.sort()
+ for index in range(len(part_indexes)):
+ if len(final_part_indexes[-1]) == 0:
+ final_part_indexes[-1].append(part_indexes[index])
+ elif len(final_part_indexes[-1]) == 2:
+ final_part_indexes.append([part_indexes[index]])
+ elif len(final_part_indexes[-1]) == 1:
+ if index != len(part_indexes) - 1:
+ this_index = part_indexes[index]
+ last_index = part_indexes[index - 1]
+ next_index = part_indexes[index + 1]
+ if (this_index == last_index + 1) and (this_index
+ == next_index - 1):
+ pass
+ elif (this_index == last_index +
+ 1) and (this_index != next_index - 1):
+ final_part_indexes[-1].append(this_index)
+ elif (this_index != last_index + 1) and (this_index !=
+ next_index - 1):
+ final_part_indexes.append([this_index])
+ final_part_indexes.append([])
+ elif (this_index !=
+ last_index + 1) and (this_index == next_index - 1):
+ final_part_indexes.append([this_index])
+ else:
+ this_index = part_indexes[index]
+ last_index = part_indexes[index - 1]
+ if (this_index == last_index + 1):
+ final_part_indexes[-1].append(this_index)
+ else:
+ final_part_indexes.append([this_index])
+ return final_dict
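+
+
+# Not part of the original file: a minimal usage sketch. `body_segmentation`
+# expands the compressed [start, end] ranges stored in the per-part dicts
+# into flat vertex index lists, while `_preprocess_segmentation_dict` goes
+# in the other direction (flat indices -> ranges).
+def _demo_usage():
+    """Illustrative only: expand a few parts and show the range helper."""
+    for model_type in ('smpl', 'smplx'):
+        seg = body_segmentation(model_type)
+        # 'all' returns every vertex index; 'rightHand' is one named part.
+        print(model_type, len(seg['all']), len(seg['rightHand']))
+    # Consecutive indices collapse to [start, end] pairs, isolated indices
+    # stay as one-element lists.
+    assert _preprocess_segmentation_dict({'part': [1, 2, 3, 7]}) == \
+        {'part': [[1, 3], [7]]}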
diff --git a/detrsmpl/core/conventions/segmentation/smpl.py b/detrsmpl/core/conventions/segmentation/smpl.py
new file mode 100644
index 0000000000000000000000000000000000000000..b841d74c0cdf08479c97ae7e5d6968af00aa2a94
--- /dev/null
+++ b/detrsmpl/core/conventions/segmentation/smpl.py
@@ -0,0 +1,239 @@
+"""Raw index information can be found from smpl-wiki website:
+
+https://meshcapade.wiki/SMPL#mesh-templates--samples
+"""
+SMPL_SEGMENTATION_DICT = {
+ 'rightHand':
+ [[5442, 5487], [5492, 5497], [5502, 5527], [5530, 5562], [5569], [5571],
+ [5574, 5583], [5588, 5589], [5592, 5605], [5610, 5614], [5621, 5622],
+ [5625], [5631, 5641], [5643, 5646], [5649, 5650], [5652, 5664], [5667],
+ [5670, 5675], [5682, 5690], [5692], [5695], [5697, 5701], [5707, 5721],
+ [5723, 5732], [5735, 5740], [5745, 5746], [5748, 5752], [6056, 6057],
+ [6066, 6067], [6158, 6239]],
+ 'rightUpLeg': [[4320, 4321], [4323, 4324], [4333, 4340], [4356, 4367],
+ [4383, 4401], [4419, 4422], [4430, 4532], [4623, 4634],
+ [4645, 4660], [4670, 4673], [4704, 4713], [4745, 4746],
+ [4757, 4760], [4801, 4802], [4829], [4834, 4841],
+ [4924, 4926], [4928, 4936], [4948, 4952], [4970, 4973],
+ [4983, 4993], [5004, 5005], [6546, 6549], [6552, 6556],
+ [6873], [6877]],
+ 'leftArm': [[626, 629], [634, 635], [680, 681], [716, 719], [769, 780],
+ [784, 793], [1231, 1234], [1258, 1261], [1271], [1281, 1282],
+ [1310, 1311], [1314, 1315], [1340, 1343], [1355, 1358],
+ [1376, 1400], [1402, 1403], [1405, 1416], [1428, 1433],
+ [1438, 1445], [1502], [1505, 1510], [1538],
+ [1541, 1543], [1545], [1619, 1622], [1631, 1642], [1645, 1656],
+ [1658, 1659], [1661, 1662], [1664], [1666, 1684], [1696, 1698],
+ [1703, 1720], [1725], [1731, 1735], [1737], [1739, 1740],
+ [1745, 1749], [1751], [1761], [1830, 1831], [1844, 1846],
+ [1850, 1851], [1854, 1855], [1858], [1860], [1865, 1867],
+ [1869, 1871], [1874, 1878], [1882, 1883], [1888, 1889], [1892],
+ [1900, 1904], [1909], [2819, 2822], [2895, 2903], [2945, 2946],
+ [2974, 2996], [3002], [3013]],
+ 'leftLeg': [[995], [998, 999], [1002], [1004, 1005], [1008], [1010],
+ [1012], [1015, 1016], [1018, 1019], [1043, 1044], [1047, 1136],
+ [1148, 1158], [1175, 1183], [1369, 1375], [1464, 1474],
+ [1522, 1532], [3174, 3210], [3319, 3335], [3432, 3436], [3469],
+ [3472, 3474]],
+ 'leftToeBase': [[3211, 3318], [3336, 3337], [3340], [3342], [3344], [3346],
+ [3348], [3350], [3352], [3354], [3357, 3358], [3360],
+ [3362]],
+ 'leftFoot': [[3327, 3469]],
+ 'spine1':
+ [[598, 601], [610, 621], [642], [645, 647], [652, 653], [658, 661],
+ [668, 671], [684, 692], [722, 725], [736], [750, 751], [761], [764],
+ [766, 767], [794, 795], [891, 894], [925, 929], [940, 943], [1190, 1197],
+ [1200, 1202], [1212], [1236], [1252, 1255], [1268, 1270], [1329, 1330],
+ [1348, 1349], [1351], [1420, 1421], [1423, 1426], [1436, 1437],
+ [1756, 1758], [2839, 2851], [2870, 2871], [2883], [2906], [2908], [3014],
+ [3017], [3025], [3030], [3033, 3034], [3037], [3039, 3044], [3076, 3077],
+ [3079], [3480], [3505], [3511], [4086, 4089], [4098, 4109], [4130, 4131],
+ [4134, 4135], [4140, 4141], [4146, 4149], [4156, 4159], [4172, 4180],
+ [4210, 4213], [4225], [4239, 4240], [4249, 4250], [4255, 4256],
+ [4282, 4283], [4377, 4380], [4411, 4415], [4426, 4429], [4676, 4683],
+ [4686, 4688], [4695], [4719], [4735, 4737], [4740], [4751, 4753],
+ [4824, 4825], [4828], [4893, 4895], [4897, 4899], [4908, 4909],
+ [5223, 5225], [6300, 6312], [6331, 6332], [6342], [6366, 6367], [6475],
+ [6477, 6478], [6481, 6482], [6485], [6487, 6491], [6878]],
+ 'spine2': [[570, 573], [584, 597], [602, 609], [622, 625], [638, 641],
+ [643, 644], [648, 651], [666, 667], [672, 675], [680, 683],
+ [693, 704], [713, 717], [726, 733], [735],
+ [737, 749], [752, 760], [762, 763], [803, 806], [811, 814],
+ [817, 821], [824, 828], [895, 896], [930, 931], [1198, 1199],
+ [1213, 1220], [1235], [1237], [1256, 1257], [1271, 1273],
+ [1279, 1280], [1283, 1309], [1312, 1313], [1319, 1320],
+ [1346, 1347], [1350], [1352], [1401], [1417, 1419], [1422],
+ [1427], [1434, 1435], [1503, 1504], [1536, 1537], [1544, 1545],
+ [1753, 1755], [1759, 1763], [1808, 1811], [1816, 1820],
+ [1834, 1839], [1868], [1879, 1880], [2812, 2813], [2852, 2869],
+ [2872], [2875, 2878], [2881, 2882], [2884, 2886], [2904, 2905],
+ [2907], [2931, 2937], [2941], [2950, 2973], [2997, 2998],
+ [3006, 3007], [3012], [3015], [3026, 3029], [3031, 3032],
+ [3035, 3036], [3038], [3059, 3067], [3073, 3075], [3078],
+ [3168, 3169], [3171], [3470, 3471], [3482, 3483], [3495, 3498],
+ [3506], [3508], [4058, 4061], [4072, 4085], [4090, 4097],
+ [4110, 4113], [4126, 4129], [4132, 4133], [4136, 4139],
+ [4154, 4155], [4160, 4163], [4168, 4171], [4181, 4192],
+ [4201, 4204], [4207], [4214, 4221], [4223, 4224], [4226, 4238],
+ [4241, 4248], [4251, 4252], [4291, 4294], [4299, 4302],
+ [4305, 4309], [4312, 4315], [4381, 4382], [4416, 4417],
+ [4684, 4685], [4696, 4703], [4718], [4720], [4738, 4739],
+ [4754, 4756], [4761, 4762], [4765, 4789], [4792, 4793],
+ [4799, 4800], [4822, 4823], [4826, 4827], [4874], [4890, 4892],
+ [4896], [4900], [4907], [4910], [4975, 4976], [5007, 5008],
+ [5013, 5014], [5222], [5226, 5230], [5269, 5272], [5277, 5281],
+ [5295, 5300], [5329], [5340, 5341], [6273, 6274], [6313, 6330],
+ [6333], [6336, 6337], [6340, 6341], [6343, 6345], [6363, 6365],
+ [6390, 6396], [6398], [6409, 6432], [6456, 6457], [6465, 6466],
+ [6476], [6479, 6480], [6483, 6484], [6486], [6496,
+ 6503], [6879]],
+ 'leftShoulder': [[591], [604, 606], [609], [634, 637], [674], [706, 713],
+ [715], [717], [730], [733, 735], [781, 783], [1238, 1245],
+ [1290, 1291], [1294], [1316, 1318], [1401, 1404], [1509],
+ [1535], [1545], [1808], [1810, 1815], [1818, 1819],
+ [1821, 1833], [1837], [1840, 1859], [1861, 1864],
+ [1872, 1873], [1880, 1881], [1884, 1887], [1890, 1891],
+ [1893, 1899], [2879, 2881], [2886, 2894], [2903],
+ [2938, 2949], [2965], [2967], [2969], [2999, 3005],
+ [3008, 3011]],
+ 'rightShoulder': [[4077], [4091, 4092], [4094, 4095], [4122, 4125], [4162],
+ [4194, 4201], [4203], [4207], [4218, 4219], [4222, 4223],
+ [4269, 4271], [4721, 4728], [4773, 4774], [4778],
+ [4796, 4798], [4874, 4877], [4982], [5006], [5014],
+ [5269], [5271, 5276], [5279], [5281, 5294], [5298],
+ [5301, 5320], [5322, 5325], [5333, 5334], [5341, 5342],
+ [5345, 5348], [5351, 5352], [5354, 5360], [6338, 6340],
+ [6345, 6353], [6362], [6397, 6408], [6424, 6425], [6428],
+ [6458, 6464], [6467, 6470]],
+ 'rightFoot': [[6727, 6869]],
+ 'head': [[0, 149], [154, 173], [176, 205], [220, 221], [225, 255],
+ [258, 283], [286, 295], [303, 304], [306, 307], [310, 332],
+ [335, 422], [427, 439], [442, 450], [454, 459], [461, 569],
+ [574, 583], [1764, 1766], [1770, 1778], [1905, 1908],
+ [2779, 2811], [2814, 2818], [3045, 3048], [3051, 3056], [3058],
+ [3069, 3072], [3161, 3163], [3165, 3167], [3485, 3494], [3499],
+ [3512, 3661], [3666, 3685], [3688, 3717], [3732, 3733],
+ [3737, 3767], [3770, 3795], [3798, 3807], [3815, 3816],
+ [3819, 3838], [3841, 3917], [3922, 3933], [3936, 3941],
+ [3945, 4057], [4062, 4071], [5231, 5233], [5235, 5243],
+ [5366, 5369], [6240, 6272], [6275, 6279], [6492, 6495],
+ [6880, 6889]],
+ 'rightArm': [[4114, 4117], [4122], [4125], [4168], [4171], [4204, 4207],
+ [4257, 4268], [4272, 4281], [4714, 4717], [4741,
+ 4744], [4756],
+ [4763, 4764], [4790, 4791], [4794, 4795], [4816, 4819],
+ [4830, 4833], [4849, 4873], [4876, 4889], [4901, 4906],
+ [4911, 4918], [4974], [4977, 4982], [5009, 5012], [5014],
+ [5088, 5091], [5100, 5111], [5114, 5125], [5128, 5131],
+ [5134, 5153], [5165, 5167],
+ [5172, 5189], [5194], [5200, 5204], [5206], [5208, 5209],
+ [5214, 5218], [5220], [5229], [5292, 5293], [5303], [5306],
+ [5309], [5311], [5314, 5315], [5318, 5319], [5321],
+ [5326, 5328], [5330, 5332], [5335, 5339], [5343, 5344],
+ [5349, 5350], [5353], [5361, 5365], [5370], [6280, 6283],
+ [6354, 6362], [6404, 6405], [6433, 6455], [6461], [6471]],
+ 'leftHandIndex1': [[2027, 2030], [2037, 2040], [2057], [2067, 2068],
+ [2123, 2130], [2132], [2145, 2146], [2152, 2154],
+ [2156, 2169], [2177, 2179], [2181], [2186, 2187],
+ [2190, 2191], [2204, 2205], [2215, 2220], [2232, 2233],
+ [2245, 2247], [2258, 2259], [2261, 2263], [2269, 2270],
+ [2272, 2274], [2276, 2277], [2280, 2283], [2291, 2594],
+ [2596, 2597], [2599, 2604], [2606, 2607], [2609, 2696]],
+ 'rightLeg': [[4481, 4482], [4485, 4486], [4491, 4493], [4495], [4498],
+ [4500, 4501], [4505, 4506], [4529], [4532, 4622],
+ [4634, 4644], [4661, 4669], [4842, 4848], [4937, 4947],
+ [4993, 5003], [6574, 6610], [6719, 6735], [6832, 6836],
+ [6869, 6872]],
+ 'rightHandIndex1': [[5488, 5491], [5498, 5501], [5518], [5528, 5529],
+ [5584, 5592], [5606, 5607], [5613], [5615, 5630],
+ [5638, 5640], [5642], [5647, 5648], [5650, 5651],
+ [5665, 5666], [5676, 5681], [5693, 5694], [5706, 5708],
+ [5719], [5721, 5724], [5730, 5731], [5733, 5735],
+ [5737, 5738], [5741, 5744], [5752, 6055], [6058, 6065],
+ [6068, 6157]],
+ 'leftForeArm': [[1546, 1618], [1620, 1621], [1623, 1630], [1643, 1644],
+ [1646, 1647], [1650, 1651], [1654, 1655], [1657, 1666],
+ [1685, 1695], [1699, 1702], [1721, 1730], [1732], [1736],
+ [1738], [1741, 1744], [1750], [1752], [1900], [1909, 1980],
+ [2019], [2059, 2060], [2073], [2089], [2098, 2112],
+ [2147, 2148], [2206, 2209], [2228], [2230], [2234, 2235],
+ [2241, 2244], [2279], [2286], [2873, 2874]],
+ 'rightForeArm': [[5015, 5087], [5090, 5099], [5112, 5113], [5116, 5117],
+ [5120, 5121], [5124, 5135], [5154, 5164], [5168, 5171],
+ [5190, 5199], [5202],
+ [5205], [5207], [5210, 5213], [5219], [5221], [5361],
+ [5370, 5441], [5480], [5520, 5521], [5534], [5550],
+ [5559, 5573], [5608, 5609], [5667, 5670], [5689], [5691],
+ [5695, 5696], [5702, 5705], [5740], [5747], [6334, 6335]],
+ 'neck': [[148], [150, 153], [172], [174, 175], [201, 202], [204, 219],
+ [222, 225], [256, 257], [284, 285], [295, 309], [333, 334],
+ [423, 426], [440, 441], [451, 453], [460, 461], [571, 572],
+ [824, 829], [1279, 1280], [1312, 1313], [1319, 1320], [1331],
+ [3049, 3050], [3057, 3059], [3068], [3164], [3661, 3665],
+ [3685, 3687], [3714, 3731], [3734, 3737], [3768, 3769],
+ [3796, 3797], [3807, 3819], [3839, 3840], [3918, 3921],
+ [3934, 3935], [3942, 3944], [3950], [4060, 4061], [4312, 4315],
+ [4761, 4762], [4792, 4793], [4799, 4800], [4807]],
+ 'rightToeBase': [[6611, 6718], [6736], [6739], [6741], [6743], [6745],
+ [6747], [6749, 6750], [6752], [6754], [6757, 6758],
+ [6760], [6762]],
+ 'spine': [[616, 617], [630, 633], [654, 657], [662, 665], [720, 721],
+ [765, 768], [796, 799], [889, 890], [916, 919], [921, 926],
+ [1188, 1189], [1211, 1212], [1248, 1251], [1264, 1267],
+ [1323, 1328], [1332, 1336], [1344, 1345], [1481, 1496], [1767],
+ [2823, 2845], [2847, 2848], [2851], [3016, 3020], [3023, 3024],
+ [3124], [3173], [3476, 3478], [3480], [3500,
+ 3502], [3504], [3509],
+ [3511], [4103, 4104], [4118, 4121], [4142, 4145], [4150, 4153],
+ [4208, 4209], [4253, 4256], [4284, 4287], [4375, 4376],
+ [4402, 4403], [4405, 4412], [4674, 4675], [4694, 4695],
+ [4731, 4734], [4747, 4750], [4803, 4806], [4808, 4812],
+ [4820, 4821], [4953, 4968], [5234], [6284, 6306], [6308, 6309],
+ [6312], [6472, 6474], [6545], [6874, 6876], [6878]],
+ 'leftUpLeg': [[833, 834], [838, 839], [847, 854], [870, 881], [897, 915],
+ [933, 936], [944, 1046], [1137, 1148], [1159, 1174],
+ [1184, 1187], [1221, 1230], [1262, 1263], [1274, 1277],
+ [1321, 1322], [1354], [1359, 1362], [1365,
+ 1368], [1451, 1453],
+ [1455, 1463], [1475], [1477, 1480], [1498, 1501],
+ [1511, 1514], [1516, 1522], [1533, 1534], [3125, 3128],
+ [3131, 3135], [3475], [3479]],
+ 'leftHand': [[1981, 2026], [2031, 2036],
+ [2041, 2066], [2069, 2101], [2107], [2111], [2113, 2122],
+ [2127], [2130, 2144], [2149, 2152], [2155], [2160],
+ [2163, 2164], [2170, 2180], [2182, 2185], [2188, 2189],
+ [2191, 2203], [2207], [2209, 2214], [2221, 2229], [2231],
+ [2234], [2236, 2240], [2246, 2260], [2262,
+ 2271], [2274, 2279],
+ [2284, 2285], [2287, 2290], [2293], [2595], [2598], [2605],
+ [2608], [2697, 2778]],
+ 'hips': [[631, 632], [654], [657], [662], [665], [676, 679], [705], [720],
+ [796], [799, 802], [807, 810], [815, 816], [822, 823], [830, 846],
+ [855, 869], [871], [878], [881, 890], [912], [915, 920], [932],
+ [937, 939], [1163], [1166], [1203, 1210], [1246, 1247],
+ [1262, 1263], [1276, 1278], [1321], [1336, 1339], [1353, 1354],
+ [1361, 1364], [1446, 1450], [1454], [1476], [1497], [1511],
+ [1513, 1515], [1533, 1534], [1539, 1540], [1768, 1769],
+ [1779, 1807], [2909, 2930], [3018, 3019], [3021, 3022],
+ [3080, 3124], [3128, 3130], [3136, 3160], [3170], [3172], [3481],
+ [3484], [3500], [3502, 3503], [3507], [3510], [4120, 4121],
+ [4142, 4143], [4150, 4151], [4164, 4167], [4193], [4208],
+ [4284, 4285], [4288, 4290], [4295, 4298], [4303, 4304],
+ [4310, 4311], [4316, 4332], [4341, 4356], [4364, 4365],
+ [4368, 4376], [4398, 4399], [4402, 4406], [4418], [4423, 4425],
+ [4649, 4650], [4689, 4693], [4729, 4730], [4745, 4746],
+ [4759, 4760], [4801], [4812, 4815], [4829], [4836, 4837],
+ [4919, 4923], [4927], [4969], [4983, 4984], [4986], [5004, 5005],
+ [5244, 5268], [6368, 6389], [6473, 6474], [6504, 6545],
+ [6549, 6551], [6557, 6573]]
+}
+
+SMPL_SUPER_SET = {
+ 'FOOT': ['leftFoot', 'leftToeBase', 'rightFoot', 'rightToeBase'],
+ 'HAND': ['leftHand', 'rightHand', 'leftHandIndex1', 'rightHandIndex1'],
+ 'LEG': ['rightUpLeg', 'leftUpLeg', 'leftLeg', 'rightLeg'],
+ 'ARM': ['leftForeArm', 'rightForeArm', 'leftArm', 'rightArm'],
+ 'HEAD': ['neck', 'head'],
+ 'UPBODY': ['spine1', 'spine2', 'leftShoulder', 'rightShoulder'],
+ 'DOWNBODY': ['spine', 'hips']
+}
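+
+
+# Illustrative helper (not part of the original file): entries in the
+# segmentation dict above are either singletons ``[i]`` or two-element pairs,
+# assumed here to denote inclusive ``[start, end]`` vertex-index ranges. A
+# minimal sketch of expanding one body part into a flat index list:
+def _expand_part_indices(segmentation, part_name):
+    """Expand singleton/range entries of one body part into vertex indices."""
+    indices = []
+    for entry in segmentation[part_name]:
+        if len(entry) == 1:
+            indices.append(entry[0])
+        else:
+            start, end = entry
+            indices.extend(range(start, end + 1))
+    return indices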
diff --git a/detrsmpl/core/conventions/segmentation/smplx.py b/detrsmpl/core/conventions/segmentation/smplx.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9d0927341ffc133385ab1d11ff1c37c9fbc4be9
--- /dev/null
+++ b/detrsmpl/core/conventions/segmentation/smplx.py
@@ -0,0 +1,269 @@
+"""Raw index information can be found from smpl-wiki website:
+
+https://meshcapade.wiki/SMPL#mesh-templates--samples
+"""
+SMPLX_SEGMENTATION_DICT = {
+ 'rightHand':
+ [[7331, 7376], [7381, 7386], [7391, 7416], [7419, 7451], [7456], [7459],
+ [7463, 7472], [7479, 7494], [7499, 7501], [7504, 7505], [7512, 7514],
+ [7520, 7530], [7532, 7535], [7539, 7553], [7556], [7558], [7560, 7564],
+ [7571, 7579], [7581], [7585, 7590], [7596, 7610], [7612, 7621],
+ [7624, 7629], [7634, 7635], [7637, 7640], [7643], [7947, 7948],
+ [7957, 7958], [8047, 8128]],
+ 'rightUpLeg': [[6225, 6226], [6228, 6229], [6238, 6245], [6261, 6272],
+ [6288, 6306], [6324, 6327], [6335, 6437], [6528, 6539],
+ [6550, 6565], [6575, 6578], [6609, 6618], [6650, 6651],
+ [6662, 6665], [6706, 6707], [6734], [6739, 6746],
+ [6829, 6831], [6833, 6841], [6853, 6857], [6875, 6878],
+ [6888, 6898], [6909, 6910], [8394, 8397], [8400, 8404],
+ [8721], [8725]],
+ 'leftArm': [[3256, 3259], [3266, 3267], [3311, 3312], [3346, 3349],
+ [3401, 3412], [3416, 3425], [3868, 3871], [3898, 3901], [3912],
+ [3920, 3921], [3947, 3948], [3951, 3952], [3973, 3976],
+ [3987, 3990], [4007, 4031], [4034, 4040], [4042, 4048],
+ [4060, 4064], [4067], [4072, 4079], [4135], [4138, 4143],
+ [4170, 4174], [4249, 4252], [4261, 4272], [4275, 4278],
+ [4281, 4290], [4295, 4296], [4301, 4319], [4322], [4334, 4336],
+ [4341, 4358], [4363], [4369, 4373], [4375], [4377, 4378],
+ [4383, 4387], [4389], [4398], [4449, 4450], [4460],
+ [4464, 4465], [4470, 4471], [4474, 4476], [4478], [4483, 4485],
+ [4487, 4489], [4492, 4496], [4500, 4501], [4506, 4507], [4510],
+ [4518, 4523], [5397, 5400], [5471, 5479], [5542, 5543],
+ [5572, 5573], [5576, 5595], [5597], [5607], [5628]],
+ 'head': [[0, 11], [16, 218], [223, 371], [376, 461], [464,
+ 495], [498, 551],
+ [554, 557], [560, 562], [565, 648], [651, 735], [738, 1209],
+ [1214, 1325], [1327, 1358], [1361, 1385], [1387, 1725],
+ [1728, 1758], [1760, 1789], [1791, 1885], [1887, 1897],
+ [1899, 1930], [1935, 1939], [1942, 1947], [1950, 2035],
+ [2037, 2148], [2152, 2217], [2220, 2483], [2485, 2530],
+ [2532, 2869], [2871, 2892], [2894, 2963], [2965, 2975],
+ [2977, 3011], [3014, 3183], [8731, 8810], [8815, 8838],
+ [8926, 8928], [8931, 8933], [8939], [8941, 8987], [8989, 9019],
+ [9028, 9160], [9162, 9164], [9166, 9382]],
+ 'leftEye': [[9383, 9928]],
+ 'rightEye': [[9929, 10474]],
+ 'leftLeg': [[3625, 3626], [3629, 3630], [3635, 3637], [3639], [3642, 3644],
+ [3649, 3650], [3675, 3733], [3737, 3769], [3781, 3791],
+ [3809, 3817], [3999, 4001], [4003, 4006], [4098, 4108],
+ [4154, 4164], [5728, 5764], [5873, 5889], [8892, 8896],
+ [8935, 8937], [9020]],
+ 'leftToeBase': [[5765, 5872], [5890], [5893], [5895], [5897], [5899],
+ [5901], [5903, 5904], [5906], [5908], [5911, 5912], [5914],
+ [5916]],
+ 'leftFoot': [[5881, 5919], [5922, 5930], [5933], [8728, 8730],
+ [8839, 8925], [8929, 8930], [8934, 8935]],
+ 'spine1':
+ [[3228, 3231], [3240, 3251], [3272, 3273], [3276, 3277], [3282, 3283],
+ [3288, 3291], [3298, 3301], [3314, 3322], [3352], [3355, 3357], [3369],
+ [3383, 3384], [3393, 3394], [3399, 3400], [3426, 3427], [3521, 3524],
+ [3555, 3559], [3570, 3573], [3824, 3830], [3833], [3836, 3838], [3844],
+ [3855, 3856], [3873], [3892, 3893], [3896, 3897], [3908, 3910],
+ [3981, 3982], [3985], [4052, 4054], [4056, 4058], [4069, 4070],
+ [4392, 4394], [5417, 5429], [5448, 5449], [5459], [5483], [5485, 5486],
+ [5489], [5531, 5532], [5534], [5632], [5634, 5635], [5638, 5639], [5642],
+ [5644, 5648], [5944], [5950], [5991, 5994], [6003, 6014], [6035, 6036],
+ [6039, 6040], [6045, 6046], [6051, 6054], [6061, 6064], [6077, 6085],
+ [6115, 6118], [6130], [6144, 6145], [6154, 6155], [6160, 6161],
+ [6187, 6188], [6282, 6285], [6316, 6320], [6331, 6334], [6581, 6588],
+ [6591, 6593], [6599], [6624], [6640, 6641], [6644, 6645], [6656, 6658],
+ [6729, 6730], [6733], [6798, 6800], [6802, 6804], [6813, 6814],
+ [7128, 7130], [8151, 8163], [8182, 8183], [8193], [8217, 8218], [8326],
+ [8328, 8329], [8332, 8333], [8336], [8338, 8342], [8726], [9026]],
+ 'spine2': [[3210, 3211], [3214, 3227], [3232, 3239], [3252, 3255],
+ [3268, 3271], [3274, 3275], [3278, 3281], [3296, 3297],
+ [3302, 3305], [3310, 3313], [3323, 3334], [3342, 3343],
+ [3345, 3347], [3358, 3365], [3367, 3368], [3370, 3382],
+ [3385, 3392], [3395, 3396], [3435, 3438], [3443, 3446],
+ [3449, 3453], [3525, 3526], [3560, 3561], [3831, 3832],
+ [3834, 3835], [3846, 3850], [3853, 3854], [3857], [3872],
+ [3874], [3894, 3895], [3911, 3913], [3922, 3946], [3979, 3980],
+ [3983, 3984], [4032], [4049, 4051], [4055], [4059], [4068],
+ [4071], [4136, 4137], [4168, 4169], [4174, 4175], [4279, 4280],
+ [4391], [4395, 4399], [4426, 4429], [4434, 4438], [4452, 4457],
+ [4486], [4497, 4498], [5349, 5350], [5395, 5396], [5430, 5447],
+ [5450], [5453, 5454], [5457, 5458], [5460, 5462], [5480, 5482],
+ [5484], [5487], [5499, 5501], [5519], [5521,
+ 5526], [5528, 5530],
+ [5533], [5536], [5547, 5556], [5558, 5571], [5598, 5599],
+ [5611, 5612], [5618, 5619], [5621], [5633], [5636, 5637],
+ [5640, 5641], [5643], [5650, 5657], [5920, 5921], [5932],
+ [5935, 5938], [5945], [5947], [5973, 5974], [5977, 5990],
+ [5995, 6002], [6015, 6018], [6031, 6034], [6037, 6038],
+ [6041, 6044], [6059, 6060], [6065, 6068], [6073, 6076],
+ [6086, 6097], [6105, 6106], [6108, 6110], [6119, 6126],
+ [6128, 6129], [6131, 6143], [6146, 6153], [6156, 6157],
+ [6196, 6199], [6204, 6207], [6210, 6214], [6286, 6287],
+ [6321, 6322], [6589, 6590], [6601, 6608], [6623], [6625],
+ [6642, 6643], [6659, 6661], [6670, 6694], [6727, 6728],
+ [6731, 6732], [6779], [6795, 6797], [6801], [6805], [6812],
+ [6815], [6880, 6881], [6912, 6913], [6918, 6919], [7127],
+ [7131, 7135], [7162, 7165], [7170, 7174], [7188, 7193], [7222],
+ [7233, 7234], [8129, 8130], [8164, 8181], [8184], [8187, 8188],
+ [8191, 8192], [8194, 8196], [8214, 8216], [8241, 8247], [8249],
+ [8260, 8283], [8307, 8308], [8316, 8317], [8327], [8330, 8331],
+ [8334, 8335], [8337], [8344, 8351], [8727], [9027]],
+ 'leftShoulder': [[3219], [3233, 3234], [3236, 3237], [3264, 3267], [3303],
+ [3336, 3341], [3343, 3346], [3362, 3363], [3366, 3367],
+ [3413, 3415], [3875, 3878], [3880, 3883], [3929, 3930],
+ [3935], [3953, 3955], [4032, 4035], [4143], [4167],
+ [4174], [4426, 4428], [4430, 4433], [4436], [4438, 4451],
+ [4455], [4458, 4477], [4479, 4482], [4490, 4491],
+ [4498, 4499], [4502, 4505], [4508, 4509], [4511, 4517],
+ [5455, 5457], [5462, 5470], [5479], [5535, 5546],
+ [5563, 5564], [5566], [5602], [5605, 5610], [5624, 5627]],
+ 'rightShoulder': [[5982], [5996, 5997], [5999, 6000], [6027, 6030], [6066],
+ [6099, 6104], [6106, 6109], [6123, 6124], [6127, 6128],
+ [6174, 6176], [6626, 6633], [6677, 6678], [6683],
+ [6701, 6703], [6779, 6782], [6887], [6911], [6918],
+ [7162, 7164], [7166, 7169], [7172], [7174, 7187], [7191],
+ [7194, 7213], [7215, 7218], [7226, 7227], [7234, 7235],
+ [7238, 7241], [7244, 7245], [7247, 7253], [8189, 8191],
+ [8196, 8204], [8213], [8248, 8259], [8275, 8276], [8278],
+ [8309, 8315], [8318, 8321]],
+ 'rightFoot': [[8575, 8717]],
+ 'rightArm':
+ [[6019, 6022], [6029, 6030], [6074, 6075], [6109, 6112], [6162, 6173],
+ [6177, 6186], [6619, 6622], [6646, 6649], [6660], [6668, 6669],
+ [6695, 6696], [6699, 6700], [6721, 6724], [6735, 6738], [6754, 6778],
+ [6781, 6794], [6806, 6811], [6816, 6823], [6879], [6882, 6887],
+ [6914, 6918], [6993, 6996], [7005, 7016], [7019, 7032], [7035, 7036],
+ [7039, 7058], [7070, 7072], [7077, 7094], [7099], [7105, 7109], [7111],
+ [7113, 7114], [7119, 7123], [7125], [7134], [7185, 7186], [7196],
+ [7200, 7201], [7206, 7207], [7210, 7212], [7214], [7219, 7221],
+ [7223, 7225], [7228, 7232], [7236, 7237], [7242, 7243], [7246],
+ [7254, 7259], [8131, 8134], [8205, 8213], [8255, 8256], [8284, 8306],
+ [8312], [8322]],
+ 'leftHandIndex1': [[4641, 4644], [4651, 4654], [4669], [4681, 4682],
+ [4737, 4745], [4759, 4760], [4766, 4768], [4770, 4783],
+ [4791, 4793], [4795], [4800, 4802], [4805],
+ [4818, 4819], [4829, 4834], [4846, 4847], [4859, 4861],
+ [4872], [4874, 4877], [4883, 4884], [4886, 4888],
+ [4890, 4891], [4894, 4897], [4905, 5210], [5213, 5220],
+ [5223, 5310]],
+ 'rightLeg': [[6386, 6387], [6390, 6391], [6396, 6398], [6400],
+ [6403, 6405], [6410, 6411], [6436, 6527], [6539, 6549],
+ [6566, 6574], [6747, 6753], [6842, 6852], [6898, 6908],
+ [8422, 8458], [8567, 8583], [8680, 8684], [8717, 8720]],
+ 'rightHandIndex1': [[7377, 7380], [7387, 7390], [7405], [7417, 7418],
+ [7473, 7481], [7495, 7496], [7502, 7504], [7506, 7519],
+ [7527, 7529], [7531], [7536, 7538], [7541],
+ [7554, 7555], [7565, 7570], [7582, 7583], [7595, 7597],
+ [7608], [7610, 7613], [7619, 7620], [7622, 7624],
+ [7626, 7627], [7630, 7633], [7641, 7946], [7949, 7956],
+ [7959, 8046]],
+ 'leftForeArm': [[4176, 4248], [4251, 4260], [4273, 4274], [4277, 4278],
+ [4283, 4284], [4287, 4290], [4293, 4296], [4299, 4302],
+ [4323, 4333], [4337, 4340], [4359, 4368], [4371], [4374],
+ [4376], [4379, 4382], [4388], [4390], [4518], [4523, 4594],
+ [4632], [4673, 4674], [4686], [4703], [4712, 4726],
+ [4761, 4762], [4820, 4823], [4842], [4844], [4848, 4849],
+ [4855, 4858], [4893], [4900], [5451, 5452]],
+ 'rightForeArm': [[6920, 6992], [6995, 7004], [7017, 7018], [7021, 7022],
+ [7025, 7026], [7029, 7040], [7059, 7069], [7073, 7076],
+ [7095, 7104], [7107],
+ [7110], [7112], [7115, 7118], [7124], [7126], [7254],
+ [7259, 7330], [7368], [7409, 7410], [7422], [7439],
+ [7448, 7462], [7497, 7498], [7556, 7559], [7578], [7580],
+ [7584, 7585], [7591, 7594], [7629], [7636], [8185, 8186]],
+ 'neck': [[12, 15], [219, 222], [372, 375], [462, 463], [496, 497],
+ [552, 553], [558, 559], [563, 564], [649, 650], [736, 737],
+ [1210, 1213], [1326], [1359, 1360], [1386], [1726, 1727], [1759],
+ [1790], [1886], [1898], [1931, 1934], [1940, 1941], [1948, 1949],
+ [2036], [2149, 2151], [2218, 2219], [2484], [2531], [2870],
+ [2893], [2964], [2976], [3012, 3013], [3184, 3213], [3353, 3354],
+ [3435, 3436], [3445, 3446], [3450], [3452, 3453], [3456, 3459],
+ [3857], [3918, 3919], [3944, 3945], [3949, 3950], [3956, 3957],
+ [3964], [5518, 5519], [5527], [5616, 5617], [5649], [5920],
+ [5951, 5976], [6196, 6197], [6206, 6207], [6211], [6213, 6214],
+ [6217, 6220], [6608], [6666, 6667], [6692, 6693], [6697, 6698],
+ [6704, 6705], [6712], [8343], [8938], [8940], [8988]],
+ 'rightToeBase': [[8459, 8566], [8584], [8587], [8589], [8591], [8593],
+ [8595], [8597, 8598], [8600], [8602], [8605, 8606],
+ [8608], [8610]],
+ 'spine': [[3244, 3245], [3260, 3263], [3284, 3287], [3292, 3295],
+ [3350, 3351], [3397, 3400], [3428, 3431], [3519, 3520],
+ [3546, 3547], [3549, 3556], [3822, 3823], [3844, 3845],
+ [3851, 3852], [3886, 3888], [3891], [3904, 3907], [3960, 3963],
+ [3965, 3968], [3970], [3977, 3978], [4114, 4129], [4400],
+ [5401, 5423], [5425, 5426], [5429], [5488, 5489], [5495, 5496],
+ [5623], [5629, 5631], [5699], [5939, 5941], [5943], [5948],
+ [5950], [6007, 6008], [6023, 6026], [6047, 6050], [6055, 6058],
+ [6113, 6114], [6158, 6161], [6189, 6192], [6280, 6281],
+ [6307, 6308], [6310, 6317], [6579, 6580], [6599, 6600],
+ [6636, 6639], [6652, 6655], [6708, 6711], [6713, 6716], [6718],
+ [6725, 6726], [6858, 6873], [7136], [8135, 8157], [8159, 8160],
+ [8163], [8323, 8325], [8393], [8722, 8724], [8726], [9022, 9024],
+ [9026]],
+ 'leftUpLeg': [[3464, 3465], [3467, 3468], [3477, 3484], [3500, 3511],
+ [3527, 3545], [3563, 3566], [3574, 3676], [3770, 3781],
+ [3792, 3803], [3805, 3808], [3818, 3821], [3858, 3867],
+ [3902, 3903], [3914, 3917], [3958, 3959], [3986],
+ [3991, 3998], [4085, 4087], [4089, 4097], [4109, 4113],
+ [4131, 4134], [4144, 4154], [4165, 4166], [5700, 5703],
+ [5706, 5710], [9021], [9025]],
+ 'eyeballs': [[9383, 9516], [9518, 9529], [9531, 9542], [9544, 9555],
+ [9557, 9568], [9570, 9581], [9583, 9594], [9596, 9607],
+ [9609, 9620], [9622, 9633], [9635, 9646], [9648, 9659],
+ [9661, 9672], [9674, 9685], [9687, 9698], [9700, 9711],
+ [9713, 9724], [9726, 9737], [9739, 9750], [9752, 9763],
+ [9765, 9776], [9778, 9789], [9791, 9803], [9805, 9816],
+ [9818, 9829], [9831, 9842], [9844, 9855], [9857, 9868],
+ [9870, 9881], [9883, 9894], [9896, 9907], [9909, 9920],
+ [9922, 10062], [10064, 10075], [10077, 10088], [10090, 10101],
+ [10103, 10114], [10116, 10127], [10129,
+ 10140], [10142, 10153],
+ [10155, 10166], [10168, 10179], [10181,
+ 10192], [10194, 10205],
+ [10207, 10218], [10220, 10231], [10233,
+ 10244], [10246, 10257],
+ [10259, 10270], [10272, 10283], [10285,
+ 10296], [10298, 10309],
+ [10311, 10322], [10324, 10335], [10337,
+ 10349], [10351, 10362],
+ [10364, 10375], [10377, 10388], [10390,
+ 10401], [10403, 10414],
+ [10416, 10427], [10429, 10440], [10442, 10453],
+ [10455, 10466], [10468, 10474]],
+ 'leftHand': [[4595, 4640],
+ [4645, 4650], [4655, 4680], [4683, 4715], [4720], [4723],
+ [4727, 4736], [4743, 4758], [4763, 4765], [4768, 4769],
+ [4776, 4778], [4784, 4794], [4796, 4799], [4803,
+ 4817], [4820],
+ [4822], [4824, 4828], [4835, 4843], [4845], [4849, 4854],
+ [4860, 4874], [4876, 4885], [4888, 4893], [4898, 4899],
+ [4901, 4904], [4907], [5211, 5212], [5221,
+ 5222], [5311, 5348],
+ [5351, 5394]],
+ 'hips': [[3262, 3263], [3284, 3285], [3292, 3293], [3306, 3309], [3335],
+ [3350], [3428, 3429], [3432, 3434], [3439, 3442], [3447, 3448],
+ [3454, 3455], [3460, 3476], [3485, 3500], [3510, 3520],
+ [3542, 3543], [3546, 3550], [3562], [3567, 3569], [3734, 3736],
+ [3798, 3799], [3804], [3839, 3843], [3879], [3884, 3885],
+ [3889, 3890], [3902, 3903], [3916, 3917], [3958], [3969, 3972],
+ [3986], [3993, 3994], [4002], [4041], [4065, 4066], [4080, 4084],
+ [4088], [4130], [4144, 4145], [4147], [4165, 4166], [4291, 4292],
+ [4297, 4298], [4320, 4321], [4401, 4425], [5490, 5494],
+ [5497, 5498], [5502, 5517], [5520], [5557], [5574, 5575], [5596],
+ [5600, 5601], [5603, 5604], [5613, 5615], [5620], [5622],
+ [5630, 5631], [5658, 5699], [5703, 5705], [5711, 5727], [5931],
+ [5934], [5939], [5941, 5942], [5946], [5949], [6025, 6026],
+ [6047, 6048], [6055, 6056], [6069, 6072], [6098], [6113],
+ [6189, 6190], [6193, 6195], [6200, 6203], [6208, 6209],
+ [6215, 6216], [6221, 6237], [6246, 6261], [6271, 6281],
+ [6303, 6304], [6307, 6311], [6323], [6328, 6330], [6556, 6557],
+ [6594, 6598], [6634, 6635], [6650, 6651], [6664, 6665], [6706],
+ [6717, 6720], [6734], [6741, 6742], [6824, 6828], [6832], [6874],
+ [6888, 6889], [6891], [6909, 6910], [7137, 7161], [8219, 8240],
+ [8324, 8325], [8352, 8393], [8397, 8399], [8405, 8421]]
+}
+
+SMPLX_SUPER_SET = {
+ 'FOOT': ['leftFoot', 'leftToeBase', 'rightFoot', 'rightToeBase'],
+ 'HAND': ['leftHand', 'rightHand', 'leftHandIndex1', 'rightHandIndex1'],
+ 'LEG': ['rightUpLeg', 'leftUpLeg', 'leftLeg', 'rightLeg'],
+ 'ARM': ['leftForeArm', 'rightForeArm', 'leftArm', 'rightArm'],
+ 'HEAD': ['neck', 'head', 'leftEye', 'rightEye', 'eyeballs'],
+ 'UPBODY': ['spine1', 'spine2', 'leftShoulder', 'rightShoulder'],
+ 'LOWBODY': ['spine', 'hips'],
+}
diff --git a/detrsmpl/core/distributed_wrapper.py b/detrsmpl/core/distributed_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..106d6553f2b1d66d1b50ab8c0fd4f3dba0e38f45
--- /dev/null
+++ b/detrsmpl/core/distributed_wrapper.py
@@ -0,0 +1,134 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.parallel import MODULE_WRAPPERS, MMDistributedDataParallel
+from mmcv.parallel.scatter_gather import scatter_kwargs
+from torch.cuda._utils import _get_device_index
+
+
+@MODULE_WRAPPERS.register_module()
+class DistributedDataParallelWrapper(nn.Module):
+ """A DistributedDataParallel wrapper for models in 3D mesh estimation task.
+
+ In 3D mesh estimation task, there is a need to wrap different modules in
+ the models with separate DistributedDataParallel. Otherwise, it will cause
+ errors for GAN training.
+ More specific, the GAN model, usually has two sub-modules:
+ generator and discriminator. If we wrap both of them in one
+ standard DistributedDataParallel, it will cause errors during training,
+ because when we update the parameters of the generator (or discriminator),
+ the parameters of the discriminator (or generator) is not updated, which is
+ not allowed for DistributedDataParallel.
+ So we design this wrapper to separately wrap DistributedDataParallel
+ for generator and discriminator.
+ In this wrapper, we perform two operations:
+ 1. Wrap the modules in the models with separate MMDistributedDataParallel.
+ Note that only modules with parameters will be wrapped.
+ 2. Do scatter operation for 'forward', 'train_step' and 'val_step'.
+ Note that the arguments of this wrapper is the same as those in
+ `torch.nn.parallel.distributed.DistributedDataParallel`.
+ Args:
+ module (nn.Module): Module that needs to be wrapped.
+ device_ids (list[int | `torch.device`]): Same as that in
+ `torch.nn.parallel.distributed.DistributedDataParallel`.
+ dim (int, optional): Same as that in the official scatter function in
+ pytorch. Defaults to 0.
+ broadcast_buffers (bool): Same as that in
+ `torch.nn.parallel.distributed.DistributedDataParallel`.
+ Defaults to False.
+ find_unused_parameters (bool, optional): Same as that in
+ `torch.nn.parallel.distributed.DistributedDataParallel`.
+ Traverse the autograd graph of all tensors contained in returned
+ value of the wrapped module’s forward function. Defaults to False.
+ kwargs (dict): Other arguments used in
+ `torch.nn.parallel.distributed.DistributedDataParallel`.
+ """
+ def __init__(self,
+ module,
+ device_ids,
+ dim=0,
+ broadcast_buffers=False,
+ find_unused_parameters=False,
+ **kwargs):
+ super().__init__()
+ assert len(device_ids) == 1, (
+            'Currently, DistributedDataParallelWrapper only supports a '
+            'single CUDA device for each process. '
+ f'The length of device_ids must be 1, but got {len(device_ids)}.')
+ self.module = module
+ self.dim = dim
+ self.to_ddp(device_ids=device_ids,
+ dim=dim,
+ broadcast_buffers=broadcast_buffers,
+ find_unused_parameters=find_unused_parameters,
+ **kwargs)
+ self.output_device = _get_device_index(device_ids[0], True)
+
+ def to_ddp(self, device_ids, dim, broadcast_buffers,
+ find_unused_parameters, **kwargs):
+ """Wrap models with separate MMDistributedDataParallel.
+
+ It only wraps the modules with parameters.
+ """
+ for name, module in self.module._modules.items():
+ if next(module.parameters(), None) is None:
+ module = module.cuda()
+ elif all(not p.requires_grad for p in module.parameters()):
+ module = module.cuda()
+ else:
+ module = MMDistributedDataParallel(
+ module.cuda(),
+ device_ids=device_ids,
+ dim=dim,
+ broadcast_buffers=broadcast_buffers,
+ find_unused_parameters=find_unused_parameters,
+ **kwargs)
+ self.module._modules[name] = module
+
+ def scatter(self, inputs, kwargs, device_ids):
+ """Scatter function.
+
+ Args:
+ inputs (Tensor): Input Tensor.
+ kwargs (dict): Args for
+ ``mmcv.parallel.scatter_gather.scatter_kwargs``.
+ device_ids (int): Device id.
+ """
+ return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
+
+ def forward(self, *inputs, **kwargs):
+ """Forward function.
+
+ Args:
+ inputs (tuple): Input data.
+ kwargs (dict): Args for
+ ``mmcv.parallel.scatter_gather.scatter_kwargs``.
+ """
+ inputs, kwargs = self.scatter(inputs, kwargs,
+ [torch.cuda.current_device()])
+ return self.module(*inputs[0], **kwargs[0])
+
+ def train_step(self, *inputs, **kwargs):
+ """Train step function.
+
+ Args:
+ inputs (Tensor): Input Tensor.
+ kwargs (dict): Args for
+ ``mmcv.parallel.scatter_gather.scatter_kwargs``.
+ """
+ inputs, kwargs = self.scatter(inputs, kwargs,
+ [torch.cuda.current_device()])
+ output = self.module.train_step(*inputs[0], **kwargs[0])
+ return output
+
+ def val_step(self, *inputs, **kwargs):
+ """Validation step function.
+
+ Args:
+ inputs (tuple): Input data.
+ kwargs (dict): Args for ``scatter_kwargs``.
+ """
+ inputs, kwargs = self.scatter(inputs, kwargs,
+ [torch.cuda.current_device()])
+ output = self.module.val_step(*inputs[0], **kwargs[0])
+ return output
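+
+
+if __name__ == '__main__':
+    # Illustrative sketch (not part of the original module): wrap a toy model
+    # whose parameterised sub-modules should be wrapped separately. This
+    # assumes a single-GPU process in which the default torch.distributed
+    # process group has already been initialised.
+    toy = nn.Module()
+    toy.generator = nn.Linear(8, 8)
+    toy.discriminator = nn.Linear(8, 1)
+    wrapped = DistributedDataParallelWrapper(toy, device_ids=[0])
+    # Each parameterised sub-module is now an MMDistributedDataParallel.
+    print(type(wrapped.module.generator).__name__)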
diff --git a/detrsmpl/core/evaluation/__init__.py b/detrsmpl/core/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6ee4cb3215cdb31960ccf00a7d81c69bd5df482
--- /dev/null
+++ b/detrsmpl/core/evaluation/__init__.py
@@ -0,0 +1,17 @@
+from detrsmpl.core.evaluation import mesh_eval
+from detrsmpl.core.evaluation.eval_hooks import DistEvalHook, EvalHook
+from detrsmpl.core.evaluation.eval_utils import (
+ fg_vertices_to_mesh_distance,
+ keypoint_3d_auc,
+ keypoint_3d_pck,
+ keypoint_accel_error,
+ keypoint_mpjpe,
+ vertice_pve,
+)
+from detrsmpl.core.evaluation.mesh_eval import compute_similarity_transform
+
+__all__ = [
+ 'compute_similarity_transform', 'keypoint_mpjpe', 'mesh_eval',
+ 'DistEvalHook', 'EvalHook', 'vertice_pve', 'keypoint_3d_pck',
+ 'keypoint_3d_auc', 'keypoint_accel_error', 'fg_vertices_to_mesh_distance'
+]
diff --git a/detrsmpl/core/evaluation/eval_hooks.py b/detrsmpl/core/evaluation/eval_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..f18709963e7f534a7241cd8ac60c7b6cbfcf03b3
--- /dev/null
+++ b/detrsmpl/core/evaluation/eval_hooks.py
@@ -0,0 +1,135 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import tempfile
+import warnings
+
+from mmcv.runner import DistEvalHook as BaseDistEvalHook
+from mmcv.runner import EvalHook as BaseEvalHook
+
+MMHUMAN3D_GREATER_KEYS = ['3dpck', 'pa-3dpck', '3dauc', 'pa-3dauc']
+MMHUMAN3D_LESS_KEYS = ['mpjpe', 'pa-mpjpe', 'pve']
+
+
+class EvalHook(BaseEvalHook):
+ def __init__(self,
+ dataloader,
+ start=None,
+ interval=1,
+ by_epoch=True,
+ save_best=None,
+ rule=None,
+ test_fn=None,
+ greater_keys=MMHUMAN3D_GREATER_KEYS,
+ less_keys=MMHUMAN3D_LESS_KEYS,
+ **eval_kwargs):
+ if test_fn is None:
+ from detrsmpl.apis import single_gpu_test
+ test_fn = single_gpu_test
+
+ # remove "gpu_collect" from eval_kwargs
+ if 'gpu_collect' in eval_kwargs:
+ warnings.warn(
+ '"gpu_collect" will be deprecated in EvalHook.'
+ 'Please remove it from the config.', DeprecationWarning)
+ _ = eval_kwargs.pop('gpu_collect')
+
+ # update "save_best" according to "key_indicator" and remove the
+ # latter from eval_kwargs
+ if 'key_indicator' in eval_kwargs or isinstance(save_best, bool):
+ warnings.warn(
+ '"key_indicator" will be deprecated in EvalHook.'
+ 'Please use "save_best" to specify the metric key,'
+ 'e.g., save_best="pa-mpjpe".', DeprecationWarning)
+
+ key_indicator = eval_kwargs.pop('key_indicator', None)
+ if save_best is True and key_indicator is None:
+                raise ValueError('key_indicator should not be None when '
+ 'save_best is set to True.')
+ save_best = key_indicator
+
+ super().__init__(dataloader, start, interval, by_epoch, save_best,
+ rule, test_fn, greater_keys, less_keys, **eval_kwargs)
+
+ def evaluate(self, runner, results):
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ eval_res = self.dataloader.dataset.evaluate(results,
+ res_folder=tmp_dir,
+ logger=runner.logger,
+ **self.eval_kwargs)
+
+ for name, val in eval_res.items():
+ runner.log_buffer.output[name] = val
+ runner.log_buffer.ready = True
+
+ if self.save_best is not None:
+ if self.key_indicator == 'auto':
+ self._init_rule(self.rule, list(eval_res.keys())[0])
+
+ return eval_res[self.key_indicator]
+
+ return None
+
+
+class DistEvalHook(BaseDistEvalHook):
+ def __init__(self,
+ dataloader,
+ start=None,
+ interval=1,
+ by_epoch=True,
+ save_best=None,
+ rule=None,
+ test_fn=None,
+ greater_keys=MMHUMAN3D_GREATER_KEYS,
+ less_keys=MMHUMAN3D_LESS_KEYS,
+ broadcast_bn_buffer=True,
+ tmpdir=None,
+ gpu_collect=False,
+ **eval_kwargs):
+
+ if test_fn is None:
+ from detrsmpl.apis import multi_gpu_test
+ test_fn = multi_gpu_test
+
+ # update "save_best" according to "key_indicator" and remove the
+ # latter from eval_kwargs
+ if 'key_indicator' in eval_kwargs or isinstance(save_best, bool):
+ warnings.warn(
+ '"key_indicator" will be deprecated in EvalHook.'
+ 'Please use "save_best" to specify the metric key,'
+ 'e.g., save_best="pa-mpjpe".', DeprecationWarning)
+
+ key_indicator = eval_kwargs.pop('key_indicator', None)
+ if save_best is True and key_indicator is None:
+                raise ValueError('key_indicator should not be None when '
+ 'save_best is set to True.')
+ save_best = key_indicator
+
+ super().__init__(dataloader, start, interval, by_epoch, save_best,
+ rule, test_fn, greater_keys, less_keys,
+ broadcast_bn_buffer, tmpdir, gpu_collect,
+ **eval_kwargs)
+
+ def evaluate(self, runner, results):
+ """Evaluate the results.
+
+ Args:
+            runner (:obj:`mmcv.Runner`): The underlying training runner.
+ results (list): Output results.
+ """
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ eval_res = self.dataloader.dataset.evaluate(results,
+ res_folder=tmp_dir,
+ logger=runner.logger,
+ **self.eval_kwargs)
+
+ for name, val in eval_res.items():
+ runner.log_buffer.output[name] = val
+ runner.log_buffer.ready = True
+
+ if self.save_best is not None:
+ if self.key_indicator == 'auto':
+ # infer from eval_results
+ self._init_rule(self.rule, list(eval_res.keys())[0])
+ return eval_res[self.key_indicator]
+
+ return None
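+
+
+# Illustrative sketch (not part of the original module): registering an
+# EvalHook that tracks the best checkpoint by PA-MPJPE. `val_dataloader` and
+# `runner` are hypothetical; the wrapped dataset is assumed to implement
+# `evaluate()`:
+#
+#   eval_hook = EvalHook(val_dataloader, interval=1, save_best='pa-mpjpe')
+#   runner.register_hook(eval_hook)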
diff --git a/detrsmpl/core/evaluation/eval_utils.py b/detrsmpl/core/evaluation/eval_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd0fddd9324cc504271c034d810ac4b975dfd495
--- /dev/null
+++ b/detrsmpl/core/evaluation/eval_utils.py
@@ -0,0 +1,287 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import trimesh
+from trimesh.proximity import closest_point
+
+from .mesh_eval import compute_similarity_transform
+
+
+def keypoint_mpjpe(pred, gt, mask, alignment='none'):
+ """Calculate the mean per-joint position error (MPJPE) and the error after
+ rigid alignment with the ground truth (PA-MPJPE).
+ batch_size: N
+ num_keypoints: K
+ keypoint_dims: C
+ Args:
+ pred (np.ndarray[N, K, C]): Predicted keypoint location.
+ gt (np.ndarray[N, K, C]): Groundtruth keypoint location.
+ mask (np.ndarray[N, K]): Visibility of the target. False for invisible
+ joints, and True for visible. Invisible joints will be ignored for
+ accuracy calculation.
+ alignment (str, optional): method to align the prediction with the
+ groundtruth. Supported options are:
+ - ``'none'``: no alignment will be applied
+ - ``'scale'``: align in the least-square sense in scale
+ - ``'procrustes'``: align in the least-square sense in scale,
+ rotation and translation.
+    Returns:
+        float: The mean per-joint position error over all visible joints.
+            With ``alignment='procrustes'`` this is PA-MPJPE; with
+            ``alignment='none'`` it is the plain MPJPE.
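+
+    Example:
+        A minimal, illustrative sketch with random inputs (not taken from the
+        original documentation):
+
+        >>> import numpy as np
+        >>> pred = np.random.rand(2, 17, 3)
+        >>> gt = np.random.rand(2, 17, 3)
+        >>> mask = np.ones((2, 17), dtype=bool)
+        >>> mpjpe = keypoint_mpjpe(pred, gt, mask, alignment='none')
+        >>> pa_mpjpe = keypoint_mpjpe(pred, gt, mask, alignment='procrustes')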
+ """
+ assert mask.any()
+
+ if alignment == 'none':
+ pass
+ elif alignment == 'procrustes':
+ pred = np.stack([
+ compute_similarity_transform(pred_i, gt_i)
+ for pred_i, gt_i in zip(pred, gt)
+ ])
+ elif alignment == 'scale':
+ pred_dot_pred = np.einsum('nkc,nkc->n', pred, pred)
+ pred_dot_gt = np.einsum('nkc,nkc->n', pred, gt)
+ scale_factor = pred_dot_gt / pred_dot_pred
+ pred = pred * scale_factor[:, None, None]
+ else:
+ raise ValueError(f'Invalid value for alignment: {alignment}')
+
+ error = np.linalg.norm(pred - gt, ord=2, axis=-1)[mask].mean()
+
+ return error
+
+
+def keypoint_accel_error(gt, pred, mask=None):
+ """Computes acceleration error:
+
+ Note that for each frame that is not visible, three entries in the
+ acceleration error should be zero'd out.
+ Args:
+ gt (Nx14x3).
+ pred (Nx14x3).
+ mask (N).
+ Returns:
+ error_accel (N-2).
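+
+    Example:
+        An illustrative sketch (not from the original docs): a constant
+        acceleration of 2 along x gives an error of 2 for every frame.
+
+        >>> import numpy as np
+        >>> gt = np.zeros((4, 14, 3))
+        >>> pred = np.zeros((4, 14, 3))
+        >>> pred[:, :, 0] = (np.arange(4) ** 2)[:, None]  # x(t) = t^2
+        >>> np.allclose(keypoint_accel_error(gt, pred), 2.0)
+        True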
+ """
+ # (N-2)x14x3
+ accel_gt = gt[:-2] - 2 * gt[1:-1] + gt[2:]
+ accel_pred = pred[:-2] - 2 * pred[1:-1] + pred[2:]
+
+ normed = np.linalg.norm(accel_pred - accel_gt, axis=2)
+
+ if mask is None:
+ new_vis = np.ones(len(normed), dtype=bool)
+ else:
+ invis = np.logical_not(mask)
+ invis1 = np.roll(invis, -1)
+ invis2 = np.roll(invis, -2)
+ new_invis = np.logical_or(invis, np.logical_or(invis1, invis2))[:-2]
+ new_vis = np.logical_not(new_invis)
+
+ return np.mean(normed[new_vis], axis=1)
+
+
+def vertice_pve(pred_verts, target_verts, alignment='none'):
+ """Computes per vertex error (PVE).
+
+ Args:
+ verts_gt (N x verts_num x 3).
+ verts_pred (N x verts_num x 3).
+ alignment (str, optional): method to align the prediction with the
+ groundtruth. Supported options are:
+ - ``'none'``: no alignment will be applied
+ - ``'scale'``: align in the least-square sense in scale
+ - ``'procrustes'``: align in the least-square sense in scale,
+ rotation and translation.
+ Returns:
+        error_verts (float): Mean per-vertex position error.
+ """
+ assert len(pred_verts) == len(target_verts)
+ if alignment == 'none':
+ pass
+ elif alignment == 'procrustes':
+ pred_verts = np.stack([
+ compute_similarity_transform(pred_i, gt_i)
+ for pred_i, gt_i in zip(pred_verts, target_verts)
+ ])
+ elif alignment == 'scale':
+ pred_dot_pred = np.einsum('nkc,nkc->n', pred_verts, pred_verts)
+ pred_dot_gt = np.einsum('nkc,nkc->n', pred_verts, target_verts)
+ scale_factor = pred_dot_gt / pred_dot_pred
+ pred_verts = pred_verts * scale_factor[:, None, None]
+ else:
+ raise ValueError(f'Invalid value for alignment: {alignment}')
+ error = np.linalg.norm(pred_verts - target_verts, ord=2, axis=-1).mean()
+ return error
+
+
+def keypoint_3d_pck(pred, gt, mask, alignment='none', threshold=150.):
+ """Calculate the Percentage of Correct Keypoints (3DPCK) w. or w/o rigid
+ alignment.
+ Paper ref: `Monocular 3D Human Pose Estimation In The Wild Using Improved
+ CNN Supervision' 3DV'2017. `__ .
+ Note:
+ - batch_size: N
+ - num_keypoints: K
+ - keypoint_dims: C
+ Args:
+ pred (np.ndarray[N, K, C]): Predicted keypoint location.
+ gt (np.ndarray[N, K, C]): Groundtruth keypoint location.
+ mask (np.ndarray[N, K]): Visibility of the target. False for invisible
+ joints, and True for visible. Invisible joints will be ignored for
+ accuracy calculation.
+ alignment (str, optional): method to align the prediction with the
+ groundtruth. Supported options are:
+ - ``'none'``: no alignment will be applied
+ - ``'scale'``: align in the least-square sense in scale
+ - ``'procrustes'``: align in the least-square sense in scale,
+ rotation and translation.
+        threshold: If the L2 distance between the prediction and the
+            groundtruth is less than this threshold, the predicted result is
+            considered correct. Default: 150 (mm).
+ Returns:
+ pck: percentage of correct keypoints.
+ """
+ assert mask.any()
+
+ if alignment == 'none':
+ pass
+ elif alignment == 'procrustes':
+ pred = np.stack([
+ compute_similarity_transform(pred_i, gt_i)
+ for pred_i, gt_i in zip(pred, gt)
+ ])
+ elif alignment == 'scale':
+ pred_dot_pred = np.einsum('nkc,nkc->n', pred, pred)
+ pred_dot_gt = np.einsum('nkc,nkc->n', pred, gt)
+ scale_factor = pred_dot_gt / pred_dot_pred
+ pred = pred * scale_factor[:, None, None]
+ else:
+ raise ValueError(f'Invalid value for alignment: {alignment}')
+
+ error = np.linalg.norm(pred - gt, ord=2, axis=-1)
+ pck = (error < threshold).astype(np.float32)[mask].mean() * 100
+
+ return pck
+
+
+def keypoint_3d_auc(pred, gt, mask, alignment='none'):
+ """Calculate the Area Under the Curve (3DAUC) computed for a range of 3DPCK
+ thresholds.
+    Paper ref: 'Monocular 3D Human Pose Estimation In The Wild Using Improved
+    CNN Supervision', 3DV 2017.
+ This implementation is derived from mpii_compute_3d_pck.m, which is
+ provided as part of the MPI-INF-3DHP test data release.
+ Note:
+ batch_size: N
+ num_keypoints: K
+ keypoint_dims: C
+ Args:
+ pred (np.ndarray[N, K, C]): Predicted keypoint location.
+ gt (np.ndarray[N, K, C]): Groundtruth keypoint location.
+ mask (np.ndarray[N, K]): Visibility of the target. False for invisible
+ joints, and True for visible. Invisible joints will be ignored for
+ accuracy calculation.
+ alignment (str, optional): method to align the prediction with the
+ groundtruth. Supported options are:
+ - ``'none'``: no alignment will be applied
+ - ``'scale'``: align in the least-square sense in scale
+ - ``'procrustes'``: align in the least-square sense in scale,
+ rotation and translation.
+ Returns:
+ auc: AUC computed for a range of 3DPCK thresholds.
+ """
+ assert mask.any()
+
+ if alignment == 'none':
+ pass
+ elif alignment == 'procrustes':
+ pred = np.stack([
+ compute_similarity_transform(pred_i, gt_i)
+ for pred_i, gt_i in zip(pred, gt)
+ ])
+ elif alignment == 'scale':
+ pred_dot_pred = np.einsum('nkc,nkc->n', pred, pred)
+ pred_dot_gt = np.einsum('nkc,nkc->n', pred, gt)
+ scale_factor = pred_dot_gt / pred_dot_pred
+ pred = pred * scale_factor[:, None, None]
+ else:
+ raise ValueError(f'Invalid value for alignment: {alignment}')
+
+ error = np.linalg.norm(pred - gt, ord=2, axis=-1)
+
+ thresholds = np.linspace(0., 150, 31)
+ pck_values = np.zeros(len(thresholds))
+ for i in range(len(thresholds)):
+ pck_values[i] = (error < thresholds[i]).astype(np.float32)[mask].mean()
+
+ auc = pck_values.mean() * 100
+
+ return auc
+
+
+def fg_vertices_to_mesh_distance(groundtruth_vertices,
+ grundtruth_landmark_points,
+ predicted_mesh_vertices, predicted_mesh_faces,
+ predicted_mesh_landmark_points):
+ """This script computes the reconstruction error between an input mesh and
+ a ground truth mesh.
+ Args:
+ groundtruth_vertices (np.ndarray[N,3]): Ground truth vertices.
+ grundtruth_landmark_points (np.ndarray[7,3]): Ground truth annotations.
+ predicted_mesh_vertices (np.ndarray[M,3]): Predicted vertices.
+ predicted_mesh_faces (np.ndarray[K,3]): Vertex indices
+ composing the predicted mesh.
+ predicted_mesh_landmark_points (np.ndarray[7,3]): Predicted points.
+
+ Return:
+ distance: Mean point to mesh distance.
+
+ The grundtruth_landmark_points and predicted_mesh_landmark_points have to
+ contain points in the following order:
+ (1) right eye outer corner, (2) right eye inner corner,
+ (3) left eye inner corner, (4) left eye outer corner,
+ (5) nose bottom, (6) right mouth corner, (7) left mouth corner.
+ """
+
+ # Do procrustes based on the 7 points:
+ _, tform = compute_similarity_transform(predicted_mesh_landmark_points,
+ grundtruth_landmark_points,
+ return_tform=True)
+ # Use tform to transform all vertices.
+ predicted_mesh_vertices_aligned = (
+ tform['scale'] * tform['rotation'].dot(predicted_mesh_vertices.T) +
+ tform['translation']).T
+
+ # Compute the mask: A circular area around the center of the face.
+ nose_bottom = np.array(grundtruth_landmark_points[4])
+ nose_bridge = (np.array(grundtruth_landmark_points[1]) + np.array(
+ grundtruth_landmark_points[2])) / 2 # between the inner eye corners
+ face_centre = nose_bottom + 0.3 * (nose_bridge - nose_bottom)
+ # Compute the radius for the face mask:
+ outer_eye_dist = np.linalg.norm(
+ np.array(grundtruth_landmark_points[0]) -
+ np.array(grundtruth_landmark_points[3]))
+ nose_dist = np.linalg.norm(nose_bridge - nose_bottom)
+ mask_radius = 1.2 * (outer_eye_dist + nose_dist) / 2
+
+ # Find all the vertex indices in mask area.
+ vertex_indices_mask = []
+ # vertex indices in the source mesh (the ground truth scan)
+ points_on_groundtruth_scan_to_measure_from = []
+ for vertex_idx, vertex in enumerate(groundtruth_vertices):
+ dist = np.linalg.norm(
+ vertex - face_centre
+ ) # We use Euclidean distance for the mask area for now.
+ if dist <= mask_radius:
+ vertex_indices_mask.append(vertex_idx)
+ points_on_groundtruth_scan_to_measure_from.append(vertex)
+ assert len(vertex_indices_mask) == len(
+ points_on_groundtruth_scan_to_measure_from)
+ # Calculate the distance to the surface of the predicted mesh.
+ predicted_mesh = trimesh.Trimesh(predicted_mesh_vertices_aligned,
+ predicted_mesh_faces)
+ _, distance, _ = closest_point(predicted_mesh,
+ points_on_groundtruth_scan_to_measure_from)
+ return distance.mean()
diff --git a/detrsmpl/core/evaluation/mesh_eval.py b/detrsmpl/core/evaluation/mesh_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..913052c4a3e09d380de8e643a047f78d7a11cda6
--- /dev/null
+++ b/detrsmpl/core/evaluation/mesh_eval.py
@@ -0,0 +1,77 @@
+# ------------------------------------------------------------------------------
+# Adapted from https://github.com/akanazawa/hmr
+# Original licence: Copyright (c) 2018 akanazawa, under the MIT License.
+# ------------------------------------------------------------------------------
+
+import numpy as np
+
+
+def compute_similarity_transform(source_points,
+ target_points,
+ return_tform=False):
+ """Computes a similarity transform (sR, t) that takes a set of 3D points
+ source_points (N x 3) closest to a set of 3D points target_points, where R
+ is an 3x3 rotation matrix, t 3x1 translation, s scale.
+
+ And return the
+ transformed 3D points source_points_hat (N x 3). i.e. solves the orthogonal
+ Procrutes problem.
+ Notes:
+ Points number: N
+ Args:
+ source_points (np.ndarray([N, 3])): Source point set.
+ target_points (np.ndarray([N, 3])): Target point set.
+ return_tform (bool) : Whether return transform
+ Returns:
+ source_points_hat (np.ndarray([N, 3])): Transformed source point set.
+ transform (dict): Returns if return_tform is True.
+ Returns rotation: r, 'scale': s, 'translation':t.
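+
+    Example:
+        An illustrative sketch (not part of the original docstring): an exact
+        similarity transform between two point sets is recovered exactly.
+
+        >>> import numpy as np
+        >>> src = np.random.rand(10, 3)
+        >>> theta = np.pi / 4
+        >>> R = np.array([[np.cos(theta), -np.sin(theta), 0.],
+        ...               [np.sin(theta), np.cos(theta), 0.],
+        ...               [0., 0., 1.]])
+        >>> tgt = 2.0 * src.dot(R.T) + np.array([1., 2., 3.])
+        >>> aligned = compute_similarity_transform(src, tgt)
+        >>> np.allclose(aligned, tgt)
+        True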
+ """
+
+ assert target_points.shape[0] == source_points.shape[0]
+ assert target_points.shape[1] == 3 and source_points.shape[1] == 3
+
+ source_points = source_points.T
+ target_points = target_points.T
+
+ # 1. Remove mean.
+ mu1 = source_points.mean(axis=1, keepdims=True)
+ mu2 = target_points.mean(axis=1, keepdims=True)
+ X1 = source_points - mu1
+ X2 = target_points - mu2
+
+ # 2. Compute variance of X1 used for scale.
+ var1 = np.sum(X1**2)
+
+ # 3. The outer product of X1 and X2.
+ K = X1.dot(X2.T)
+
+ # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are
+ # singular vectors of K.
+ U, _, Vh = np.linalg.svd(K)
+ V = Vh.T
+ # Construct Z that fixes the orientation of R to get det(R)=1.
+ Z = np.eye(U.shape[0])
+ Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T)))
+ # Construct R.
+ R = V.dot(Z.dot(U.T))
+
+ # 5. Recover scale.
+ scale = np.trace(R.dot(K)) / var1
+
+ # 6. Recover translation.
+ t = mu2 - scale * (R.dot(mu1))
+
+ # 7. Transform the source points:
+ source_points_hat = scale * R.dot(source_points) + t
+
+ source_points_hat = source_points_hat.T
+
+ if return_tform:
+ return source_points_hat, {
+ 'rotation': R,
+ 'scale': scale,
+ 'translation': t
+ }
+
+ return source_points_hat
diff --git a/detrsmpl/core/optimizer/__init__.py b/detrsmpl/core/optimizer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4340ffc075afdcdf3d9f7a398ead394ca5a168a1
--- /dev/null
+++ b/detrsmpl/core/optimizer/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .builder import OPTIMIZERS, build_optimizers
+
+__all__ = ['build_optimizers', 'OPTIMIZERS']
diff --git a/detrsmpl/core/optimizer/builder.py b/detrsmpl/core/optimizer/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f659453a797def86628250cfbb7638ec0f323f
--- /dev/null
+++ b/detrsmpl/core/optimizer/builder.py
@@ -0,0 +1,52 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.runner import build_optimizer
+from mmcv.utils import Registry
+
+OPTIMIZERS = Registry('optimizers')
+
+
+def build_optimizers(model, cfgs):
+ """Build multiple optimizers from configs. If `cfgs` contains several dicts
+ for optimizers, then a dict for each constructed optimizers will be
+ returned. If `cfgs` only contains one optimizer config, the constructed
+ optimizer itself will be returned. For example,
+
+ 1) Multiple optimizer configs:
+
+ .. code-block:: python
+
+ optimizer_cfg = dict(
+ model1=dict(type='SGD', lr=lr),
+ model2=dict(type='SGD', lr=lr))
+
+ The return dict is
+ ``dict('model1': torch.optim.Optimizer, 'model2': torch.optim.Optimizer)``
+
+ 2) Single optimizer config:
+
+ .. code-block:: python
+
+ optimizer_cfg = dict(type='SGD', lr=lr)
+
+ The return is ``torch.optim.Optimizer``.
+
+ Args:
+ model (:obj:`nn.Module`): The model with parameters to be optimized.
+ cfgs (dict): The config dict of the optimizer.
+
+ Returns:
+ dict[:obj:`torch.optim.Optimizer`] | :obj:`torch.optim.Optimizer`:
+ The initialized optimizers.
+ """
+ optimizers = {}
+ if hasattr(model, 'module'):
+ model = model.module
+ # determine whether 'cfgs' has several dicts for optimizers
+ if all(isinstance(v, dict) for v in cfgs.values()):
+ for key, cfg in cfgs.items():
+ cfg_ = cfg.copy()
+ module = getattr(model, key)
+ optimizers[key] = build_optimizer(module, cfg_)
+ return optimizers
+
+ return build_optimizer(model, cfgs)
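+
+
+if __name__ == '__main__':
+    # Illustrative sketch (not part of the original module): build separate
+    # SGD optimizers for two sub-modules of a toy model. The sub-module names
+    # ('backbone', 'head') are arbitrary and must match the config keys.
+    import torch.nn as nn
+    toy = nn.Module()
+    toy.backbone = nn.Linear(8, 8)
+    toy.head = nn.Linear(8, 2)
+    optims = build_optimizers(
+        toy,
+        dict(backbone=dict(type='SGD', lr=0.01),
+             head=dict(type='SGD', lr=0.1)))
+    print(sorted(optims.keys()))  # ['backbone', 'head']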
diff --git a/detrsmpl/core/post_processing/__init__.py b/detrsmpl/core/post_processing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b3945d2377ba675b2f20b63b3b550acccdbdd4f
--- /dev/null
+++ b/detrsmpl/core/post_processing/__init__.py
@@ -0,0 +1,15 @@
+from .builder import build_post_processing
+from .smooth.gaus1d_filter import Gaus1dFilter
+from .smooth.oneeuro_filter import OneEuroFilter
+from .smooth.savgol_filter import SGFilter
+from .smooth.smoothnet import SmoothNetFilter
+from .speed_up.deciwatch import DeciWatchPostProcessing
+
+__all__ = [
+ 'build_post_processing',
+ 'OneEuroFilter',
+ 'SGFilter',
+ 'Gaus1dFilter',
+ 'SmoothNetFilter',
+ 'DeciWatchPostProcessing',
+]
diff --git a/detrsmpl/core/post_processing/bbox/__init__.py b/detrsmpl/core/post_processing/bbox/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..59755b46c6df58c4c5dc74b5fe3d08fa10d1c49f
--- /dev/null
+++ b/detrsmpl/core/post_processing/bbox/__init__.py
@@ -0,0 +1,6 @@
+from .assigners import AssignResult, BaseAssigner
+
+__all__ = [
+ 'AssignResult',
+ 'BaseAssigner',
+]
diff --git a/detrsmpl/core/post_processing/bbox/assigners/__init__.py b/detrsmpl/core/post_processing/bbox/assigners/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5aa9ad6c937ca6a98afe04d2ac9b81c12d20b185
--- /dev/null
+++ b/detrsmpl/core/post_processing/bbox/assigners/__init__.py
@@ -0,0 +1,8 @@
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+from .builder import build_assigner
+from .hungarian_assigner import HungarianAssigner
+
+__all__ = [
+ 'build_assigner', 'HungarianAssigner', 'AssignResult', 'BaseAssigner'
+]
diff --git a/detrsmpl/core/post_processing/bbox/assigners/assign_result.py b/detrsmpl/core/post_processing/bbox/assigners/assign_result.py
new file mode 100644
index 0000000000000000000000000000000000000000..a44da61acc0c2bcb5ef0ce74f165f17d1da2764a
--- /dev/null
+++ b/detrsmpl/core/post_processing/bbox/assigners/assign_result.py
@@ -0,0 +1,208 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+# from mmdet.utils import util_mixins
+from detrsmpl.utils import util_mixins
+
+
+class AssignResult(util_mixins.NiceRepr):
+ """Stores assignments between predicted and truth boxes.
+
+ Attributes:
+ num_gts (int): the number of truth boxes considered when computing this
+ assignment
+
+ gt_inds (LongTensor): for each predicted box indicates the 1-based
+ index of the assigned truth box. 0 means unassigned and -1 means
+ ignore.
+
+ max_overlaps (FloatTensor): the iou between the predicted box and its
+ assigned truth box.
+
+ labels (None | LongTensor): If specified, for each predicted box
+ indicates the category label of the assigned truth box.
+
+ Example:
+ >>> # An assign result between 4 predicted boxes and 9 true boxes
+ >>> # where only two boxes were assigned.
+ >>> num_gts = 9
+        >>> max_overlaps = torch.FloatTensor([0, .5, .9, 0])
+ >>> gt_inds = torch.LongTensor([-1, 1, 2, 0])
+ >>> labels = torch.LongTensor([0, 3, 4, 0])
+ >>> self = AssignResult(num_gts, gt_inds, max_overlaps, labels)
+ >>> print(str(self)) # xdoctest: +IGNORE_WANT
+        <AssignResult(num_gts=9, gt_inds.shape=(4,), max_overlaps.shape=(4,),
+                      labels.shape=(4,))>
+ >>> # Force addition of gt labels (when adding gt as proposals)
+ >>> new_labels = torch.LongTensor([3, 4, 5])
+ >>> self.add_gt_(new_labels)
+ >>> print(str(self)) # xdoctest: +IGNORE_WANT
+        <AssignResult(num_gts=9, gt_inds.shape=(7,), max_overlaps.shape=(7,),
+                      labels.shape=(7,))>
+ """
+ def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
+ self.num_gts = num_gts
+ self.gt_inds = gt_inds
+ self.max_overlaps = max_overlaps
+ self.labels = labels
+ # Interface for possible user-defined properties
+ self._extra_properties = {}
+
+ @property
+ def num_preds(self):
+ """int: the number of predictions in this assignment"""
+ return len(self.gt_inds)
+
+ def set_extra_property(self, key, value):
+ """Set user-defined new property."""
+ assert key not in self.info
+ self._extra_properties[key] = value
+
+ def get_extra_property(self, key):
+ """Get user-defined property."""
+ return self._extra_properties.get(key, None)
+
+ @property
+ def info(self):
+ """dict: a dictionary of info about the object"""
+ basic_info = {
+ 'num_gts': self.num_gts,
+ 'num_preds': self.num_preds,
+ 'gt_inds': self.gt_inds,
+ 'max_overlaps': self.max_overlaps,
+ 'labels': self.labels,
+ }
+ basic_info.update(self._extra_properties)
+ return basic_info
+
+ def __nice__(self):
+ """str: a "nice" summary string describing this assign result"""
+ parts = []
+ parts.append(f'num_gts={self.num_gts!r}')
+ if self.gt_inds is None:
+ parts.append(f'gt_inds={self.gt_inds!r}')
+ else:
+ parts.append(f'gt_inds.shape={tuple(self.gt_inds.shape)!r}')
+ if self.max_overlaps is None:
+ parts.append(f'max_overlaps={self.max_overlaps!r}')
+ else:
+ parts.append('max_overlaps.shape='
+ f'{tuple(self.max_overlaps.shape)!r}')
+ if self.labels is None:
+ parts.append(f'labels={self.labels!r}')
+ else:
+ parts.append(f'labels.shape={tuple(self.labels.shape)!r}')
+ return ', '.join(parts)
+
+ @classmethod
+ def random(cls, **kwargs):
+ """Create random AssignResult for tests or debugging.
+
+ Args:
+ num_preds: number of predicted boxes
+ num_gts: number of true boxes
+ p_ignore (float): probability of a predicted box assigned to an
+ ignored truth
+            p_assigned (float): probability of a predicted box being
+                assigned
+ p_use_label (float | bool): with labels or not
+ rng (None | int | numpy.random.RandomState): seed or state
+
+ Returns:
+ :obj:`AssignResult`: Randomly generated assign results.
+
+ Example:
+ >>> from mmdet.core.bbox.assigners.assign_result import * # NOQA
+ >>> self = AssignResult.random()
+ >>> print(self.info)
+ """
+ from mmdet.core.bbox import demodata
+ rng = demodata.ensure_rng(kwargs.get('rng', None))
+
+ num_gts = kwargs.get('num_gts', None)
+ num_preds = kwargs.get('num_preds', None)
+ p_ignore = kwargs.get('p_ignore', 0.3)
+ p_assigned = kwargs.get('p_assigned', 0.7)
+ p_use_label = kwargs.get('p_use_label', 0.5)
+        num_classes = kwargs.get('num_classes', 3)
+
+ if num_gts is None:
+ num_gts = rng.randint(0, 8)
+ if num_preds is None:
+ num_preds = rng.randint(0, 16)
+
+ if num_gts == 0:
+ max_overlaps = torch.zeros(num_preds, dtype=torch.float32)
+ gt_inds = torch.zeros(num_preds, dtype=torch.int64)
+ if p_use_label is True or p_use_label < rng.rand():
+ labels = torch.zeros(num_preds, dtype=torch.int64)
+ else:
+ labels = None
+ else:
+ import numpy as np
+
+ # Create an overlap for each predicted box
+ max_overlaps = torch.from_numpy(rng.rand(num_preds))
+
+ # Construct gt_inds for each predicted box
+ is_assigned = torch.from_numpy(rng.rand(num_preds) < p_assigned)
+ # maximum number of assignments constraints
+ n_assigned = min(num_preds, min(num_gts, is_assigned.sum()))
+
+ assigned_idxs = np.where(is_assigned)[0]
+ rng.shuffle(assigned_idxs)
+ assigned_idxs = assigned_idxs[0:n_assigned]
+ assigned_idxs.sort()
+
+ is_assigned[:] = 0
+ is_assigned[assigned_idxs] = True
+
+ is_ignore = torch.from_numpy(
+ rng.rand(num_preds) < p_ignore) & is_assigned
+
+ gt_inds = torch.zeros(num_preds, dtype=torch.int64)
+
+ true_idxs = np.arange(num_gts)
+ rng.shuffle(true_idxs)
+ true_idxs = torch.from_numpy(true_idxs)
+ gt_inds[is_assigned] = true_idxs[:n_assigned].long()
+
+ gt_inds = torch.from_numpy(
+ rng.randint(1, num_gts + 1, size=num_preds))
+ gt_inds[is_ignore] = -1
+ gt_inds[~is_assigned] = 0
+ max_overlaps[~is_assigned] = 0
+
+ if p_use_label is True or p_use_label < rng.rand():
+ if num_classes == 0:
+ labels = torch.zeros(num_preds, dtype=torch.int64)
+ else:
+ labels = torch.from_numpy(
+ # remind that we set FG labels to [0, num_class-1]
+ # since mmdet v2.0
+ # BG cat_id: num_class
+ rng.randint(0, num_classes, size=num_preds))
+ labels[~is_assigned] = 0
+ else:
+ labels = None
+
+ self = cls(num_gts, gt_inds, max_overlaps, labels)
+ return self
+
+ def add_gt_(self, gt_labels):
+ """Add ground truth as assigned results.
+
+ Args:
+ gt_labels (torch.Tensor): Labels of gt boxes
+ """
+ self_inds = torch.arange(1,
+ len(gt_labels) + 1,
+ dtype=torch.long,
+ device=gt_labels.device)
+ self.gt_inds = torch.cat([self_inds, self.gt_inds])
+
+ self.max_overlaps = torch.cat(
+ [self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])
+
+ if self.labels is not None:
+ self.labels = torch.cat([gt_labels, self.labels])
diff --git a/detrsmpl/core/post_processing/bbox/assigners/base_assigner.py b/detrsmpl/core/post_processing/bbox/assigners/base_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..e63b71a206234e861d574d7c569f9fb93fc6883e
--- /dev/null
+++ b/detrsmpl/core/post_processing/bbox/assigners/base_assigner.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+
+class BaseAssigner(metaclass=ABCMeta):
+ """Base assigner that assigns boxes to ground truth boxes."""
+ @abstractmethod
+ def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
+ """Assign boxes to either a ground truth boxes or a negative boxes."""
diff --git a/detrsmpl/core/post_processing/bbox/assigners/builder.py b/detrsmpl/core/post_processing/bbox/assigners/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c27ec35f7e314113e33a48a959dac523258125a8
--- /dev/null
+++ b/detrsmpl/core/post_processing/bbox/assigners/builder.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.utils import Registry, build_from_cfg
+
+BBOX_ASSIGNERS = Registry('bbox_assigner')
+
+
+def build_assigner(cfg, **default_args):
+ """Builder of box assigner."""
+ return build_from_cfg(cfg, BBOX_ASSIGNERS, default_args)
diff --git a/detrsmpl/core/post_processing/bbox/assigners/hungarian_assigner.py b/detrsmpl/core/post_processing/bbox/assigners/hungarian_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd6fabf88c7f39bb561c7656b2ee2f2e24a32788
--- /dev/null
+++ b/detrsmpl/core/post_processing/bbox/assigners/hungarian_assigner.py
@@ -0,0 +1,189 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import torch
+
+# from detrsmpl.core.post_processing.bbox.transforms
+# import bbox_cxcywh_to_xyxy
+from ..match_costs import build_match_cost
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+from .builder import BBOX_ASSIGNERS
+
+try:
+ from scipy.optimize import linear_sum_assignment
+except ImportError:
+ linear_sum_assignment = None
+
+
+@BBOX_ASSIGNERS.register_module()
+class HungarianAssigner(BaseAssigner):
+ """Computes one-to-one matching between predictions and ground truth.
+
+ This class computes an assignment between the targets and the predictions
+    based on the costs. Here the cost is a weighted sum of keypoint matching
+    terms (a 2D keypoint cost and, optionally, a 3D keypoint cost). The
+    targets don't include the no_object class, so generally there are more
+ predictions than targets. After the one-to-one matching, the un-matched
+ are treated as backgrounds. Thus each query prediction will be assigned
+ with `0` or a positive integer indicating the ground truth index:
+
+ - 0: negative sample, no assigned gt
+ - positive integer: positive sample, index (1-based) of assigned gt
+
+ Args:
+        kp3d_cost (dict, optional): Config dict of the 3D keypoint cost used
+            for matching. Defaults to a ``Keypoints3DCost`` with weight 1.0.
+        kp2d_cost (dict, optional): Config dict of the 2D keypoint cost used
+            for matching. Defaults to a ``Keypoints2DCost`` with weight 1.0.
+ """
+ def __init__(
+ self,
+ # cls_cost=dict(type='ClassificationCost', weight=1.),
+ kp3d_cost=dict(type='Keypoints3DCost', covention='smpl_54',
+ weight=1.0),
+ kp2d_cost=dict(type='Keypoints2DCost', covention='smpl_54',
+ weight=1.0),
+ ):
+ # self.cls_cost = build_match_cost(cls_cost)
+ self.kp2d_cost = build_match_cost(kp2d_cost)
+ self.kp3d_cost = build_match_cost(kp3d_cost)
+
+ def assign(
+ self,
+ pred_smpl_pose,
+ pred_smpl_shape,
+ pred_kp3d,
+ pred_vert,
+ pred_cam,
+ gt_smpl_pose,
+ gt_smpl_shape,
+ gt_kp2d,
+ gt_kp3d,
+ has_keypoints2d,
+ has_keypoints3d,
+ has_smpl,
+ img_meta,
+ gt_bboxes_ignore=None,
+ eps=1e-7,
+ # pred_smpl_orient,
+ # pred_keypoints2d,
+ # gt_bboxes,
+ # gt_labels,
+ ):
+ """Computes one-to-one matching based on the weighted costs.
+
+        This method assigns each query prediction to a ground truth or to the
+        background. In `assigned_gt_inds`, -1 means don't care,
+        0 means negative sample, and a positive number is the (1-based) index
+        of the assigned gt.
+        The assignment is done in the following steps; the order matters.
+
+ 1. assign every prediction to -1
+ 2. compute the weighted costs
+ 3. do Hungarian matching on CPU based on the costs
+ 4. assign all to 0 (background) first, then for each matched pair
+ between predictions and gts, treat this prediction as foreground
+ and assign the corresponding gt index (plus 1) to it.
+
+ Args:
+            pred_smpl_pose (Tensor): Predicted SMPL pose parameters,
+                shape [num_query, ...].
+            pred_smpl_shape (Tensor): Predicted SMPL shape parameters.
+            pred_kp3d (Tensor): Predicted 3D keypoints.
+            pred_vert (Tensor): Predicted mesh vertices.
+            pred_cam (Tensor): Predicted camera parameters.
+            gt_smpl_pose (Tensor): Ground truth SMPL pose parameters,
+                shape [num_gt, ...].
+            gt_smpl_shape (Tensor): Ground truth SMPL shape parameters.
+            gt_kp2d (Tensor): Ground truth 2D keypoints.
+            gt_kp3d (Tensor): Ground truth 3D keypoints.
+            has_keypoints2d (Tensor): Flags indicating valid 2D keypoints.
+            has_keypoints3d (Tensor): Flags indicating valid 3D keypoints.
+            has_smpl (Tensor): Flags indicating valid SMPL annotations.
+ img_meta (dict): Meta information for current image.
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+ labelled as `ignored`. Default None.
+ eps (int | float, optional): A value added to the denominator for
+ numerical stability. Default 1e-7.
+
+ Returns:
+ :obj:`AssignResult`: The assigned result.
+ """
+ assert gt_bboxes_ignore is None, \
+ 'Only case when gt_bboxes_ignore is None is supported.'
+ num_gts, num_bboxes = gt_smpl_pose.size(0), pred_smpl_pose.size(0)
+
+ # 1. assign -1 by default
+ assigned_gt_inds = pred_smpl_pose.new_full((num_bboxes, ),
+ -1,
+ dtype=torch.long)
+ assigned_labels = pred_smpl_pose.new_full((num_bboxes, ),
+ -1,
+ dtype=torch.long)
+ if num_gts == 0 or num_bboxes == 0:
+ # No ground truth or boxes, return empty assignment
+ if num_gts == 0:
+ # No ground truth, assign all to background
+ assigned_gt_inds[:] = 0
+ return AssignResult(num_gts,
+ assigned_gt_inds,
+ None,
+ labels=assigned_labels)
+ # img_h, img_w, _ = img_meta['img_shape']
+ # factor = gt_bboxes.new_tensor([img_w, img_h, img_w,
+ # img_h]).unsqueeze(0)
+
+ # 2. compute the weighted costs
+ # classification and bboxcost.
+ # cls_cost = self.cls_cost(cls_pred, gt_labels)
+ # regression L1 cost
+ # normalize_gt_bboxes = gt_bboxes / factor
+
+ # kp3d_cost
+ kp3d_cost = self.kp3d_cost(pred_kp3d, gt_kp3d)
+
+ # kp2d_cost
+ kp2d_cost = self.kp2d_cost(pred_kp3d, pred_cam, gt_kp2d)
+ # smpl_pose_cost
+
+ # smpl_betas_cost
+
+ # verts_cost
+
+ # TODO: bbox_cost
+
+ # TODO: occlusion == bbox intersection
+
+ # regression iou cost; GIoU is used by default in the official DETR.
+ # bboxes = bbox_cxcywh_to_xyxy(bbox_pred) * factor
+ # iou_cost = self.iou_cost(pred_smpl_pose, gt_smpl_pose)
+ # weighted sum of above three costs
+ cost = kp2d_cost # + kp3d_cost
+
+ # 3. do Hungarian matching on CPU using linear_sum_assignment
+ cost = cost.detach().cpu()
+ if linear_sum_assignment is None:
+ raise ImportError('Please run "pip install scipy" '
+ 'to install scipy first.')
+ matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
+ matched_row_inds = torch.from_numpy(matched_row_inds).to(
+ pred_smpl_pose.device)
+ matched_col_inds = torch.from_numpy(matched_col_inds).to(
+ pred_smpl_pose.device)
+
+ # 4. assign backgrounds and foregrounds
+ # assign all indices to backgrounds first
+ assigned_gt_inds[:] = 0
+ # assign foregrounds based on matching results
+ assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
+ # assigned_labels[matched_row_inds] = None # gt_smpl_pose[matched_col_inds]
+ assigned_labels = None
+ return AssignResult(num_gts,
+ assigned_gt_inds,
+ None,
+ labels=assigned_labels)
+
+ # Note: `num_gts` is the number of ground truth instances and
+ # `assigned_gt_inds` is exposed as `AssignResult.gt_inds`.
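+
+ # Illustrative usage sketch (shapes and values below are assumptions for
+ # demonstration only). Given an instance `matcher` of the assigner class
+ # defined above:
+ #
+ # import torch
+ # num_query, num_gt, num_kp = 100, 2, 54
+ # result = matcher.assign(
+ #     pred_smpl_pose=torch.rand(num_query, 24, 3, 3),
+ #     pred_smpl_shape=torch.rand(num_query, 10),
+ #     pred_kp3d=torch.rand(num_query, num_kp, 3),
+ #     pred_vert=None,
+ #     pred_cam=torch.rand(num_query, 3),
+ #     gt_smpl_pose=torch.rand(num_gt, 24, 3, 3),
+ #     gt_smpl_shape=torch.rand(num_gt, 10),
+ #     gt_kp2d=torch.rand(num_gt, num_kp, 3),
+ #     gt_kp3d=torch.rand(num_gt, num_kp, 4),
+ #     has_keypoints2d=None, has_keypoints3d=None, has_smpl=None,
+ #     img_meta=dict())
+ # # result.gt_inds[i] is 0 (background) or the 1-based index of the
+ # # ground truth matched to query i.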
diff --git a/detrsmpl/core/post_processing/bbox/coder/__init__.py b/detrsmpl/core/post_processing/bbox/coder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8217120e4bac7e1044f85242eda60ce09a9d58a
--- /dev/null
+++ b/detrsmpl/core/post_processing/bbox/coder/__init__.py
@@ -0,0 +1,5 @@
+from .base_bbox_coder import BaseBBoxCoder
+from .builder import build_bbox_coder
+from .distance_point_bbox_coder import DistancePointBBoxCoder
+
+__all__ = ['build_bbox_coder', 'BaseBBoxCoder', 'DistancePointBBoxCoder']
diff --git a/detrsmpl/core/post_processing/bbox/coder/base_bbox_coder.py b/detrsmpl/core/post_processing/bbox/coder/base_bbox_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd1f4f97ea96763f85daf0d7010a911a2e9cf88a
--- /dev/null
+++ b/detrsmpl/core/post_processing/bbox/coder/base_bbox_coder.py
@@ -0,0 +1,17 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+
+class BaseBBoxCoder(metaclass=ABCMeta):
+ """Base bounding box coder."""
+ def __init__(self, **kwargs):
+ pass
+
+ @abstractmethod
+ def encode(self, bboxes, gt_bboxes):
+ """Encode deltas between bboxes and ground truth boxes."""
+
+ @abstractmethod
+ def decode(self, bboxes, bboxes_pred):
+ """Decode the predicted bboxes according to prediction and base
+ boxes."""
diff --git a/detrsmpl/core/post_processing/bbox/coder/builder.py b/detrsmpl/core/post_processing/bbox/coder/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..154eba608de21b22e1273f479c6644c4e6220e85
--- /dev/null
+++ b/detrsmpl/core/post_processing/bbox/coder/builder.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.utils import Registry, build_from_cfg
+
+BBOX_CODERS = Registry('bbox_coder')
+
+
+def build_bbox_coder(cfg, **default_args):
+ """Builder of box coder."""
+ return build_from_cfg(cfg, BBOX_CODERS, default_args)
diff --git a/detrsmpl/core/post_processing/bbox/coder/distance_point_bbox_coder.py b/detrsmpl/core/post_processing/bbox/coder/distance_point_bbox_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c04acbd87ff5832b78f98e90f16920da1ec3428e
--- /dev/null
+++ b/detrsmpl/core/post_processing/bbox/coder/distance_point_bbox_coder.py
@@ -0,0 +1,63 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..transforms import bbox2distance, distance2bbox
+from .base_bbox_coder import BaseBBoxCoder
+from .builder import BBOX_CODERS
+
+
+@BBOX_CODERS.register_module()
+class DistancePointBBoxCoder(BaseBBoxCoder):
+ """Distance Point BBox coder.
+
+ This coder encodes gt bboxes (x1, y1, x2, y2) into (top, bottom, left,
+ right) and decode it back to the original.
+
+ Args:
+ clip_border (bool, optional): Whether clip the objects outside the
+ border of the image. Defaults to True.
+ """
+
+ def __init__(self, clip_border=True):
+ super(BaseBBoxCoder, self).__init__()
+ self.clip_border = clip_border
+
+ def encode(self, points, gt_bboxes, max_dis=None, eps=0.1):
+ """Encode bounding box to distances.
+
+ Args:
+ points (Tensor): Shape (N, 2), The format is [x, y].
+ gt_bboxes (Tensor): Shape (N, 4), The format is "xyxy"
+ max_dis (float): Upper bound of the distance. Default None.
+ eps (float): a small value to ensure target < max_dis, instead <=.
+ Default 0.1.
+
+ Returns:
+ Tensor: Box transformation deltas. The shape is (N, 4).
+ """
+ assert points.size(0) == gt_bboxes.size(0)
+ assert points.size(-1) == 2
+ assert gt_bboxes.size(-1) == 4
+ return bbox2distance(points, gt_bboxes, max_dis, eps)
+
+ def decode(self, points, pred_bboxes, max_shape=None):
+ """Decode distance prediction to bounding box.
+
+ Args:
+ points (Tensor): Shape (B, N, 2) or (N, 2).
+ pred_bboxes (Tensor): Distance from the given point to 4
+ boundaries (left, top, right, bottom). Shape (B, N, 4)
+ or (N, 4)
+ max_shape (Sequence[int] or torch.Tensor or Sequence[
+ Sequence[int]],optional): Maximum bounds for boxes, specifies
+ (H, W, C) or (H, W). If priors shape is (B, N, 4), then
+ the max_shape should be a Sequence[Sequence[int]],
+ and the length of max_shape should also be B.
+ Default None.
+ Returns:
+ Tensor: Boxes with shape (N, 4) or (B, N, 4)
+ """
+ assert points.size(0) == pred_bboxes.size(0)
+ assert points.size(-1) == 2
+ assert pred_bboxes.size(-1) == 4
+ if self.clip_border is False:
+ max_shape = None
+ return distance2bbox(points, pred_bboxes, max_shape)
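+
+ # Minimal usage sketch (illustrative values only):
+ #
+ # import torch
+ # coder = DistancePointBBoxCoder()
+ # points = torch.tensor([[5., 5.], [10., 10.]])
+ # gt = torch.tensor([[0., 0., 8., 8.], [4., 4., 20., 20.]])
+ # dist = coder.encode(points, gt)     # (left, top, right, bottom)
+ # boxes = coder.decode(points, dist)  # recovers gt when no clipping applies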
diff --git a/detrsmpl/core/post_processing/bbox/match_costs/__init__.py b/detrsmpl/core/post_processing/bbox/match_costs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..abebdda7277e180ae0a00d78eee3e821061c81e2
--- /dev/null
+++ b/detrsmpl/core/post_processing/bbox/match_costs/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .builder import build_match_cost
+from .match_cost import (
+ BBoxL1Cost,
+ ClassificationCost,
+ CrossEntropyLossCost,
+ DiceCost,
+ FocalLossCost,
+ IoUCost,
+)
+
+__all__ = [
+ 'build_match_cost', 'ClassificationCost', 'BBoxL1Cost', 'IoUCost',
+ 'FocalLossCost', 'DiceCost', 'CrossEntropyLossCost'
+]
diff --git a/detrsmpl/core/post_processing/bbox/match_costs/builder.py b/detrsmpl/core/post_processing/bbox/match_costs/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea086adff23c5adbc35d448d5a93daf1a04bdc53
--- /dev/null
+++ b/detrsmpl/core/post_processing/bbox/match_costs/builder.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.utils import Registry, build_from_cfg
+
+MATCH_COST = Registry('Match Cost')
+
+
+def build_match_cost(cfg, default_args=None):
+ """Builder of matching cost."""
+ return build_from_cfg(cfg, MATCH_COST, default_args)
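+
+ # Usage sketch (illustrative): matching costs are built from config dicts,
+ # e.g.
+ #
+ # cost = build_match_cost(dict(type='ClassificationCost', weight=1.0))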
diff --git a/detrsmpl/core/post_processing/bbox/match_costs/match_cost.py b/detrsmpl/core/post_processing/bbox/match_costs/match_cost.py
new file mode 100644
index 0000000000000000000000000000000000000000..883803c20abcc574773fe0c5a2473ef5d1d0b919
--- /dev/null
+++ b/detrsmpl/core/post_processing/bbox/match_costs/match_cost.py
@@ -0,0 +1,551 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn.functional as F
+from mmdet.core.bbox.iou_calculators import bbox_overlaps
+from mmdet.core.bbox.transforms import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh
+from typing import Optional, Tuple, Union
+from detrsmpl.core.conventions.keypoints_mapping import get_keypoint_idx
+from detrsmpl.utils.geometry import project_points
+
+from .builder import MATCH_COST
+
+
+@MATCH_COST.register_module()
+class BBoxL1Cost:
+ """BBoxL1Cost.
+
+ Args:
+ weight (int | float, optional): loss_weight
+ box_format (str, optional): 'xyxy' for DETR, 'xywh' for Sparse_RCNN
+
+ Examples:
+ >>> from mmdet.core.bbox.match_costs.match_cost import BBoxL1Cost
+ >>> import torch
+ >>> self = BBoxL1Cost()
+ >>> bbox_pred = torch.rand(1, 4)
+ >>> gt_bboxes= torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]])
+ >>> factor = torch.tensor([10, 8, 10, 8])
+ >>> self(bbox_pred, gt_bboxes, factor)
+ tensor([[1.6172, 1.6422]])
+ """
+ def __init__(self, weight=1., box_format='xyxy'):
+ self.weight = weight
+ assert box_format in ['xyxy', 'xywh']
+ self.box_format = box_format
+
+ def __call__(self, bbox_pred, gt_bboxes):
+ """
+ Args:
+ bbox_pred (Tensor): Predicted boxes with normalized coordinates
+ (cx, cy, w, h), which are all in range [0, 1]. Shape
+ (num_query, 4).
+ gt_bboxes (Tensor): Ground truth boxes with normalized
+ coordinates (x1, y1, x2, y2). Shape (num_gt, 4).
+
+ Returns:
+ torch.Tensor: bbox_cost value with weight
+ """
+ if self.box_format == 'xywh':
+ gt_bboxes = bbox_xyxy_to_cxcywh(gt_bboxes)
+ elif self.box_format == 'xyxy':
+ bbox_pred = bbox_cxcywh_to_xyxy(bbox_pred)
+ bbox_cost = torch.cdist(bbox_pred, gt_bboxes, p=1)
+ return bbox_cost * self.weight
+
+
+@MATCH_COST.register_module()
+class FocalLossCost:
+ """FocalLossCost.
+
+ Args:
+ weight (int | float, optional): loss_weight
+ alpha (int | float, optional): focal_loss alpha
+ gamma (int | float, optional): focal_loss gamma
+ eps (float, optional): default 1e-12
+ binary_input (bool, optional): Whether the input is binary,
+ default False.
+
+ Examples:
+ >>> from mmdet.core.bbox.match_costs.match_cost import FocalLossCost
+ >>> import torch
+ >>> self = FocalLossCost()
+ >>> cls_pred = torch.rand(4, 3)
+ >>> gt_labels = torch.tensor([0, 1, 2])
+ >>> factor = torch.tensor([10, 8, 10, 8])
+ >>> self(cls_pred, gt_labels)
+ tensor([[-0.3236, -0.3364, -0.2699],
+ [-0.3439, -0.3209, -0.4807],
+ [-0.4099, -0.3795, -0.2929],
+ [-0.1950, -0.1207, -0.2626]])
+ """
+ def __init__(self,
+ weight=1.,
+ alpha=0.25,
+ gamma=2,
+ eps=1e-12,
+ binary_input=False):
+ self.weight = weight
+ self.alpha = alpha
+ self.gamma = gamma
+ self.eps = eps
+ self.binary_input = binary_input
+
+ def _focal_loss_cost(self, cls_pred, gt_labels):
+ """
+ Args:
+ cls_pred (Tensor): Predicted classification logits, shape
+ (num_query, num_class).
+ gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
+
+ Returns:
+ torch.Tensor: cls_cost value with weight
+ """
+ cls_pred = cls_pred.sigmoid()
+ neg_cost = -(1 - cls_pred + self.eps).log() * (
+ 1 - self.alpha) * cls_pred.pow(self.gamma)
+ pos_cost = -(cls_pred + self.eps).log() * self.alpha * (
+ 1 - cls_pred).pow(self.gamma)
+
+ cls_cost = pos_cost[:, gt_labels] - neg_cost[:, gt_labels]
+ return cls_cost * self.weight
+
+ def _mask_focal_loss_cost(self, cls_pred, gt_labels):
+ """
+ Args:
+ cls_pred (Tensor): Predicted classification logits
+ in shape (num_query, d1, ..., dn), dtype=torch.float32.
+ gt_labels (Tensor): Ground truth in shape (num_gt, d1, ..., dn),
+ dtype=torch.long. Labels should be binary.
+
+ Returns:
+ Tensor: Focal cost matrix with weight in shape\
+ (num_query, num_gt).
+ """
+ cls_pred = cls_pred.flatten(1)
+ gt_labels = gt_labels.flatten(1).float()
+ n = cls_pred.shape[1]
+ cls_pred = cls_pred.sigmoid()
+ neg_cost = -(1 - cls_pred + self.eps).log() * (
+ 1 - self.alpha) * cls_pred.pow(self.gamma)
+ pos_cost = -(cls_pred + self.eps).log() * self.alpha * (
+ 1 - cls_pred).pow(self.gamma)
+
+ cls_cost = torch.einsum('nc,mc->nm', pos_cost, gt_labels) + \
+ torch.einsum('nc,mc->nm', neg_cost, (1 - gt_labels))
+ return cls_cost / n * self.weight
+
+ def __call__(self, cls_pred, gt_labels):
+ """
+ Args:
+ cls_pred (Tensor): Predicted classification logits.
+ gt_labels (Tensor)): Labels.
+
+ Returns:
+ Tensor: Focal cost matrix with weight in shape\
+ (num_query, num_gt).
+ """
+ if self.binary_input:
+ return self._mask_focal_loss_cost(cls_pred, gt_labels)
+ else:
+ return self._focal_loss_cost(cls_pred, gt_labels)
+
+
+@MATCH_COST.register_module()
+class ClassificationCost:
+ """ClsSoftmaxCost.
+
+ Args:
+ weight (int | float, optional): loss_weight
+
+ Examples:
+ >>> from mmdet.core.bbox.match_costs.match_cost import \
+ ... ClassificationCost
+ >>> import torch
+ >>> self = ClassificationCost()
+ >>> cls_pred = torch.rand(4, 3)
+ >>> gt_labels = torch.tensor([0, 1, 2])
+ >>> factor = torch.tensor([10, 8, 10, 8])
+ >>> self(cls_pred, gt_labels)
+ tensor([[-0.3430, -0.3525, -0.3045],
+ [-0.3077, -0.2931, -0.3992],
+ [-0.3664, -0.3455, -0.2881],
+ [-0.3343, -0.2701, -0.3956]])
+ """
+ def __init__(self, weight=1.):
+ self.weight = weight
+
+ def __call__(self, cls_pred, gt_labels):
+ """
+ Args:
+ cls_pred (Tensor): Predicted classification logits, shape
+ (num_query, num_class).
+ gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
+
+ Returns:
+ torch.Tensor: cls_cost value with weight
+ """
+ # Following the official DETR repo: instead of the NLL used in the
+ # loss, the cost is approximated by 1 - cls_score[gt_label].
+ # The constant 1 doesn't change the matching, so it is omitted.
+ cls_score = cls_pred.softmax(-1)
+ cls_cost = -cls_score[:, gt_labels]
+ return cls_cost * self.weight
+
+
+@MATCH_COST.register_module()
+class IoUCost:
+ """IoUCost.
+
+ Args:
+ iou_mode (str, optional): iou mode such as 'iou' | 'giou'
+ weight (int | float, optional): loss weight
+
+ Examples:
+ >>> from mmdet.core.bbox.match_costs.match_cost import IoUCost
+ >>> import torch
+ >>> self = IoUCost()
+ >>> bboxes = torch.FloatTensor([[1,1, 2, 2], [2, 2, 3, 4]])
+ >>> gt_bboxes = torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]])
+ >>> self(bboxes, gt_bboxes)
+ tensor([[-0.1250, 0.1667],
+ [ 0.1667, -0.5000]])
+ """
+ def __init__(self, iou_mode='giou', weight=1.):
+ self.weight = weight
+ self.iou_mode = iou_mode
+
+ def __call__(self, bboxes, gt_bboxes):
+ """
+ Args:
+ bboxes (Tensor): Predicted boxes with unnormalized coordinates
+ (x1, y1, x2, y2). Shape (num_query, 4).
+ gt_bboxes (Tensor): Ground truth boxes with unnormalized
+ coordinates (x1, y1, x2, y2). Shape (num_gt, 4).
+
+ Returns:
+ torch.Tensor: iou_cost value with weight
+ """
+ # overlaps: [num_bboxes, num_gt]
+ overlaps = bbox_overlaps(bboxes,
+ gt_bboxes,
+ mode=self.iou_mode,
+ is_aligned=False)
+ # The 1 is a constant that doesn't change the matching, so omitted.
+ iou_cost = -overlaps
+ return iou_cost * self.weight
+
+
+@MATCH_COST.register_module()
+class DiceCost:
+ """Cost of mask assignments based on dice losses.
+
+ Args:
+ weight (int | float, optional): loss_weight. Defaults to 1.
+ pred_act (bool, optional): Whether to apply sigmoid to mask_pred.
+ Defaults to False.
+ eps (float, optional): Default 1e-3.
+ naive_dice (bool, optional): If True, use the naive dice loss
+ in which the power of the number in the denominator is
+ the first power. If False, use the second power that
+ is adopted by K-Net and SOLO.
+ Defaults to True.
+ """
+ def __init__(self, weight=1., pred_act=False, eps=1e-3, naive_dice=True):
+ self.weight = weight
+ self.pred_act = pred_act
+ self.eps = eps
+ self.naive_dice = naive_dice
+
+ def binary_mask_dice_loss(self, mask_preds, gt_masks):
+ """
+ Args:
+ mask_preds (Tensor): Mask prediction in shape (num_query, *).
+ gt_masks (Tensor): Ground truth in shape (num_gt, *)
+ store 0 or 1, 0 for negative class and 1 for
+ positive class.
+
+ Returns:
+ Tensor: Dice cost matrix in shape (num_query, num_gt).
+ """
+ mask_preds = mask_preds.flatten(1)
+ gt_masks = gt_masks.flatten(1).float()
+ numerator = 2 * torch.einsum('nc,mc->nm', mask_preds, gt_masks)
+ if self.naive_dice:
+ denominator = mask_preds.sum(-1)[:, None] + \
+ gt_masks.sum(-1)[None, :]
+ else:
+ denominator = mask_preds.pow(2).sum(1)[:, None] + \
+ gt_masks.pow(2).sum(1)[None, :]
+ loss = 1 - (numerator + self.eps) / (denominator + self.eps)
+ return loss
+
+ def __call__(self, mask_preds, gt_masks):
+ """
+ Args:
+ mask_preds (Tensor): Mask prediction logits in shape (num_query, *)
+ gt_masks (Tensor): Ground truth in shape (num_gt, *)
+
+ Returns:
+ Tensor: Dice cost matrix with weight in shape (num_query, num_gt).
+ """
+ if self.pred_act:
+ mask_preds = mask_preds.sigmoid()
+ dice_cost = self.binary_mask_dice_loss(mask_preds, gt_masks)
+ return dice_cost * self.weight
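+
+ # Illustrative sketch: pairwise dice cost between 3 predicted masks and
+ # 2 ground truth masks (random values, shown only for the output shape).
+ #
+ # import torch
+ # dice = DiceCost(pred_act=True)
+ # mask_preds = torch.rand(3, 16, 16)              # logits
+ # gt_masks = (torch.rand(2, 16, 16) > 0.5).long()
+ # cost = dice(mask_preds, gt_masks)               # shape (3, 2)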
+
+
+@MATCH_COST.register_module()
+class CrossEntropyLossCost:
+ """CrossEntropyLossCost.
+
+ Args:
+ weight (int | float, optional): loss weight. Defaults to 1.
+ use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+ or softmax. Defaults to True.
+ Examples:
+ >>> from mmdet.core.bbox.match_costs import CrossEntropyLossCost
+ >>> import torch
+ >>> bce = CrossEntropyLossCost(use_sigmoid=True)
+ >>> cls_pred = torch.tensor([[7.6, 1.2], [-1.3, 10]])
+ >>> gt_labels = torch.tensor([[1, 1], [1, 0]])
+ >>> print(bce(cls_pred, gt_labels))
+ """
+ def __init__(self, weight=1., use_sigmoid=True):
+ assert use_sigmoid, 'use_sigmoid = False is not supported yet.'
+ self.weight = weight
+ self.use_sigmoid = use_sigmoid
+
+ def _binary_cross_entropy(self, cls_pred, gt_labels):
+ """
+ Args:
+ cls_pred (Tensor): The prediction with shape (num_query, 1, *) or
+ (num_query, *).
+ gt_labels (Tensor): The learning label of prediction with
+ shape (num_gt, *).
+
+ Returns:
+ Tensor: Cross entropy cost matrix in shape (num_query, num_gt).
+ """
+ cls_pred = cls_pred.flatten(1).float()
+ gt_labels = gt_labels.flatten(1).float()
+ n = cls_pred.shape[1]
+ pos = F.binary_cross_entropy_with_logits(cls_pred,
+ torch.ones_like(cls_pred),
+ reduction='none')
+ neg = F.binary_cross_entropy_with_logits(cls_pred,
+ torch.zeros_like(cls_pred),
+ reduction='none')
+ cls_cost = torch.einsum('nc,mc->nm', pos, gt_labels) + \
+ torch.einsum('nc,mc->nm', neg, 1 - gt_labels)
+ cls_cost = cls_cost / n
+
+ return cls_cost
+
+ def __call__(self, cls_pred, gt_labels):
+ """
+ Args:
+ cls_pred (Tensor): Predicted classification logits.
+ gt_labels (Tensor): Labels.
+
+ Returns:
+ Tensor: Cross entropy cost matrix with weight in
+ shape (num_query, num_gt).
+ """
+ if self.use_sigmoid:
+ cls_cost = self._binary_cross_entropy(cls_pred, gt_labels)
+ else:
+ raise NotImplementedError
+
+ return cls_cost * self.weight
+
+
+@MATCH_COST.register_module()
+class Keypoints3DCost(object):
+ """L1 matching cost between predicted and ground truth 3D keypoints.
+
+ Keypoints are pelvis-aligned before the cost is computed, and the cost
+ is normalized by the number of valid (confidence > 0) joints.
+
+ Args:
+ convention (str): Keypoint convention used to look up the hip joints,
+ e.g. 'smpl_54'.
+ weight (float, optional): Scale factor of the cost. Default 1.0.
+ """
+ def __init__(
+ self,
+ convention,
+ weight=1.0,
+ ) -> None:
+ self.weight = weight
+ self.convention = convention
+
+ def __call__(self,
+ pred_keypoints3d: torch.Tensor,
+ gt_keypoints3d: torch.Tensor,
+ has_keypoints3d: Optional[torch.Tensor] = None):
+ """Compute the 3D keypoint cost matrix.
+
+ Args:
+ pred_keypoints3d (torch.Tensor): Predicted 3D keypoints with shape
+ [num_query, kp_num, 3].
+ gt_keypoints3d (torch.Tensor): Ground truth 3D keypoints with
+ confidence, shape [num_gt, kp_num, 4].
+ has_keypoints3d (torch.Tensor, optional): Per-instance flags for
+ valid 3D keypoints. Defaults to None.
+
+ Returns:
+ torch.Tensor: Cost matrix with shape [num_query, num_gt].
+ """
+ # B: batch_size N: instance_num K: kp_num D: 2 for 2D; 3 for 3D
+ Q = pred_keypoints3d.shape[0] # Q means query num
+ N, K, D = gt_keypoints3d.shape
+
+ gt_keypoints3d = gt_keypoints3d.unsqueeze(1).repeat([1, Q, 1, 1])
+ keypoints3d_conf = gt_keypoints3d[..., 3].float().unsqueeze(-1)
+ keypoints3d_conf = keypoints3d_conf.repeat(1, 1, 1, 3)
+ gt_keypoints3d = gt_keypoints3d[..., :3].float()
+ pred_keypoints3d = pred_keypoints3d.unsqueeze(0).repeat([N, 1, 1,
+ 1]).float()
+
+ right_hip_idx = get_keypoint_idx('right_hip_extra', self.convention)
+ left_hip_idx = get_keypoint_idx('left_hip_extra', self.convention)
+
+ gt_pelvis = (gt_keypoints3d[:, :, right_hip_idx, :] +
+ gt_keypoints3d[:, :, left_hip_idx, :]) / 2
+ pred_pelvis = (pred_keypoints3d[:, :, right_hip_idx, :] +
+ pred_keypoints3d[:, :, left_hip_idx, :]) / 2
+
+ gt_keypoints3d = gt_keypoints3d - gt_pelvis[:, :, None, :]
+ pred_keypoints3d = pred_keypoints3d - pred_pelvis[:, :, None, :]
+
+ # [Q, N]
+ loss = torch.abs(gt_keypoints3d - pred_keypoints3d).sum([-2,
+ -1]).permute(
+ 1, 0)
+ # shape: N
+ avg_factor = (keypoints3d_conf[:, 0, :, 0] > 0).sum(-1)
+
+ loss = self.weight * (loss / avg_factor)
+ return loss
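+
+ # Illustrative sketch (assumes the 'smpl_54' convention provides the
+ # 'left_hip_extra'/'right_hip_extra' joints; shapes are for demonstration):
+ #
+ # import torch
+ # cost_fn = Keypoints3DCost(convention='smpl_54', weight=1.0)
+ # pred_kp3d = torch.rand(100, 54, 3)      # [num_query, K, 3]
+ # gt_kp3d = torch.rand(2, 54, 4)          # [num_gt, K, 3 + conf]
+ # cost = cost_fn(pred_kp3d, gt_kp3d)      # [num_query, num_gt]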
+
+
+@MATCH_COST.register_module()
+class Keypoints2DCost(object):
+ """Reprojected 2D keypoint L1 matching cost.
+
+ Predicted 3D keypoints are projected to the image plane with the
+ predicted weak-perspective camera, scaled by the image resolution, and
+ compared against the ground truth 2D keypoints with an L1 cost. The cost
+ is normalized by the number of valid (confidence > 0) joints.
+
+ Args:
+ convention (str): Keypoint convention, e.g. 'smpl_54'.
+ weight (float, optional): Scale factor of the cost. Default 1.0.
+ img_res (int, optional): Image resolution used for projection.
+ Default 512.
+ focal_length (float, optional): Focal length of the assumed camera.
+ Default 5000.
+ """
+ def __init__(
+ self,
+ convention,
+ weight=1.0,
+ img_res=512,
+ focal_length=5000.,
+ ) -> None:
+ self.weight = weight
+ self.convention = convention
+ self.img_res = img_res
+ self.focal_length = focal_length
+
+ def __call__(self,
+ pred_keypoints3d: torch.Tensor,
+ pred_camera: torch.Tensor,
+ gt_keypoints2d: torch.Tensor,
+ has_keypoints2d: Optional[torch.Tensor] = None):
+ """Compute the projected 2D keypoint cost matrix.
+
+ Args:
+ pred_keypoints3d (torch.Tensor): Predicted 3D keypoints with shape
+ [num_query, kp_num, 3].
+ pred_camera (torch.Tensor): Predicted weak-perspective camera
+ parameters (s, tx, ty) with shape [num_query, 3].
+ gt_keypoints2d (torch.Tensor): Ground truth 2D keypoints with
+ confidence, shape [num_gt, kp_num, 3].
+ has_keypoints2d (torch.Tensor, optional): Per-instance flags for
+ valid 2D keypoints. Defaults to None.
+
+ Returns:
+ torch.Tensor: Cost matrix with shape [num_query, num_gt].
+ """
+ # B: batch_size N: instance_num K: kp_num D: 2 for 2D; 3 for 3D
+ Q = pred_keypoints3d.shape[0] # Q means query num
+ N, K, D = gt_keypoints2d.shape
+
+ gt_keypoints2d = gt_keypoints2d.unsqueeze(1).repeat([1, Q, 1, 1])
+ keypoints2d_conf = gt_keypoints2d[..., 2].float().unsqueeze(-1)
+ keypoints2d_conf = keypoints2d_conf.repeat(1, 1, 1, 2)
+ gt_keypoints2d = gt_keypoints2d[..., :2].float()
+ pred_keypoints3d = pred_keypoints3d.unsqueeze(0).repeat([N, 1, 1,
+ 1]).float()
+ pred_camera = pred_camera.unsqueeze(0).repeat([N, 1, 1]).float()
+
+ cam_t = torch.stack([
+ pred_camera[..., 1], pred_camera[..., 2], 2 * self.focal_length /
+ (self.img_res * pred_camera[..., 0] + 1e-9)
+ ],
+ dim=-1)
+
+ K = torch.zeros([N, Q, 3, 3], device=pred_keypoints3d.device)
+ K[..., 0, 0] = self.focal_length
+ K[..., 1, 1] = self.focal_length
+ K[..., 2, 2] = 1.
+ K[..., :-1, -1] = torch.tensor([self.img_res / 2., self.img_res / 2.],
+ device=pred_keypoints3d.device)
+
+ # transform
+ pred_keypoints3d_ = pred_keypoints3d + cam_t.unsqueeze(2)
+ projected_kp3d = pred_keypoints3d_ / pred_keypoints3d_[
+ ..., -1].unsqueeze(-1)
+
+ # apply camera intrinsics
+ projected_kp3d = torch.einsum('nqij,nqkj->nqki', K, projected_kp3d)
+ pred_keypoints2d = projected_kp3d[..., :-1]
+
+ # Normalize keypoints to [-1, 1]
+ pred_keypoints2d = 2 * pred_keypoints2d / (self.img_res - 1)
+ gt_keypoints2d = 2 * gt_keypoints2d / (self.img_res - 1)
+
+ # compute the loss
+ # [Q, N]
+ loss = torch.abs(gt_keypoints2d - pred_keypoints2d).sum([-2,
+ -1]).permute(
+ 1, 0)
+ # shape: N
+ avg_factor = (keypoints2d_conf[:, 0, :, 0] > 0).sum(-1)
+
+ loss = self.weight * (loss / avg_factor)
+ return loss
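+
+ # Illustrative sketch; the weak-perspective camera is (s, tx, ty) and the
+ # shapes/values below are assumptions for demonstration only.
+ #
+ # import torch
+ # cost_fn = Keypoints2DCost(convention='smpl_54', img_res=512)
+ # pred_kp3d = torch.rand(100, 54, 3)            # [num_query, K, 3]
+ # pred_cam = torch.rand(100, 3) + 0.5           # keep the scale s > 0
+ # gt_kp2d = torch.rand(2, 54, 3)                # [num_gt, K, 2 + conf]
+ # cost = cost_fn(pred_kp3d, pred_cam, gt_kp2d)  # [num_query, num_gt]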
+
+
+@MATCH_COST.register_module()
+class KeypointsMSECost(object):
+ """Weighted MSE matching cost between predicted and target keypoints.
+
+ Args:
+ weight (float, optional): Scale factor of the cost. Default 1.0.
+ keypoint_weight (Tensor, optional): Per-keypoint weights used as a
+ fallback when none are passed to `__call__`. Default None.
+ """
+ def __init__(self, weight=1.0, keypoint_weight=None) -> None:
+ self.weight = weight
+ # optional per-keypoint weighting used as a fallback in __call__
+ self.keypoint_weight = keypoint_weight
+
+ def __call__(self,
+ pred,
+ target,
+ pred_conf=None,
+ target_conf=None,
+ keypoint_weight=None):
+
+ # N: instance (query) num, B: target num, K: kp_num, D: 2 for 2D; 3 for 3D
+ N = pred.shape[0]
+ B, K, D = target.shape
+
+ # confidences are assumed to be given per predicted / per target joint
+ pred_conf = pred_conf.view((1, N, K, 1)) \
+ if pred_conf is not None else 1.0
+ target_conf = target_conf.view((B, 1, K, 1)) \
+ if target_conf is not None else 1.0
+ keypoint_weight = keypoint_weight.view((1, 1, K, 1)) \
+ if keypoint_weight is not None else \
+ self.keypoint_weight.view((1, 1, K, 1)).type_as(pred) \
+ if self.keypoint_weight is not None else 1.0
+
+ weight = keypoint_weight * pred_conf * target_conf
+
+ pred = pred.unsqueeze(0).repeat([B, 1, 1, 1]) # [B, N, K, D]
+ target = target.unsqueeze(1).repeat([1, N, 1, 1]) # [B, N, K, D]
+
+ # weighted squared error, reduced over keypoints and dims -> [N, B]
+ loss = self.weight * (weight * F.mse_loss(
+ pred, target, reduction='none')).sum([-2, -1]).permute(1, 0)
+
+ return loss
diff --git a/detrsmpl/core/post_processing/bbox/samplers/__init__.py b/detrsmpl/core/post_processing/bbox/samplers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8361def97989d1d08d978879a63ecdf7e5458c5a
--- /dev/null
+++ b/detrsmpl/core/post_processing/bbox/samplers/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from .base_sampler import BaseSampler
+from .builder import build_sampler
+from .pseudo_sampler import PseudoSampler
+
+__all__ = ['build_sampler', 'BaseSampler', 'PseudoSampler']
diff --git a/detrsmpl/core/post_processing/bbox/samplers/base_sampler.py b/detrsmpl/core/post_processing/bbox/samplers/base_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..dee649739e03013050089d831a28e4c549b06768
--- /dev/null
+++ b/detrsmpl/core/post_processing/bbox/samplers/base_sampler.py
@@ -0,0 +1,105 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+import torch
+
+from .sampling_result import SamplingResult
+
+
+class BaseSampler(metaclass=ABCMeta):
+ """Base class of samplers."""
+ def __init__(self,
+ num,
+ pos_fraction,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True,
+ **kwargs):
+ self.num = num
+ self.pos_fraction = pos_fraction
+ self.neg_pos_ub = neg_pos_ub
+ self.add_gt_as_proposals = add_gt_as_proposals
+ self.pos_sampler = self
+ self.neg_sampler = self
+
+ @abstractmethod
+ def _sample_pos(self, assign_result, num_expected, **kwargs):
+ """Sample positive samples."""
+ pass
+
+ @abstractmethod
+ def _sample_neg(self, assign_result, num_expected, **kwargs):
+ """Sample negative samples."""
+ pass
+
+ def sample(self,
+ assign_result,
+ bboxes,
+ gt_bboxes,
+ gt_labels=None,
+ **kwargs):
+ """Sample positive and negative bboxes.
+
+ This is a simple implementation of bbox sampling given candidates,
+ assigning results and ground truth bboxes.
+
+ Args:
+ assign_result (:obj:`AssignResult`): Bbox assigning results.
+ bboxes (Tensor): Boxes to be sampled from.
+ gt_bboxes (Tensor): Ground truth bboxes.
+ gt_labels (Tensor, optional): Class labels of ground truth bboxes.
+
+ Returns:
+ :obj:`SamplingResult`: Sampling result.
+
+ Example:
+ >>> from mmdet.core.bbox import RandomSampler
+ >>> from mmdet.core.bbox import AssignResult
+ >>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes
+ >>> rng = ensure_rng(None)
+ >>> assign_result = AssignResult.random(rng=rng)
+ >>> bboxes = random_boxes(assign_result.num_preds, rng=rng)
+ >>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng)
+ >>> gt_labels = None
+ >>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1,
+ >>> add_gt_as_proposals=False)
+ >>> self = self.sample(assign_result, bboxes, gt_bboxes, gt_labels)
+ """
+ if len(bboxes.shape) < 2:
+ bboxes = bboxes[None, :]
+
+ bboxes = bboxes[:, :4]
+
+ gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
+ if self.add_gt_as_proposals and len(gt_bboxes) > 0:
+ if gt_labels is None:
+ raise ValueError(
+ 'gt_labels must be given when add_gt_as_proposals is True')
+ bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
+ assign_result.add_gt_(gt_labels)
+ gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
+ gt_flags = torch.cat([gt_ones, gt_flags])
+
+ num_expected_pos = int(self.num * self.pos_fraction)
+ pos_inds = self.pos_sampler._sample_pos(assign_result,
+ num_expected_pos,
+ bboxes=bboxes,
+ **kwargs)
+ # We found that sampled indices have duplicated items occasionally.
+ # (may be a bug of PyTorch)
+ pos_inds = pos_inds.unique()
+ num_sampled_pos = pos_inds.numel()
+ num_expected_neg = self.num - num_sampled_pos
+ if self.neg_pos_ub >= 0:
+ _pos = max(1, num_sampled_pos)
+ neg_upper_bound = int(self.neg_pos_ub * _pos)
+ if num_expected_neg > neg_upper_bound:
+ num_expected_neg = neg_upper_bound
+ neg_inds = self.neg_sampler._sample_neg(assign_result,
+ num_expected_neg,
+ bboxes=bboxes,
+ **kwargs)
+ neg_inds = neg_inds.unique()
+
+ sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
+ assign_result, gt_flags)
+ return sampling_result
diff --git a/detrsmpl/core/post_processing/bbox/samplers/builder.py b/detrsmpl/core/post_processing/bbox/samplers/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c56bad5c8e5502c44d70f9b6660f16e98626949
--- /dev/null
+++ b/detrsmpl/core/post_processing/bbox/samplers/builder.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.utils import Registry, build_from_cfg
+
+BBOX_SAMPLERS = Registry('bbox_sampler')
+
+
+def build_sampler(cfg, **default_args):
+ """Builder of box sampler."""
+ return build_from_cfg(cfg, BBOX_SAMPLERS, default_args)
diff --git a/detrsmpl/core/post_processing/bbox/samplers/pseudo_sampler.py b/detrsmpl/core/post_processing/bbox/samplers/pseudo_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..f37aa52327b319504b6b7e9a8290f1d4728181dd
--- /dev/null
+++ b/detrsmpl/core/post_processing/bbox/samplers/pseudo_sampler.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from .base_sampler import BaseSampler
+from .builder import BBOX_SAMPLERS
+from .sampling_result import SamplingResult
+
+
+@BBOX_SAMPLERS.register_module()
+class PseudoSampler(BaseSampler):
+ """A pseudo sampler that does not actually perform sampling."""
+ def __init__(self, **kwargs):
+ pass
+
+ def _sample_pos(self, **kwargs):
+ """Sample positive samples."""
+ raise NotImplementedError
+
+ def _sample_neg(self, **kwargs):
+ """Sample negative samples."""
+ raise NotImplementedError
+
+ def sample(self, assign_result, bboxes, gt_bboxes, *args, **kwargs):
+ """Directly returns the positive and negative indices of samples.
+
+ Args:
+ assign_result (:obj:`AssignResult`): Assigned results
+ bboxes (torch.Tensor): Bounding boxes
+ gt_bboxes (torch.Tensor): Ground truth boxes
+
+ Returns:
+ :obj:`SamplingResult`: sampler results
+ """
+ pos_inds = torch.nonzero(assign_result.gt_inds > 0,
+ as_tuple=False).squeeze(-1).unique()
+ neg_inds = torch.nonzero(assign_result.gt_inds == 0,
+ as_tuple=False).squeeze(-1).unique()
+ gt_flags = bboxes.new_zeros(bboxes.shape[0], dtype=torch.uint8)
+ sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
+ assign_result, gt_flags)
+ return sampling_result
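+
+ # Illustrative sketch: PseudoSampler simply splits the queries into
+ # positives (assigned to a gt) and negatives (background) without any
+ # resampling.
+ #
+ # sampling_result = PseudoSampler().sample(assign_result, bboxes, gt_bboxes)
+ # pos_inds, neg_inds = sampling_result.pos_inds, sampling_result.neg_inds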
diff --git a/detrsmpl/core/post_processing/bbox/samplers/sampling_result.py b/detrsmpl/core/post_processing/bbox/samplers/sampling_result.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1ac5785b8df0b5335b61cc64d78c39aa46cfe25
--- /dev/null
+++ b/detrsmpl/core/post_processing/bbox/samplers/sampling_result.py
@@ -0,0 +1,150 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmdet.utils import util_mixins
+
+
+class SamplingResult(util_mixins.NiceRepr):
+ """Bbox sampling result.
+
+ Example:
+ >>> # xdoctest: +IGNORE_WANT
+ >>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA
+ >>> self = SamplingResult.random(rng=10)
+ >>> print(f'self = {self}')
+ self =
+ """
+ def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
+ gt_flags):
+ self.pos_inds = pos_inds
+ self.neg_inds = neg_inds
+ self.pos_bboxes = bboxes[pos_inds]
+ self.neg_bboxes = bboxes[neg_inds]
+ self.pos_is_gt = gt_flags[pos_inds]
+
+ self.num_gts = gt_bboxes.shape[0]
+ self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
+
+ if gt_bboxes.numel() == 0:
+ # hack for index error case
+ assert self.pos_assigned_gt_inds.numel() == 0
+ self.pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4)
+ else:
+ if len(gt_bboxes.shape) < 2:
+ gt_bboxes = gt_bboxes.view(-1, 4)
+
+ self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds.long(), :]
+
+ if assign_result.labels is not None:
+ self.pos_gt_labels = assign_result.labels[pos_inds]
+ else:
+ self.pos_gt_labels = None
+
+ @property
+ def bboxes(self):
+ """torch.Tensor: concatenated positive and negative boxes"""
+ return torch.cat([self.pos_bboxes, self.neg_bboxes])
+
+ def to(self, device):
+ """Change the device of the data inplace.
+
+ Example:
+ >>> self = SamplingResult.random()
+ >>> print(f'self = {self.to(None)}')
+ >>> # xdoctest: +REQUIRES(--gpu)
+ >>> print(f'self = {self.to(0)}')
+ """
+ _dict = self.__dict__
+ for key, value in _dict.items():
+ if isinstance(value, torch.Tensor):
+ _dict[key] = value.to(device)
+ return self
+
+ def __nice__(self):
+ data = self.info.copy()
+ data['pos_bboxes'] = data.pop('pos_bboxes').shape
+ data['neg_bboxes'] = data.pop('neg_bboxes').shape
+ parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())]
+ body = ' ' + ',\n '.join(parts)
+ return '{\n' + body + '\n}'
+
+ @property
+ def info(self):
+ """Returns a dictionary of info about the object."""
+ return {
+ 'pos_inds': self.pos_inds,
+ 'neg_inds': self.neg_inds,
+ 'pos_bboxes': self.pos_bboxes,
+ 'neg_bboxes': self.neg_bboxes,
+ 'pos_is_gt': self.pos_is_gt,
+ 'num_gts': self.num_gts,
+ 'pos_assigned_gt_inds': self.pos_assigned_gt_inds,
+ }
+
+ @classmethod
+ def random(cls, rng=None, **kwargs):
+ """
+ Args:
+ rng (None | int | numpy.random.RandomState): seed or state.
+ kwargs (keyword arguments):
+ - num_preds: number of predicted boxes
+ - num_gts: number of true boxes
+ - p_ignore (float): probability of a predicted box assigned to \
+ an ignored truth.
+ - p_assigned (float): probability of a predicted box not being \
+ assigned.
+ - p_use_label (float | bool): with labels or not.
+
+ Returns:
+ :obj:`SamplingResult`: Randomly generated sampling result.
+
+ Example:
+ >>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA
+ >>> self = SamplingResult.random()
+ >>> print(self.__dict__)
+ """
+ from mmdet.core.bbox import demodata
+ from mmdet.core.bbox.assigners.assign_result import AssignResult
+ from mmdet.core.bbox.samplers.random_sampler import RandomSampler
+ rng = demodata.ensure_rng(rng)
+
+ # make probabilistic?
+ num = 32
+ pos_fraction = 0.5
+ neg_pos_ub = -1
+
+ assign_result = AssignResult.random(rng=rng, **kwargs)
+
+ # Note we could just compute an assignment
+ bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng)
+ gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng)
+
+ if rng.rand() > 0.2:
+ # sometimes algorithms squeeze their data, be robust to that
+ gt_bboxes = gt_bboxes.squeeze()
+ bboxes = bboxes.squeeze()
+
+ if assign_result.labels is None:
+ gt_labels = None
+ else:
+ gt_labels = None # todo
+
+ if gt_labels is None:
+ add_gt_as_proposals = False
+ else:
+ add_gt_as_proposals = True # make probabilistic?
+
+ sampler = RandomSampler(num,
+ pos_fraction,
+ neg_pos_ub=neg_pos_ub,
+ add_gt_as_proposals=add_gt_as_proposals,
+ rng=rng)
+ self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels)
+ return self
diff --git a/detrsmpl/core/post_processing/bbox/transforms.py b/detrsmpl/core/post_processing/bbox/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d72076a5621c5b59c081a8a190b4c8d167c26a5
--- /dev/null
+++ b/detrsmpl/core/post_processing/bbox/transforms.py
@@ -0,0 +1,270 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+
+
+def find_inside_bboxes(bboxes, img_h, img_w):
+ """Find bboxes as long as a part of bboxes is inside the image.
+
+ Args:
+ bboxes (Tensor): Shape (N, 4).
+ img_h (int): Image height.
+ img_w (int): Image width.
+
+ Returns:
+ Tensor: Index of the remaining bboxes.
+ """
+ inside_inds = (bboxes[:, 0] < img_w) & (bboxes[:, 2] > 0) \
+ & (bboxes[:, 1] < img_h) & (bboxes[:, 3] > 0)
+ return inside_inds
+
+
+def bbox_flip(bboxes, img_shape, direction='horizontal'):
+ """Flip bboxes horizontally or vertically.
+
+ Args:
+ bboxes (Tensor): Shape (..., 4*k)
+ img_shape (tuple): Image shape.
+ direction (str): Flip direction, options are "horizontal", "vertical",
+ "diagonal". Default: "horizontal"
+
+ Returns:
+ Tensor: Flipped bboxes.
+ """
+ assert bboxes.shape[-1] % 4 == 0
+ assert direction in ['horizontal', 'vertical', 'diagonal']
+ flipped = bboxes.clone()
+ if direction == 'horizontal':
+ flipped[..., 0::4] = img_shape[1] - bboxes[..., 2::4]
+ flipped[..., 2::4] = img_shape[1] - bboxes[..., 0::4]
+ elif direction == 'vertical':
+ flipped[..., 1::4] = img_shape[0] - bboxes[..., 3::4]
+ flipped[..., 3::4] = img_shape[0] - bboxes[..., 1::4]
+ else:
+ flipped[..., 0::4] = img_shape[1] - bboxes[..., 2::4]
+ flipped[..., 1::4] = img_shape[0] - bboxes[..., 3::4]
+ flipped[..., 2::4] = img_shape[1] - bboxes[..., 0::4]
+ flipped[..., 3::4] = img_shape[0] - bboxes[..., 1::4]
+ return flipped
+
+
+def bbox_mapping(bboxes,
+ img_shape,
+ scale_factor,
+ flip,
+ flip_direction='horizontal'):
+ """Map bboxes from the original image scale to testing scale."""
+ new_bboxes = bboxes * bboxes.new_tensor(scale_factor)
+ if flip:
+ new_bboxes = bbox_flip(new_bboxes, img_shape, flip_direction)
+ return new_bboxes
+
+
+def bbox_mapping_back(bboxes,
+ img_shape,
+ scale_factor,
+ flip,
+ flip_direction='horizontal'):
+ """Map bboxes from testing scale to original image scale."""
+ new_bboxes = bbox_flip(bboxes, img_shape,
+ flip_direction) if flip else bboxes
+ new_bboxes = new_bboxes.view(-1, 4) / new_bboxes.new_tensor(scale_factor)
+ return new_bboxes.view(bboxes.shape)
+
+
+def bbox2roi(bbox_list):
+ """Convert a list of bboxes to roi format.
+
+ Args:
+ bbox_list (list[Tensor]): a list of bboxes corresponding to a batch
+ of images.
+
+ Returns:
+ Tensor: shape (n, 5), [batch_ind, x1, y1, x2, y2]
+ """
+ rois_list = []
+ for img_id, bboxes in enumerate(bbox_list):
+ if bboxes.size(0) > 0:
+ img_inds = bboxes.new_full((bboxes.size(0), 1), img_id)
+ rois = torch.cat([img_inds, bboxes[:, :4]], dim=-1)
+ else:
+ rois = bboxes.new_zeros((0, 5))
+ rois_list.append(rois)
+ rois = torch.cat(rois_list, 0)
+ return rois
+
+
+def roi2bbox(rois):
+ """Convert rois to bounding box format.
+
+ Args:
+ rois (torch.Tensor): RoIs with the shape (n, 5) where the first
+ column indicates batch id of each RoI.
+
+ Returns:
+ list[torch.Tensor]: Converted boxes of corresponding rois.
+ """
+ bbox_list = []
+ img_ids = torch.unique(rois[:, 0].cpu(), sorted=True)
+ for img_id in img_ids:
+ inds = (rois[:, 0] == img_id.item())
+ bbox = rois[inds, 1:]
+ bbox_list.append(bbox)
+ return bbox_list
+
+
+def bbox2result(bboxes, labels, num_classes):
+ """Convert detection results to a list of numpy arrays.
+
+ Args:
+ bboxes (torch.Tensor | np.ndarray): shape (n, 5)
+ labels (torch.Tensor | np.ndarray): shape (n, )
+ num_classes (int): class number, including background class
+
+ Returns:
+ list(ndarray): bbox results of each class
+ """
+ if bboxes.shape[0] == 0:
+ return [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes)]
+ else:
+ if isinstance(bboxes, torch.Tensor):
+ bboxes = bboxes.detach().cpu().numpy()
+ labels = labels.detach().cpu().numpy()
+ return [bboxes[labels == i, :] for i in range(num_classes)]
+
+
+def distance2bbox(points, distance, max_shape=None):
+ """Decode distance prediction to bounding box.
+
+ Args:
+ points (Tensor): Shape (B, N, 2) or (N, 2).
+ distance (Tensor): Distance from the given point to 4
+ boundaries (left, top, right, bottom). Shape (B, N, 4) or (N, 4)
+ max_shape (Sequence[int] or torch.Tensor or Sequence[
+ Sequence[int]],optional): Maximum bounds for boxes, specifies
+ (H, W, C) or (H, W). If priors shape is (B, N, 4), then
+ the max_shape should be a Sequence[Sequence[int]]
+ and the length of max_shape should also be B.
+
+ Returns:
+ Tensor: Boxes with shape (N, 4) or (B, N, 4)
+ """
+
+ x1 = points[..., 0] - distance[..., 0]
+ y1 = points[..., 1] - distance[..., 1]
+ x2 = points[..., 0] + distance[..., 2]
+ y2 = points[..., 1] + distance[..., 3]
+
+ bboxes = torch.stack([x1, y1, x2, y2], -1)
+
+ if max_shape is not None:
+ if bboxes.dim() == 2 and not torch.onnx.is_in_onnx_export():
+ # speed up
+ bboxes[:, 0::2].clamp_(min=0, max=max_shape[1])
+ bboxes[:, 1::2].clamp_(min=0, max=max_shape[0])
+ return bboxes
+
+ # clip bboxes with dynamic `min` and `max` for onnx
+ if torch.onnx.is_in_onnx_export():
+ from mmdet.core.export import dynamic_clip_for_onnx
+ x1, y1, x2, y2 = dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape)
+ bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
+ return bboxes
+ if not isinstance(max_shape, torch.Tensor):
+ max_shape = x1.new_tensor(max_shape)
+ max_shape = max_shape[..., :2].type_as(x1)
+ if max_shape.ndim == 2:
+ assert bboxes.ndim == 3
+ assert max_shape.size(0) == bboxes.size(0)
+
+ min_xy = x1.new_tensor(0)
+ max_xy = torch.cat([max_shape, max_shape],
+ dim=-1).flip(-1).unsqueeze(-2)
+ bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
+ bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)
+
+ return bboxes
+
+
+def bbox2distance(points, bbox, max_dis=None, eps=0.1):
+ """Encode bounding boxes as distances from points to the four box sides.
+
+ Args:
+ points (Tensor): Shape (n, 2), [x, y].
+ bbox (Tensor): Shape (n, 4), "xyxy" format
+ max_dis (float): Upper bound of the distance.
+ eps (float): a small value to ensure target < max_dis, instead <=
+
+ Returns:
+ Tensor: Distances (left, top, right, bottom) with shape (n, 4).
+ """
+ left = points[:, 0] - bbox[:, 0]
+ top = points[:, 1] - bbox[:, 1]
+ right = bbox[:, 2] - points[:, 0]
+ bottom = bbox[:, 3] - points[:, 1]
+ if max_dis is not None:
+ left = left.clamp(min=0, max=max_dis - eps)
+ top = top.clamp(min=0, max=max_dis - eps)
+ right = right.clamp(min=0, max=max_dis - eps)
+ bottom = bottom.clamp(min=0, max=max_dis - eps)
+ return torch.stack([left, top, right, bottom], -1)
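+
+ # Illustrative round-trip sketch: distance2bbox inverts bbox2distance for
+ # points lying inside their boxes (when no clamping is applied).
+ #
+ # points = torch.tensor([[5., 5.]])
+ # boxes = torch.tensor([[0., 0., 8., 8.]])
+ # distance2bbox(points, bbox2distance(points, boxes))  # -> boxes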
+
+
+def bbox_rescale(bboxes, scale_factor=1.0):
+ """Rescale bounding box w.r.t. scale_factor.
+
+ Args:
+ bboxes (Tensor): Shape (n, 4) for bboxes or (n, 5) for rois
+ scale_factor (float): rescale factor
+
+ Returns:
+ Tensor: Rescaled bboxes.
+ """
+ if bboxes.size(1) == 5:
+ bboxes_ = bboxes[:, 1:]
+ inds_ = bboxes[:, 0]
+ else:
+ bboxes_ = bboxes
+ cx = (bboxes_[:, 0] + bboxes_[:, 2]) * 0.5
+ cy = (bboxes_[:, 1] + bboxes_[:, 3]) * 0.5
+ w = bboxes_[:, 2] - bboxes_[:, 0]
+ h = bboxes_[:, 3] - bboxes_[:, 1]
+ w = w * scale_factor
+ h = h * scale_factor
+ x1 = cx - 0.5 * w
+ x2 = cx + 0.5 * w
+ y1 = cy - 0.5 * h
+ y2 = cy + 0.5 * h
+ if bboxes.size(1) == 5:
+ rescaled_bboxes = torch.stack([inds_, x1, y1, x2, y2], dim=-1)
+ else:
+ rescaled_bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
+ return rescaled_bboxes
+
+
+def bbox_cxcywh_to_xyxy(bbox):
+ """Convert bbox coordinates from (cx, cy, w, h) to (x1, y1, x2, y2).
+
+ Args:
+ bbox (Tensor): Shape (n, 4) for bboxes.
+
+ Returns:
+ Tensor: Converted bboxes.
+ """
+ cx, cy, w, h = bbox.split((1, 1, 1, 1), dim=-1)
+ bbox_new = [(cx - 0.5 * w), (cy - 0.5 * h), (cx + 0.5 * w), (cy + 0.5 * h)]
+ return torch.cat(bbox_new, dim=-1)
+
+
+def bbox_xyxy_to_cxcywh(bbox):
+ """Convert bbox coordinates from (x1, y1, x2, y2) to (cx, cy, w, h).
+
+ Args:
+ bbox (Tensor): Shape (n, 4) for bboxes.
+
+ Returns:
+ Tensor: Converted bboxes.
+ """
+ x1, y1, x2, y2 = bbox.split((1, 1, 1, 1), dim=-1)
+ bbox_new = [(x1 + x2) / 2, (y1 + y2) / 2, (x2 - x1), (y2 - y1)]
+ return torch.cat(bbox_new, dim=-1)
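+
+ # Illustrative round-trip sketch: the two conversions are inverses.
+ #
+ # b = torch.tensor([[10., 10., 20., 30.]])     # (x1, y1, x2, y2)
+ # bbox_cxcywh_to_xyxy(bbox_xyxy_to_cxcywh(b))  # -> b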
diff --git a/detrsmpl/core/post_processing/builder.py b/detrsmpl/core/post_processing/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae2acf4afdf20891a259d14e86a674f695a61c89
--- /dev/null
+++ b/detrsmpl/core/post_processing/builder.py
@@ -0,0 +1,8 @@
+from mmcv.utils import Registry
+
+POST_PROCESSING = Registry('post_processing')
+
+
+def build_post_processing(cfg):
+ """Build post processing function."""
+ return POST_PROCESSING.build(cfg)
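+
+ # Usage sketch (illustrative): the smoothing/speed-up modules registered
+ # with this registry are built from config dicts, e.g.
+ #
+ # smoother = build_post_processing(
+ #     dict(type='savgol', window_size=11, polyorder=2))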
diff --git a/detrsmpl/core/post_processing/smooth/__init__.py b/detrsmpl/core/post_processing/smooth/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/detrsmpl/core/post_processing/smooth/gaus1d_filter.py b/detrsmpl/core/post_processing/smooth/gaus1d_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..403f79a39d4bfe8193e5c26b60a116acc4b7aaf3
--- /dev/null
+++ b/detrsmpl/core/post_processing/smooth/gaus1d_filter.py
@@ -0,0 +1,60 @@
+import warnings
+
+import numpy as np
+import scipy.signal as signal
+import torch
+from scipy.ndimage.filters import gaussian_filter1d
+
+from ..builder import POST_PROCESSING
+
+
+@POST_PROCESSING.register_module(name=['Gaus1dFilter', 'gaus1d'])
+class Gaus1dFilter:
+ """Applies a median filter and then a Gaussian filter. Code adapted from:
+ https://github.com/akanazawa/human_dynamics/blob/master/src/util/smooth_bbox.py.
+
+ Args:
+ x (np.ndarray): input pose
+ window_size (int, optional): for median filters (must be odd).
+ sigma (float, optional): Sigma for gaussian smoothing.
+
+ Returns:
+ np.ndarray: Smoothed poses
+ """
+ def __init__(self, window_size=11, sigma=4):
+ super(Gaus1dFilter, self).__init__()
+
+ self.window_size = window_size
+ self.sigma = sigma
+
+ def __call__(self, x=None):
+ if self.window_size % 2 == 0:
+ window_size = self.window_size - 1
+ else:
+ window_size = self.window_size
+ if window_size > x.shape[0]:
+ window_size = x.shape[0]
+ if len(x.shape) != 3:
+ warnings.warn('x should be a tensor or numpy of [T*M,K,C]')
+ assert len(x.shape) == 3
+ x_type = x
+ if isinstance(x, torch.Tensor):
+ if x.is_cuda:
+ x = x.cpu().numpy()
+ else:
+ x = x.numpy()
+
+ smoothed = np.array(
+ [signal.medfilt(param, window_size) for param in x.T]).T
+ smooth_poses = np.array(
+ [gaussian_filter1d(traj, self.sigma) for traj in smoothed.T]).T
+
+ if isinstance(x_type, torch.Tensor):
+ # we also return tensor by default
+ if x_type.is_cuda:
+ smooth_poses = torch.from_numpy(smooth_poses).cuda()
+ else:
+ smooth_poses = torch.from_numpy(smooth_poses)
+
+ return smooth_poses
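+
+ # Illustrative usage sketch; the input is expected to have shape
+ # [T*M, K, C] (frames x keypoints x coordinates), values are arbitrary.
+ #
+ # import numpy as np
+ # poses = np.random.rand(100, 54, 3)
+ # smoothed = Gaus1dFilter(window_size=11, sigma=4)(poses)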
diff --git a/detrsmpl/core/post_processing/smooth/oneeuro_filter.py b/detrsmpl/core/post_processing/smooth/oneeuro_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ff5ea5dfde38273414c22e5ebd61739a0f3ae21
--- /dev/null
+++ b/detrsmpl/core/post_processing/smooth/oneeuro_filter.py
@@ -0,0 +1,114 @@
+import math
+import warnings
+
+import numpy as np
+import torch
+
+from ..builder import POST_PROCESSING
+
+
+def smoothing_factor(t_e, cutoff):
+ r = 2 * math.pi * cutoff * t_e
+ return r / (r + 1)
+
+
+def exponential_smoothing(a, x, x_prev):
+ return a * x + (1 - a) * x_prev
+
+
+class OneEuro:
+ def __init__(self,
+ t0,
+ x0,
+ dx0=0.0,
+ min_cutoff=1.0,
+ beta=0.0,
+ d_cutoff=1.0):
+ """Initialize the one euro filter."""
+ super(OneEuro, self).__init__()
+ # The parameters.
+ self.min_cutoff = float(min_cutoff)
+ self.beta = float(beta)
+ self.d_cutoff = float(d_cutoff)
+ # Previous values.
+ self.x_prev = x0
+ self.dx_prev = dx0
+ self.t_prev = t0
+
+ def __call__(self, t, x):
+ """Compute the filtered signal."""
+ t_e = t - self.t_prev
+
+ # The filtered derivative of the signal.
+ a_d = smoothing_factor(t_e, self.d_cutoff) # [k, c]
+ dx = (x - self.x_prev) / t_e
+ dx_hat = exponential_smoothing(a_d, dx, self.dx_prev)
+
+ # The filtered signal.
+ cutoff = self.min_cutoff + self.beta * np.abs(dx_hat)
+ a = smoothing_factor(t_e, cutoff)
+ x_hat = exponential_smoothing(a, x, self.x_prev)
+ # Memorize the previous values.
+ self.x_prev = x_hat
+ self.dx_prev = dx_hat
+ self.t_prev = t
+ return x_hat
+
+
+@POST_PROCESSING.register_module(name=['OneEuroFilter', 'oneeuro'])
+class OneEuroFilter:
+ """One-Euro filter; source code:
+ https://github.com/mkocabas/VIBE/blob/c0c3f77d587351c806e901221a9dc05d1ffade4b/lib/utils/smooth_pose.py.
+
+ Args:
+ min_cutoff (float, optional):
+ Decreasing the minimum cutoff frequency decreases slow speed jitter
+ beta (float, optional):
+ Increasing the speed coefficient(beta) decreases speed lag.
+
+ Returns:
+ np.ndarray: smoothed poses
+ """
+ def __init__(self, min_cutoff=0.004, beta=0.7):
+ super(OneEuroFilter, self).__init__()
+
+ self.min_cutoff = min_cutoff
+ self.beta = beta
+
+ def __call__(self, x=None):
+ # x (np.ndarray): input poses.
+ if len(x.shape) != 3:
+ warnings.warn('x should be a tensor or numpy of [T*M,K,C]')
+ assert len(x.shape) == 3
+ x_type = x
+ if isinstance(x, torch.Tensor):
+ if x.is_cuda:
+ x = x.cpu().numpy()
+ else:
+ x = x.numpy()
+
+ one_euro_filter = OneEuro(
+ np.zeros_like(x[0]),
+ x[0],
+ min_cutoff=self.min_cutoff,
+ beta=self.beta,
+ )
+
+ pred_pose_hat = np.zeros_like(x)
+
+ # initialize
+ pred_pose_hat[0] = x[0]
+
+ for idx, pose in enumerate(x[1:]):
+ idx += 1
+ t = np.ones_like(pose) * idx
+ pose = one_euro_filter(t, pose)
+ pred_pose_hat[idx] = pose
+
+ if isinstance(x_type, torch.Tensor):
+ # we also return tensor by default
+ if x_type.is_cuda:
+ pred_pose_hat = torch.from_numpy(pred_pose_hat).cuda()
+ else:
+ pred_pose_hat = torch.from_numpy(pred_pose_hat)
+ return pred_pose_hat
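+
+ # Illustrative usage sketch; input shape [T*M, K, C], values arbitrary.
+ #
+ # import numpy as np
+ # poses = np.random.rand(100, 54, 3)
+ # smoothed = OneEuroFilter(min_cutoff=0.004, beta=0.7)(poses)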
diff --git a/detrsmpl/core/post_processing/smooth/savgol_filter.py b/detrsmpl/core/post_processing/smooth/savgol_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..a481833557f92c3b19af26ee5c75f091fab89540
--- /dev/null
+++ b/detrsmpl/core/post_processing/smooth/savgol_filter.py
@@ -0,0 +1,73 @@
+import warnings
+
+import numpy as np
+import scipy.signal as signal
+import torch
+
+from ..builder import POST_PROCESSING
+
+
+@POST_PROCESSING.register_module(name=['SGFilter', 'savgol'])
+class SGFilter:
+ """Savitzky-Golay filter; implementation from scipy:
+ https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.savgol_filter.html.
+
+ Args:
+ window_size (float):
+ The length of the filter window
+ (i.e., the number of coefficients).
+ window_length must be a positive odd integer.
+ polyorder (int):
+ The order of the polynomial used to fit the samples.
+ polyorder must be less than window_length.
+
+ Returns:
+ smoothed poses (np.ndarray, torch.tensor)
+ """
+ def __init__(self, window_size=11, polyorder=2):
+ super(SGFilter, self).__init__()
+
+ # 1-D Savitzky-Golay filter
+ self.window_size = window_size
+ self.polyorder = polyorder
+
+ def __call__(self, x=None):
+ # x.shape: [t,k,c]
+ if self.window_size % 2 == 0:
+ window_size = self.window_size - 1
+ else:
+ window_size = self.window_size
+ if window_size > x.shape[0]:
+ window_size = x.shape[0]
+ if window_size <= self.polyorder:
+ polyorder = window_size - 1
+ else:
+ polyorder = self.polyorder
+ assert polyorder > 0
+ assert window_size > polyorder
+ if len(x.shape) != 3:
+ warnings.warn('x should be a tensor or numpy of [T*M,K,C]')
+ assert len(x.shape) == 3
+ x_type = x
+ if isinstance(x, torch.Tensor):
+ if x.is_cuda:
+ x = x.cpu().numpy()
+ else:
+ x = x.numpy()
+ smooth_poses = np.zeros_like(x)
+ # smooth at different axis
+ C = x.shape[-1]
+ for i in range(C):
+ smooth_poses[..., i] = signal.savgol_filter(x[..., i],
+ window_size,
+ polyorder,
+ axis=0)
+
+ if isinstance(x_type, torch.Tensor):
+ # we also return tensor by default
+ if x_type.is_cuda:
+ smooth_poses = torch.from_numpy(smooth_poses).cuda()
+ else:
+ smooth_poses = torch.from_numpy(smooth_poses)
+ return smooth_poses
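+
+ # Illustrative usage sketch; input shape [T*M, K, C], values arbitrary.
+ #
+ # import numpy as np
+ # poses = np.random.rand(100, 54, 3)
+ # smoothed = SGFilter(window_size=11, polyorder=2)(poses)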
diff --git a/detrsmpl/core/post_processing/smooth/smoothnet.py b/detrsmpl/core/post_processing/smooth/smoothnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1412b72cb6119a07894939754a6343e8332f714d
--- /dev/null
+++ b/detrsmpl/core/post_processing/smooth/smoothnet.py
@@ -0,0 +1,237 @@
+from typing import Optional
+
+import numpy as np
+import torch
+from mmcv.runner import load_checkpoint
+from torch import Tensor, nn
+
+from detrsmpl.utils.transforms import (
+ aa_to_rotmat,
+ rot6d_to_rotmat,
+ rotmat_to_aa,
+ rotmat_to_rot6d,
+)
+from ..builder import POST_PROCESSING
+
+
+class SmoothNetResBlock(nn.Module):
+ """Residual block module used in SmoothNet.
+
+ Args:
+ in_channels (int): Input channel number.
+ hidden_channels (int): The hidden feature channel number.
+ dropout (float): Dropout probability. Default: 0.1
+ Shape:
+ Input: (*, in_channels)
+ Output: (*, in_channels)
+ """
+ def __init__(self, in_channels, hidden_channels, dropout=0.1):
+ super().__init__()
+ self.linear1 = nn.Linear(in_channels, hidden_channels)
+ self.linear2 = nn.Linear(hidden_channels, in_channels)
+ self.lrelu = nn.LeakyReLU(0.2, inplace=True)
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
+
+ def forward(self, x):
+ identity = x
+ x = self.linear1(x)
+ x = self.dropout(x)
+ x = self.lrelu(x)
+ x = self.linear2(x)
+ x = self.dropout(x)
+ x = self.lrelu(x)
+
+ out = x + identity
+ return out
+
+
+class SmoothNet(nn.Module):
+ """SmoothNet is a plug-and-play temporal-only network to refine human
+ poses. It works for 2d/3d/6d pose smoothing.
+ "SmoothNet: A Plug-and-Play Network for Refining Human Poses in Videos",
+ arXiv'2021. More details can be found in the `paper
+ `__ .
+ Note:
+ N: The batch size
+ T: The temporal length of the pose sequence
+ C: The total pose dimension (e.g. keypoint_number * keypoint_dim)
+ Args:
+ window_size (int): The size of the input window.
+ output_size (int): The size of the output window.
+ hidden_size (int): The hidden feature dimension in the encoder,
+ the decoder and between residual blocks. Default: 512
+ res_hidden_size (int): The hidden feature dimension inside the
+ residual blocks. Default: 512
+ num_blocks (int): The number of residual blocks. Default: 5
+ dropout (float): Dropout probability. Default: 0.1
+ Shape:
+ Input: (N, C, T) the original pose sequence
+ Output: (N, C, T) the smoothed pose sequence
+ """
+ def __init__(self,
+ window_size: int,
+ output_size: int,
+ hidden_size: int = 512,
+ res_hidden_size: int = 512,
+ num_blocks: int = 5,
+ dropout: float = 0.1):
+ super().__init__()
+ self.window_size = window_size
+ self.output_size = output_size
+ self.hidden_size = hidden_size
+ self.res_hidden_size = res_hidden_size
+ self.num_blocks = num_blocks
+ self.dropout = dropout
+
+ assert output_size <= window_size, (
+ 'The output size should be less than or equal to the window size.',
+ f' Got output_size=={output_size} and window_size=={window_size}')
+
+ # Build encoder layers
+ self.encoder = nn.Sequential(nn.Linear(window_size, hidden_size),
+ nn.LeakyReLU(0.1, inplace=True))
+
+ # Build residual blocks
+ res_blocks = []
+ for _ in range(num_blocks):
+ res_blocks.append(
+ SmoothNetResBlock(in_channels=hidden_size,
+ hidden_channels=res_hidden_size,
+ dropout=dropout))
+ self.res_blocks = nn.Sequential(*res_blocks)
+
+ # Build decoder layers
+ self.decoder = nn.Linear(hidden_size, output_size)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward function."""
+ N, C, T = x.shape
+ num_windows = T - self.window_size + 1
+
+ assert T >= self.window_size, (
+ 'Input sequence length must be no less than the window size. ',
+ f'Got x.shape[2]=={T} and window_size=={self.window_size}')
+
+ # Unfold x to obtain input sliding windows
+ # [N, C, num_windows, window_size]
+ x = x.unfold(2, self.window_size, 1)
+
+ # Forward layers
+ x = self.encoder(x)
+ x = self.res_blocks(x)
+ x = self.decoder(x) # [N, C, num_windows, output_size]
+
+ # Accumulate output ensembles
+ out = x.new_zeros(N, C, T)
+ count = x.new_zeros(T)
+
+ for t in range(num_windows):
+ out[..., t:t + self.output_size] += x[:, :, t]
+ count[t:t + self.output_size] += 1.0
+
+ return out.div(count)
+
+
+@POST_PROCESSING.register_module(name=['SmoothNetFilter', 'smoothnet'])
+class SmoothNetFilter:
+ """Apply SmoothNet filter.
+ "SmoothNet: A Plug-and-Play Network for Refining Human Poses in Videos",
+ arXiv'2021. More details can be found in the `paper
+ `__ .
+ Args:
+ window_size (int): The size of the filter window. It's also the
+ window_size of SmoothNet model.
+ output_size (int): The output window size of SmoothNet model.
+ checkpoint (str): The checkpoint file of the pretrained SmoothNet
+ model. Please note that `checkpoint` should be matched with
+ `window_size` and `output_size`.
+ hidden_size (int): SmoothNet argument. See :class:`SmoothNet` for
+ details. Default: 512
+ res_hidden_size (int): SmoothNet argument. See :class:`SmoothNet`
+ for details. Default: 512
+ num_blocks (int): SmoothNet argument. See :class:`SmoothNet` for
+ details. Default: 5
+ device (str): Device for model inference. Default: 'cpu'
+ """
+ def __init__(
+ self,
+ window_size: int,
+ output_size: int,
+ checkpoint: Optional[str] = None,
+ hidden_size: int = 512,
+ res_hidden_size: int = 512,
+ num_blocks: int = 5,
+ device: str = 'cpu',
+ ):
+ super(SmoothNetFilter, self).__init__()
+ self.window_size = window_size
+ self.device = device
+ self.smoothnet = SmoothNet(window_size, output_size, hidden_size,
+ res_hidden_size, num_blocks)
+ self.smoothnet.to(device)
+ if checkpoint:
+ load_checkpoint(self.smoothnet,
+ checkpoint,
+ map_location=self.device)
+ self.smoothnet.eval()
+
+ for p in self.smoothnet.parameters():
+ p.requires_grad_(False)
+
+ def __call__(self, x: np.ndarray):
+ x_type = 'tensor'
+ if not isinstance(x, torch.Tensor):
+ x_type = 'array'
+
+ assert x.ndim == 3, ('Input should be an array with shape [T, K, C]'
+ f', but got invalid shape {x.shape}')
+
+ T, K, C = x.shape
+
+ assert C == 3 or C == 6 or C == 9
+
+ if T < self.window_size:
+ # Skip smoothing if the input length is less than the window size
+ smoothed = x
+ else:
+ if x_type == 'array':
+ dtype = x.dtype
+
+ # Convert to tensor and forward the model
+ with torch.no_grad():
+ if x_type == 'array':
+ x = torch.tensor(x,
+ dtype=torch.float32,
+ device=self.device)
+ if C == 9:
+ input_type = 'matrix'
+ x = rotmat_to_rot6d(x.reshape(-1, 3, 3)).reshape(T, K, -1)
+ elif C == 3:
+ input_type = 'axis_angles'
+ x = rotmat_to_rot6d(aa_to_rotmat(x.reshape(-1,
+ 3))).reshape(
+ T, K, -1)
+ else:
+ input_type = 'rotation_6d'
+ x = x.view(1, T, -1).permute(0, 2, 1) # to [1, KC, T]
+ smoothed = self.smoothnet(x) # in shape [1, KC, T]
+
+ # Convert model output back to input shape and format
+ smoothed = smoothed.permute(0, 2, 1).view(T, K, -1) # to [T, K, C]
+
+ if input_type == 'matrix':
+ smoothed = rot6d_to_rotmat(smoothed.reshape(-1, 6)).reshape(
+ T, K, C)
+ elif input_type == 'axis_angles':
+ smoothed = rotmat_to_aa(
+ rot6d_to_rotmat(smoothed.reshape(-1, 6))).reshape(T, K, C)
+
+ if x_type == 'array':
+ smoothed = smoothed.cpu().numpy().astype(
+ dtype) # to numpy.ndarray
+
+ return smoothed
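+
+
+# A minimal usage sketch (assuming a SmoothNet checkpoint trained with a
+# matching window_size/output_size is available; the path below is only a
+# placeholder):
+#
+#     smooth_filter = SmoothNetFilter(
+#         window_size=8,
+#         output_size=8,
+#         checkpoint='pretrained_models/smoothnet.pth',
+#         device='cpu')
+#     poses = np.random.rand(100, 24, 3)   # [T, K, C] axis-angle per frame
+#     smoothed = smooth_filter(poses)      # same shape and dtype as the input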
diff --git a/detrsmpl/core/post_processing/speed_up/__init__.py b/detrsmpl/core/post_processing/speed_up/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/detrsmpl/core/post_processing/speed_up/deciwatch.py b/detrsmpl/core/post_processing/speed_up/deciwatch.py
new file mode 100644
index 0000000000000000000000000000000000000000..631bf4c0f88eb57d0af8c819c2923c23614dc59b
--- /dev/null
+++ b/detrsmpl/core/post_processing/speed_up/deciwatch.py
@@ -0,0 +1,716 @@
+import copy
+import math
+from typing import Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from mmcv.runner import load_checkpoint
+from torch import Tensor, nn
+
+from detrsmpl.utils.transforms import (
+ aa_to_rotmat,
+ rot6d_to_rotmat,
+ rotmat_to_aa,
+ rotmat_to_rot6d,
+)
+from ..builder import POST_PROCESSING
+
+
+@POST_PROCESSING.register_module(name=['DeciWatchPostProcessing', 'deciwatch'])
+class DeciWatchPostProcessing:
+    """Apply the DeciWatch post-processing filter: https://arxiv.org/abs/2203.08713.
+
+    Args:
+        interval (int): The sampling interval between visible frames.
+        slide_window_q (int): The number of intervals per sliding window;
+            the window size is ``slide_window_q * interval + 1`` frames.
+        checkpoint (str): Path to the pretrained DeciWatch checkpoint.
+        device (Union[torch.device, str], optional):
+            The device (cpu or gpu) used for model inference.
+            Defaults to None.
+
+    Returns:
+        torch.Tensor: The smoothed poses, in the same rotation format
+            as the input.
+    """
+
+ def __init__(self, interval, slide_window_q, checkpoint, device=None):
+ super(DeciWatchPostProcessing, self).__init__()
+ self.interval = interval
+ self.slide_window_q = slide_window_q
+ self.slide_window_size = self.slide_window_q * self.interval + 1
+ self.device = device
+
+ self.input_dimension = 24 * 6
+
+ self.model = DeciWatch(sample_interval=self.interval).to(self.device)
+
+ self.checkpoint_path = checkpoint
+
+ print(f'load checkpoint from local path: {self.checkpoint_path}')
+ load_checkpoint(
+ self.model, self.checkpoint_path, map_location=self.device)
+
+ def __call__(self, x=None):
+        # x.shape: [T, 24, 3], [T, 24, 3, 3] or [T, 24 * 6]
+ seq_len = x.shape[0]
+ assert seq_len > self.slide_window_size
+        assert x.shape[1:] == (24, 3, 3) or x.shape[1:] == (
+            self.input_dimension, ) or x.shape[1:] == (24, 3)
+
+ if x.shape[1:] == (24, 3, 3):
+ input_type = 'matrix'
+ x = torch.tensor(x).to(self.device)
+ x = rotmat_to_rot6d(x).reshape(-1, self.input_dimension)
+ elif x.shape[1:] == (24, 3):
+ input_type = 'axis_angles'
+ x = torch.tensor(x).to(self.device)
+ x = rotmat_to_rot6d(aa_to_rotmat(x.reshape(-1, 3))).reshape(
+ -1, self.input_dimension)
+ else:
+ x = torch.tensor(x).to(self.device)
+ x = x.reshape(-1, self.input_dimension)
+ input_type = 'rotation_6d'
+
+ input = x.clone()
+
+ slide_window_x = torch.as_strided(
+ input, ((seq_len - self.slide_window_size) // (self.interval) + 1,
+ self.slide_window_size, self.input_dimension),
+ (self.interval * self.input_dimension, self.input_dimension, 1),
+ storage_offset=0).reshape(-1, self.slide_window_size,
+ self.input_dimension)
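+        # as_strided builds overlapping windows without copying: window i
+        # covers frames [i * interval, i * interval + slide_window_size), so
+        # consecutive windows overlap by slide_window_size - interval frames.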
+
+ smoothed_len = (
+ seq_len - self.slide_window_size
+ ) // self.interval * self.interval + self.slide_window_size
+
+ with torch.no_grad():
+ smooth_poses, _ = self.model(slide_window_x, self.device)
+
+ output_poses = [[] for i in range(smoothed_len)]
+
+ for i in range(smooth_poses.shape[0]):
+ for j in range(self.slide_window_size):
+ output_poses[i * self.interval + j].append(smooth_poses[i,
+ j, :])
+
+ smooth_poses = torch.cat(
+ (smooth_poses[:, :self.slide_window_size - 1, :].reshape(
+ -1, self.input_dimension), smooth_poses[-1, -1, :].reshape(
+ -1, self.input_dimension)),
+ dim=0)
+
+ for i in range(smoothed_len):
+ output_poses[i] = torch.stack(output_poses[i]).mean(0)
+
+ output_poses = torch.stack(output_poses)
+
+ if smoothed_len < seq_len:
+ output_poses = torch.cat((output_poses, x[smoothed_len:, :]),
+ dim=0)
+
+ if input_type == 'matrix':
+ output_poses = rot6d_to_rotmat(output_poses.reshape(
+ -1, 6)).reshape(-1, 24, 3, 3)
+ elif input_type == 'axis_angles':
+ output_poses = rotmat_to_aa(
+ rot6d_to_rotmat(output_poses.reshape(-1,
+ 6))).reshape(-1, 24, 3)
+
+ return output_poses
+
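+# A minimal usage sketch (assuming a DeciWatch checkpoint trained with the
+# same interval/slide_window_q is available; the path below is only a
+# placeholder):
+#
+#     speed_up = DeciWatchPostProcessing(
+#         interval=10,
+#         slide_window_q=2,   # window size = 2 * 10 + 1 = 21 frames
+#         checkpoint='pretrained_models/deciwatch.pth',
+#         device='cpu')
+#     poses = np.random.rand(120, 24, 3)   # [T, 24, 3] axis-angle
+#     smoothed = speed_up(poses)           # (T, 24, 3) tensor of smoothed poses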
+
+class PositionEmbeddingSine_1D(nn.Module):
+    """Standard sinusoidal position embedding, very similar to the one used
+    in the "Attention Is All You Need" paper, adapted here to 1D temporal
+    sequences."""
+
+ def __init__(self,
+ num_pos_feats=64,
+ temperature=10000,
+ normalize=True,
+ scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError('normalize should be True if scale is passed')
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, B, L):
+
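+        # Returns a sinusoidal embedding of shape [L, B, 2 * num_pos_feats]:
+        # frame indices are (optionally) normalized to [0, scale] and passed
+        # through sin/cos at frequencies controlled by `temperature`.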
+ position = torch.arange(0, L, dtype=torch.float32).unsqueeze(0)
+ position = position.repeat(B, 1)
+
+ if self.normalize:
+ eps = 1e-6
+ position = position / (position[:, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32)
+ dim_t = self.temperature**(2 * (torch.div(dim_t, 1)) /
+ self.num_pos_feats)
+
+ pe = torch.zeros(B, L, self.num_pos_feats * 2)
+ pe[:, :, 0::2] = torch.sin(position[:, :, None] / dim_t)
+ pe[:, :, 1::2] = torch.cos(position[:, :, None] / dim_t)
+
+ pe = pe.permute(1, 0, 2)
+
+ return pe
+
+
+class DeciWatch(nn.Module):
+ """Apply DeciWatch framework for 10x efficiency.
+ "DeciWatch: A Simple Baseline for 10× Efficient 2D and 3D Pose Estimation",
+    arXiv'2022. More details can be found in the `paper
+    <https://arxiv.org/abs/2203.08713>`__ .
+ Args:
+ input_dim (int): The size of input spatial dimension,
+ e.g., 15*2 for 2d pose on the jhmdb dataset
+ sample_interval (int): DeciWatch argument. See :class:`DeciWatch`
+ for details. The intervals of the uniform sampling.
+ The sampling ratio is: 1/sample_interval. Default: 10
+        encoder_hidden_dim (int): Hidden dimension in the encoder.
+            Default: 16
+        decoder_hidden_dim (int): Hidden dimension in the decoder.
+            Default: 16
+        dropout (float): Dropout probability. Default: 0.1
+        nheads (int): Number of attention heads. Default: 4
+        dim_feedforward (int): Dimension of the feed-forward layers.
+            Default: 256
+        enc_layers (int): Number of encoder layers. Default: 3
+        dec_layers (int): Number of decoder layers. Default: 3
+ activation (str): DeciWatch argument. See :class:`DeciWatch`
+ for details. Activation function in deciwatch.
+ Default: 'leaky_relu'
+ pre_norm (bool): DeciWatch argument. See :class:`DeciWatch`
+ for details. Whether to normalize before positional embedding.
+ Default: False
+ """
+
+ def __init__(self,
+ input_dim=24 * 6,
+ sample_interval=10,
+ encoder_hidden_dim=16,
+ decoder_hidden_dim=16,
+ dropout=0.1,
+ nheads=4,
+ dim_feedforward=256,
+ enc_layers=3,
+ dec_layers=3,
+ activation='leaky_relu',
+ pre_norm=False):
+ super(DeciWatch, self).__init__()
+ self.pos_embed_dim = encoder_hidden_dim
+ self.pos_embed = self.build_position_encoding(self.pos_embed_dim)
+
+ self.sample_interval = sample_interval
+
+ self.deciwatch_par = {
+ 'input_dim': input_dim,
+ 'encoder_hidden_dim': encoder_hidden_dim,
+ 'decoder_hidden_dim': decoder_hidden_dim,
+ 'dropout': dropout,
+ 'nheads': nheads,
+ 'dim_feedforward': dim_feedforward,
+ 'enc_layers': enc_layers,
+ 'dec_layers': dec_layers,
+ 'activation': activation,
+ 'pre_norm': pre_norm
+ }
+
+ self.transformer = build_model(self.deciwatch_par)
+
+ def build_position_encoding(self, pos_embed_dim):
+ N_steps = pos_embed_dim // 2
+ position_embedding = PositionEmbeddingSine_1D(N_steps, normalize=True)
+ return position_embedding
+
+ def generate_unifrom_mask(self, L, sample_interval=10):
+        # mask values: 1 = unseen (to be recovered), 0 = seen (sampled frame)
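+        # e.g. with L == 9 and sample_interval == 4 the encoder mask is
+        # [0, 1, 1, 1, 0, 1, 1, 1, 0]: only frames 0, 4 and 8 are visible,
+        # while the decoder mask keeps every frame of the window visible.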
+
+ seq_len = L
+ if (seq_len - 1) % sample_interval != 0:
+ raise Exception(
+ 'The following equation should be satisfied: [Window size] \
+ = [sample interval] * Q + 1, where Q is an integer.')
+
+ sample_mask = np.ones(seq_len, dtype=np.int32)
+ sample_mask[::sample_interval] = 0
+
+ encoder_mask = sample_mask
+ decoder_mask = np.array([0] * L, dtype=np.int32)
+
+ return torch.tensor(encoder_mask), torch.tensor(decoder_mask)
+
+ def seqence_interpolation(self, motion, rate):
+
+ seq_len = motion.shape[-1]
+ indice = torch.arange(seq_len, dtype=int)
+ chunk = torch.div(indice, rate).type(torch.long)
+ remain = indice % rate
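+        # chunk * rate is the index of the last visible frame at or before
+        # each position and remain is the offset past it, so e.g. with
+        # rate == 4, frame 6 is blended from frames 4 and 8 with equal weight.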
+
+ prev = motion[:, :, chunk * rate]
+
+ next = torch.cat([
+ motion[:, :, (chunk[:-1] + 1) * rate], motion[:, :, -1, np.newaxis]
+ ], -1)
+ remain = remain.to(motion.device)
+
+ interpolate = (prev / rate * (rate - remain)) + (next / rate * remain)
+
+ return interpolate
+
+ def forward(self, sequence, device):
+ B, L, C = sequence.shape
+ seq = sequence.permute(0, 2, 1) # B,C,L
+
+ encoder_mask, decoder_mask = self.generate_unifrom_mask(
+ L, sample_interval=self.sample_interval)
+ encoder_mask = encoder_mask.to(seq.device)
+ decoder_mask = decoder_mask.to(seq.device)
+
+ self.input_seq = seq * (1 - encoder_mask.int())
+ self.input_seq_interp = self.seqence_interpolation(
+ self.input_seq, self.sample_interval)
+ # self.input_seq=self.input_seq.reshape(1,1,-1)
+ self.encoder_mask = encoder_mask.unsqueeze(0).repeat(B, 1).to(device)
+ self.decoder_mask = decoder_mask.unsqueeze(0).repeat(B, 1).to(device)
+
+ self.encoder_pos_embed = self.pos_embed(B, L).to(device)
+ self.decoder_pos_embed = self.encoder_pos_embed.clone().to(device)
+
+ self.recover, self.denoise = self.transformer.forward(
+ input_seq=self.input_seq.to(torch.float32),
+ encoder_mask=self.encoder_mask,
+ encoder_pos_embed=self.encoder_pos_embed,
+ input_seq_interp=self.input_seq_interp,
+ decoder_mask=self.decoder_mask,
+ decoder_pos_embed=self.decoder_pos_embed,
+ sample_interval=self.sample_interval,
+ device=device)
+
+ self.recover = self.recover.permute(1, 0, 2).reshape(B, L, C)
+ self.denoise = self.denoise.permute(1, 0, 2).reshape(B, L, C)
+
+ return self.recover, self.denoise
+
+
+class DeciWatchTransformer(nn.Module):
+
+ def __init__(self,
+ input_nc,
+ encoder_hidden_dim=512,
+ decoder_hidden_dim=512,
+ nhead=8,
+ num_encoder_layers=6,
+ num_decoder_layers=6,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation='relu',
+ pre_norm=False):
+ super(DeciWatchTransformer, self).__init__()
+
+ self.joints_dim = input_nc
+ # bring in semantic (5 frames) temporal information into tokens
+ self.decoder_embed = nn.Conv1d(
+ self.joints_dim,
+ decoder_hidden_dim,
+ kernel_size=5,
+ stride=1,
+ padding=2)
+
+ self.encoder_embed = nn.Linear(self.joints_dim, encoder_hidden_dim)
+
+ encoder_layer = DeciWatchTransformerEncoderLayer(
+ encoder_hidden_dim, nhead, dim_feedforward, dropout, activation,
+ pre_norm)
+ encoder_norm = nn.LayerNorm(encoder_hidden_dim) if pre_norm else None
+ self.encoder = DeciWatchTransformerEncoder(encoder_layer,
+ num_encoder_layers,
+ encoder_norm)
+
+ decoder_layer = DeciWatchTransformerDecoderLayer(
+ decoder_hidden_dim, nhead, dim_feedforward, dropout, activation,
+ pre_norm)
+ decoder_norm = nn.LayerNorm(decoder_hidden_dim)
+ self.decoder = DeciWatchTransformerDecoder(decoder_layer,
+ num_decoder_layers,
+ decoder_norm)
+
+ self.decoder_joints_embed = nn.Linear(decoder_hidden_dim,
+ self.joints_dim)
+ self.encoder_joints_embed = nn.Linear(encoder_hidden_dim,
+ self.joints_dim)
+
+ # reset parameters
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ self.encoder_hidden_dim = encoder_hidden_dim
+ self.decoder_hidden_dim = decoder_hidden_dim
+
+ self.nhead = nhead
+
+ def _generate_square_subsequent_mask(self, sz):
+ mask = torch.triu(torch.ones(sz, sz), 1)
+ mask = mask.masked_fill(mask == 1, float('-inf'))
+ return mask
+
+ def interpolate_embedding(self, input, rate):
+
+ tmp = input.clone()
+ seq_len = input.shape[0]
+ indice = torch.arange(seq_len, dtype=int).to(self.device)
+ chunk = torch.div(indice, rate).type(torch.long)
+ remain = indice % rate
+
+ prev = tmp[chunk * rate]
+
+ next = torch.cat([tmp[(chunk[:-1] + 1) * rate], tmp[-1].unsqueeze(0)],
+ dim=0)
+
+ interpolate = (prev / rate * (rate - remain.view(-1, 1, 1))) + (
+ next / rate * remain.view(-1, 1, 1))
+
+ return interpolate
+
+ def forward(self, input_seq, encoder_mask, encoder_pos_embed,
+ input_seq_interp, decoder_mask, decoder_pos_embed,
+ sample_interval, device):
+
+ self.device = device
+
+ # flatten NxCxL to LxNxC
+ bs, c, _ = input_seq.shape
+ input_seq = input_seq.permute(2, 0, 1)
+ input_seq_interp = input_seq_interp.permute(2, 0, 1)
+
+ input = input_seq.clone()
+
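+        # DeciWatch runs in three stages: the encoder denoises the sparsely
+        # sampled (visible) frames, the denoised frames are linearly
+        # interpolated to every time step, and the decoder refines the
+        # interpolation into the final recovered sequence.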
+ # mask on all sequences:
+ trans_src = self.encoder_embed(input_seq)
+ mem = self.encode(trans_src, encoder_mask, encoder_pos_embed)
+ reco = self.encoder_joints_embed(mem) + input
+
+ interp = self.interpolate_embedding(reco, sample_interval)
+ center = interp.clone()
+ trans_tgt = self.decoder_embed(interp.permute(1, 2,
+ 0)).permute(2, 0, 1)
+
+ output = self.decode(mem, encoder_mask, encoder_pos_embed, trans_tgt,
+ decoder_mask, decoder_pos_embed)
+
+ joints = self.decoder_joints_embed(output) + center
+ return joints, reco
+
+ def encode(self, src, src_mask, pos_embed):
+
+ mask = torch.eye(src.shape[0]).bool().to(src.device)
+ memory = self.encoder(
+ src, mask=mask, src_key_padding_mask=src_mask, pos=pos_embed)
+
+ return memory
+
+ def decode(self, memory, memory_mask, memory_pos, tgt, tgt_mask, tgt_pos):
+ hs = self.decoder(
+ tgt,
+ memory,
+ tgt_key_padding_mask=tgt_mask,
+ memory_key_padding_mask=memory_mask,
+ pos=memory_pos,
+ query_pos=tgt_pos)
+ return hs
+
+
+class DeciWatchTransformerEncoder(nn.Module):
+
+ def __init__(self, encoder_layer, num_layers, norm=None):
+ super().__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+
+ def forward(self,
+ src,
+ mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ output = src
+
+ for layer in self.layers:
+ output = layer(
+ output,
+ src_mask=mask,
+ src_key_padding_mask=src_key_padding_mask,
+ pos=pos)
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ return output
+
+
+class DeciWatchTransformerDecoder(nn.Module):
+
+ def __init__(self, decoder_layer, num_layers, norm=None):
+ super().__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+
+ def forward(self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+ output = tgt
+
+ for layer in self.layers:
+ output = layer(
+ output,
+ memory,
+ tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask,
+ pos=pos,
+ query_pos=query_pos)
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ return output
+
+
+class DeciWatchTransformerEncoderLayer(nn.Module):
+
+ def __init__(self,
+ encoder_hidden_dim,
+ nhead,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation='relu',
+ pre_norm=False):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(
+ encoder_hidden_dim, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(encoder_hidden_dim, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, encoder_hidden_dim)
+
+ self.norm1 = nn.LayerNorm(encoder_hidden_dim)
+ self.norm2 = nn.LayerNorm(encoder_hidden_dim)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.pre_norm = pre_norm
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ q = k = self.with_pos_embed(src, pos)
+ src2 = self.self_attn(
+ q,
+ k,
+ value=src,
+ attn_mask=src_mask,
+ key_padding_mask=src_key_padding_mask.bool())[0]
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+ src = src + self.dropout2(src2)
+ src = self.norm2(src)
+ return src
+
+ def forward_pre(self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ src2 = self.norm1(src)
+ q = k = self.with_pos_embed(src2, pos)
+ src2 = self.self_attn(
+ q,
+ k,
+ value=src2,
+ attn_mask=src_mask,
+ key_padding_mask=src_key_padding_mask)[0]
+ src = src + self.dropout1(src2)
+ src2 = self.norm2(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+ src = src + self.dropout2(src2)
+ return src
+
+ def forward(self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None):
+ if self.pre_norm:
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
+
+
+class DeciWatchTransformerDecoderLayer(nn.Module):
+
+ def __init__(self,
+ decoder_hidden_dim,
+ nhead,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation='relu',
+ pre_norm=False):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(
+ decoder_hidden_dim, nhead, dropout=dropout)
+ self.multihead_attn = nn.MultiheadAttention(
+ decoder_hidden_dim, nhead, dropout=dropout)
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(decoder_hidden_dim, dim_feedforward)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, decoder_hidden_dim)
+
+ self.norm1 = nn.LayerNorm(decoder_hidden_dim)
+ self.norm2 = nn.LayerNorm(decoder_hidden_dim)
+ self.norm3 = nn.LayerNorm(decoder_hidden_dim)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+ self.pre_norm = pre_norm
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward_post(self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+ q = k = self.with_pos_embed(tgt, query_pos)
+ tgt2 = self.self_attn(
+ q,
+ k,
+ value=tgt,
+ attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask.bool())[0]
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+ tgt2 = self.multihead_attn(
+ query=self.with_pos_embed(tgt, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory,
+ attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask.bool())[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout3(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt
+
+ def forward_pre(self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+ tgt2 = self.norm1(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(
+ q,
+ k,
+ value=tgt2,
+ attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout1(tgt2)
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.multihead_attn(
+ query=self.with_pos_embed(tgt2, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory,
+ attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask.bool())[0]
+
+ tgt = tgt + self.dropout2(tgt2)
+ tgt2 = self.norm3(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout3(tgt2)
+ return tgt
+
+ def forward(self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+ if self.pre_norm:
+ return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
+ tgt_key_padding_mask,
+ memory_key_padding_mask, pos, query_pos)
+ return self.forward_post(tgt, memory, tgt_mask, memory_mask,
+ tgt_key_padding_mask, memory_key_padding_mask,
+ pos, query_pos)
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def build_model(args):
+ return DeciWatchTransformer(
+ input_nc=args['input_dim'],
+ decoder_hidden_dim=args['decoder_hidden_dim'],
+ encoder_hidden_dim=args['encoder_hidden_dim'],
+ dropout=args['dropout'],
+ nhead=args['nheads'],
+ dim_feedforward=args['dim_feedforward'],
+ num_encoder_layers=args['enc_layers'],
+ num_decoder_layers=args['dec_layers'],
+ activation=args['activation'],
+ pre_norm=args['pre_norm'],
+ )
+
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string."""
+ if activation == 'relu':
+ return F.relu
+ if activation == 'gelu':
+ return F.gelu
+ if activation == 'glu':
+ return F.glu
+ if activation == 'leaky_relu':
+ return F.leaky_relu
+    raise RuntimeError(F'activation should be relu/gelu/glu/leaky_relu, not {activation}.')
diff --git a/detrsmpl/core/renderer/__init__.py b/detrsmpl/core/renderer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/detrsmpl/core/renderer/matplotlib3d_renderer.py b/detrsmpl/core/renderer/matplotlib3d_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e607227778225f3ecebff33c5672bd5ffdfc2750
--- /dev/null
+++ b/detrsmpl/core/renderer/matplotlib3d_renderer.py
@@ -0,0 +1,408 @@
+import io
+import os
+import shutil
+from pathlib import Path
+from typing import Iterable, List, Optional, Union
+
+import cv2
+import mmcv
+import numpy as np
+from matplotlib import pyplot as plt
+from matplotlib.lines import Line2D
+from mpl_toolkits.mplot3d import Axes3D
+
+from detrsmpl.core.conventions.cameras.convert_convention import \
+ enc_camera_convention # prevent yapf isort conflict
+from detrsmpl.utils.demo_utils import get_different_colors
+from detrsmpl.utils.ffmpeg_utils import images_to_video
+from detrsmpl.utils.path_utils import check_path_suffix
+
+
+class Axes3dBaseRenderer(object):
+ """Base renderer."""
+ def init_camera(self,
+ cam_elev_angle=10,
+ cam_elev_speed=0.0,
+ cam_hori_angle=45,
+ cam_hori_speed=0.5):
+        """Initialize the camera trajectory with the given arguments.
+
+ Args:
+ cam_elev_angle (int, optional):
+ The pitch angle where camera starts.
+ Defaults to 10.
+ cam_elev_speed (float, optional):
+ The pitch angle camera steps in one frame.
+ It will go back and forth between -30 and 30 degree.
+ Defaults to 0.0.
+ cam_hori_angle (int, optional):
+ The yaw angle where camera starts. Defaults to 45.
+ cam_hori_speed (float, optional):
+ The yaw angle camera steps in one frame.
+ It will go back and forth between 0 and 90 degree.
+ Defaults to 0.5.
+ """
+ self.cam_elevation_args = [cam_elev_angle, cam_elev_speed]
+ self.cam_horizon_args = [cam_hori_angle, cam_hori_speed]
+ self.if_camera_init = True
+
+ def _get_camera_vector_list(self, frame_number):
+ """Generate self.cam_vector_list according to hori and elev arguments.
+
+ Args:
+ frame_number (int):
+ Number of frames.
+
+ Returns:
+ List[List[float, float]]:
+ A list of float vectors.
+ """
+ self.cam_vector_list = [
+ [self.cam_elevation_args[0], self.cam_horizon_args[0]],
+ ]
+ ele_sign = 1
+ hor_sign = 1
+ for _ in range(frame_number - 1):
+ new_ele_angle = ele_sign * self.cam_elevation_args[
+ 1] + self.cam_vector_list[-1][0]
+ # if elevation angle out of range, go backwards
+ if new_ele_angle <= self.cam_elevation_args[
+ 1] or new_ele_angle >= 30:
+ ele_sign = (-1) * ele_sign
+ new_ele_angle = (ele_sign * self.cam_elevation_args[1] +
+ self.cam_vector_list[-1][0])
+ new_hor_angle = (hor_sign * self.cam_horizon_args[1] +
+ self.cam_vector_list[-1][1])
+ # if horizon angle out of range, go backwards
+ if new_hor_angle >= 90 - 2 * self.cam_horizon_args[
+ 1] or new_hor_angle <= 2 * self.cam_horizon_args[1]:
+ hor_sign = (-1) * hor_sign
+ new_hor_angle = (hor_sign * self.cam_horizon_args[1] +
+ self.cam_vector_list[-1][1])
+ self.cam_vector_list.append([new_ele_angle, new_hor_angle])
+ return self.cam_vector_list
+
+ @staticmethod
+ def _get_visual_range(points: np.ndarray) -> np.ndarray:
+        """Calculate the visual range according to the input points. It makes
+        sure that no point falls outside the visible range.
+
+ Args:
+ points (np.ndarray):
+ An array of 3D points.
+ Axis at the last dim.
+
+ Returns:
+ np.ndarray:
+ An array in shape [3, 2].
+ It marks the lower bound and the upper bound
+ along each axis.
+ """
+ axis_num = points.shape[-1]
+ axis_stat = np.zeros(shape=[axis_num, 4])
+ for axis_index in range(axis_num):
+ axis_data = points[..., axis_index]
+ axis_min = np.min(axis_data)
+ axis_max = np.max(axis_data)
+ axis_mid = (axis_min + axis_max) / 2.0
+ axis_span = axis_max - axis_min
+ axis_stat[axis_index] = np.asarray(
+ (axis_min, axis_max, axis_mid, axis_span))
+ max_span = np.max(axis_stat[:, 3])
+ visual_range = np.zeros(shape=[axis_num, 2])
+ for axis_index in range(axis_num):
+ visual_range[axis_index, 0] =\
+ axis_stat[axis_index, 2] - max_span/2.0
+ visual_range[axis_index, 1] =\
+ axis_stat[axis_index, 2] + max_span/2.0
+ return visual_range
+
+ def _draw_scene(self,
+ visual_range,
+ axis_len=1.0,
+ cam_elev_angle=10,
+ cam_hori_angle=45):
+ """Draw an empty scene according to visual range and camera vector.
+
+ Args:
+ visual_range (np.ndarray):
+ Return value of _get_visual_range().
+ axis_len (float, optional):
+ The length of every axis.
+ Defaults to 1.0.
+ cam_elev_angle (int, optional):
+ Pitch angle of the camera.
+ Defaults to 10.
+ cam_hori_angle (int, optional):
+ Yaw angle of the camera.
+ Defaults to 45.
+
+ Returns:
+ list: Figure and Axes3D
+ """
+ fig = plt.figure()
+ ax = Axes3D(fig, auto_add_to_figure=False)
+ fig.add_axes(ax)
+ ax.set_xlim(*visual_range[0])
+ ax.set_ylim(*visual_range[1])
+ ax.set_zlim(*visual_range[2])
+ ax.view_init(cam_elev_angle, cam_hori_angle)
+ mid_point = [
+ np.average(visual_range[0]),
+ np.average(visual_range[1]),
+ np.average(visual_range[2]),
+ ]
+ # draw axis
+ zero_point = np.array([0, 0, 0])
+ x_axis = np.array([(visual_range[0][1] - mid_point[0]) * axis_len, 0,
+ 0])
+ y_axis = np.array(
+ [0, (visual_range[1][1] - mid_point[1]) * axis_len, 0])
+ z_axis = np.array(
+ [0, 0, (visual_range[2][1] - mid_point[2]) * axis_len])
+ ax = _plot_line_on_fig(ax, zero_point, x_axis, 'r')
+ ax = _plot_line_on_fig(ax, zero_point, y_axis, 'g')
+ ax = _plot_line_on_fig(ax, zero_point, z_axis, 'b')
+ return fig, ax
+
+
+class Axes3dJointsRenderer(Axes3dBaseRenderer):
+    """Renderer of 3D joints."""
+ def __init__(self):
+ self.if_camera_init = False
+ self.cam_vector_list = None
+ self.if_connection_setup = False
+ self.if_frame_updated = False
+ self.temp_path = ''
+
+ def set_connections(self, limbs_connection, limbs_palette):
+ """set body limbs."""
+ self.limbs_connection = limbs_connection
+ self.limbs_palette = limbs_palette
+ self.if_connection_setup = True
+
+ def render_kp3d_to_video(
+ self,
+ keypoints_np: np.ndarray,
+ output_path: Optional[str] = None,
+ convention='opencv',
+ fps: Union[float, int] = 30,
+ resolution: Iterable[int] = (720, 720),
+ visual_range: Iterable[int] = (-100, 100),
+ frame_names: Optional[List[str]] = None,
+ disable_limbs: bool = False,
+ return_array: bool = False,
+ ) -> None:
+ """Render 3d keypoints to a video.
+
+ Args:
+ keypoints_np (np.ndarray): shape of input array should be
+ (f * n * J * 3).
+ output_path (str): output video path or frame folder.
+            convention (str, optional): camera/axis convention of the
+                input keypoints. Defaults to 'opencv'.
+            fps (Union[float, int], optional): fps.
+                Defaults to 30.
+            resolution (Iterable[int], optional): (width, height) of
+                output video.
+                Defaults to (720, 720).
+            visual_range (Iterable[int], optional): range of axis value.
+                Defaults to (-100, 100).
+            frame_names (Optional[List[str]], optional): List of string
+                for frame title, no title if None. Defaults to None.
+            disable_limbs (bool, optional): whether to disable drawing
+                limbs. Defaults to False.
+            return_array (bool, optional): whether to return the rendered
+                frames as a stacked array. Defaults to False.
+        Returns:
+            Optional[np.ndarray]: the rendered frames if `return_array`
+                is True, otherwise None.
+ """
+ assert self.if_camera_init is True
+ assert self.if_connection_setup is True
+ sign, axis = enc_camera_convention(convention)
+ if output_path is not None:
+ if check_path_suffix(output_path, ['.mp4', '.gif']):
+ self.temp_path = os.path.join(
+ Path(output_path).parent,
+ Path(output_path).name + '_output_temp')
+ mmcv.mkdir_or_exist(self.temp_path)
+ print('make dir', self.temp_path)
+ self.remove_temp = True
+ else:
+ self.temp_path = output_path
+ self.remove_temp = False
+ else:
+ self.temp_path = None
+ keypoints_np = _set_new_pose(keypoints_np, sign, axis)
+ if not self.if_frame_updated:
+ if self.cam_vector_list is None:
+ self._get_camera_vector_list(
+ frame_number=keypoints_np.shape[0])
+ assert len(self.cam_vector_list) == keypoints_np.shape[0]
+ if visual_range is None:
+ visual_range = self._get_visual_range(keypoints_np)
+ else:
+ visual_range = np.asarray(visual_range)
+ if len(visual_range.shape) == 1:
+ one_dim_visual_range = np.expand_dims(visual_range, 0)
+ visual_range = one_dim_visual_range.repeat(3, axis=0)
+ image_array = self._export_frames(keypoints_np, resolution,
+ visual_range, frame_names,
+ disable_limbs, return_array)
+ self.if_frame_updated = True
+
+ if output_path is not None:
+ if check_path_suffix(output_path, '.mp4'):
+ images_to_video(self.temp_path,
+ output_path,
+ img_format='frame_%06d.png',
+ fps=fps)
+ return image_array
+
+ def _export_frames(self, keypoints_np, resolution, visual_range,
+ frame_names, disable_limbs, return_array):
+ """Write output/temp images."""
+ image_array = []
+ for frame_index in range(keypoints_np.shape[0]):
+ keypoints_frame = keypoints_np[frame_index]
+ cam_ele, cam_hor = self.cam_vector_list[frame_index]
+ fig, ax = \
+ self._draw_scene(visual_range=visual_range, axis_len=0.5,
+ cam_elev_angle=cam_ele,
+ cam_hori_angle=cam_hor)
+ # draw limbs
+ num_person = keypoints_frame.shape[0]
+ for person_index, keypoints_person in enumerate(keypoints_frame):
+ if num_person >= 2:
+ self.limbs_palette = get_different_colors(
+ num_person)[person_index].reshape(-1, 3)
+ if not disable_limbs:
+ for part_name, limbs in self.limbs_connection.items():
+ if part_name == 'body':
+ linewidth = 2
+ else:
+ linewidth = 1
+ if isinstance(self.limbs_palette, np.ndarray):
+ color = self.limbs_palette.astype(
+ np.int32).reshape(-1, 3)
+ elif isinstance(self.limbs_palette, dict):
+ color = np.array(
+ self.limbs_palette[part_name]).astype(np.int32)
+ for limb_index, limb in enumerate(limbs):
+ limb_index = min(limb_index, len(color) - 1)
+
+ ax = _plot_line_on_fig(
+ ax,
+ keypoints_person[limb[0]],
+ keypoints_person[limb[1]],
+ color=np.array(color[limb_index]) / 255.0,
+ linewidth=linewidth)
+ scatter_points_index = list(
+ set(
+ np.array(self.limbs_connection['body']).reshape(
+ -1).tolist()))
+ ax.scatter(keypoints_person[scatter_points_index, 0],
+ keypoints_person[scatter_points_index, 1],
+ keypoints_person[scatter_points_index, 2],
+ c=np.array([0, 0, 0]).reshape(1, -1),
+ s=10,
+ marker='o')
+ if num_person >= 2:
+ ax.xaxis.set_ticklabels([])
+ ax.yaxis.set_ticklabels([])
+ ax.zaxis.set_ticklabels([])
+ labels = []
+ custom_lines = []
+ for person_index in range(num_person):
+ color = get_different_colors(
+ num_person)[person_index].reshape(1, 3) / 255.0
+ custom_lines.append(
+ Line2D([0], [0],
+ linestyle='-',
+ color=color[0],
+ lw=2,
+ marker='',
+ markeredgecolor='k',
+ markeredgewidth=.1,
+ markersize=20))
+ labels.append(f'person_{person_index + 1}')
+ ax.legend(
+ handles=custom_lines,
+ labels=labels,
+ loc='upper left',
+ )
+ plt.close('all')
+ rgb_mat = _get_cv2mat_from_buf(fig)
+ resized_mat = cv2.resize(rgb_mat, resolution)
+ if frame_names is not None:
+ cv2.putText(
+ resized_mat, str(frame_names[frame_index]),
+ (resolution[0] // 10, resolution[1] // 10),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5 * resolution[0] / 500,
+ np.array([255, 255, 255]).astype(np.int32).tolist(), 2)
+ if self.temp_path is not None:
+ frame_path = os.path.join(self.temp_path,
+ 'frame_%06d.png' % frame_index)
+ cv2.imwrite(frame_path, resized_mat)
+ if return_array:
+ image_array.append(resized_mat[None])
+ if return_array:
+ image_array = np.concatenate(image_array)
+ return image_array
+ else:
+ return None
+
+ def __del__(self):
+ """remove temp images."""
+ self.remove_temp_frames()
+
+ def remove_temp_frames(self):
+ """remove temp images."""
+ if self.temp_path is not None:
+ if Path(self.temp_path).is_dir() and self.remove_temp:
+ shutil.rmtree(self.temp_path)
+
+
+def _set_new_pose(pose_np, sign, axis):
+ """set new pose with axis convention."""
+ target_sign = [-1, 1, -1]
+ target_axis = ['x', 'z', 'y']
+
+ pose_rearrange_axis_result = pose_np.copy()
+ for axis_index, axis_name in enumerate(target_axis):
+ src_axis_index = axis.index(axis_name)
+ pose_rearrange_axis_result[..., axis_index] = \
+ pose_np[..., src_axis_index]
+
+ for dim_index in range(pose_rearrange_axis_result.shape[-1]):
+ pose_rearrange_axis_result[
+ ..., dim_index] = sign[dim_index] / target_sign[
+ dim_index] * pose_rearrange_axis_result[..., dim_index]
+ return pose_rearrange_axis_result
+
+
+def _plot_line_on_fig(ax,
+ point1_location,
+ point2_location,
+ color,
+ linewidth=1):
+ """Draw line on fig with matplotlib."""
+ ax.plot([point1_location[0], point2_location[0]],
+ [point1_location[1], point2_location[1]],
+ [point1_location[2], point2_location[2]],
+ color=color,
+ linewidth=linewidth)
+ return ax
+
+
+def _get_cv2mat_from_buf(fig, dpi=180):
+ """Get numpy image from IO."""
+ buf = io.BytesIO()
+ fig.savefig(buf, format='png', dpi=dpi)
+ buf.seek(0)
+ img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
+ buf.close()
+ img = cv2.imdecode(img_arr, 1)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ return img
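+
+
+# A minimal usage sketch (the limb connections, palette and keypoints below
+# are placeholder values):
+#
+#     renderer = Axes3dJointsRenderer()
+#     renderer.init_camera(cam_elev_angle=10, cam_hori_angle=45)
+#     renderer.set_connections(
+#         limbs_connection={'body': [[0, 1], [1, 2]]},
+#         limbs_palette=np.array([[255, 0, 0]]))
+#     keypoints = np.random.rand(30, 1, 3, 3)  # [frames, persons, joints, xyz]
+#     renderer.render_kp3d_to_video(
+#         keypoints, output_path='demo_out/kp3d.mp4', convention='opencv',
+#         fps=30, visual_range=(-1, 1))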
diff --git a/detrsmpl/core/renderer/mpr_renderer/__init__.py b/detrsmpl/core/renderer/mpr_renderer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e833984e8c1f456d89a610c6878c3832610b3cb4
--- /dev/null
+++ b/detrsmpl/core/renderer/mpr_renderer/__init__.py
@@ -0,0 +1,6 @@
+"""minimal_pytorch_rasterizer is a CUDA non-differentiable mesh rasterization
+library for pytorch tensors with python bindings.
+
+This code is adapted from
+`https://github.com/rmbashirov/minimal_pytorch_rasterizer`.
+"""
diff --git a/detrsmpl/core/renderer/mpr_renderer/camera.py b/detrsmpl/core/renderer/mpr_renderer/camera.py
new file mode 100644
index 0000000000000000000000000000000000000000..a73fad7c5edab58369583259fccd200bbac7a821
--- /dev/null
+++ b/detrsmpl/core/renderer/mpr_renderer/camera.py
@@ -0,0 +1,52 @@
+import numpy as np
+import torch
+
+
+class Pinhole2D:
+ def __init__(self, K=None, fx=None, fy=None, cx=None, cy=None, h=0, w=0):
+ if K is not None:
+ assert fx is None and fy is None and cx is None and cy is None
+ self.fx = K[0, 0]
+ self.fy = K[1, 1]
+ self.cx = K[0, 2]
+ self.cy = K[1, 2]
+ else:
+ assert \
+ fx is not None and fy is not None and \
+ cx is not None and cy is not None
+ self.fx = fx
+ self.fy = fy
+ self.cx = cx
+ self.cy = cy
+ self.h = h
+ self.w = w
+
+ def get_K(self):
+ return np.array([[self.fx, 0, self.cx], [0, self.fy, self.cy],
+ [0, 0, 1]])
+
+ def project_ndc(self, vertices, eps=1e-9):
+ """
+ vertices: torch.Tensor of shape (N, 3), 3 stands for xyz
+ """
+ assert isinstance(vertices, torch.Tensor)
+ assert len(vertices.shape) == 2
+ assert vertices.shape[1] == 3
+ K = torch.tensor(self.get_K(),
+ device=vertices.device,
+ dtype=vertices.dtype)
+
+ # apply intrinsics
+ vertices_ndc = vertices @ K.transpose(1, 0)
+
+ # divide xy by z, leave z unchanged
+ vertices_ndc[:, [0, 1]] /= vertices_ndc[:, [2]] + eps
+
+ # convert x from [0, w) to [-1, 1] range
+ # convert y from [0, h) to [-1, 1] range
+ wh = torch.tensor([self.w, self.h],
+ device=vertices.device,
+ dtype=vertices.dtype).unsqueeze(0)
+ vertices_ndc[:, [0, 1]] = 2 * vertices_ndc[:, [0, 1]] / wh - 1
+
+ return vertices_ndc
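+
+
+# A minimal usage sketch (the intrinsics below are arbitrary example values):
+#
+#     cam = Pinhole2D(fx=500.0, fy=500.0, cx=256.0, cy=256.0, h=512, w=512)
+#     vertices = torch.rand(100, 3) + torch.tensor([0.0, 0.0, 2.0])  # in front of the camera
+#     vertices_ndc = cam.project_ndc(vertices)  # (100, 3): x, y in NDC, z kept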
diff --git a/detrsmpl/core/renderer/mpr_renderer/cuda/rasterizer.cpp b/detrsmpl/core/renderer/mpr_renderer/cuda/rasterizer.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..fd7402b8b4c6db20f67ed42d485833deda38017f
--- /dev/null
+++ b/detrsmpl/core/renderer/mpr_renderer/cuda/rasterizer.cpp
@@ -0,0 +1,83 @@
+#include <torch/extension.h>
+#include <vector>
+
+// CUDA forward declarations
+
+std::vector<torch::Tensor> estimate_normals_cuda(
+ const torch::Tensor& vertices_ndc,
+ const torch::Tensor& faces,
+ const torch::Tensor& vertices,
+ const torch::Tensor& vertices_filter,
+ int h, int w
+);
+
+
+torch::Tensor project_mesh_cuda(
+ const torch::Tensor& vertices_ndc,
+ const torch::Tensor& faces,
+ const torch::Tensor& vertice_values,
+ const torch::Tensor& vertices_filter,
+ int h, int w
+);
+
+// C++ interface
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+void check_equal_dtype(const torch::Tensor& a, const torch::Tensor& b) {
+ TORCH_CHECK(
+ a.dtype() == b.dtype(),
+ "expected equal dtype, got ", a.dtype(), " != ", b.dtype()
+ );
+}
+
+void check_equal_gpuid(const torch::Tensor& a, const torch::Tensor& b) {
+ TORCH_CHECK(
+ a.device().index() == b.device().index(),
+ "expected equal gpu id, got ", a.device().index(), " != ", b.device().index()
+ );
+}
+
+std::vector<torch::Tensor> estimate_normals(
+ const torch::Tensor& vertices_ndc,
+ const torch::Tensor& faces,
+ const torch::Tensor& vertices,
+ const torch::Tensor& vertices_filter,
+ int h, int w
+) {
+ TORCH_CHECK(h > 0, "h expected to be > 0");
+ TORCH_CHECK(w > 0, "w expected to be > 0");
+ CHECK_INPUT(vertices_ndc);
+ CHECK_INPUT(faces);
+ CHECK_INPUT(vertices_filter);
+ return estimate_normals_cuda(
+ vertices_ndc, faces, vertices, vertices_filter,
+ h, w
+ );
+}
+
+torch::Tensor project_mesh(
+ const torch::Tensor& vertices_ndc,
+ const torch::Tensor& faces,
+ const torch::Tensor& vertice_values,
+ const torch::Tensor& vertices_filter,
+ int h, int w
+) {
+ TORCH_CHECK(h > 0, "h expected to be > 0");
+ TORCH_CHECK(w > 0, "w expected to be > 0");
+ CHECK_INPUT(vertices_ndc);
+ CHECK_INPUT(faces);
+ CHECK_INPUT(vertice_values);
+ CHECK_INPUT(vertices_filter);
+ return project_mesh_cuda(
+ vertices_ndc, faces, vertice_values, vertices_filter,
+ h, w
+ );
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("estimate_normals", &estimate_normals, "estimate_normals (CUDA)");
+ m.def("project_mesh", &project_mesh, "project_mesh (CUDA)");
+}
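+
+// The extension is typically built on the fly from Python; a sketch (the
+// exact build wiring depends on the repository's setup):
+//
+//   from torch.utils.cpp_extension import load
+//   mpr_cuda = load(name='mpr_cuda',
+//                   sources=['rasterizer.cpp', 'rasterizer_kernel.cu'])
+//   attrs = mpr_cuda.project_mesh(
+//       vertices_ndc, faces, vertice_values, vertices_filter, h, w)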
diff --git a/detrsmpl/core/renderer/mpr_renderer/cuda/rasterizer_kernel.cu b/detrsmpl/core/renderer/mpr_renderer/cuda/rasterizer_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..797a78c602b28040d21585b786fb8666f28f70f2
--- /dev/null
+++ b/detrsmpl/core/renderer/mpr_renderer/cuda/rasterizer_kernel.cu
@@ -0,0 +1,577 @@
+/*
+
+There are 2 ways to rasterize triangles that come to mind:
+1) iterate over all pixels (they define the CUDA grid); for the selected pixel, feed all triangles to 1 CUDA block
+2) iterate over all triangles (they define the CUDA grid); for the selected triangle, feed the pixels bounded by it to 1 CUDA block
+
+The 2nd way is implemented here.
+*/
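+
+// Overall flow: rasterize_cuda_kernel first writes the minimum depth of every
+// covered pixel into a z-buffer using atomic min operations; a second pass
+// (interpolate_cuda_kernel / estimate_normals_cuda_kernel) revisits the same
+// pixels, keeps only the fragments whose depth matches the buffer, and
+// interpolates per-vertex attributes with perspective-correct barycentric
+// weights.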
+
+
+#include <torch/extension.h>
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/Exceptions.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <vector>
+
+#define BLOCK_SIZE 512
+#define BLOCK_SIZE_2D_X 32
+#define BLOCK_SIZE_2D_Y 16
+#define BLOCK_SIZE_3D_X 32
+#define BLOCK_SIZE_3D_Y 8
+#define BLOCK_SIZE_3D_Z 4
+
+// vertices coords:
+// vertices[:, 0]: x
+// vertices[:, 1]: y
+// vertices[:, 2]: z
+
+// 2d tensor axis:
+// 0: yi
+// 1: xi
+
+// 3d tensor axis:
+// 0: zi
+// 1: yi
+// 2: xi
+
+template <typename scalar_t>
+__device__ __forceinline__ scalar_t atomicMinFloat(scalar_t * addr, scalar_t value) {
+ scalar_t old;
+ old = (value >= 0) ? __int_as_float(atomicMin((int *)addr, __float_as_int(value))) :
+ __uint_as_float(atomicMax((unsigned int *)addr, __float_as_uint(value)));
+ return old;
+}
+
+__device__ double atomicMin_double(double* address, double val)
+{
+ unsigned long long int* address_as_ull = (unsigned long long int*) address;
+ unsigned long long int old = *address_as_ull, assumed;
+ do {
+ assumed = old;
+ old = atomicCAS(address_as_ull, assumed,
+ __double_as_longlong(fmin(val, __longlong_as_double(assumed))));
+ } while (assumed != old);
+ return __longlong_as_double(old);
+}
+
+// kernel utils
+
+template <typename scalar_t>
+__device__ int lower_bound(const scalar_t* values, const scalar_t value, const int N) {
+ int left = 0;
+ int right = N;
+ int mid;
+ while (right - left > 1) {
+ mid = (left + right) / 2;
+ if (values[mid] < value) {
+ left = mid;
+ } else {
+ right = mid;
+ }
+ }
+ return right;
+}
+
+// kernels
+
+template <typename scalar_t>
+__global__ void rasterize_cuda_kernel(
+    const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> vertices_ndc,
+    const torch::PackedTensorAccessor32<int32_t, 2, torch::RestrictPtrTraits> faces,
+    const torch::PackedTensorAccessor32<uint8_t, 1, torch::RestrictPtrTraits> vertices_filter,
+    torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> depth,
+ scalar_t* global_face_ndc_inv,
+ int* global_is_bad_face
+) {
+ const int face_indx = blockIdx.x;
+ const int H = depth.size(0);
+ const int W = depth.size(1);
+
+ scalar_t min_x, max_x, min_y, max_y;
+ scalar_t denom;
+
+ __shared__ int vertices_per_thread_x, vertices_per_thread_y;
+ __shared__ int ai, bi, ci;
+ __shared__ bool is_bad_face;
+ __shared__ int min_xi, max_xi, min_yi, max_yi;
+ __shared__ scalar_t face_ndc[9];
+ __shared__ scalar_t face_ndc_inv[9];
+ const scalar_t eps = 1e-5;
+
+ if (threadIdx.x == 0 && threadIdx.y == 0) {
+ ai = faces[face_indx][0];
+ bi = faces[face_indx][1];
+ ci = faces[face_indx][2];
+
+ if (vertices_filter[ai] == 0 || vertices_filter[bi] == 0 || vertices_filter[ci] == 0) {
+ is_bad_face = true;
+ global_is_bad_face[face_indx] = 1;
+ return;
+ }
+
+ face_ndc[0] = vertices_ndc[ai][0]; face_ndc[1] = vertices_ndc[ai][1]; face_ndc[2] = vertices_ndc[ai][2];
+ face_ndc[3] = vertices_ndc[bi][0]; face_ndc[4] = vertices_ndc[bi][1]; face_ndc[5] = vertices_ndc[bi][2];
+ face_ndc[6] = vertices_ndc[ci][0]; face_ndc[7] = vertices_ndc[ci][1]; face_ndc[8] = vertices_ndc[ci][2];
+
+ // negative vertex
+ is_bad_face = false;
+ if (face_ndc[2] < eps || face_ndc[5] < eps || face_ndc[8] < eps) {
+ is_bad_face = true;
+ global_is_bad_face[face_indx] = 1;
+ return;
+ }
+
+ face_ndc_inv[0] = face_ndc[4] - face_ndc[7];
+ face_ndc_inv[1] = face_ndc[6] - face_ndc[3];
+ face_ndc_inv[2] = face_ndc[3] * face_ndc[7] - face_ndc[6] * face_ndc[4];
+ face_ndc_inv[3] = face_ndc[7] - face_ndc[1];
+ face_ndc_inv[4] = face_ndc[0] - face_ndc[6];
+ face_ndc_inv[5] = face_ndc[6] * face_ndc[1] - face_ndc[0] * face_ndc[7];
+ face_ndc_inv[6] = face_ndc[1] - face_ndc[4];
+ face_ndc_inv[7] = face_ndc[3] - face_ndc[0];
+ face_ndc_inv[8] = face_ndc[0] * face_ndc[4] - face_ndc[3] * face_ndc[1];
+
+ denom = (
+ face_ndc[6] * (face_ndc[1] - face_ndc[4]) +
+ face_ndc[0] * (face_ndc[4] - face_ndc[7]) +
+ face_ndc[3] * (face_ndc[7] - face_ndc[1])
+ );
+
+// if (abs(denom) < eps) {
+// is_bad_face = true;
+// global_is_bad_face[face_indx] = 1;
+// return;
+// }
+
+ for (int i = 0; i < 9; ++i) {
+ face_ndc_inv[i] /= denom;
+ }
+
+ for (int i = 0; i < 9; ++i) {
+ global_face_ndc_inv[9 * face_indx + i] = face_ndc_inv[i];
+ }
+
+ global_is_bad_face[face_indx] = 0;
+
+ min_x = min(min(face_ndc[0], face_ndc[3]), face_ndc[6]);
+ min_x = (min_x + 1) / 2 * W; // convert from ndc to img coordinates
+        min_xi = static_cast<int>(floorf(static_cast<float>(min_x)));
+        min_xi = min(max(min_xi, 0), W - 1);
+        max_x = max(max(face_ndc[0], face_ndc[3]), face_ndc[6]);
+        max_x = (max_x + 1) / 2 * W;
+        max_xi = static_cast<int>(ceilf(static_cast<float>(max_x)));
+        max_xi = min(max(max_xi, 0), W - 1);
+
+        min_y = min(min(face_ndc[1], face_ndc[4]), face_ndc[7]);
+        min_y = (min_y + 1) / 2 * H;
+        min_yi = static_cast<int>(floorf(static_cast<float>(min_y)));
+        min_yi = min(max(min_yi, 0), H - 1);
+        max_y = max(max(face_ndc[1], face_ndc[4]), face_ndc[7]);
+        max_y = (max_y + 1) / 2 * H;
+        max_yi = static_cast<int>(ceilf(static_cast<float>(max_y)));
+ max_yi = min(max(max_yi, 0), H - 1);
+
+ vertices_per_thread_x = (max_xi - min_xi) / blockDim.x + 1;
+ vertices_per_thread_y = (max_yi - min_yi) / blockDim.y + 1;
+ }
+ __syncthreads();
+ if (is_bad_face) {
+ return;
+ }
+
+ const int left = min_xi + vertices_per_thread_x * threadIdx.x;
+ const int right = min(left + vertices_per_thread_x, max_xi);
+
+ const int top = min_yi + vertices_per_thread_y * threadIdx.y;
+ const int bottom = min(top + vertices_per_thread_y, max_yi);
+
+ scalar_t x, y, face_z, wa, wb, wc, wsum;
+ for (int i = top; i <= bottom; i++) {
+ for (int j = left; j <= right; j++) {
+ x = 2 * ((scalar_t)j + 0.5) / W - 1;
+ y = 2 * ((scalar_t)i + 0.5) / H - 1;
+
+ // check pixel is inside the face
+ if (((y - face_ndc[1]) * (face_ndc[3] - face_ndc[0]) > (x - face_ndc[0]) * (face_ndc[4] - face_ndc[1])) ||
+ ((y - face_ndc[4]) * (face_ndc[6] - face_ndc[3]) > (x - face_ndc[3]) * (face_ndc[7] - face_ndc[4])) ||
+ ((y - face_ndc[7]) * (face_ndc[0] - face_ndc[6]) > (x - face_ndc[6]) * (face_ndc[1] - face_ndc[7]))) {
+ continue;
+ }
+
+ wa = face_ndc_inv[0] * x + face_ndc_inv[1] * y + face_ndc_inv[2];
+ wb = face_ndc_inv[3] * x + face_ndc_inv[4] * y + face_ndc_inv[5];
+ wc = face_ndc_inv[6] * x + face_ndc_inv[7] * y + face_ndc_inv[8];
+ wsum = wa + wb + wc;
+ wa /= wsum; wb /= wsum; wc /= wsum;
+
+ wa /= face_ndc[2];
+ wb /= face_ndc[5];
+ wc /= face_ndc[8];
+ wsum = wa + wb + wc;
+ wa /= wsum; wb /= wsum; wc /= wsum;
+
+ face_z = wa * face_ndc[2] + wb * face_ndc[5] + wc * face_ndc[8];
+
+ if (sizeof(scalar_t) == sizeof(double)) {
+ atomicMin_double((double*)&depth[i][j], (double)face_z);
+ } else {
+ atomicMinFloat(&depth[i][j], face_z);
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t>
+__global__ void interpolate_cuda_kernel(
+    const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> vertices_ndc,
+    const torch::PackedTensorAccessor32<int32_t, 2, torch::RestrictPtrTraits> faces,
+    const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> depth,
+    const scalar_t* global_face_ndc_inv,
+    const int* global_is_bad_face,
+    const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> vertice_values,
+    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> result
+) {
+ const int face_indx = blockIdx.x;
+
+ if (global_is_bad_face[face_indx]) {
+ return;
+ }
+
+ const int H = depth.size(0);
+ const int W = depth.size(1);
+ const int C = vertice_values.size(1);
+ const scalar_t eps = 1e-5;
+
+ scalar_t min_x, max_x, min_y, max_y;
+ __shared__ int vertices_per_thread_x, vertices_per_thread_y;
+ __shared__ int ai, bi, ci;
+ __shared__ scalar_t face_ndc[9];
+ __shared__ scalar_t face_ndc_inv[9];
+ __shared__ int min_xi, max_xi, min_yi, max_yi;
+
+ if (threadIdx.x == 0 && threadIdx.y == 0) {
+ ai = faces[face_indx][0];
+ bi = faces[face_indx][1];
+ ci = faces[face_indx][2];
+
+ face_ndc[0] = vertices_ndc[ai][0]; face_ndc[1] = vertices_ndc[ai][1]; face_ndc[2] = vertices_ndc[ai][2];
+ face_ndc[3] = vertices_ndc[bi][0]; face_ndc[4] = vertices_ndc[bi][1]; face_ndc[5] = vertices_ndc[bi][2];
+ face_ndc[6] = vertices_ndc[ci][0]; face_ndc[7] = vertices_ndc[ci][1]; face_ndc[8] = vertices_ndc[ci][2];
+
+ for (int i = 0; i < 9; ++i) {
+ face_ndc_inv[i] = global_face_ndc_inv[9 * face_indx + i];
+ }
+
+ min_x = min(min(face_ndc[0], face_ndc[3]), face_ndc[6]);
+ min_x = (min_x + 1) / 2 * W; // convert from ndc to img coordinates
+        min_xi = static_cast<int>(floorf(static_cast<float>(min_x)));
+        min_xi = min(max(min_xi, 0), W - 1);
+        max_x = max(max(face_ndc[0], face_ndc[3]), face_ndc[6]);
+        max_x = (max_x + 1) / 2 * W;
+        max_xi = static_cast<int>(ceilf(static_cast<float>(max_x)));
+        max_xi = min(max(max_xi, 0), W - 1);
+
+        min_y = min(min(face_ndc[1], face_ndc[4]), face_ndc[7]);
+        min_y = (min_y + 1) / 2 * H;
+        min_yi = static_cast<int>(floorf(static_cast<float>(min_y)));
+        min_yi = min(max(min_yi, 0), H - 1);
+        max_y = max(max(face_ndc[1], face_ndc[4]), face_ndc[7]);
+        max_y = (max_y + 1) / 2 * H;
+        max_yi = static_cast<int>(ceilf(static_cast<float>(max_y)));
+ max_yi = min(max(max_yi, 0), H - 1);
+
+ vertices_per_thread_x = (max_xi - min_xi) / blockDim.x + 1;
+ vertices_per_thread_y = (max_yi - min_yi) / blockDim.y + 1;
+ }
+ __syncthreads();
+
+ const int left = min_xi + vertices_per_thread_x * threadIdx.x;
+ const int right = min(left + vertices_per_thread_x, max_xi);
+
+ const int top = min_yi + vertices_per_thread_y * threadIdx.y;
+ const int bottom = min(top + vertices_per_thread_y, max_yi);
+
+ scalar_t x, y, face_z, wa, wb, wc, wsum;
+ for (int i = top; i <= bottom; i++) {
+ for (int j = left; j <= right; j++) {
+ x = 2 * ((scalar_t)j + 0.5) / W - 1;
+ y = 2 * ((scalar_t)i + 0.5) / H - 1;
+
+ // check pixel is inside the face
+ if (((y - face_ndc[1]) * (face_ndc[3] - face_ndc[0]) > (x - face_ndc[0]) * (face_ndc[4] - face_ndc[1])) ||
+ ((y - face_ndc[4]) * (face_ndc[6] - face_ndc[3]) > (x - face_ndc[3]) * (face_ndc[7] - face_ndc[4])) ||
+ ((y - face_ndc[7]) * (face_ndc[0] - face_ndc[6]) > (x - face_ndc[6]) * (face_ndc[1] - face_ndc[7]))) {
+ continue;
+ }
+
+ wa = face_ndc_inv[0] * x + face_ndc_inv[1] * y + face_ndc_inv[2];
+ wb = face_ndc_inv[3] * x + face_ndc_inv[4] * y + face_ndc_inv[5];
+ wc = face_ndc_inv[6] * x + face_ndc_inv[7] * y + face_ndc_inv[8];
+ wsum = wa + wb + wc;
+ wa /= wsum; wb /= wsum; wc /= wsum;
+
+ wa /= face_ndc[2];
+ wb /= face_ndc[5];
+ wc /= face_ndc[8];
+ wsum = wa + wb + wc;
+ wa /= wsum; wb /= wsum; wc /= wsum;
+
+ face_z = wa * face_ndc[2] + wb * face_ndc[5] + wc * face_ndc[8];
+
+ if (face_z - eps < depth[i][j]) {
+ for (int c = 0; c < C; c++) {
+ result[i][j][c] = wa * vertice_values[ai][c] + wb * vertice_values[bi][c] + wc * vertice_values[ci][c];
+ }
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t>
+__global__ void estimate_normals_cuda_kernel(
+    const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> vertices_ndc,
+    const torch::PackedTensorAccessor32<int32_t, 2, torch::RestrictPtrTraits> faces,
+    const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> depth,
+    const scalar_t* global_face_ndc_inv,
+    const int* global_is_bad_face,
+    const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> vertices,
+    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> coords,
+    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> normals
+) {
+ const int face_indx = blockIdx.x;
+
+ if (global_is_bad_face[face_indx]) {
+ return;
+ }
+
+ const int H = depth.size(0);
+ const int W = depth.size(1);
+ const scalar_t eps = 1e-5;
+
+ scalar_t min_x, max_x, min_y, max_y;
+ scalar_t v1x, v1y, v1z, v2x, v2y, v2z, nlen;
+ __shared__ int vertices_per_thread_x, vertices_per_thread_y;
+ __shared__ int ai, bi, ci;
+ __shared__ scalar_t face[9];
+ __shared__ scalar_t face_ndc[9];
+ __shared__ scalar_t face_ndc_inv[9];
+ __shared__ int min_xi, max_xi, min_yi, max_yi;
+ __shared__ scalar_t nx, ny, nz;
+
+ if (threadIdx.x == 0 && threadIdx.y == 0) {
+ ai = faces[face_indx][0];
+ bi = faces[face_indx][1];
+ ci = faces[face_indx][2];
+
+ face[0] = vertices[ai][0]; face[1] = vertices[ai][1]; face[2] = vertices[ai][2];
+ face[3] = vertices[bi][0]; face[4] = vertices[bi][1]; face[5] = vertices[bi][2];
+ face[6] = vertices[ci][0]; face[7] = vertices[ci][1]; face[8] = vertices[ci][2];
+
+ v1x = face[3] - face[0]; v2x = face[6] - face[0];
+ v1y = face[4] - face[1]; v2y = face[7] - face[1];
+ v1z = face[5] - face[2]; v2z = face[8] - face[2];
+
+ nx = v1y * v2z - v1z * v2y;
+ ny = v1z * v2x - v1x * v2z;
+ nz = v1x * v2y - v1y * v2x;
+ nlen = nx * nx + ny * ny + nz * nz;
+ nlen = (scalar_t)sqrt((float)nlen);
+ nx /= nlen;
+ ny /= nlen;
+ nz /= nlen;
+
+ face_ndc[0] = vertices_ndc[ai][0]; face_ndc[1] = vertices_ndc[ai][1]; face_ndc[2] = vertices_ndc[ai][2];
+ face_ndc[3] = vertices_ndc[bi][0]; face_ndc[4] = vertices_ndc[bi][1]; face_ndc[5] = vertices_ndc[bi][2];
+ face_ndc[6] = vertices_ndc[ci][0]; face_ndc[7] = vertices_ndc[ci][1]; face_ndc[8] = vertices_ndc[ci][2];
+
+ for (int i = 0; i < 9; ++i) {
+ face_ndc_inv[i] = global_face_ndc_inv[9 * face_indx + i];
+ }
+
+ min_x = min(min(face_ndc[0], face_ndc[3]), face_ndc[6]);
+ min_x = (min_x + 1) / 2 * W; // convert from ndc to img coordinates
+        min_xi = static_cast<int>(floorf(static_cast<float>(min_x)));
+        min_xi = min(max(min_xi, 0), W - 1);
+        max_x = max(max(face_ndc[0], face_ndc[3]), face_ndc[6]);
+        max_x = (max_x + 1) / 2 * W;
+        max_xi = static_cast<int>(ceilf(static_cast<float>(max_x)));
+        max_xi = min(max(max_xi, 0), W - 1);
+
+        min_y = min(min(face_ndc[1], face_ndc[4]), face_ndc[7]);
+        min_y = (min_y + 1) / 2 * H;
+        min_yi = static_cast<int>(floorf(static_cast<float>(min_y)));
+        min_yi = min(max(min_yi, 0), H - 1);
+        max_y = max(max(face_ndc[1], face_ndc[4]), face_ndc[7]);
+        max_y = (max_y + 1) / 2 * H;
+        max_yi = static_cast<int>(ceilf(static_cast<float>(max_y)));
+ max_yi = min(max(max_yi, 0), H - 1);
+
+ vertices_per_thread_x = (max_xi - min_xi) / blockDim.x + 1;
+ vertices_per_thread_y = (max_yi - min_yi) / blockDim.y + 1;
+ }
+ __syncthreads();
+
+ const int left = min_xi + vertices_per_thread_x * threadIdx.x;
+ const int right = min(left + vertices_per_thread_x, max_xi);
+
+ const int top = min_yi + vertices_per_thread_y * threadIdx.y;
+ const int bottom = min(top + vertices_per_thread_y, max_yi);
+
+ scalar_t x, y, face_z, wa, wb, wc, wsum;
+ for (int i = top; i <= bottom; i++) {
+ for (int j = left; j <= right; j++) {
+ x = 2 * ((scalar_t)j + 0.5) / W - 1;
+ y = 2 * ((scalar_t)i + 0.5) / H - 1;
+
+ // check pixel is inside the face
+ if (((y - face_ndc[1]) * (face_ndc[3] - face_ndc[0]) > (x - face_ndc[0]) * (face_ndc[4] - face_ndc[1])) ||
+ ((y - face_ndc[4]) * (face_ndc[6] - face_ndc[3]) > (x - face_ndc[3]) * (face_ndc[7] - face_ndc[4])) ||
+ ((y - face_ndc[7]) * (face_ndc[0] - face_ndc[6]) > (x - face_ndc[6]) * (face_ndc[1] - face_ndc[7]))) {
+ continue;
+ }
+
+ wa = face_ndc_inv[0] * x + face_ndc_inv[1] * y + face_ndc_inv[2];
+ wb = face_ndc_inv[3] * x + face_ndc_inv[4] * y + face_ndc_inv[5];
+ wc = face_ndc_inv[6] * x + face_ndc_inv[7] * y + face_ndc_inv[8];
+ wsum = wa + wb + wc;
+ wa /= wsum; wb /= wsum; wc /= wsum;
+
+ wa /= face_ndc[2];
+ wb /= face_ndc[5];
+ wc /= face_ndc[8];
+ wsum = wa + wb + wc;
+ wa /= wsum; wb /= wsum; wc /= wsum;
+
+ face_z = wa * face_ndc[2] + wb * face_ndc[5] + wc * face_ndc[8];
+
+ if (face_z - eps < depth[i][j]) {
+ coords[i][j][0] = wa * face[0] + wb * face[3] + wc * face[6];
+ coords[i][j][1] = wa * face[1] + wb * face[4] + wc * face[7];
+ coords[i][j][2] = wa * face[2] + wb * face[5] + wc * face[8];
+
+ normals[i][j][0] = nx;
+ normals[i][j][1] = ny;
+ normals[i][j][2] = nz;
+ }
+ }
+ }
+}
+
+// cpp defined functions
+
+torch::Tensor project_mesh_cuda(
+ const torch::Tensor& vertices_ndc,
+ const torch::Tensor& faces,
+ const torch::Tensor& vertice_values,
+ const torch::Tensor& vertices_filter,
+ int H, int W
+) {
+ const int N = vertices_ndc.size(0);
+ const int C = vertice_values.size(1);
+ const int M = faces.size(0);
+
+ const int gpuid = vertices_ndc.device().index();
+ AT_CUDA_CHECK(cudaSetDevice(gpuid));
+ auto options = torch::dtype(vertices_ndc.scalar_type()).device(torch::kCUDA, gpuid);
+
+ const dim3 dimGrid(M);
+ const dim3 dimBlock(4, 4);
+
+ auto depth = torch::ones({H, W}, options) * 1e10;
+ auto result = torch::zeros({H, W, C}, options);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(vertices_ndc.scalar_type(), "project_mesh_cuda_kernel", [&] {
+ scalar_t* global_face_ndc_inv;
+ cudaMalloc(&global_face_ndc_inv, M * 9 * sizeof(scalar_t));
+ int* global_is_bad_face;
+ cudaMalloc(&global_is_bad_face, M * sizeof(int));
+ rasterize_cuda_kernel<scalar_t><<<dimGrid, dimBlock>>>(
+ vertices_ndc.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
+ faces.packed_accessor32<int32_t, 2, torch::RestrictPtrTraits>(),
+ vertices_filter.packed_accessor32<uint8_t, 1, torch::RestrictPtrTraits>(),
+ depth.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
+ global_face_ndc_inv,
+ global_is_bad_face
+ );
+ AT_CUDA_CHECK(cudaGetLastError());
+
+ interpolate_cuda_kernel<scalar_t><<<dimGrid, dimBlock>>>(
+ vertices_ndc.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
+ faces.packed_accessor32<int32_t, 2, torch::RestrictPtrTraits>(),
+ depth.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
+ global_face_ndc_inv,
+ global_is_bad_face,
+ vertice_values.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
+ result.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>()
+ );
+ AT_CUDA_CHECK(cudaGetLastError());
+
+ cudaFree(global_face_ndc_inv);
+ cudaFree(global_is_bad_face);
+ AT_CUDA_CHECK(cudaGetLastError());
+ });
+
+ return result;
+}
+
+
+std::vector<torch::Tensor> estimate_normals_cuda(
+ const torch::Tensor& vertices_ndc,
+ const torch::Tensor& faces,
+ const torch::Tensor& vertices,
+ const torch::Tensor& vertices_filter,
+ int H, int W
+) {
+ const int N = vertices_ndc.size(0);
+ const int M = faces.size(0);
+
+ const int gpuid = vertices_ndc.device().index();
+ AT_CUDA_CHECK(cudaSetDevice(gpuid));
+ auto options = torch::dtype(vertices_ndc.scalar_type()).device(torch::kCUDA, gpuid);
+
+ const dim3 dimGrid(M);
+ const dim3 dimBlock(4, 4);
+
+ auto depth = torch::ones({H, W}, options) * 1e10;
+ auto coords = torch::zeros({H, W, 3}, options);
+ auto normals = torch::zeros({H, W, 3}, options);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(vertices_ndc.scalar_type(), "project_mesh_cuda_kernel", [&] {
+ scalar_t* global_face_ndc_inv;
+ cudaMalloc(&global_face_ndc_inv, M * 9 * sizeof(scalar_t));
+ int* global_is_bad_face;
+ cudaMalloc(&global_is_bad_face, M * sizeof(int));
+ rasterize_cuda_kernel<scalar_t><<<dimGrid, dimBlock>>>(
+ vertices_ndc.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
+ faces.packed_accessor32<int32_t, 2, torch::RestrictPtrTraits>(),
+ vertices_filter.packed_accessor32<uint8_t, 1, torch::RestrictPtrTraits>(),
+ depth.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
+ global_face_ndc_inv,
+ global_is_bad_face
+ );
+ AT_CUDA_CHECK(cudaGetLastError());
+
+ estimate_normals_cuda_kernel<scalar_t><<<dimGrid, dimBlock>>>(
+ vertices_ndc.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
+ faces.packed_accessor32<int32_t, 2, torch::RestrictPtrTraits>(),
+ depth.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
+ global_face_ndc_inv,
+ global_is_bad_face,
+ vertices.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
+ coords.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
+ normals.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>()
+ );
+ AT_CUDA_CHECK(cudaGetLastError());
+
+ cudaFree(global_face_ndc_inv);
+ cudaFree(global_is_bad_face);
+ AT_CUDA_CHECK(cudaGetLastError());
+ });
+
+ return {coords, normals};
+}
diff --git a/detrsmpl/core/renderer/mpr_renderer/rasterizer.py b/detrsmpl/core/renderer/mpr_renderer/rasterizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..18746e8b582193b8cb15219be568a9d74cf411a3
--- /dev/null
+++ b/detrsmpl/core/renderer/mpr_renderer/rasterizer.py
@@ -0,0 +1,68 @@
+import torch
+
+try:
+ from detrsmpl.core.renderer.mpr_renderer.cuda.rasterizer import \
+ estimate_normals as estimate_normals_cuda # noqa: E501
+ from detrsmpl.core.renderer.mpr_renderer.cuda.rasterizer import \
+ project_mesh as project_mesh_cuda # noqa: E501
+except (ImportError, ModuleNotFoundError):
+ print('Please reinstall MMHuman3D to build mpr_renderer.')
+ raise
+
+
+def estimate_normals(vertices, faces, pinhole, vertices_filter=None):
+ """Estimate the vertices normals with the specified faces and camera.
+
+ Args:
+ vertices (torch.tensor): Shape should be (num_verts, 3).
+ faces (torch.tensor): The faces of the mesh.
+ pinhole (object): The pinhole camera object.
+
+ Returns:
+ coords (torch.tensor): The estimated coordinates.
+ normals (torch.tensor): The estimated normals.
+ """
+ if vertices_filter is None:
+ assert torch.is_tensor(vertices)
+ assert vertices.is_cuda
+ assert len(vertices.shape) == 2
+ n = vertices.shape[0]
+ vertices_filter = torch.ones((n),
+ dtype=torch.uint8,
+ device=vertices.device)
+ vertices = vertices.contiguous()
+ vertices_ndc = pinhole.project_ndc(vertices)
+ coords, normals = estimate_normals_cuda(vertices_ndc, faces, vertices,
+ vertices_filter, pinhole.h,
+ pinhole.w)
+ return coords, normals
+
+
+def project_mesh(vertices,
+ faces,
+ vertice_values,
+ pinhole,
+ vertices_filter=None):
+ """Project mesh to the image plane with the specified faces and camera.
+
+ Args:
+ vertices (torch.tensor): Shape should be (num_verts, 3).
+ faces (torch.tensor): The faces of the mesh.
+ vertice_values (torch.tensor): The per-vertex values (e.g., depth)
+ to be projected.
+ pinhole (object): The pinhole camera object.
+
+ Returns:
+ torch.tensor: The projected mesh.
+ """
+ if vertices_filter is None:
+ assert torch.is_tensor(vertices)
+ assert vertices.is_cuda
+ assert len(vertices.shape) == 2
+ n = vertices.shape[0]
+ vertices_filter = torch.ones((n),
+ dtype=torch.uint8,
+ device=vertices.device)
+ vertices = vertices.contiguous()
+ vertices_ndc = pinhole.project_ndc(vertices)
+ return project_mesh_cuda(vertices_ndc, faces, vertice_values,
+ vertices_filter, pinhole.h, pinhole.w)
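+
+# The block below is a commented-out usage sketch, not part of the module:
+# it assumes CUDA tensors and a `Pinhole2D` camera (constructed as in
+# smpl_realrender.py); the tensor shapes are illustrative placeholders.
+#
+# import torch
+# from detrsmpl.core.renderer.mpr_renderer.camera import Pinhole2D
+#
+# pinhole = Pinhole2D(fx=5000., fy=5000., cx=112., cy=112., w=224, h=224)
+# verts = torch.rand(6890, 3, device='cuda') + torch.tensor([0., 0., 2.], device='cuda')
+# faces = torch.randint(0, 6890, (13776, 3), dtype=torch.int32, device='cuda')
+# coords, normals = estimate_normals(verts, faces, pinhole)      # each (h, w, 3)
+# depth_map = project_mesh(verts, faces, verts[:, [2]], pinhole)  # (h, w, 1)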
diff --git a/detrsmpl/core/renderer/mpr_renderer/smpl_realrender.py b/detrsmpl/core/renderer/mpr_renderer/smpl_realrender.py
new file mode 100644
index 0000000000000000000000000000000000000000..e795b1fbabfccd59f7ba4b638071ec211cf4c1b6
--- /dev/null
+++ b/detrsmpl/core/renderer/mpr_renderer/smpl_realrender.py
@@ -0,0 +1,48 @@
+import cv2
+import numpy as np
+import torch
+
+from detrsmpl.core.renderer.mpr_renderer.camera import Pinhole2D
+from detrsmpl.core.renderer.mpr_renderer.rasterizer import \
+ estimate_normals # noqa: E501
+from detrsmpl.core.renderer.mpr_renderer.utils import \
+ vis_normals # noqa: E501
+
+
+class VisualizerMeshSMPL:
+ def __init__(self,
+ device=None,
+ body_models=None,
+ focal_length=5000.,
+ camera_center=[112., 112.],
+ resolution=None,
+ scale=None):
+ self.body_models = body_models
+ self.pinhole2d = Pinhole2D(fx=focal_length,
+ fy=focal_length,
+ cx=camera_center[0],
+ cy=camera_center[1],
+ w=resolution[1],
+ h=resolution[0])
+ self.device = torch.device(device)
+ self.faces = self.body_models.faces_tensor.to(dtype=torch.int32,
+ device=self.device)
+
+ def __call__(self, vertices, bg=None, **kwargs):
+ assert vertices.device == self.faces.device
+ vertices = vertices.clone()
+ coords, normals = estimate_normals(vertices=vertices,
+ faces=self.faces,
+ pinhole=self.pinhole2d)
+ vis = vis_normals(coords, normals)
+ if bg is not None:
+ mask = coords[:, :, [2]] <= 0
+ vis = (
+ vis[:, :, None] +
+ torch.tensor(bg).to(mask.device) * mask).cpu().numpy().astype(
+ np.uint8)
+ else:
+ # convert gray to 3 channel img
+ vis = vis.detach().cpu().numpy()
+ vis = cv2.merge((vis, vis, vis))
+ return vis
diff --git a/detrsmpl/core/renderer/mpr_renderer/utils.py b/detrsmpl/core/renderer/mpr_renderer/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3256d1092a5dd3481a8201530d8f6c5d8eb50dcd
--- /dev/null
+++ b/detrsmpl/core/renderer/mpr_renderer/utils.py
@@ -0,0 +1,37 @@
+import torch
+
+
+def vis_z_buffer(z, percentile=1, vis_pad=0.2):
+ z = z[:, :, 0]
+ mask = z > 1e-5
+ if torch.sum(mask) == 0:
+ z[...] = 0
+ else:
+ vmin = torch.quantile(z[mask], percentile / 100)
+ vmax = torch.quantile(z[mask], 1 - percentile / 100)
+ pad = (vmax - vmin) * vis_pad
+ vmin_padded = vmin - pad
+ vmax_padded = vmax + pad
+ z[mask] = vmin + vmax - z[mask]
+ z = (z - vmin_padded) / (vmax_padded - vmin_padded)
+ z = torch.clip(torch.round(z * 255), 0, 255)
+ z_cpu = z.to(dtype=torch.uint8).detach().cpu().numpy()
+ return z_cpu
+
+
+def vis_normals(coords, normals, vis_pad=0.2):
+ mask = coords[:, :, 2] > 0
+ coords_masked = -coords[mask]
+ normals_masked = normals[mask]
+
+ coords_len = torch.sqrt(torch.sum(coords_masked**2, dim=1))
+
+ dot = torch.sum(coords_masked * normals_masked, dim=1) / coords_len
+
+ h, w = normals.shape[:2]
+ vis = torch.zeros((h, w), dtype=coords.dtype, device=coords.device)
+ vis[mask] = torch.clamp(dot, 0, 1) * (1 - 2 * vis_pad) + vis_pad
+
+ vis = (vis * 255).to(dtype=torch.uint8)
+
+ return vis
diff --git a/detrsmpl/core/renderer/torch3d_renderer/__init__.py b/detrsmpl/core/renderer/torch3d_renderer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/detrsmpl/core/renderer/torch3d_renderer/base_renderer.py b/detrsmpl/core/renderer/torch3d_renderer/base_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6a37a534cca1f1c6e9ae5fd5e2f4971aa326f3e
--- /dev/null
+++ b/detrsmpl/core/renderer/torch3d_renderer/base_renderer.py
@@ -0,0 +1,273 @@
+import os.path as osp
+import shutil
+from pathlib import Path
+from typing import Optional, Tuple, Union
+
+import cv2
+import mmcv
+import torch
+import torch.nn as nn
+from pytorch3d.renderer import (
+ AmbientLights,
+ BlendParams,
+ DirectionalLights,
+ Materials,
+ MeshRasterizer,
+ PointLights,
+ RasterizationSettings,
+)
+
+from detrsmpl.core.cameras import MMCamerasBase
+from detrsmpl.utils.ffmpeg_utils import images_to_gif, images_to_video
+from detrsmpl.utils.path_utils import check_path_suffix
+from .lights import build_lights
+from .shader import build_shader
+from .utils import normalize, rgb2bgr, tensor2array
+
+
+class BaseRenderer(nn.Module):
+ def __init__(self,
+ resolution: Tuple[int, int] = None,
+ device: Union[torch.device, str] = 'cpu',
+ output_path: Optional[str] = None,
+ out_img_format: str = '%06d.png',
+ **kwargs) -> None:
+ """BaseRenderer for differentiable rendering and visualization.
+
+ Args:
+ resolution (Iterable[int]):
+ (width, height) of the rendered images resolution.
+ device (Union[torch.device, str], optional):
+ You can pass a str or torch.device for cpu or gpu render.
+ Defaults to 'cpu'.
+ output_path (Optional[str], optional):
+ Output path of the video or images to be saved.
+ Defaults to None.
+ out_img_format (str, optional): The image format string for
+ saving the images.
+ Defaults to '%06d.png'.
+
+ **kwargs is used for render setting.
+ You can set up your render kwargs like:
+ {
+ 'shader': {
+ 'type': 'soft_phong'
+ },
+ 'lights': {
+ 'type': 'directional',
+ 'direction': [[10.0, 10.0, 10.0]],
+ 'ambient_color': [[0.5, 0.5, 0.5]],
+ 'diffuse_color': [[0.5, 0.5, 0.5]],
+ 'specular_color': [[0.5, 0.5, 0.5]],
+ },
+ 'materials': {
+ 'ambient_color': [[1, 1, 1]],
+ 'diffuse_color': [[0.5, 0.5, 0.5]],
+ 'specular_color': [[0.5, 0.5, 0.5]],
+ 'shininess': 60.0,
+ },
+ 'rasterizer': {
+ 'bin_size': 0,
+ 'blur_radius': 0.0,
+ 'faces_per_pixel': 1,
+ 'perspective_correct': False,
+ },
+ 'blend_params': {'background_color': (1.0, 1.0, 1.0)},
+ },
+ You can change any parameter within its suitable range; please check
+ configs/render/smpl.py. A commented example config is given at the
+ end of this file.
+
+ Returns:
+ None
+ """
+ super().__init__()
+ self.device = device
+ self.output_path = output_path
+ self.resolution = resolution
+ self.temp_path = None
+ self.out_img_format = out_img_format
+ self._set_output_path(output_path)
+ self._init_renderer(**kwargs)
+
+ def _init_renderer(self,
+ rasterizer: Union[dict, nn.Module] = None,
+ shader: Union[dict, nn.Module] = None,
+ materials: Union[dict, Materials] = None,
+ lights: Union[dict, DirectionalLights, PointLights,
+ AmbientLights] = None,
+ blend_params: Union[dict, BlendParams] = None,
+ **kwargs):
+ """Initial renderer."""
+ if isinstance(materials, dict):
+ materials = Materials(**materials)
+ elif materials is None:
+ materials = Materials()
+ elif not isinstance(materials, Materials):
+ raise TypeError(f'Wrong type of materials: {type(materials)}.')
+
+ if isinstance(lights, dict):
+ self.lights = build_lights(lights)
+ elif lights is None:
+ self.lights = AmbientLights()
+ elif isinstance(lights,
+ (AmbientLights, PointLights, DirectionalLights)):
+ self.lights = lights
+ else:
+ raise TypeError(f'Wrong type of lights: {type(lights)}.')
+
+ if isinstance(blend_params, dict):
+ blend_params = BlendParams(**blend_params)
+ elif blend_params is None:
+ blend_params = BlendParams()
+ elif not isinstance(blend_params, BlendParams):
+ raise TypeError(
+ f'Wrong type of blend_params: {type(blend_params)}.')
+
+ if isinstance(rasterizer, nn.Module):
+ if self.resolution is not None:
+ rasterizer.raster_settings.image_size = self.resolution
+ self.rasterizer = rasterizer
+ elif isinstance(rasterizer, dict):
+ if self.resolution is not None:
+ rasterizer['image_size'] = self.resolution
+ raster_settings = RasterizationSettings(**rasterizer)
+ self.rasterizer = MeshRasterizer(raster_settings=raster_settings)
+ elif rasterizer is None:
+ self.rasterizer = MeshRasterizer(
+ raster_settings=RasterizationSettings(
+ image_size=self.resolution,
+ bin_size=0,
+ blur_radius=0,
+ faces_per_pixel=1,
+ perspective_correct=False))
+ else:
+ raise TypeError(
+ f'Wrong type of rasterizer: {type(self.rasterizer)}.')
+
+ if self.resolution is None:
+ self.resolution = self.rasterizer.raster_settings.image_size
+ assert self.resolution is not None
+ self.resolution = (self.resolution, self.resolution) if isinstance(
+ self.resolution, int) else tuple(self.resolution)
+ if isinstance(shader, nn.Module):
+ self.shader = shader
+ elif isinstance(shader, dict):
+ shader.update(materials=materials,
+ lights=self.lights,
+ blend_params=blend_params)
+ self.shader = build_shader(shader)
+ elif shader is None:
+ self.shader = build_shader(
+ dict(type=self.shader_type,
+ materials=materials,
+ lights=self.lights,
+ blend_params=blend_params))
+ else:
+ raise TypeError(f'Wrong type of shader: {type(self.shader)}.')
+ self = self.to(self.device)
+
+ def to(self, device):
+ if isinstance(device, str):
+ device = torch.device(device)
+ self.device = device
+ if getattr(self.rasterizer, 'cameras', None) is not None:
+ self.rasterizer.cameras = self.rasterizer.cameras.to(device)
+
+ if getattr(self.shader, 'cameras', None) is not None:
+ self.shader.cameras = self.shader.cameras.to(device)
+ if getattr(self.shader, 'materials', None) is not None:
+ self.shader.materials = self.shader.materials.to(device)
+ if getattr(self.shader, 'lights', None) is not None:
+ self.shader.lights = self.shader.lights.to(device)
+ return self
+
+ def _set_output_path(self, output_path):
+ if output_path is not None:
+ self.output_path = output_path
+ if check_path_suffix(output_path, ['.mp4', '.gif']):
+ self.temp_path = osp.join(
+ Path(output_path).parent,
+ Path(output_path).name + '_output_temp')
+ elif check_path_suffix(output_path, ['.png', '.jpg', '.jpeg']):
+ mmcv.mkdir_or_exist(Path(output_path).parent)
+ self.temp_path = osp.join(
+ Path(output_path).parent,
+ Path(output_path).name + '_output_temp')
+ else:
+ self.temp_path = output_path
+ mmcv.mkdir_or_exist(self.temp_path)
+ print('Make dir', self.temp_path)
+
+ def _update_resolution(self, cameras, **kwargs):
+ if isinstance(cameras, MMCamerasBase):
+ self.resolution = (int(cameras.resolution[0][0]),
+ int(cameras.resolution[0][1]))
+ if 'resolution' in kwargs:
+ self.resolution = kwargs.get('resolution')
+ self.rasterizer.raster_settings.image_size = self.resolution
+
+ def export(self):
+ """Export output video if need."""
+ if self.output_path is not None:
+ folder = self.temp_path if self.temp_path is not None else\
+ self.output_path
+ if check_path_suffix(self.output_path, ['.mp4']):
+ images_to_video(input_folder=folder,
+ output_path=self.output_path,
+ img_format=self.out_img_format)
+ elif check_path_suffix(self.output_path, ['.gif']):
+ images_to_gif(input_folder=folder,
+ output_path=self.output_path,
+ img_format=self.out_img_format)
+
+ def __del__(self):
+ """remove_temp_files."""
+ if self.output_path is not None:
+ if Path(self.output_path).is_file():
+ self._remove_temp_frames()
+
+ def _remove_temp_frames(self):
+ """Remove temp files."""
+ if self.temp_path:
+ if osp.exists(self.temp_path) and osp.isdir(self.temp_path):
+ shutil.rmtree(self.temp_path)
+
+ def _write_images(self, rgba, backgrounds, indexes):
+ """Write output/temp images."""
+ if rgba.shape[-1] > 3:
+ rgbs, valid_masks = rgba[..., :3], rgba[..., 3:]
+ else:
+ rgbs = rgba[..., :3]
+ valid_masks = torch.ones_like(rgbs[..., :1])
+ rgbs = normalize(rgbs, origin_value_range=(0, 1), clip=True)
+ bgrs = rgb2bgr(rgbs)
+ if backgrounds is not None:
+ image_max = 1.0 if backgrounds.max() <= 1.0 else 255
+ backgrounds = normalize(backgrounds,
+ origin_value_range=(0, image_max),
+ out_value_range=(0, 1))
+ output_images = bgrs * valid_masks + (1 -
+ valid_masks) * backgrounds
+ output_images = tensor2array(output_images)
+
+ else:
+ output_images = tensor2array(bgrs)
+ for idx, real_idx in enumerate(indexes):
+ folder = self.temp_path if self.temp_path is not None else\
+ self.output_path
+ cv2.imwrite(osp.join(folder, self.out_img_format % real_idx),
+ output_images[idx])
+
+ def forward(self):
+ """"Should be called by each sub renderer class."""
+ raise NotImplementedError()
+
+ def tensor2rgba(self, tensor: torch.Tensor):
+ valid_masks = (tensor[..., 3:] > 0) * 1.0
+ rgbs = tensor[..., :3]
+
+ rgbs = normalize(rgbs,
+ origin_value_range=[0, 1],
+ out_value_range=[0, 1])
+ rgba = torch.cat([rgbs, valid_masks], -1)
+ return rgba
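+
+# The dict below is a commented-out sketch of the render kwargs described in
+# the BaseRenderer docstring; the concrete values are illustrative and the
+# type aliases come from the registries in builder.py and lights/builder.py.
+#
+# renderer_cfg = dict(
+#     type='mesh',
+#     resolution=(512, 512),
+#     shader=dict(type='SoftPhongShader'),
+#     lights=dict(type='directional',
+#                 direction=[[10.0, 10.0, 10.0]],
+#                 ambient_color=[[0.5, 0.5, 0.5]],
+#                 diffuse_color=[[0.5, 0.5, 0.5]],
+#                 specular_color=[[0.5, 0.5, 0.5]]),
+#     materials=dict(shininess=60.0),
+#     rasterizer=dict(bin_size=0, blur_radius=0.0, faces_per_pixel=1,
+#                     perspective_correct=False),
+#     blend_params=dict(background_color=(1.0, 1.0, 1.0)),
+# )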
diff --git a/detrsmpl/core/renderer/torch3d_renderer/builder.py b/detrsmpl/core/renderer/torch3d_renderer/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..57834e17a1c44c1c5a7008a21d6fdd019c6f8ec3
--- /dev/null
+++ b/detrsmpl/core/renderer/torch3d_renderer/builder.py
@@ -0,0 +1,45 @@
+from mmcv.utils import Registry
+
+from .base_renderer import BaseRenderer
+from .depth_renderer import DepthRenderer
+from .mesh_renderer import MeshRenderer
+from .normal_renderer import NormalRenderer
+from .pointcloud_renderer import PointCloudRenderer
+from .segmentation_renderer import SegmentationRenderer
+from .silhouette_renderer import SilhouetteRenderer
+from .uv_renderer import UVRenderer
+
+RENDERER = Registry('renderer')
+RENDERER.register_module(
+ name=['base', 'Base', 'base_renderer', 'BaseRenderer'],
+ module=BaseRenderer)
+RENDERER.register_module(
+ name=['Depth', 'depth', 'depth_renderer', 'DepthRenderer'],
+ module=DepthRenderer)
+RENDERER.register_module(
+ name=['Mesh', 'mesh', 'mesh_renderer', 'MeshRenderer'],
+ module=MeshRenderer)
+RENDERER.register_module(
+ name=['Normal', 'normal', 'normal_renderer', 'NormalRenderer'],
+ module=NormalRenderer)
+RENDERER.register_module(name=[
+ 'PointCloud', 'pointcloud', 'point_cloud', 'pointcloud_renderer',
+ 'PointCloudRenderer'
+],
+ module=PointCloudRenderer)
+RENDERER.register_module(name=[
+ 'segmentation', 'segmentation_renderer', 'Segmentation',
+ 'SegmentationRenderer'
+],
+ module=SegmentationRenderer)
+RENDERER.register_module(name=[
+ 'silhouette', 'silhouette_renderer', 'Silhouette', 'SilhouetteRenderer'
+],
+ module=SilhouetteRenderer)
+RENDERER.register_module(name=['uv_renderer', 'uv', 'UV', 'UVRenderer'],
+ module=UVRenderer)
+
+
+def build_renderer(cfg):
+ """Build renderers."""
+ return RENDERER.build(cfg)
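+
+
+# Commented-out usage sketch (the config values are illustrative): any of
+# the aliases registered above may be used as the `type` field.
+#
+# renderer = build_renderer(
+#     dict(type='depth', resolution=(256, 256), device='cuda', depth_max=10.0))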
diff --git a/detrsmpl/core/renderer/torch3d_renderer/depth_renderer.py b/detrsmpl/core/renderer/torch3d_renderer/depth_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..51a5cbc5f2817769c92270316d6f4fd0b2c352de
--- /dev/null
+++ b/detrsmpl/core/renderer/torch3d_renderer/depth_renderer.py
@@ -0,0 +1,109 @@
+from typing import Iterable, Optional, Tuple, Union
+
+import torch
+from pytorch3d.structures import Meshes
+
+from detrsmpl.core.cameras import MMCamerasBase
+from .base_renderer import BaseRenderer
+from .shader import build_shader
+from .utils import normalize
+
+
+class DepthRenderer(BaseRenderer):
+ """Render depth map with the help of camera system."""
+ shader_type = 'DepthShader'
+
+ def __init__(
+ self,
+ resolution: Tuple[int, int] = None,
+ device: Union[torch.device, str] = 'cpu',
+ output_path: Optional[str] = None,
+ out_img_format: str = '%06d.png',
+ depth_max: Union[int, float, torch.Tensor] = None,
+ **kwargs,
+ ) -> None:
+ """Renderer for depth map of meshes.
+
+ Args:
+ resolution (Iterable[int]):
+ (width, height) of the rendered images resolution.
+ device (Union[torch.device, str], optional):
+ You can pass a str or torch.device for cpu or gpu render.
+ Defaults to 'cpu'.
+ output_path (Optional[str], optional):
+ Output path of the video or images to be saved.
+ Defaults to None.
+ out_img_format (str, optional): The image format string for
+ saving the images.
+ Defaults to '%06d.png'.
+
+ depth_max (Union[int, float, torch.Tensor], optional):
+ The max value for normalize depth range. Defaults to None.
+
+ Returns:
+ None
+ """
+ super().__init__(resolution=resolution,
+ device=device,
+ output_path=output_path,
+ out_img_format=out_img_format,
+ **kwargs)
+ self.depth_max = depth_max
+
+ def _init_renderer(self,
+ rasterizer=None,
+ shader=None,
+ materials=None,
+ lights=None,
+ blend_params=None,
+ **kwargs):
+ shader = build_shader(dict(
+ type='DepthShader')) if shader is None else shader
+ return super()._init_renderer(rasterizer, shader, materials, lights,
+ blend_params, **kwargs)
+
+ def forward(self,
+ meshes: Optional[Meshes] = None,
+ cameras: Optional[MMCamerasBase] = None,
+ indexes: Optional[Iterable[int]] = None,
+ backgrounds: Optional[torch.Tensor] = None,
+ **kwargs):
+ """Render depth map.
+
+ Args:
+ meshes (Optional[Meshes], optional): meshes to be rendered.
+ Defaults to None.
+ cameras (Optional[MMCamerasBase], optional): cameras for rendering.
+ Defaults to None.
+ indexes (Optional[Iterable[int]], optional): indexes for the
+ images.
+ Defaults to None.
+ backgrounds (Optional[torch.Tensor], optional): background images.
+ Defaults to None.
+
+ Returns:
+ Union[torch.Tensor, None]: return tensor or None.
+ """
+ meshes = meshes.to(self.device)
+ self._update_resolution(cameras, **kwargs)
+
+ fragments = self.rasterizer(meshes_world=meshes, cameras=cameras)
+ depth_map = self.shader(fragments=fragments,
+ meshes=meshes,
+ cameras=cameras)
+
+ if self.output_path is not None:
+ rgba = self.tensor2rgba(depth_map)
+ self._write_images(rgba, backgrounds, indexes)
+
+ return depth_map
+
+ def tensor2rgba(self, tensor: torch.Tensor):
+ rgbs, valid_masks = tensor.repeat(1, 1, 1, 3), (tensor > 0) * 1.0
+ depth_max = self.depth_max if self.depth_max is not None else rgbs.max(
+ )
+ rgbs = normalize(rgbs,
+ origin_value_range=(0, depth_max),
+ out_value_range=(0, 1))
+ return torch.cat([rgbs, valid_masks], -1)
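+
+# Commented-out sketch of how tensor2rgba above normalizes depth: with
+# depth_max = D, a raw depth d maps to d / D in [0, 1], and empty pixels
+# (depth == 0) get alpha 0. `my_meshes` and `my_cameras` are placeholders.
+#
+# renderer = DepthRenderer(resolution=(256, 256), device='cuda', depth_max=10.0)
+# depth_map = renderer(meshes=my_meshes, cameras=my_cameras)
+# rgba = renderer.tensor2rgba(depth_map)  # last channel is the validity mask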
diff --git a/detrsmpl/core/renderer/torch3d_renderer/lights/__init__.py b/detrsmpl/core/renderer/torch3d_renderer/lights/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..06d85748e0646a09f1d6828d1d6ffc636011bc27
--- /dev/null
+++ b/detrsmpl/core/renderer/torch3d_renderer/lights/__init__.py
@@ -0,0 +1,10 @@
+# yapf: disable
+from .builder import ( # noqa: F401
+ AmbientLights,
+ DirectionalLights,
+ PointLights,
+ build_lights,
+)
+from .lights import MMLights # noqa: F401
+
+# yapf: enable
diff --git a/detrsmpl/core/renderer/torch3d_renderer/lights/builder.py b/detrsmpl/core/renderer/torch3d_renderer/lights/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c82511f41722380486567503b83c817441070ce
--- /dev/null
+++ b/detrsmpl/core/renderer/torch3d_renderer/lights/builder.py
@@ -0,0 +1,17 @@
+from mmcv.utils import Registry
+
+from .lights import AmbientLights, DirectionalLights, PointLights # noqa:E401
+
+LIGHTS = Registry('lights')
+LIGHTS.register_module(
+ name=['directional', 'directional_lights', 'DirectionalLights'],
+ module=DirectionalLights)
+LIGHTS.register_module(name=['point', 'point_lights', 'PointLights'],
+ module=PointLights)
+LIGHTS.register_module(name=['ambient', 'ambient_lights', 'AmbientLights'],
+ module=AmbientLights)
+
+
+def build_lights(cfg):
+ """Build lights."""
+ return LIGHTS.build(cfg)
diff --git a/detrsmpl/core/renderer/torch3d_renderer/lights/lights.py b/detrsmpl/core/renderer/torch3d_renderer/lights/lights.py
new file mode 100644
index 0000000000000000000000000000000000000000..62eda2eddba87237c65e6f2924a8921cde6fed74
--- /dev/null
+++ b/detrsmpl/core/renderer/torch3d_renderer/lights/lights.py
@@ -0,0 +1,80 @@
+from typing import Union
+
+import torch
+from pytorch3d.renderer.lighting import AmbientLights as _AmbientLights
+from pytorch3d.renderer.lighting import DirectionalLights as _DirectionalLights
+from pytorch3d.renderer.lighting import PointLights as _PointLights
+from pytorch3d.renderer.utils import TensorProperties
+
+MMLIGHT_ATTR = [
+ 'ambient_color', 'diffuse_color', 'specular_color', 'location', 'direction'
+]
+
+
+class MMLights(TensorProperties):
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+ _N = 1
+ self.mmlight_attr_list = []
+ for attr_name in MMLIGHT_ATTR:
+ if hasattr(self, attr_name):
+ self.mmlight_attr_list.append(attr_name)
+ for k in self.mmlight_attr_list:
+ v = getattr(self, k)
+ if not isinstance(v, torch.Tensor):
+ v = torch.Tensor(v)
+ v = v.view(-1, 3)
+ setattr(self, k, v)
+
+ if getattr(self, k).shape[0] > _N:
+ _N = getattr(self, k).shape[0]
+ for k in self.mmlight_attr_list:
+ if getattr(self, k).shape[0] == 1:
+ setattr(self, k, getattr(self, k).repeat(_N, 1))
+ self._N = _N
+
+ def __len__(self, ):
+ return self._N
+
+ def __getitem__(self, index: Union[int, slice]):
+ if isinstance(index, int):
+ index = [index]
+ kwargs = {}
+ for k in self.mmlight_attr_list:
+ kwargs[k] = getattr(self, k)[index]
+
+ return self.__class__(device=self.device, **kwargs)
+
+ def extend(self, N):
+ kwargs = {}
+ for k in self.mmlight_attr_list:
+ kwargs[k] = getattr(self, k).repeat(N, 1)
+ return self.__class__(device=self.device, **kwargs)
+
+ def extend_(self, N):
+ for k in self.mmlight_attr_list:
+ setattr(self, k, getattr(self, k).repeat(N, 1))
+ self._N = N
+
+
+class AmbientLights(_AmbientLights, MMLights):
+ def __init__(self, ambient_color=None, device='cpu', **kwargs) -> None:
+ if ambient_color is None:
+ ambient_color = ((1.0, 1.0, 1.0), )
+ diffuse_color = ((0.0, 0.0, 0.0), )
+ super(_AmbientLights, self).__init__(ambient_color=ambient_color,
+ diffuse_color=diffuse_color,
+ device=device)
+
+ def __getitem__(self, index: Union[int, slice]):
+ return super(_AmbientLights, self).__getitem__(index)
+
+
+class PointLights(_PointLights, MMLights):
+ def __getitem__(self, index: Union[int, slice]):
+ return super(_PointLights, self).__getitem__(index)
+
+
+class DirectionalLights(_DirectionalLights, MMLights):
+ def __getitem__(self, index: Union[int, slice]):
+ return super(_DirectionalLights, self).__getitem__(index)
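+
+
+# Commented-out sketch of the MMLights broadcasting above: attributes given
+# for a single light are repeated to the common batch size, and
+# __getitem__ / extend slice or repeat along it. The colors are illustrative.
+#
+# lights = PointLights(ambient_color=[[0.5, 0.5, 0.5]],
+#                      diffuse_color=[[0.3, 0.3, 0.3]],
+#                      specular_color=[[0.5, 0.5, 0.5]],
+#                      location=[[0.0, 1.0, 0.0]])
+# lights4 = lights.extend(4)  # len(lights4) == 4
+# first = lights4[0]          # PointLights with batch size 1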
diff --git a/detrsmpl/core/renderer/torch3d_renderer/mesh_renderer.py b/detrsmpl/core/renderer/torch3d_renderer/mesh_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ff4876d6bf6cc259c666eb93cc3abcaf0abe581
--- /dev/null
+++ b/detrsmpl/core/renderer/torch3d_renderer/mesh_renderer.py
@@ -0,0 +1,81 @@
+from typing import Iterable, Optional, Tuple, Union
+
+import torch
+from pytorch3d.structures import Meshes
+
+from detrsmpl.core.cameras import MMCamerasBase
+from .base_renderer import BaseRenderer
+from .lights import MMLights
+
+
+class MeshRenderer(BaseRenderer):
+ """Render RGBA image with the help of camera system."""
+ shader_type = 'SoftPhongShader'
+
+ def __init__(
+ self,
+ resolution: Tuple[int, int] = None,
+ device: Union[torch.device, str] = 'cpu',
+ output_path: Optional[str] = None,
+ out_img_format: str = '%06d.png',
+ **kwargs,
+ ) -> None:
+ """Renderer for RGBA image of meshes.
+
+ Args:
+ resolution (Iterable[int]):
+ (width, height) of the rendered images resolution.
+ device (Union[torch.device, str], optional):
+ You can pass a str or torch.device for cpu or gpu render.
+ Defaults to 'cpu'.
+ output_path (Optional[str], optional):
+ Output path of the video or images to be saved.
+ Defaults to None.
+ out_img_format (str, optional): The image format string for
+ saving the images.
+ Defaults to '%06d.png'.
+ """
+ super().__init__(resolution=resolution,
+ device=device,
+ output_path=output_path,
+ out_img_format=out_img_format,
+ **kwargs)
+
+ def forward(self,
+ meshes: Meshes,
+ cameras: Optional[MMCamerasBase] = None,
+ lights: Optional[MMLights] = None,
+ indexes: Optional[Iterable[int]] = None,
+ backgrounds: Optional[torch.Tensor] = None,
+ **kwargs) -> Union[torch.Tensor, None]:
+ """Render Meshes.
+
+ Args:
+ meshes (Meshes): meshes to be rendered.
+ cameras (Optional[MMCamerasBase], optional): cameras for render.
+ Defaults to None.
+ lights (Optional[MMLights], optional): lights for render.
+ Defaults to None.
+ indexes (Optional[Iterable[int]], optional): indexes for images.
+ Defaults to None.
+ backgrounds (Optional[torch.Tensor], optional): background images.
+ Defaults to None.
+
+ Returns:
+ Union[torch.Tensor, None]: return tensor or None.
+ """
+
+ meshes = meshes.to(self.device)
+ self._update_resolution(cameras, **kwargs)
+ fragments = self.rasterizer(meshes_world=meshes, cameras=cameras)
+
+ rendered_images = self.shader(
+ fragments=fragments,
+ meshes=meshes,
+ cameras=cameras,
+ lights=self.lights if lights is None else lights)
+
+ if self.output_path is not None:
+ rgba = self.tensor2rgba(rendered_images)
+ self._write_images(rgba, backgrounds, indexes)
+ return rendered_images
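+
+# Commented-out sketch of a direct forward call (`my_*` names are
+# placeholders): the shader shades the rasterized fragments with the given
+# (or default) lights and returns an RGBA image batch.
+#
+# renderer = MeshRenderer(resolution=(512, 512), device='cuda',
+#                         output_path='demo_out/render.mp4')
+# images = renderer(meshes=my_meshes, cameras=my_cameras, lights=my_lights)
+# renderer.export()  # assembles the written frames into the output video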
diff --git a/detrsmpl/core/renderer/torch3d_renderer/meshes.py b/detrsmpl/core/renderer/torch3d_renderer/meshes.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5b8783b78218e1635b6f61660dfa77f4d2d1e71
--- /dev/null
+++ b/detrsmpl/core/renderer/torch3d_renderer/meshes.py
@@ -0,0 +1,526 @@
+from typing import Iterable, List, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from pytorch3d.renderer import TexturesUV, TexturesVertex
+from pytorch3d.renderer.mesh.textures import TexturesBase
+from pytorch3d.structures import Meshes, list_to_padded, padded_to_list
+
+from detrsmpl.models.body_models.builder import SMPL, SMPLX
+from detrsmpl.utils.mesh_utils import \
+ join_meshes_as_batch as _join_meshes_as_batch
+from .builder import build_renderer
+from .textures.textures import TexturesNearest
+from .utils import align_input_to_padded
+
+
+class ParametricMeshes(Meshes):
+ """Mesh structure for parametric body models, E.g., smpl, smplx, mano,
+ flame.
+
+ There are 3 ways to initialize the verts:
+ 1): Pass the verts directly as verts_padded (N, V, 3) or verts_list
+ (list of (N, 3)).
+ 2): Pass body_model and pose_params.
+ 3): Pass meshes. Could be Meshes or ParametricMeshes.
+ Will use the verts from the meshes.
+ There are 3 ways to initialize the faces:
+ 1): Pass the faces directly as faces_padded (N, F, 3) or faces_list
+ (list of (F, 3)).
+ 2): Pass body_model and will use body_model.faces_tensor.
+ 3): Pass meshes. Could be Meshes or ParametricMeshes.
+ Will use the faces from the meshes.
+ There are 4 ways to initialize the textures.
+ 1): Pass the textures directly.
+ 2): Pass the texture_images of shape (H, W, 3) for single person or
+ (_N_individual, H, W, 3) for multi-person. `body_model` should be
+ passed and should have `uv_renderer`.
+ 3): Pass the vertex_color of shape (3) or (V, 3) or (N, V, 3).
+ 4): Pass meshes. Could be Meshes or ParametricMeshes.
+ Will use the textures directly from the meshes.
+ """
+ # TODO: More model class to be added (FLAME, MANO)
+ MODEL_CLASSES = {'smpl': SMPL, 'smplx': SMPLX}
+
+ def __init__(self,
+ verts: Union[List[torch.Tensor], torch.Tensor] = None,
+ faces: Union[List[torch.Tensor], torch.Tensor] = None,
+ textures: TexturesBase = None,
+ meshes: Meshes = None,
+ body_model: Union[nn.Module, dict] = None,
+ uv_renderer: Union[nn.Module, dict] = None,
+ vertex_color: Union[Iterable[float], torch.Tensor,
+ np.ndarray] = ((1, 1, 1), ),
+ use_nearest: bool = False,
+ texture_images: Union[torch.Tensor, List[torch.Tensor],
+ None] = None,
+ model_type: str = 'smpl',
+ N_individual_override: int = None,
+ *,
+ verts_normals: torch.Tensor = None,
+ **pose_params) -> None:
+
+ if isinstance(meshes, Meshes):
+ verts = meshes.verts_padded()
+ faces = meshes.faces_padded()
+ textures = meshes.textures
+
+ self.model_type = body_model._get_name().lower(
+ ) if body_model is not None else model_type
+
+ self.model_class = self.MODEL_CLASSES[self.model_type]
+
+ use_list = False
+
+ # format verts as verts_padded: (N, V, 3)
+ if verts is None:
+ assert body_model is not None
+ verts = body_model(**pose_params)['vertices']
+ elif isinstance(verts, list):
+ verts = list_to_padded(verts)
+ use_list = True
+ # specify number of individuals
+ if N_individual_override is not None:
+ verts = verts.view(
+ -1, self.model_class.NUM_VERTS * N_individual_override, 3)
+
+ # the information of _N_individual should be revealed in verts's shape
+ self._N_individual = int(verts.shape[-2] // self.model_class.NUM_VERTS)
+
+ assert verts.shape[1] % self.model_class.NUM_VERTS == 0
+ verts = verts.view(-1, self.model_class.NUM_VERTS * self._N_individual,
+ 3)
+ device = verts.device
+ N, V, _ = verts.shape
+
+ # format faces as faces_padded: (N, F, 3)
+ if isinstance(faces, list):
+ faces = list_to_padded(faces)
+ self.face_individual = faces[0][:self.model_class.NUM_FACES].to(
+ device)
+ elif faces is None:
+ assert body_model is not None
+ self.face_individual = body_model.faces_tensor[None].to(device)
+ faces = self.get_faces_padded(N, self._N_individual)
+ elif isinstance(faces, torch.Tensor):
+ faces = align_input_to_padded(faces, ndim=3, batch_size=N)
+ self.face_individual = faces[:1, :self.model_class.NUM_FACES].to(
+ device)
+ else:
+ raise ValueError(f'Wrong type of faces: {type(faces)}.')
+
+ assert faces.shape == (N,
+ self.model_class.NUM_FACES * self._N_individual,
+ 3)
+ F = faces.shape[1]
+ if textures is None:
+ if texture_images is None:
+ # input vertex_color should be
+ # (3), (1, 3), (1, 1, 3). all the same color
+ # (V, 3), (1, V, 3), each vertex has a single color
+ # (N, V, 3), each batch each vertex has a single color
+ if isinstance(vertex_color, (tuple, list)):
+ vertex_color = torch.Tensor(vertex_color)
+ elif isinstance(vertex_color, np.ndarray):
+ vertex_color = torch.from_numpy(vertex_color)
+ if vertex_color.numel() == 3:
+ vertex_color = vertex_color.view(1, 3).repeat(V, 1)
+ vertex_color = align_input_to_padded(vertex_color,
+ ndim=3,
+ batch_size=N)
+ assert vertex_color.shape == verts.shape
+ if use_nearest:
+ textures = TexturesNearest(
+ verts_features=vertex_color).to(device)
+ else:
+ textures = TexturesVertex(
+ verts_features=vertex_color).to(device)
+ else:
+
+ texture_images = align_input_to_padded(texture_images,
+ ndim=4,
+ batch_size=N).to(device)
+
+ assert uv_renderer is not None
+ if isinstance(uv_renderer, dict):
+ uv_renderer = build_renderer(uv_renderer)
+ uv_renderer = uv_renderer.to(device)
+ textures = uv_renderer.wrap_texture(texture_images).to(device)
+ if self._N_individual > 1:
+ textures = textures.join_scene()
+ textures = textures.extend(N)
+
+ num_verts_per_mesh = [V for _ in range(N)]
+ num_faces_per_mesh = [F for _ in range(N)]
+
+ if use_list:
+ verts = padded_to_list(verts, num_verts_per_mesh)
+ faces = padded_to_list(faces, num_faces_per_mesh)
+ super().__init__(
+ verts=verts,
+ faces=faces,
+ textures=textures,
+ verts_normals=verts_normals,
+ )
+
+ def get_faces_padded(self, N_batch, N_individual):
+ faces = self.face_individual.repeat(N_batch, N_individual, 1)
+ faces_offset = torch.arange(N_individual).view(N_individual, 1).repeat(
+ 1, self.model_class.NUM_FACES).view(1, -1, 1).to(faces.device)
+ faces = faces + faces_offset * self.model_class.NUM_VERTS
+ return faces
+
+ def _compute_list(self):
+ self._faces_list = self.faces_list()
+ self._verts_list = self.verts_list()
+
+ def extend(self, N_batch: int, N_scene: int = 1):
+ if N_batch == 1:
+ meshes_batch = self
+ else:
+ meshes_batch = join_meshes_as_batch([self for _ in range(N_batch)])
+
+ if N_scene == 1:
+ meshes = meshes_batch
+ else:
+ meshes = join_batch_meshes_as_scene(
+ [meshes_batch for _ in range(N_scene)])
+ return meshes
+
+ def clone(self):
+ """Modified from pytorch3d and add `model_type` in
+ __class__.__init__."""
+ verts_list = self.verts_list()
+ faces_list = self.faces_list()
+ new_verts_list = [v.clone() for v in verts_list]
+ new_faces_list = [f.clone() for f in faces_list]
+ other = self.__class__(verts=new_verts_list,
+ faces=new_faces_list,
+ model_type=self.model_type)
+ for k in self._INTERNAL_TENSORS:
+ v = getattr(self, k)
+ if torch.is_tensor(v):
+ setattr(other, k, v.clone())
+
+ # Textures is not a tensor but has a clone method
+ if self.textures is not None:
+ other.textures = self.textures.clone()
+ return other
+
+ def detach(self):
+ """Modified from pytorch3d and add `model_type` in
+ __class__.__init__."""
+ verts_list = self.verts_list()
+ faces_list = self.faces_list()
+ new_verts_list = [v.detach() for v in verts_list]
+ new_faces_list = [f.detach() for f in faces_list]
+ other = self.__class__(verts=new_verts_list,
+ faces=new_faces_list,
+ model_type=self.model_type)
+
+ for k in self._INTERNAL_TENSORS:
+ v = getattr(self, k)
+ if torch.is_tensor(v):
+ setattr(other, k, v.detach())
+
+ # Textures is not a tensor but has a detach method
+ if self.textures is not None:
+ other.textures = self.textures.detach()
+ return other
+
+ def update_padded(self, new_verts_padded: torch.Tensor):
+ """Modified from pytorch3d and add `model_type` in
+ __class__.__init__."""
+ def check_shapes(x, size):
+ if x.shape[0] != size[0]:
+ raise ValueError('new values must have the same batch size.')
+ if x.shape[1] != size[1]:
+ raise ValueError(
+ 'new values must have the same number of points.')
+ if x.shape[2] != size[2]:
+ raise ValueError('new values must have the same dimension.')
+
+ check_shapes(new_verts_padded, [self._N, self._V, 3])
+
+ new = self.__class__(verts=new_verts_padded,
+ faces=self.faces_padded(),
+ model_type=self.model_type)
+
+ if new._N != self._N or new._V != self._V or new._F != self._F:
+ raise ValueError('Inconsistent sizes after construction.')
+
+ # overwrite the equisized flag
+ new.equisized = self.equisized
+
+ # overwrite textures if any
+ new.textures = self.textures
+
+ # copy auxiliary tensors
+ copy_tensors = ['_num_verts_per_mesh', '_num_faces_per_mesh', 'valid']
+
+ for k in copy_tensors:
+ v = getattr(self, k)
+ if torch.is_tensor(v):
+ setattr(new, k, v) # shallow copy
+
+ # shallow copy of faces_list if any, st new.faces_list()
+ # does not re-compute from _faces_padded
+ new._faces_list = self._faces_list
+
+ # update verts/faces packed if they are computed in self
+ if self._verts_packed is not None:
+ copy_tensors = [
+ '_faces_packed',
+ '_verts_packed_to_mesh_idx',
+ '_faces_packed_to_mesh_idx',
+ '_mesh_to_verts_packed_first_idx',
+ '_mesh_to_faces_packed_first_idx',
+ ]
+ for k in copy_tensors:
+ v = getattr(self, k)
+ assert torch.is_tensor(v)
+ setattr(new, k, v) # shallow copy
+ # update verts_packed
+ pad_to_packed = self.verts_padded_to_packed_idx()
+ new_verts_packed = new_verts_padded.reshape(-1,
+ 3)[pad_to_packed, :]
+ new._verts_packed = new_verts_packed
+ new._verts_padded_to_packed_idx = pad_to_packed
+
+ # update edges packed if they are computed in self
+ if self._edges_packed is not None:
+ copy_tensors = [
+ '_edges_packed',
+ '_edges_packed_to_mesh_idx',
+ '_mesh_to_edges_packed_first_idx',
+ '_faces_packed_to_edges_packed',
+ '_num_edges_per_mesh',
+ ]
+ for k in copy_tensors:
+ v = getattr(self, k)
+ assert torch.is_tensor(v)
+ setattr(new, k, v) # shallow copy
+
+ # update laplacian if it is compute in self
+ if self._laplacian_packed is not None:
+ new._laplacian_packed = self._laplacian_packed
+
+ assert new._verts_list is None
+ assert new._verts_normals_packed is None
+ assert new._faces_normals_packed is None
+ assert new._faces_areas_packed is None
+
+ return new
+
+ def __getitem__(self, index: Union[tuple, int, list, slice, torch.Tensor]):
+ """Slice the meshes by the batch dim like pytorch3d Meshes. And slice
+ by scene dim due to the topology of the parametric meshes.
+
+ Args:
+ index (Union[tuple, int, list, slice, torch.Tensor]): indexes; if
+ only a single index is passed, the scene dim is ignored.
+ """
+ if isinstance(index, tuple):
+ batch_index, individual_index = index
+ else:
+ batch_index, individual_index = index, None
+
+ if isinstance(batch_index, int):
+ batch_index = [batch_index]
+ elif isinstance(batch_index, (tuple, list, slice)):
+ batch_index = torch.arange(self._N)[batch_index]
+ batch_index = torch.tensor(batch_index) if not isinstance(
+ batch_index, torch.Tensor) else batch_index
+ batch_index = batch_index.to(self.device, dtype=torch.long)
+
+ if (batch_index >= self._N).any():
+ raise IndexError('list index out of range')
+
+ if individual_index is None:
+ return self.__class__(verts=self.verts_padded()[batch_index],
+ faces=self.faces_padded()[batch_index],
+ textures=self.textures[batch_index]
+ if self.textures is not None else None,
+ model_type=self.model_type)
+
+ if isinstance(individual_index, int):
+ individual_index = [individual_index]
+ elif isinstance(individual_index, (tuple, list, slice)):
+ individual_index = torch.arange(
+ self._N_individual)[individual_index]
+ individual_index = torch.tensor(individual_index) if not isinstance(
+ individual_index, torch.Tensor) else individual_index
+ if (individual_index >= self._N_individual).any():
+ raise IndexError('list index out of range')
+ vertex_index = [
+ torch.arange(self.model_class.NUM_VERTS) +
+ idx * self.model_class.NUM_VERTS for idx in individual_index
+ ]
+ vertex_index = torch.cat(vertex_index).to(self.device).long()
+
+ new_face_num = self.model_class.NUM_FACES * len(individual_index)
+
+ verts_padded = self.verts_padded()[batch_index][:, vertex_index]
+ faces_padded = self.get_faces_padded(len(verts_padded),
+ len(individual_index))
+
+ textures_batch = self.textures[batch_index]
+
+ if isinstance(textures_batch, TexturesUV):
+ # TODO: there is still some problem with `TexturesUV`
+ # slice and need to fix the function `join_meshes_as_scene`.
+ # It is recommended that we re-inplement the `TexturesUV`
+ # as `ParametricTexturesUV`, mainly for the `__getitem__`
+ # and `join_scene` functions.
+
+ # textures_batch.get('unique_map_index ')
+
+ # This version only consider the maps tensor as different id.
+ maps = textures_batch.maps_padded()
+ width_individual = maps.shape[-2] // self._N_individual
+ maps_index = [
+ torch.arange(width_individual * idx,
+ width_individual * (idx + 1))
+ for idx in individual_index
+ ]
+ maps_index = torch.cat(maps_index).to(self.device)
+ verts_uvs_padded = textures_batch.verts_uvs_padded(
+ )[:, :len(vertex_index)] * torch.Tensor([
+ self._N_individual / len(individual_index), 1
+ ]).view(1, 1, 2).to(self.device)
+ faces_uvs_padded = textures_batch.faces_uvs_padded(
+ )[:, :new_face_num]
+ maps_padded = maps[:, :, maps_index]
+ textures = TexturesUV(faces_uvs=faces_uvs_padded,
+ verts_uvs=verts_uvs_padded,
+ maps=maps_padded)
+ elif isinstance(textures_batch, (TexturesVertex, TexturesNearest)):
+ verts_features_padded = textures_batch.verts_features_padded(
+ )[:, vertex_index]
+ textures = textures_batch.__class__(verts_features_padded)
+ meshes = self.__class__(verts=verts_padded,
+ faces=faces_padded,
+ textures=textures,
+ model_type=self.model_type)
+ return meshes
+
+ @property
+ def shape(self, ):
+ return (len(self), self._N_individual)
+
+
+def join_meshes_as_batch(meshes: List[ParametricMeshes],
+ include_textures: bool = True) -> ParametricMeshes:
+ """Join the meshes along the batch dim.
+
+ Args:
+ meshes (Union[ParametricMeshes, List[ParametricMeshes, Meshes,
+ List[Meshes]]]): Meshes object that contains a batch of meshes,
+ or a list of Meshes objects.
+ include_textures (bool, optional): whether to try to join the textures.
+ Defaults to True.
+
+ Returns:
+ ParametricMeshes: the joined ParametricMeshes.
+ """
+ if isinstance(meshes, ParametricMeshes):
+ raise ValueError('Wrong first argument to join_meshes_as_batch.')
+ first = meshes[0]
+
+ assert all(mesh.model_type == first.model_type
+ for mesh in meshes), 'model_type should all be the same.'
+
+ meshes = _join_meshes_as_batch(meshes, include_textures=include_textures)
+ return ParametricMeshes(model_type=first.model_type, meshes=meshes)
+
+
+def join_meshes_as_scene(meshes: Union[ParametricMeshes,
+ List[ParametricMeshes]],
+ include_textures: bool = True) -> ParametricMeshes:
+ """Join the meshes along the scene dim.
+
+ Args:
+ meshes (Union[ParametricMeshes, List[ParametricMeshes]]):
+ ParametricMeshes object that contains a batch of meshes,
+ or a list of ParametricMeshes objects.
+ include_textures (bool, optional): whether to try to join the textures.
+ Defaults to True.
+
+ Returns:
+ ParametricMeshes: the joined ParametricMeshes.
+ """
+ first = meshes[0]
+ assert all(mesh.model_type == first.model_type
+ for mesh in meshes), 'model_type should all be the same.'
+
+ if isinstance(meshes, List):
+ meshes = join_meshes_as_batch(meshes,
+ include_textures=include_textures)
+
+ if len(meshes) == 1:
+ return meshes
+ verts = meshes.verts_packed() # (sum(V_n), 3)
+ # Offset automatically done by faces_packed
+ faces = meshes.faces_packed() # (sum(F_n), 3)
+ textures = None
+
+ if include_textures and meshes.textures is not None:
+ textures = meshes.textures.join_scene()
+
+ mesh = ParametricMeshes(verts=verts.unsqueeze(0),
+ faces=faces.unsqueeze(0),
+ textures=textures,
+ model_type=first.model_type)
+
+ return mesh
+
+
+def join_batch_meshes_as_scene(
+ meshes: List[ParametricMeshes],
+ include_textures: bool = True) -> ParametricMeshes:
+ """Join `meshes` as a scene each batch. For ParametricMeshes. The Meshes
+ must share the same batch size, and topology could be different. They must
+ all be on the same device. If `include_textures` is true, the textures
+ should be the same type, all be None is not accepted. If `include_textures`
+ is False, textures are ignored. The return meshes will have no textures.
+
+ Args:
+ meshes (List[ParametricMeshes]): Meshes object that contains a list of
+ Meshes objects.
+ include_textures (bool, optional): whether to try to join the textures.
+ Defaults to True.
+
+
+ Returns:
+ New Meshes in which the different Meshes are joined as one scene per
+ batch.
+ """
+ first = meshes[0]
+
+ assert all(mesh.model_type == first.model_type
+ for mesh in meshes), 'model_type should all be the same.'
+
+ assert all(len(mesh) == len(first) for mesh in meshes)
+ if not all(mesh.shape[1] == first.shape[1] for mesh in meshes):
+ meshes_temp = []
+ for mesh_scene in meshes:
+ meshes_temp.extend([
+ mesh_scene[:, individual_index]
+ for individual_index in range(mesh_scene._N_individual)
+ ])
+ meshes = meshes_temp
+ for mesh in meshes:
+ mesh._verts_list = padded_to_list(mesh.verts_padded(),
+ mesh.num_verts_per_mesh().tolist())
+ num_scene_size = len(meshes)
+ num_batch_size = len(meshes[0])
+
+ meshes_all = []
+ for j in range(num_batch_size):
+ meshes_batch = []
+ for i in range(num_scene_size):
+ meshes_batch.append(meshes[i][j])
+ meshes_all.append(join_meshes_as_scene(meshes_batch, include_textures))
+ meshes_final = join_meshes_as_batch(meshes_all, include_textures)
+
+ return meshes_final
diff --git a/detrsmpl/core/renderer/torch3d_renderer/normal_renderer.py b/detrsmpl/core/renderer/torch3d_renderer/normal_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e9364830c251302261b3a7f1b5578e86e706495
--- /dev/null
+++ b/detrsmpl/core/renderer/torch3d_renderer/normal_renderer.py
@@ -0,0 +1,89 @@
+from typing import Iterable, Optional, Union
+
+import torch
+from pytorch3d.structures import Meshes
+
+from detrsmpl.core.cameras import MMCamerasBase
+from .base_renderer import BaseRenderer
+from .utils import normalize
+
+
+class NormalRenderer(BaseRenderer):
+ """Render normal map with the help of camera system."""
+ shader_type = 'NormalShader'
+
+ def __init__(
+ self,
+ resolution: Iterable[int] = None,
+ device: Union[torch.device, str] = 'cpu',
+ output_path: Optional[str] = None,
+ out_img_format: str = '%06d.png',
+ **kwargs,
+ ) -> None:
+ """Renderer for normal map of meshes.
+
+ Args:
+ resolution (Iterable[int]):
+ (width, height) of the rendered images resolution.
+ device (Union[torch.device, str], optional):
+ You can pass a str or torch.device for cpu or gpu render.
+ Defaults to 'cpu'.
+ output_path (Optional[str], optional):
+ Output path of the video or images to be saved.
+ Defaults to None.
+ out_img_format (str, optional): The image format string for
+ saving the images.
+ Defaults to '%06d.png'.
+
+ Returns:
+ None
+ """
+ super().__init__(resolution=resolution,
+ device=device,
+ output_path=output_path,
+ obj_path=None,
+ out_img_format=out_img_format,
+ **kwargs)
+
+ def forward(self,
+ meshes: Optional[Meshes] = None,
+ cameras: Optional[MMCamerasBase] = None,
+ indexes: Optional[Iterable[int]] = None,
+ backgrounds: Optional[torch.Tensor] = None,
+ **kwargs):
+ """Render Meshes.
+
+ Args:
+ meshes (Optional[Meshes], optional): meshes to be rendered.
+ Defaults to None.
+ cameras (Optional[MMCamerasBase], optional): cameras for render.
+ Defaults to None.
+ indexes (Optional[Iterable[int]], optional): indexes for the
+ images.
+ Defaults to None.
+ backgrounds (Optional[torch.Tensor], optional): background images.
+ Defaults to None.
+
+ Returns:
+ Union[torch.Tensor, None]: return tensor or None.
+ """
+
+ meshes = meshes.to(self.device)
+ self._update_resolution(cameras, **kwargs)
+ fragments = self.rasterizer(meshes_world=meshes, cameras=cameras)
+ normal_map = self.shader(fragments=fragments,
+ meshes=meshes,
+ cameras=cameras)
+
+ if self.output_path is not None:
+ rgba = self.tensor2rgba(normal_map)
+ self._write_images(rgba, backgrounds, indexes)
+
+ return normal_map
+
+ def tensor2rgba(self, tensor: torch.Tensor):
+ rgbs, valid_masks = tensor[..., :3], (tensor[..., 3:] > 0) * 1.0
+ rgbs = normalize(rgbs,
+ origin_value_range=(-1, 1),
+ out_value_range=(0, 1))
+ return torch.cat([rgbs, valid_masks], -1)
diff --git a/detrsmpl/core/renderer/torch3d_renderer/pointcloud_renderer.py b/detrsmpl/core/renderer/torch3d_renderer/pointcloud_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..045e69ba5346d94811c39ba0b1440b9e642d2ac0
--- /dev/null
+++ b/detrsmpl/core/renderer/torch3d_renderer/pointcloud_renderer.py
@@ -0,0 +1,161 @@
+import warnings
+from typing import Iterable, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from pytorch3d.renderer import (
+ AlphaCompositor,
+ PointsRasterizationSettings,
+ PointsRasterizer,
+)
+from pytorch3d.structures import Meshes, Pointclouds
+
+from detrsmpl.core.cameras import MMCamerasBase
+from detrsmpl.utils.mesh_utils import mesh_to_pointcloud_vc
+from .base_renderer import BaseRenderer
+
+
+class PointCloudRenderer(BaseRenderer):
+ def __init__(self,
+ resolution: Tuple[int, int] = None,
+ device: Union[torch.device, str] = 'cpu',
+ output_path: Optional[str] = None,
+ out_img_format: str = '%06d.png',
+ radius: Optional[float] = None,
+ **kwargs) -> None:
+ """Point cloud renderer.
+
+ Args:
+ resolution (Iterable[int]):
+ (width, height) of the rendered images resolution.
+ device (Union[torch.device, str], optional):
+ You can pass a str or torch.device for cpu or gpu render.
+ Defaults to 'cpu'.
+ output_path (Optional[str], optional):
+ Output path of the video or images to be saved.
+ Defaults to None.
+ out_img_format (str, optional): name format for temp images.
+ Defaults to '%06d.png'.
+ radius (float, optional): radius of points. Defaults to None.
+
+ Returns:
+ None
+ """
+ self.radius = radius
+ super().__init__(resolution=resolution,
+ device=device,
+ output_path=output_path,
+ out_img_format=out_img_format,
+ **kwargs)
+
+ def to(self, device):
+ if isinstance(device, str):
+ device = torch.device(device)
+ self.device = device
+ if getattr(self.rasterizer, 'cameras', None) is not None:
+ self.rasterizer.cameras = self.rasterizer.cameras.to(device)
+
+ self.compositor = self.compositor.to(device)
+ return self
+
+ def _init_renderer(self, rasterizer=None, compositor=None, **kwargs):
+ """Set render params."""
+
+ if isinstance(rasterizer, nn.Module):
+ rasterizer.raster_settings.image_size = self.resolution
+ self.rasterizer = rasterizer
+ elif isinstance(rasterizer, dict):
+ rasterizer['image_size'] = self.resolution
+ if self.radius is not None:
+ rasterizer.update(radius=self.radius)
+ raster_settings = PointsRasterizationSettings(**rasterizer)
+ self.rasterizer = PointsRasterizer(raster_settings=raster_settings)
+ elif rasterizer is None:
+ self.rasterizer = PointsRasterizer(
+ raster_settings=PointsRasterizationSettings(
+ radius=self.radius,
+ image_size=self.resolution,
+ points_per_pixel=10))
+ else:
+ raise TypeError(
+ f'Wrong type of rasterizer: {type(self.rasterizer)}.')
+
+ if isinstance(compositor, dict):
+ self.compositor = AlphaCompositor(**compositor)
+ elif isinstance(compositor, nn.Module):
+ self.compositor = compositor
+ elif compositor is None:
+ self.compositor = AlphaCompositor()
+ else:
+ raise TypeError(
+ f'Wrong type of compositor: {type(self.compositor)}.')
+ self = self.to(self.device)
+
+ def forward(
+ self,
+ pointclouds: Optional[Pointclouds] = None,
+ vertices: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+ verts_rgba: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+ meshes: Meshes = None,
+ cameras: Optional[MMCamerasBase] = None,
+ indexes: Optional[Iterable[int]] = None,
+ backgrounds: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[None, torch.Tensor]:
+ """Render pointclouds.
+
+ Args:
+ pointclouds (Optional[Pointclouds], optional): pytorch3d data
+ structure. If not None, `vertices` and `verts_rgba` will
+ be ignored.
+ Defaults to None.
+ vertices (Optional[Union[torch.Tensor, List[torch.Tensor]]],
+ optional): coordinate tensor of points. Defaults to None.
+ verts_rgba (Optional[Union[torch.Tensor, List[torch.Tensor]]],
+ optional): color tensor of points. Defaults to None.
+ indexes (Optional[Iterable[int]], optional): indexes for the
+ images.
+ Defaults to None.
+ backgrounds (Optional[torch.Tensor], optional): background images.
+ Defaults to None.
+
+ Returns:
+ Union[None, torch.Tensor]: Return tensor or None.
+ """
+ if pointclouds is None:
+ if meshes is not None:
+ pointclouds = mesh_to_pointcloud_vc(meshes)
+ else:
+ assert vertices is not None
+ if isinstance(vertices, torch.Tensor):
+ if vertices.ndim == 2:
+ vertices = vertices[None]
+ if isinstance(verts_rgba, torch.Tensor):
+ if verts_rgba.ndim == 2:
+ verts_rgba = verts_rgba[None]
+ pointclouds = Pointclouds(points=vertices, features=verts_rgba)
+ else:
+ if vertices is not None or verts_rgba is not None:
+ warnings.warn(
+ 'Redundant input, will ignore `vertices` and `verts_rgba`.')
+ pointclouds = pointclouds.to(self.device)
+ self._update_resolution(cameras, **kwargs)
+ fragments = self.rasterizer(pointclouds, cameras=cameras)
+ r = self.rasterizer.raster_settings.radius
+
+ dists2 = fragments.dists.permute(0, 3, 1, 2)
+ weights = 1 - dists2 / (r * r)
+ rendered_images = self.compositor(
+ fragments.idx.long().permute(0, 3, 1, 2),
+ weights,
+ pointclouds.features_packed().permute(1, 0),
+ **kwargs,
+ )
+ rendered_images = rendered_images.permute(0, 2, 3, 1)
+
+ if self.output_path is not None:
+ rgba = self.tensor2rgba(rendered_images)
+ self._write_images(rgba, backgrounds, indexes)
+
+ return rendered_images
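+
+# Commented-out sketch of the two input routes handled by forward() above
+# (`my_cameras` is a placeholder): pass a ready-made Pointclouds object, or
+# raw per-point coordinates plus RGBA features.
+#
+# renderer = PointCloudRenderer(resolution=(512, 512), device='cuda', radius=0.003)
+# points = torch.rand(1, 10000, 3, device='cuda')
+# colors = torch.ones(1, 10000, 4, device='cuda')
+# images = renderer(vertices=points, verts_rgba=colors, cameras=my_cameras)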
diff --git a/detrsmpl/core/renderer/torch3d_renderer/render_runner.py b/detrsmpl/core/renderer/torch3d_renderer/render_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9c313011b206438a1b075d920cc22f90d71b32c
--- /dev/null
+++ b/detrsmpl/core/renderer/torch3d_renderer/render_runner.py
@@ -0,0 +1,125 @@
+import math
+import os
+from typing import Iterable, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from pytorch3d.renderer import MeshRenderer, SoftSilhouetteShader
+from pytorch3d.renderer.cameras import CamerasBase
+from pytorch3d.structures import Meshes
+from tqdm import trange
+
+from detrsmpl.core.cameras import MMCamerasBase
+from detrsmpl.core.cameras.builder import build_cameras
+from .base_renderer import BaseRenderer
+from .builder import build_renderer
+from .lights import AmbientLights, MMLights, build_lights
+
+osj = os.path.join
+
+
+def render(renderer: Union[nn.Module, dict],
+ meshes: Union[Meshes, None] = None,
+ output_path: Optional[str] = None,
+ resolution: Union[Iterable[int], int] = None,
+ device: Union[str, torch.device] = 'cpu',
+ cameras: Union[MMCamerasBase, CamerasBase, dict, None] = None,
+ lights: Union[MMLights, dict, None] = None,
+ batch_size: int = 5,
+ return_tensor: bool = False,
+ no_grad: bool = False,
+ verbose: bool = True,
+ **forward_params):
+
+ if isinstance(renderer, dict):
+ renderer = build_renderer(renderer)
+ elif isinstance(renderer, MeshRenderer):
+ if isinstance(renderer.shader, SoftSilhouetteShader):
+ renderer = build_renderer(
+ dict(type='silhouette',
+ resolution=resolution,
+ shader=renderer.shader,
+ rasterizer=renderer.rasterizer))
+ else:
+ renderer = build_renderer(
+ dict(type='mesh',
+ resolution=resolution,
+ shader=renderer.shader,
+ rasterizer=renderer.rasterizer))
+ elif isinstance(renderer, BaseRenderer):
+ renderer = renderer
+ else:
+ raise TypeError('Wrong input renderer type.')
+
+ renderer = renderer.to(device)
+ if output_path is not None:
+ renderer._set_output_path(output_path)
+
+ if isinstance(cameras, dict):
+ cameras = build_cameras(cameras)
+ elif isinstance(cameras, MMCamerasBase):
+ cameras = cameras
+ elif isinstance(cameras,
+ CamerasBase) and not isinstance(cameras, MMCamerasBase):
+ cameras = build_cameras(
+ dict(type=cameras.__class__.__name__,
+ K=cameras.K,
+ R=cameras.R,
+ T=cameras.T,
+ in_ndc=cameras.in_ndc(),
+ resolution=resolution))
+ else:
+ raise TypeError('Wrong input cameras type.')
+ num_frames = len(meshes)
+ if isinstance(lights, dict):
+ lights = build_lights(lights)
+ elif isinstance(lights, MMLights):
+ lights = lights
+ elif lights is None:
+ lights = AmbientLights(device=device).extend(num_frames)
+ else:
+ raise ValueError('Wrong light type.')
+
+ if len(cameras) == 1:
+ cameras = cameras.extend(num_frames)
+ if len(lights) == 1:
+ lights = lights.extend(num_frames)
+
+ forward_params.update(lights=lights, cameras=cameras, meshes=meshes)
+
+ batch_size = min(batch_size, num_frames)
+ tensors = []
+ for k in forward_params:
+ if isinstance(forward_params[k], np.ndarray):
+ forward_params.update(
+ {k: torch.tensor(forward_params[k]).to(device)})
+ if verbose:
+ iter_func = trange
+ else:
+ iter_func = range
+    for i in iter_func(math.ceil(num_frames / batch_size)):
+        indexes = list(
+            range(i * batch_size, min((i + 1) * batch_size, len(meshes))))
+        forward_params_batch = {}
+
+        for k in forward_params:
+            if hasattr(forward_params[k], '__getitem__'):
+                forward_params_batch[k] = forward_params[k][indexes].to(
+                    device)
+
+        if no_grad:
+            with torch.no_grad():
+                images_batch = renderer(indexes=indexes,
+                                        **forward_params_batch)
+        else:
+            images_batch = renderer(indexes=indexes, **forward_params_batch)
+        tensors.append(images_batch)
+
+ renderer.export()
+
+    if return_tensor:
+        return torch.cat(tensors)
+    else:
+        # Detach and move to CPU so the concatenation also works for CUDA
+        # tensors.
+        return np.concatenate(
+            [batch.detach().cpu().numpy() for batch in tensors])
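+
+
+# Usage sketch (illustrative only, not part of the original module): drive a
+# mesh renderer over a batch of textured meshes. `my_textured_meshes` and
+# `my_camera_cfg` are placeholder names, and it is assumed that the renderer
+# and camera builders accept the preset dict and camera config shown here.
+#
+#   from detrsmpl.core.renderer.torch3d_renderer.render_smpl_config import \
+#       RENDER_CONFIGS
+#   images = render(renderer=RENDER_CONFIGS['hq'],
+#                   meshes=my_textured_meshes,
+#                   cameras=my_camera_cfg,
+#                   resolution=(512, 512),
+#                   device='cuda',
+#                   batch_size=4,
+#                   no_grad=True,
+#                   return_tensor=True)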
diff --git a/detrsmpl/core/renderer/torch3d_renderer/render_smpl_config.py b/detrsmpl/core/renderer/torch3d_renderer/render_smpl_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4afc80b78b7ccd679dd08d71b969b261de242e7d
--- /dev/null
+++ b/detrsmpl/core/renderer/torch3d_renderer/render_smpl_config.py
@@ -0,0 +1,149 @@
+base_directional_light = {
+ 'type': 'directional',
+ 'direction': [[1, 1, 1]],
+ 'ambient_color': [[0.5, 0.5, 0.5]],
+ 'diffuse_color': [[0.5, 0.5, 0.5]],
+ 'specular_color': [[0.5, 0.5, 0.5]],
+}
+
+base_point_light = {
+ 'type': 'point',
+ 'ambient_color': [[1, 1, 1]],
+ 'diffuse_color': [[0.3, 0.3, 0.3]],
+ 'specular_color': [[0.5, 0.5, 0.5]],
+ 'location': [[2.0, 2.0, -2.0]],
+}
+
+base_ambient_light = {
+ 'type': 'ambient',
+ 'ambient_color': [[1.0, 1.0, 1.0]],
+}
+
+base_material = {
+ 'ambient_color': [[1, 1, 1]],
+ 'diffuse_color': [[0.5, 0.5, 0.5]],
+ 'specular_color': [[0.15, 0.15, 0.15]],
+ 'shininess': 60.0,
+}
+
+silhouette_material = {
+ 'ambient_color': [[1.0, 1.0, 1.0]],
+ 'diffuse_color': [[0.0, 0.0, 0.0]],
+ 'specular_color': [[0.0, 0.0, 0.0]],
+ 'shininess': 1.0,
+}
+
+white_blend_params = {'background_color': (1.0, 1.0, 1.0)}
+
+black_blend_params = {'background_color': (0.0, 0.0, 0.0)}
+
+RENDER_CONFIGS = {
+ # low quality
+ 'lq': {
+ 'type': 'mesh',
+ 'shader': {
+ 'type': 'hard_flat'
+ },
+ 'lights': base_directional_light,
+ 'materials': base_material,
+ 'rasterizer': {
+ 'bin_size': 0,
+ 'blur_radius': 0.0,
+ 'faces_per_pixel': 1,
+ 'perspective_correct': False,
+ },
+ 'blend_params': white_blend_params,
+ },
+ # medium quality
+ 'mq': {
+ 'type': 'mesh',
+ 'shader': {
+ 'type': 'soft_gouraud'
+ },
+ 'lights': base_directional_light,
+ 'materials': base_material,
+ 'rasterizer': {
+ 'bin_size': 0,
+ 'blur_radius': 0.0,
+ 'faces_per_pixel': 1,
+ 'perspective_correct': False,
+ },
+ 'blend_params': white_blend_params,
+ },
+ # high quality
+ 'hq': {
+ 'type': 'mesh',
+ 'shader': {
+ 'type': 'soft_phong'
+ },
+ 'lights': base_directional_light,
+ 'materials': base_material,
+ 'rasterizer': {
+ 'bin_size': 0,
+ 'blur_radius': 0.0,
+ 'faces_per_pixel': 1,
+ 'perspective_correct': False,
+ },
+ 'blend_params': white_blend_params,
+ },
+ 'silhouette': {
+ 'type': 'silhouette',
+ 'lights': None,
+        'materials': silhouette_material,
+ 'rasterizer': {
+ 'bin_size': 0,
+ 'blur_radius': 2e-5,
+ 'faces_per_pixel': 50,
+ 'perspective_correct': False,
+ },
+ 'blend_params': black_blend_params,
+ },
+ 'part_silhouette': {
+ 'type': 'segmentation',
+ 'material': base_material,
+ 'rasterizer': {
+ 'bin_size': 0,
+ 'blur_radius': 0.0,
+ 'faces_per_pixel': 1,
+ 'perspective_correct': False,
+ },
+ 'blend_params': black_blend_params,
+ },
+ 'depth': {
+ 'type': 'depth',
+ 'rasterizer': {
+ 'bin_size': 0,
+ 'blur_radius': 0.0,
+ 'faces_per_pixel': 1,
+ 'perspective_correct': False,
+ },
+ 'blend_params': black_blend_params,
+ },
+ 'normal': {
+ 'type': 'normal',
+ 'rasterizer': {
+ 'bin_size': 0,
+ 'blur_radius': 0.0,
+ 'faces_per_pixel': 1,
+ 'perspective_correct': False,
+ },
+ 'blend_params': white_blend_params,
+ },
+ 'pointcloud': {
+ 'type': 'pointcloud',
+ 'compositor': {
+ 'background_color': [
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ ],
+ },
+ 'rasterizer': {
+ 'points_per_pixel': 10,
+ 'radius': 0.003,
+ 'bin_size': None,
+ 'max_points_per_bin': None,
+ }
+ }
+}
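+
+
+if __name__ == '__main__':
+    # Minimal sketch (not part of the original config): the presets above are
+    # plain dicts, so callers typically deep-copy one and override fields
+    # before handing it to the renderer builder.
+    import copy
+
+    config = copy.deepcopy(RENDER_CONFIGS['hq'])
+    config['rasterizer']['faces_per_pixel'] = 2
+    print(sorted(config.keys()))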
diff --git a/detrsmpl/core/renderer/torch3d_renderer/segmentation_renderer.py b/detrsmpl/core/renderer/torch3d_renderer/segmentation_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c403e4073af365e1709879fb1ad0ff48124d4df9
--- /dev/null
+++ b/detrsmpl/core/renderer/torch3d_renderer/segmentation_renderer.py
@@ -0,0 +1,106 @@
+from typing import Iterable, Optional, Tuple, Union
+
+import torch
+from pytorch3d.structures import Meshes
+
+from detrsmpl.core.cameras import MMCamerasBase
+from detrsmpl.utils.demo_utils import get_different_colors
+from .base_renderer import BaseRenderer
+from .utils import normalize
+
+
+class SegmentationRenderer(BaseRenderer):
+ """Render segmentation map into a segmentation index tensor."""
+ shader_type = 'SegmentationShader'
+
+ def __init__(self,
+ resolution: Tuple[int, int] = None,
+ device: Union[torch.device, str] = 'cpu',
+ output_path: Optional[str] = None,
+ out_img_format: str = '%06d.png',
+ num_class: int = 1,
+ **kwargs) -> None:
+ """Render vertex-color mesh into a segmentation map of a (B, H, W)
+ tensor. For visualization, the output rgba image will be (B, H, W, 4),
+ and the color palette comes from `get_different_colors`. The
+ segmentation map is a tensor each pixel saves the classification index.
+ Please make sure you have allocate each pixel a correct classification
+ index by defining a textures of vertex color.
+
+ [CrossEntropyLoss](https://pytorch.org/docs/stable/generated/torch.nn.
+ CrossEntropyLoss.html)
+
+ Args:
+ resolution (Iterable[int]):
+ (width, height) of the rendered images resolution.
+ device (Union[torch.device, str], optional):
+ You can pass a str or torch.device for cpu or gpu render.
+ Defaults to 'cpu'.
+ output_path (Optional[str], optional):
+ Output path of the video or images to be saved.
+ Defaults to None.
+ out_img_format (str, optional): The image format string for
+ saving the images.
+ Defaults to '%06d.png'.
+ num_class (int, optional): number of segmentation parts.
+ Defaults to 1.
+
+ Returns:
+ None
+ """
+ super().__init__(resolution=resolution,
+ device=device,
+ output_path=output_path,
+ obj_path=None,
+ out_img_format=out_img_format,
+ **kwargs)
+ self.num_class = num_class
+
+ def forward(self,
+ meshes: Meshes,
+ cameras: Optional[MMCamerasBase] = None,
+ indexes: Optional[Iterable[int]] = None,
+ backgrounds: Optional[torch.Tensor] = None,
+ **kwargs):
+ """Render segmentation map.
+
+ Args:
+ meshes (Meshes): meshes to be rendered.
+                Requires the textures type to be `TexturesNearest`.
+ The color indicates the class index of the triangle.
+ cameras (Optional[MMCamerasBase], optional): cameras for render.
+ Defaults to None.
+ indexes (Optional[Iterable[int]], optional): indexes for images.
+ Defaults to None.
+ backgrounds (Optional[torch.Tensor], optional): background images.
+ Defaults to None.
+
+ Returns:
+ Union[torch.Tensor, None]: return tensor or None.
+ """
+
+ meshes = meshes.to(self.device)
+ self._update_resolution(cameras, **kwargs)
+ fragments = self.rasterizer(meshes_world=meshes, cameras=cameras)
+ segmentation_map = self.shader(fragments=fragments,
+ meshes=meshes,
+ cameras=cameras)
+
+        if self.output_path is not None:
+            rgba = self.tensor2rgba(segmentation_map)
+            self._write_images(rgba, backgrounds, indexes)
+
+ return segmentation_map
+
+ def tensor2rgba(self, tensor: torch.Tensor):
+ valid_masks = (tensor[..., :] > 0) * 1.0
+ color = torch.Tensor(get_different_colors(self.num_class))
+ color = torch.cat([torch.zeros(1, 3), color]).to(self.device)
+ B, H, W, _ = tensor.shape
+ rgbs = color[tensor.view(-1)].view(B, H, W, 3) * valid_masks
+ rgbs = normalize(rgbs.float(),
+ origin_value_range=(0, 255),
+ out_value_range=(0, 1))
+ rgba = torch.cat([rgbs, valid_masks], -1)
+ return rgba
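+
+
+# Usage sketch (illustrative only, not part of the original module): this
+# renderer corresponds to the 'part_silhouette' preset (type 'segmentation')
+# in render_smpl_config.py. `part_colored_meshes`, `my_cameras` and the class
+# count below are placeholders; the vertex-color textures must encode the
+# per-part class indices.
+#
+#   renderer = SegmentationRenderer(resolution=(512, 512),
+#                                   device='cuda',
+#                                   num_class=24)
+#   seg_map = renderer(meshes=part_colored_meshes, cameras=my_cameras)
+#   rgba = renderer.tensor2rgba(seg_map)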
diff --git a/detrsmpl/core/renderer/torch3d_renderer/shader/__init__.py b/detrsmpl/core/renderer/torch3d_renderer/shader/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e55fad81f1eb3de0b5e39d7107ffc579b314446
--- /dev/null
+++ b/detrsmpl/core/renderer/torch3d_renderer/shader/__init__.py
@@ -0,0 +1,16 @@
+# yapf: disable
+from .builder import ( # noqa: F401
+ DepthShader,
+ HardFlatShader,
+ HardGouraudShader,
+ HardPhongShader,
+ NoLightShader,
+ NormalShader,
+ SegmentationShader,
+ SilhouetteShader,
+ SoftGouraudShader,
+ SoftPhongShader,
+ build_shader,
+)
+
+# yapf: enable
diff --git a/detrsmpl/core/renderer/torch3d_renderer/shader/builder.py b/detrsmpl/core/renderer/torch3d_renderer/shader/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..900d4440dbdcae86ac1babae876e442b4de1f373
--- /dev/null
+++ b/detrsmpl/core/renderer/torch3d_renderer/shader/builder.py
@@ -0,0 +1,51 @@
+from mmcv.utils import Registry
+from pytorch3d.renderer import (
+ HardFlatShader,
+ HardGouraudShader,
+ HardPhongShader,
+ SoftGouraudShader,
+ SoftPhongShader,
+)
+
+from .shader import (
+ DepthShader,
+ NoLightShader,
+ NormalShader,
+ SegmentationShader,
+ SilhouetteShader,
+)
+
+SHADER = Registry('shader')
+SHADER.register_module(name=[
+ 'flat', 'hard_flat_shader', 'hard_flat', 'HardFlat', 'HardFlatShader'
+],
+ module=HardFlatShader)
+SHADER.register_module(name=['hard_phong', 'HardPhong', 'HardPhongShader'],
+ module=HardPhongShader)
+SHADER.register_module(
+ name=['hard_gouraud', 'HardGouraud', 'HardGouraudShader'],
+ module=HardGouraudShader)
+SHADER.register_module(
+ name=['soft_gouraud', 'SoftGouraud', 'SoftGouraudShader'],
+ module=SoftGouraudShader)
+SHADER.register_module(name=['soft_phong', 'SoftPhong', 'SoftPhongShader'],
+ module=SoftPhongShader)
+SHADER.register_module(name=['silhouette', 'Silhouette', 'SilhouetteShader'],
+ module=SilhouetteShader)
+SHADER.register_module(
+ name=['nolight', 'nolight_shader', 'NoLight', 'NoLightShader'],
+ module=NoLightShader)
+SHADER.register_module(
+ name=['normal', 'normal_shader', 'Normal', 'NormalShader'],
+ module=NormalShader)
+SHADER.register_module(name=['depth', 'depth_shader', 'Depth', 'DepthShader'],
+ module=DepthShader)
+SHADER.register_module(name=[
+ 'segmentation', 'segmentation_shader', 'Segmentation', 'SegmentationShader'
+],
+ module=SegmentationShader)
+
+
+def build_shader(cfg):
+ """Build shader."""
+ return SHADER.build(cfg)
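+
+
+if __name__ == '__main__':
+    # Minimal sketch (not part of the original module): build a shader from a
+    # registry config dict, mirroring how the mesh renderers construct their
+    # shaders internally.
+    shader = build_shader(dict(type='soft_phong'))
+    print(type(shader).__name__)  # SoftPhongShader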
diff --git a/detrsmpl/core/renderer/torch3d_renderer/shader/shader.py b/detrsmpl/core/renderer/torch3d_renderer/shader/shader.py
new file mode 100644
index 0000000000000000000000000000000000000000..977b0859cfae5a7270aad3639318251e6bc68e3a
--- /dev/null
+++ b/detrsmpl/core/renderer/torch3d_renderer/shader/shader.py
@@ -0,0 +1,103 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from pytorch3d.ops import interpolate_face_attributes
+from pytorch3d.renderer import BlendParams, hard_rgb_blend
+from pytorch3d.renderer.mesh.shader import SoftSilhouetteShader
+from pytorch3d.structures.utils import padded_to_packed
+
+
+class SilhouetteShader(SoftSilhouetteShader):
+ """Avoid unexpected keyword argument error."""
+ def __init__(self,
+ blend_params: Optional[BlendParams] = None,
+ **kwargs) -> None:
+ super().__init__(blend_params)
+
+
+class NoLightShader(nn.Module):
+ """No light shader."""
+ def __init__(self,
+ blend_params: Optional[BlendParams] = None,
+ **kwargs) -> None:
+ """Initlialize without blend_params."""
+ super().__init__()
+ self.blend_params = blend_params if blend_params is not None\
+ else BlendParams()
+
+ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
+ """Sample without light."""
+ texels = meshes.sample_textures(fragments)
+ blend_params = kwargs.get('blend_params', self.blend_params)
+ images = hard_rgb_blend(texels, fragments, blend_params)
+ return images
+
+
+class DepthShader(nn.Module):
+ """No light shader."""
+ def __init__(self,
+ blend_params: Optional[BlendParams] = None,
+ **kwargs) -> None:
+ """Initlialize without blend_params."""
+ super().__init__()
+ self.blend_params = blend_params if blend_params is not None\
+ else BlendParams()
+
+ def forward(self, fragments, meshes, cameras, **kwargs) -> torch.Tensor:
+ """Sample without light."""
+ verts_depth = cameras.compute_depth_of_points(meshes.verts_padded())
+ faces = meshes.faces_packed() # (F, 3)
+ verts_depth = padded_to_packed(verts_depth)
+ faces_depth = verts_depth[faces]
+ depth_map = interpolate_face_attributes(
+ pix_to_face=fragments.pix_to_face,
+ barycentric_coords=fragments.bary_coords,
+ face_attributes=faces_depth)
+ return depth_map[..., 0, :]
+
+
+class NormalShader(nn.Module):
+ """No light shader."""
+ def __init__(self,
+ blend_params: Optional[BlendParams] = None,
+ **kwargs) -> None:
+ """Initlialize without blend_params."""
+ super().__init__()
+ self.blend_params = blend_params if blend_params is not None\
+ else BlendParams()
+
+ def forward(self, fragments, meshes, cameras, **kwargs) -> torch.Tensor:
+ """Sample without light."""
+ verts_normal = cameras.compute_normal_of_meshes(meshes)
+ faces = meshes.faces_packed() # (F, 3)
+ verts_normal = padded_to_packed(verts_normal)
+ faces_normal = verts_normal[faces]
+ normal_map = interpolate_face_attributes(
+ pix_to_face=fragments.pix_to_face,
+ barycentric_coords=fragments.bary_coords,
+ face_attributes=faces_normal)
+ return normal_map[..., 0, :]
+
+
+class SegmentationShader(nn.Module):
+ """No light shader."""
+ def __init__(self,
+ blend_params: Optional[BlendParams] = None,
+ **kwargs) -> None:
+ """Initlialize without blend_params."""
+ super().__init__()
+ self.blend_params = blend_params if blend_params is not None\
+ else BlendParams()
+
+ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
+ """Sample without light."""
+ verts_class = meshes.textures.verts_features_padded()
+ faces = meshes.faces_packed() # (F, 3)
+ verts_class = padded_to_packed(verts_class)
+ faces_class = verts_class[faces]
+ segmentation_map = interpolate_face_attributes(
+ pix_to_face=fragments.pix_to_face,
+ barycentric_coords=fragments.bary_coords,
+ face_attributes=faces_class).long()
+ return segmentation_map[..., :, 0]
diff --git a/detrsmpl/core/renderer/torch3d_renderer/silhouette_renderer.py b/detrsmpl/core/renderer/torch3d_renderer/silhouette_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..850f9d7bec5e107b8b2abc56d655da7e2c017b01
--- /dev/null
+++ b/detrsmpl/core/renderer/torch3d_renderer/silhouette_renderer.py
@@ -0,0 +1,89 @@
+from typing import Iterable, Optional, Tuple, Union
+
+import torch
+from pytorch3d.structures import Meshes
+
+from detrsmpl.core.cameras import MMCamerasBase
+from .base_renderer import BaseRenderer
+from .utils import normalize
+
+
+class SilhouetteRenderer(BaseRenderer):
+ """Silhouette renderer."""
+ shader_type = 'SilhouetteShader'
+
+ def __init__(
+ self,
+ resolution: Tuple[int, int] = None,
+ device: Union[torch.device, str] = 'cpu',
+ output_path: Optional[str] = None,
+ out_img_format: str = '%06d.png',
+ **kwargs,
+ ) -> None:
+ """SilhouetteRenderer for neural rendering and visualization.
+
+ Args:
+ resolution (Iterable[int]):
+ (width, height) of the rendered images resolution.
+ device (Union[torch.device, str], optional):
+ You can pass a str or torch.device for cpu or gpu render.
+ Defaults to 'cpu'.
+ output_path (Optional[str], optional):
+ Output path of the video or images to be saved.
+ Defaults to None.
+ out_img_format (str, optional): The image format string for
+ saving the images.
+ Defaults to '%06d.png'.
+
+ Returns:
+ None
+ """
+ super().__init__(resolution=resolution,
+ device=device,
+ output_path=output_path,
+ out_img_format=out_img_format,
+ **kwargs)
+
+ def forward(self,
+ meshes: Optional[Meshes] = None,
+ cameras: Optional[MMCamerasBase] = None,
+ images: Optional[torch.Tensor] = None,
+ indexes: Iterable[str] = None,
+ backgrounds: Optional[torch.Tensor] = None,
+ **kwargs):
+ """Render silhouette map.
+
+ Args:
+ meshes (Optional[Meshes], optional): meshes to be rendered.
+                Textures are not used for silhouette rendering.
+ Defaults to None.
+ cameras (Optional[MMCamerasBase], optional): cameras for render.
+ Defaults to None.
+ indexes (Optional[Iterable[int]], optional): indexes for images.
+ Defaults to None.
+ backgrounds (Optional[torch.Tensor], optional): background images.
+ Defaults to None.
+
+ Returns:
+ Union[torch.Tensor, None]: return tensor or None.
+ """
+ meshes = meshes.to(self.device)
+ self._update_resolution(cameras, **kwargs)
+ fragments = self.rasterizer(meshes_world=meshes, cameras=cameras)
+ silhouette_map = self.shader(fragments=fragments,
+ meshes=meshes,
+ cameras=cameras)
+
+ if self.output_path is not None:
+ rgba = self.tensor2rgba(silhouette_map)
+ self._write_images(rgba, backgrounds, indexes)
+
+ return silhouette_map
+
+ def tensor2rgba(self, tensor: torch.Tensor):
+ silhouette = tensor[..., 3:]
+ rgbs = silhouette.repeat(1, 1, 1, 3)
+ valid_masks = (silhouette > 0) * 1.0
+ rgbs = normalize(rgbs, out_value_range=(0, 1))
+ return torch.cat([rgbs, valid_masks], -1)
diff --git a/detrsmpl/core/renderer/torch3d_renderer/smpl_renderer.py b/detrsmpl/core/renderer/torch3d_renderer/smpl_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..870d0dde2c5d2e0192856a08519eb8b5cd99f8f7
--- /dev/null
+++ b/detrsmpl/core/renderer/torch3d_renderer/smpl_renderer.py
@@ -0,0 +1,279 @@
+import os.path as osp
+from pathlib import Path
+from typing import Iterable, Optional, Tuple, Union
+
+import cv2
+import mmcv
+import numpy as np
+import torch
+from pytorch3d.structures import Meshes
+from torch.nn.functional import interpolate
+
+from detrsmpl.core.cameras import MMCamerasBase
+from detrsmpl.utils.ffmpeg_utils import images_to_array
+from detrsmpl.utils.path_utils import check_path_suffix
+from .base_renderer import BaseRenderer
+from .builder import build_renderer
+from .lights import DirectionalLights, PointLights
+from .utils import align_input_to_padded, normalize, rgb2bgr, tensor2array
+
+
+class SMPLRenderer(BaseRenderer):
+ """Render SMPL(X) with different render choices."""
+ def __init__(self,
+ resolution: Tuple[int, int] = None,
+ device: Union[torch.device, str] = 'cpu',
+ output_path: Optional[str] = None,
+ return_tensor: bool = False,
+ alpha: float = 1.0,
+ out_img_format: str = '%06d.png',
+ read_img_format: str = None,
+ render_choice='mq',
+ frames_folder: Optional[str] = None,
+ plot_kps: bool = False,
+ vis_kp_index: bool = False,
+ final_resolution: Tuple[int, int] = None,
+ **kwargs) -> None:
+ super(BaseRenderer, self).__init__()
+
+ self.device = device
+ self.resolution = resolution
+ self.render_choice = render_choice
+ self.output_path = output_path
+ self.frames_folder = frames_folder
+ self.plot_kps = plot_kps
+ self.vis_kp_index = vis_kp_index
+ self.read_img_format = read_img_format
+ self.out_img_format = out_img_format
+ self.final_resolution = final_resolution
+ self.return_tensor = return_tensor
+ if output_path is not None:
+ if check_path_suffix(output_path, ['.mp4', '.gif']):
+ self.temp_path = osp.join(
+ Path(output_path).parent,
+ Path(output_path).name + '_output_temp')
+ mmcv.mkdir_or_exist(self.temp_path)
+ print('make dir', self.temp_path)
+ else:
+ self.temp_path = output_path
+
+ self.image_renderer = build_renderer(
+ dict(device=device, resolution=resolution, **kwargs))
+
+ if plot_kps:
+ self.alpha = max(min(0.8, alpha), 0.1)
+ self.joints_renderer = build_renderer(
+ dict(type='pointcloud',
+ resolution=resolution,
+ device=device,
+ radius=0.008))
+ else:
+ self.alpha = max(min(1.0, alpha), 0.1)
+ """
+ Render Mesh for SMPL and SMPL-X. For function render_smpl.
+ 2 modes: mesh render with different quality and palette,
+ or silhouette render.
+
+ Args:
+ resolution (Iterable[int]): (height, width of render images)
+ faces (Union[np.ndarray, torch.LongTensor]): face of mesh to
+ be rendered.
+ device (torch.device, optional): cuda or cpu device.
+ Defaults to torch.device('cpu').
+ output_path (Optional[str], optional): render output path.
+ could be .mp4 or .gif or a folder.
+ Else: 1). If `render_choice` in ['lq', 'mq', 'hq'], the output
+ video will be a smpl mesh video which each person in a single
+ color.
+ 2). If `render_choice` is `silhouette`, the output video will
+ be a black-white smpl silhouette video.
+ 3). If `render_choice` is `part_silhouette`, the output video
+ will be a smpl mesh video which each body-part in a single
+ color.
+ If None, no video will be wrote.
+ Defaults to None.
+ palette (Optional[List[str]], optional):
+ List of palette string. Defaults to ['blue'].
+ return_tensor (bool, optional): Whether return tensors.
+ return None if set to False.
+ Defaults to False.
+ alpha (float, optional): transparency value, from 0.0 to 1.0.
+ Defaults to 1.0.
+
+ Returns:
+ None
+ """
+
+ def to(self, device):
+ return super(BaseRenderer, self).to(device)
+
+ def forward(
+ self,
+ meshes: Meshes,
+ cameras: Optional[MMCamerasBase] = None,
+ images: Optional[torch.Tensor] = None,
+ joints: Optional[torch.Tensor] = None,
+ joints_gt: Optional[torch.Tensor] = None,
+ indexes: Optional[Iterable[int]] = None,
+ **kwargs,
+ ) -> Union[None, torch.Tensor]:
+ """Forward render procedure.
+
+ Args:
+            meshes (Meshes): meshes to be rendered, one batch element per
+                frame. The number of people per frame influences the
+                visualization.
+ images (Optional[torch.Tensor], optional): Tensor of background
+ images. If None, no background.
+ Defaults to None.
+ joints (Optional[torch.Tensor], optional):
+ joints produced from smpl model.
+ Defaults to None.
+ joints_gt (Optional[torch.Tensor], optional):
+ ground-truth points passed.
+ Defaults to None.
+ indexes (Optional[Iterable[int]], optional):
+ indexes for writing images.
+ Defaults to None.
+
+ Returns:
+ Union[None, torch.Tensor]:
+ return None if not return_tensor.
+ Else: 1). If render images, the output tensor shape would be
+ (frame, h, w, 4) or (frame, num_people, h, w, 4), depends on
+ number of people.
+ 2). If render silhouette, the output tensor shape will be
+ (frame, h, w) or (frame, num_people, h, w).
+ 3). If render part silhouette, the output tensor shape should
+                be (frame, h, w, 1) or (frame, num_people, h, w, 1).
+ """
+ num_frames = len(meshes)
+ if self.frames_folder is not None and images is None:
+
+ images = images_to_array(self.frames_folder,
+ resolution=self.resolution,
+ img_format=self.read_img_format,
+ start=indexes[0],
+ end=indexes[-1] + 1,
+ disable_log=True).astype(np.float64)
+ images = torch.Tensor(images).to(self.device)
+ images = align_input_to_padded(
+ images,
+ ndim=4,
+ batch_size=num_frames,
+ padding_mode='ones',
+ )
+ if images is not None:
+ images = images.to(self.device)
+
+ lights = getattr(self.image_renderer, 'lights', None)
+ if isinstance(lights, DirectionalLights):
+ lights = lights.clone()
+ lights.direction = -cameras.get_camera_plane_normals()
+ elif isinstance(lights, PointLights):
+ lights = lights.clone()
+ lights.location = -cameras.get_camera_plane_normals(
+ ) - cameras.get_camera_center()
+
+ rendered_tensor = self.image_renderer(meshes=meshes,
+ cameras=cameras,
+ lights=lights,
+ indexes=indexes)
+
+ rendered_images = self.image_renderer.tensor2rgba(rendered_tensor)
+
+ rgbs = rendered_images[..., :3]
+ valid_masks = rendered_images[..., 3:]
+ images = normalize(images,
+ origin_value_range=[0, 255],
+ out_value_range=[0, 1],
+ dtype=torch.float32) if images is not None else None
+
+ bgrs = rgb2bgr(rgbs)
+
+ # write temp images for the output video
+ if self.output_path is not None:
+
+ if images is not None:
+ output_images = bgrs * valid_masks * self.alpha + \
+ images * valid_masks * (
+ 1 - self.alpha) + (1 - valid_masks) * images
+
+ else:
+ output_images = bgrs
+
+ if self.plot_kps:
+
+ joints = joints.to(self.device)
+ joints_2d = cameras.transform_points_screen(
+ joints, image_size=self.resolution)[..., :2]
+ if joints_gt is None:
+ joints_padded = joints
+ num_joints = joints_padded.shape[1]
+ joints_rgb_padded = torch.ones(
+ num_frames, num_joints, 4) * (torch.tensor(
+ [0.0, 1.0, 0.0, 1.0]).view(1, 1, 4))
+ else:
+ joints_gt = joints_gt.to(self.device)
+ joints_padded = torch.cat([joints, joints_gt], dim=1)
+ num_joints = joints.shape[1]
+ num_joints_gt = joints_gt.shape[1]
+ joints_rgb = torch.ones(num_frames, num_joints, 4) * (
+ torch.tensor([0.0, 1.0, 0.0, 1.0]).view(1, 1, 4))
+ joints_rgb_gt = torch.ones(
+ num_frames, num_joints_gt, 4) * (torch.tensor(
+ [1.0, 0.0, 0.0, 1.0]).view(1, 1, 4))
+ joints_rgb_padded = torch.cat([joints_rgb, joints_rgb_gt],
+ dim=1)
+
+ pointcloud_images = self.joints_renderer(
+ vertices=joints_padded,
+ verts_rgba=joints_rgb_padded.to(self.device),
+ cameras=cameras)
+
+ pointcloud_rgb = pointcloud_images[..., :3]
+ pointcloud_bgr = rgb2bgr(pointcloud_rgb)
+ pointcloud_mask = (pointcloud_images[..., 3:] > 0) * 1.0
+ output_images = output_images * (
+ 1 - pointcloud_mask) + pointcloud_mask * pointcloud_bgr
+
+ output_images = tensor2array(output_images)
+
+ for frame_idx, real_idx in enumerate(indexes):
+ folder = self.temp_path if self.temp_path is not None else\
+ self.output_path
+ im = output_images[frame_idx]
+ if self.plot_kps and self.vis_kp_index:
+ point_xy = joints_2d[frame_idx]
+ for j_idx in range(point_xy.shape[-2]):
+ x = point_xy[j_idx, 0]
+ y = point_xy[j_idx, 1]
+ cv2.putText(im, str(j_idx), (int(x), int(y)),
+ cv2.FONT_HERSHEY_SIMPLEX,
+ 0.25 * self.final_resolution[1] / 500,
+ [0, 0, 0],
+ int(1 * self.final_resolution[1] / 1000))
+ if self.final_resolution != self.resolution:
+                    im = cv2.resize(im, self.final_resolution,
+                                    interpolation=cv2.INTER_CUBIC)
+ # cv2.imwrite(osp.join(folder, self.out_img_format % real_idx),
+ # im)
+ cv2.imwrite(self.output_path, im)
+
+ # return
+ if self.return_tensor:
+
+ if images is not None:
+ rendered_map = torch.tensor(output_images)
+ else:
+ rendered_map = rendered_tensor
+
+ if self.final_resolution != self.resolution:
+ rendered_map = interpolate(rendered_map,
+ size=self.final_resolution,
+ mode='bilinear')
+ return rendered_map
+ else:
+ return output_images
diff --git a/detrsmpl/core/renderer/torch3d_renderer/textures/__init__.py b/detrsmpl/core/renderer/torch3d_renderer/textures/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f7c950fc41d29935d34dbc8d7daa585da2b2f42
--- /dev/null
+++ b/detrsmpl/core/renderer/torch3d_renderer/textures/__init__.py
@@ -0,0 +1,10 @@
+# yapf: disable
+from .builder import ( # noqa:F401
+ TexturesAtlas,
+ TexturesNearest,
+ TexturesUV,
+ TexturesVertex,
+ build_textures,
+)
+
+# yapf: enable
diff --git a/detrsmpl/core/renderer/torch3d_renderer/textures/builder.py b/detrsmpl/core/renderer/torch3d_renderer/textures/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f91ccad7747711dd15774d598134730166e60224
--- /dev/null
+++ b/detrsmpl/core/renderer/torch3d_renderer/textures/builder.py
@@ -0,0 +1,22 @@
+from mmcv.utils import Registry
+from pytorch3d.renderer import TexturesAtlas, TexturesUV, TexturesVertex
+
+from .textures import TexturesNearest
+
+TEXTURES = Registry('textures')
+TEXTURES.register_module(
+ name=['TexturesAtlas', 'textures_atlas', 'atlas', 'Atlas'],
+ module=TexturesAtlas)
+TEXTURES.register_module(
+ name=['TexturesNearest', 'textures_nearest', 'nearest', 'Nearest'],
+ module=TexturesNearest)
+TEXTURES.register_module(name=['TexturesUV', 'textures_uv', 'uv'],
+ module=TexturesUV)
+TEXTURES.register_module(
+ name=['TexturesVertex', 'textures_vertex', 'vertex', 'vc'],
+ module=TexturesVertex)
+
+
+def build_textures(cfg):
+ """Build textures."""
+ return TEXTURES.build(cfg)
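+
+
+if __name__ == '__main__':
+    # Minimal sketch (not part of the original module): build a vertex-color
+    # texture from a registry config, as the renderers do when attaching
+    # per-vertex colors to a mesh. The vertex count 6890 matches SMPL.
+    import torch
+
+    textures = build_textures(
+        dict(type='vertex', verts_features=torch.rand(1, 6890, 3)))
+    print(type(textures).__name__)  # TexturesVertex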
diff --git a/detrsmpl/core/renderer/torch3d_renderer/textures/textures.py b/detrsmpl/core/renderer/torch3d_renderer/textures/textures.py
new file mode 100644
index 0000000000000000000000000000000000000000..264e8f60cb225ae59719b5ee7fca2689b3aa9962
--- /dev/null
+++ b/detrsmpl/core/renderer/torch3d_renderer/textures/textures.py
@@ -0,0 +1,23 @@
+import torch
+from pytorch3d.ops import interpolate_face_attributes
+from pytorch3d.renderer import TexturesVertex
+
+
+class TexturesNearest(TexturesVertex):
+ """Textures for nearest interpolation."""
+ def sample_textures(self, fragments, faces_packed=None) -> torch.Tensor:
+ """Rewrite sample_textures to use the nearest interpolation.
+
+ This function will only be called in render forwarding.
+ """
+ verts_features_packed = self.verts_features_packed()
+ faces_verts_features = verts_features_packed[faces_packed]
+ bary_coords = fragments.bary_coords
+ _, idx = torch.max(bary_coords, -1)
+ mask = torch.arange(bary_coords.size(-1)).reshape(1, 1, -1).to(
+ self.device) == idx.unsqueeze(-1)
+ bary_coords *= 0
+ bary_coords[mask] = 1
+ texels = interpolate_face_attributes(fragments.pix_to_face,
+ bary_coords, faces_verts_features)
+ return texels
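+
+
+if __name__ == '__main__':
+    # Minimal sketch (not part of the original module): attach nearest-
+    # interpolated per-vertex colors to a toy two-triangle mesh. Rendering is
+    # left to the mesh renderers elsewhere in this package.
+    from pytorch3d.structures import Meshes
+
+    verts = torch.rand(1, 4, 3)
+    faces = torch.tensor([[[0, 1, 2], [0, 2, 3]]])
+    textures = TexturesNearest(verts_features=torch.rand(1, 4, 3))
+    meshes = Meshes(verts=verts, faces=faces, textures=textures)
+    print(meshes.textures.verts_features_padded().shape)  # (1, 4, 3)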
diff --git a/detrsmpl/core/renderer/torch3d_renderer/utils.py b/detrsmpl/core/renderer/torch3d_renderer/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..26374dbe09bc12cd09fe33d9b12cfd618d803836
--- /dev/null
+++ b/detrsmpl/core/renderer/torch3d_renderer/utils.py
@@ -0,0 +1,113 @@
+from typing import List, Union
+
+import numpy as np
+import torch
+from pytorch3d.structures import list_to_padded
+
+try:
+ from typing import Literal
+except ImportError:
+ from typing_extensions import Literal
+
+
+def normalize(value,
+ origin_value_range=None,
+ out_value_range=(0, 1),
+ dtype=None,
+ clip=False) -> Union[torch.Tensor, np.ndarray]:
+ """Normalize the tensor or array and convert dtype."""
+ if origin_value_range is not None:
+ value = (value - origin_value_range[0]) / (
+ origin_value_range[1] - origin_value_range[0] + 1e-9)
+
+ else:
+ value = (value - value.min()) / (value.max() - value.min())
+ value = value * (out_value_range[1] -
+ out_value_range[0]) + out_value_range[0]
+ if clip:
+ value = torch.clip(value,
+ min=out_value_range[0],
+ max=out_value_range[1])
+ if isinstance(value, torch.Tensor):
+ if dtype is not None:
+ return value.type(dtype)
+ else:
+ return value
+ elif isinstance(value, np.ndarray):
+ if dtype is not None:
+ return value.astype(dtype)
+ else:
+ return value
+
+
+def tensor2array(image: torch.Tensor) -> np.ndarray:
+ """Convert image tensor to array."""
+ image = image.detach().cpu().numpy()
+ image = normalize(image,
+ origin_value_range=(0, 1),
+ out_value_range=(0, 255),
+ dtype=np.uint8)
+ return image
+
+
+def array2tensor(image: np.ndarray) -> torch.Tensor:
+ """Convert image array to tensor."""
+ image = torch.Tensor(image)
+ image = normalize(image,
+ origin_value_range=(0, 255),
+ out_value_range=(0, 1),
+ dtype=torch.float32)
+ return image
+
+
+def rgb2bgr(rgbs) -> Union[torch.Tensor, np.ndarray]:
+ """Convert color channels."""
+ bgrs = [rgbs[..., 2, None], rgbs[..., 1, None], rgbs[..., 0, None]]
+ if isinstance(rgbs, torch.Tensor):
+ bgrs = torch.cat(bgrs, -1)
+ elif isinstance(rgbs, np.ndarray):
+ bgrs = np.concatenate(bgrs, -1)
+ return bgrs
+
+
+def align_input_to_padded(tensor: Union[List[torch.Tensor], torch.Tensor],
+ ndim: int = 3,
+ batch_size: int = None,
+ padding_mode: Literal['ones', 'zeros', 'repeat',
+ 'none'] = 'none'):
+ if isinstance(tensor, list):
+ for i in range(len(tensor)):
+            if tensor[i].ndim == ndim:
+ tensor[i] = tensor[i][0]
+ tensor = list_to_padded(tensor, equisized=True)
+ assert tensor.ndim in (ndim, ndim - 1)
+ if tensor.ndim == ndim - 1:
+ tensor = tensor.unsqueeze(0)
+
+ if batch_size is not None:
+ current_batch_size = tensor.shape[0]
+ if current_batch_size == 1:
+ tensor = tensor.repeat_interleave(batch_size, 0)
+ elif current_batch_size < batch_size:
+ if padding_mode == 'ones':
+ tensor = torch.cat([
+ tensor,
+ torch.ones_like(tensor)[:1].repeat_interleave(
+ batch_size - current_batch_size, 0)
+ ])
+            elif padding_mode == 'zeros':
+ tensor = torch.cat([
+ tensor,
+ torch.zeros_like(tensor)[:1].repeat_interleave(
+ batch_size - current_batch_size, 0)
+ ])
+ elif padding_mode == 'repeat':
+ tensor = tensor.repeat_interleave(
+ batch_size // current_batch_size + 1, 0)[:batch_size]
+ else:
+ raise ValueError('Wrong batch_size to allocate,'
+ ' please specify padding mode.')
+ elif current_batch_size > batch_size:
+ tensor = tensor[:batch_size]
+
+ return tensor
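+
+
+if __name__ == '__main__':
+    # Minimal self-check sketch (not part of the original module): map a fake
+    # uint8 image batch into [0, 1], swap its color channels and convert it
+    # back to uint8.
+    fake_rgb = np.random.randint(0, 256, size=(2, 4, 4, 3)).astype(np.uint8)
+    tensor = array2tensor(fake_rgb)   # float32 in [0, 1]
+    bgr = rgb2bgr(tensor)             # channel order reversed
+    restored = tensor2array(bgr)      # uint8 in [0, 255]
+    assert restored.shape == fake_rgb.shape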
diff --git a/detrsmpl/core/renderer/torch3d_renderer/uv_renderer.py b/detrsmpl/core/renderer/torch3d_renderer/uv_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a8e667d1a40504cf9e7b7004e79e6102b0bb3bc
--- /dev/null
+++ b/detrsmpl/core/renderer/torch3d_renderer/uv_renderer.py
@@ -0,0 +1,520 @@
+import warnings
+from typing import Iterable, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from pytorch3d.io.obj_io import load_objs_as_meshes
+from pytorch3d.ops import interpolate_face_attributes
+from pytorch3d.renderer.mesh import TexturesUV
+from pytorch3d.renderer.mesh.rasterizer import (
+ MeshRasterizer,
+ RasterizationSettings,
+)
+from pytorch3d.structures import Meshes
+from pytorch3d.structures.utils import padded_to_packed
+
+from detrsmpl.core.cameras.cameras import (
+ FoVOrthographicCameras,
+ MMCamerasBase,
+)
+from detrsmpl.utils.path_utils import check_path_suffix
+from .utils import array2tensor, rgb2bgr
+
+
+class UVRenderer(nn.Module):
+ """Renderer for SMPL(x) UV map."""
+ def __init__(
+ self,
+ resolution: Tuple[int] = 1024,
+ model_type: Optional[str] = 'smpl',
+ uv_param_path: Optional[str] = None,
+ obj_path: Optional[str] = None,
+ device: Union[torch.device, str] = 'cpu',
+ threshold_size: int = 512,
+ # TODO: Solved the sample bug when the resolution is too small.
+ # set threshold_size is just a temporary solution.
+
+ # TODO: add smplx_uv.npz and eval the warping & sampling of smplx
+ # model.
+ ):
+ super().__init__()
+ self.threshold_size = threshold_size
+ num_verts = {'smpl': 6890, 'smplx': 10475}
+ self.NUM_VERTS = num_verts[model_type]
+ self.device = device
+ self.resolution = (resolution, resolution) if isinstance(
+ resolution, int) else resolution
+ self.uv_param_path = uv_param_path
+ self.obj_path = obj_path
+ if uv_param_path is not None:
+ check_path_suffix(uv_param_path, allowed_suffix=['npz'])
+ param_dict = dict(np.load(uv_param_path))
+
+ verts_uv = torch.Tensor(param_dict['verts_uv'])
+ verts_u, verts_v = torch.unbind(verts_uv, -1)
+ verts_v_ = 1 - verts_u.unsqueeze(-1)
+ verts_u_ = verts_v.unsqueeze(-1)
+ self.verts_uv = torch.cat([verts_u_, verts_v_], -1).to(self.device)
+ self.faces_uv = torch.LongTensor(param_dict['faces_uv']).to(
+ self.device)
+
+ self.NUM_VT = self.verts_uv.shape[0]
+
+ self.faces_tensor = torch.LongTensor(param_dict['faces'].astype(
+ np.int64)).to(self.device)
+ self.num_faces = self.faces_uv.shape[0]
+ elif obj_path is not None:
+ check_path_suffix(obj_path, allowed_suffix=['obj'])
+ mesh_template = load_objs_as_meshes([obj_path])
+ self.faces_uv = mesh_template.textures.faces_uvs_padded()[0].to(
+ self.device)
+ self.verts_uv = mesh_template.textures.verts_uvs_padded()[0].to(
+ self.device)
+ self.NUM_VT = self.verts_uv.shape[0]
+ self.faces_tensor = mesh_template.faces_padded()[0].to(self.device)
+ self.num_faces = self.faces_uv.shape[0]
+ self.update_fragments()
+ self.update_face_uv_pixel()
+
+ self = self.to(self.device)
+
+ def to(self, device):
+ if isinstance(device, str):
+ device = torch.device(device)
+ self.device = device
+ for k in dir(self):
+ if isinstance(getattr(self, k), (torch.Tensor)):
+ setattr(self, k, getattr(self, k).to(device))
+ return self
+
+ def update_fragments(self):
+ """Update pix_to_face, bary_coords."""
+ rasterizer = MeshRasterizer(cameras=FoVOrthographicCameras(
+ min_x=1, max_x=0, max_y=1, min_y=0, device=self.device),
+ raster_settings=RasterizationSettings(
+ blur_radius=0,
+ image_size=self.resolution,
+ faces_per_pixel=1,
+ perspective_correct=False,
+ )).to(self.device)
+ verts_uv = torch.cat([
+ self.verts_uv[None],
+ torch.ones(1, self.NUM_VT, 1).to(self.device)
+ ], -1)
+
+ fragments = rasterizer(
+ Meshes(verts=verts_uv, faces=self.faces_uv[None]))
+ self.pix_to_face = fragments.pix_to_face[0, ..., 0]
+ self.bary_coords = fragments.bary_coords[0, ..., 0, :]
+ self.mask = (self.pix_to_face >= 0).long()
+
+ def update_face_uv_pixel(self):
+ """Move the pixels lie on the edges inside the mask, then refine the
+ rest points by searching the nearest pixel in the faces it should be
+ in."""
+ H, W = self.resolution
+ device = self.device
+ cameras = FoVOrthographicCameras(min_x=1,
+ max_x=0,
+ max_y=1,
+ min_y=0,
+ device=self.device)
+ verts_uv = torch.cat([
+ self.verts_uv[None],
+ torch.ones(1, self.NUM_VT, 1).to(self.device)
+ ], -1)
+
+ verts_uv_pixel = cameras.transform_points_screen(
+ verts_uv, image_size=self.resolution).round().long()[0, ..., :2]
+ verts_uv_pixel[..., 0] = torch.clip(verts_uv_pixel[..., 0],
+ min=0,
+ max=W - 1)
+ verts_uv_pixel[..., 1] = torch.clip(verts_uv_pixel[..., 1],
+ min=0,
+ max=H - 1)
+ verts_uv_pixel = verts_uv_pixel.long()
+ mask = self.mask
+
+ wrong_indexes = torch.where(
+ mask[verts_uv_pixel.view(-1, 2)[:, 1],
+ verts_uv_pixel.view(-1, 2)[:, 0]] == 0)[0]
+ for wrong_index in wrong_indexes:
+ proposed_faces = torch.where(self.faces_uv == wrong_index)[0]
+ vert_xy = verts_uv_pixel[wrong_index]
+ faces_xy = []
+ for face_id in proposed_faces:
+ x = torch.where(self.pix_to_face == face_id)[1]
+ y = torch.where(self.pix_to_face == face_id)[0]
+ if x.shape[0] > 0:
+ face_xy = torch.cat([x.unsqueeze(-1), y.unsqueeze(-1)], -1)
+ faces_xy.append(face_xy)
+ if len(faces_xy) > 0:
+ faces_xy = torch.cat(faces_xy, 0)
+ min_arg = torch.argmin(
+ torch.sqrt(((faces_xy - vert_xy) *
+ (faces_xy - vert_xy)).sum(-1).float()))
+
+ verts_uv_pixel[wrong_index] = faces_xy[min_arg]
+
+ up_bound = ((mask[:-1] - mask[1:]) < 0).long()
+ bottom_bound = ((mask[1:] - mask[:-1]) < 0).long()
+ left_bound = ((mask[:, :-1] - mask[:, 1:]) < 0).long()
+ right_bound = ((mask[:, 1:] - mask[:, :-1]) < 0).long()
+
+ left_bound = torch.cat(
+ [left_bound, torch.zeros(H, 1).to(device)], 1).unsqueeze(-1)
+ right_bound = torch.cat([torch.zeros(H, 1).to(device), right_bound],
+ 1).unsqueeze(-1)
+ up_bound = torch.cat([up_bound, torch.zeros(1, W).to(device)],
+ 0).unsqueeze(-1)
+ bottom_bound = torch.cat([torch.zeros(1, W).to(device), bottom_bound],
+ 0).unsqueeze(-1)
+
+ leftup_corner_ = ((mask[:-1, :-1] - mask[1:, 1:]) < 0).long()
+ rightup_corner_ = ((mask[:-1, 1:] - mask[1:, :-1]) < 0).long()
+ leftbottom_corner_ = ((mask[1:, :-1] - mask[:-1, 1:]) < 0).long()
+ rightbottom_corner_ = ((mask[1:, 1:] - mask[:-1, :-1]) < 0).long()
+
+ leftup_corner = torch.zeros_like(mask).long()
+ leftup_corner[:-1, :-1] = leftup_corner_
+ leftup_corner = leftup_corner.unsqueeze(-1)
+
+ rightup_corner = torch.zeros_like(mask).long()
+ rightup_corner[:-1, 1:] = rightup_corner_
+ rightup_corner = rightup_corner.unsqueeze(-1)
+
+ leftbottom_corner = torch.zeros_like(mask).long()
+ leftbottom_corner[1:, :-1] = leftbottom_corner_
+ leftbottom_corner = leftbottom_corner.unsqueeze(-1)
+
+ rightbottom_corner = torch.zeros_like(mask).long()
+ rightbottom_corner[1:, 1:] = rightbottom_corner_
+ rightbottom_corner = rightbottom_corner.unsqueeze(-1)
+
+ stride_uv_mask = torch.cat([
+ right_bound * -1 + left_bound * 1 + rightbottom_corner * -1 +
+ leftbottom_corner * 1 + rightup_corner * -1 + leftup_corner * 1,
+ up_bound * 1 + bottom_bound * -1 + rightbottom_corner * -1 +
+ leftbottom_corner * -1 + rightup_corner * 1 + leftup_corner * 1
+ ], -1).long()
+
+ verts_uv_pixel = verts_uv_pixel + stride_uv_mask[
+ verts_uv_pixel.view(-1, 2)[:, 1],
+ verts_uv_pixel.view(-1, 2)[:, 0]].view(self.NUM_VT, 2)
+
+ face_uv_pixel = verts_uv_pixel[self.faces_uv]
+
+ face_uv_pixel = face_uv_pixel.long()
+ self.face_uv_pixel = face_uv_pixel
+
+ def forward(self,
+ verts_attr: Optional[torch.Tensor],
+ resolution: Optional[Iterable[int]] = None) -> torch.Tensor:
+ """Interpolate the vertex attributes to a map.
+
+ Args:
+ verts_attr (Optional[torch.Tensor]): shape should be (N, V, C),
+ required.
+ resolution (Optional[Iterable[int]], optional): resolution to
+ override self.resolution. If None, will use self.resolution.
+ Defaults to None.
+
+ Returns:
+ torch.Tensor: interpolated maps of (N, H, W, C)
+ """
+ if verts_attr.ndim == 2:
+ verts_attr = verts_attr[None]
+ if resolution is not None and resolution != self.resolution:
+ self.resolution = resolution
+ self.update_fragments()
+ self.update_face_uv_pixel()
+
+ bary_coords = self.bary_coords
+ pix_to_face = self.pix_to_face
+
+ N, V, C = verts_attr.shape
+ assert V == self.NUM_VERTS
+ verts_attr = verts_attr.view(N * V, C).to(self.device)
+        offset_idx = torch.arange(0, N).long() * self.NUM_VERTS
+ faces_packed = self.faces_tensor[None].repeat(
+ N, 1, 1) + offset_idx.view(-1, 1, 1).to(self.device)
+ faces_packed = faces_packed.view(-1, 3)
+ face_attr = verts_attr[faces_packed]
+ assert face_attr.shape == (N * self.num_faces, 3, C)
+ pix_to_face = self.pix_to_face.unsqueeze(0).repeat(N, 1,
+ 1).unsqueeze(-1)
+ bary_coords = self.bary_coords[None].repeat(N, 1, 1, 1).unsqueeze(-2)
+ maps_padded = interpolate_face_attributes(
+ pix_to_face=pix_to_face.to(self.device),
+ barycentric_coords=bary_coords.to(self.device),
+ face_attributes=face_attr.to(self.device),
+ ).squeeze(-2)
+ return maps_padded
+
+ def forward_normal_map(self,
+ meshes: Meshes = None,
+ vertices: torch.Tensor = None,
+ resolution: Optional[Iterable[int]] = None,
+ cameras: MMCamerasBase = None) -> torch.Tensor:
+ """Interpolate verts normals to a normal map.
+
+ Args:
+ meshes (Meshes): input smpl mesh.
+ Will override vertices if both not None.
+ Defaults to None.
+ vertices (torch.Tensor, optional):
+ smpl vertices. Defaults to None.
+ resolution (Optional[Iterable[int]], optional): resolution to
+ override self.resolution. If None, will use self.resolution.
+ Defaults to None.
+ cameras (MMCamerasBase, optional):
+ cameras to see the mesh.
+ Defaults to None.
+ Returns:
+ torch.Tensor: Normal map of shape (N, H, W, 3)
+ """
+        if meshes is None:
+            if vertices is None:
+                raise ValueError('No valid input.')
+            meshes = Meshes(verts=vertices,
+                            faces=self.faces_tensor[None].repeat(
+                                vertices.shape[0], 1, 1))
+        verts_normals = meshes.verts_normals_padded()
+ if cameras:
+ verts_normals = cameras.get_world_to_view_transform(
+ ).transform_normals(verts_normals)
+ normal_map = self.forward(verts_attr=verts_normals,
+ resolution=resolution)
+ return normal_map
+
+ def forward_uvd_map(self,
+ meshes: Meshes = None,
+ vertices: torch.Tensor = None,
+ resolution: Optional[Iterable[int]] = None,
+ cameras: MMCamerasBase = None) -> torch.Tensor:
+ """Interpolate the verts xyz value to a uvd map.
+
+ Args:
+ meshes (Meshes): input smpl mesh.
+ Defaults to None.
+ vertices (torch.Tensor, optional):
+ smpl vertices. Will override meshes if both not None.
+ Defaults to None.
+ resolution (Optional[Iterable[int]], optional): resolution to
+ override self.resolution. If None, will use self.resolution.
+ Defaults to None.
+ cameras (MMCamerasBase, optional):
+ cameras to see the mesh.
+ Defaults to None.
+
+ Returns:
+ torch.Tensor: UVD map of shape (N, H, W, 3)
+ """
+ if vertices is not None:
+ verts_uvd = vertices
+ elif vertices is None and meshes is not None:
+ verts_uvd = meshes.verts_padded()
+ else:
+ raise ValueError('No valid input.')
+ if cameras:
+            verts_uvd = cameras.get_world_to_view_transform(
+            ).transform_points(verts_uvd)
+ uvd_map = self.forward(verts_attr=verts_uvd, resolution=resolution)
+ return uvd_map
+
+ def vertex_resample(
+ self,
+ maps_padded: torch.Tensor,
+ h_flip: bool = False,
+ ) -> torch.Tensor:
+ """Resample the vertex attributes from a map.
+
+ Args:
+ maps_padded (torch.Tensor): shape should be (N, H, W, C). Required.
+ h_flip (bool, optional): whether flip horizontally.
+ Defaults to False.
+
+ Returns:
+ torch.Tensor: resampled vertex attributes. Shape will be (N, V, C)
+ """
+ if maps_padded.ndim == 3:
+ maps_padded = maps_padded[None]
+
+ if h_flip:
+ maps_padded = torch.flip(maps_padded, dims=[2])
+ N, H, W, C = maps_padded.shape
+
+ if H < self.threshold_size or W < self.threshold_size:
+ maps_padded = F.interpolate(
+ maps_padded.permute(0, 3, 1, 2),
+ size=(self.threshold_size, self.threshold_size),
+ mode='bicubic',
+ align_corners=False).permute(0, 2, 3, 1)
+ H, W = self.threshold_size, self.threshold_size
+ if (H, W) != self.resolution:
+ self.resolution = (H, W)
+ self.update_fragments()
+ self.update_face_uv_pixel()
+        offset_idx = torch.arange(0, N).long() * self.NUM_VERTS
+ faces_packed = self.faces_tensor[None].repeat(
+ N, 1, 1) + offset_idx.view(-1, 1, 1).to(self.device)
+ faces_packed = faces_packed.view(-1, 3)
+
+ verts_feature_packed = torch.zeros(N * self.NUM_VERTS,
+ C).to(self.device)
+
+ face_uv_pixel = self.face_uv_pixel.view(-1, 2)
+ verts_feature_packed[
+ faces_packed] = maps_padded[:, face_uv_pixel[:, 1],
+ face_uv_pixel[:, 0]].view(
+ N * self.num_faces, 3, C)
+ verts_feature_padded = verts_feature_packed.view(N, self.NUM_VERTS, C)
+
+ return verts_feature_padded
+
+ def wrap_normal(
+ self,
+ meshes: Meshes,
+ normal: torch.Tensor = None,
+ normal_map: torch.Tensor = None,
+ ) -> Meshes:
+ """Warp a normal map or vertex normal to the input meshes.
+
+ Args:
+ meshes (Meshes): the input meshes.
+ normal (torch.Tensor, optional): vertex normal. Shape should be
+ (N, V, 3).
+ Defaults to None.
+ normal_map (torch.Tensor, optional):
+ normal map. Defaults to None.
+
+ Returns:
+ Meshes: returned meshes.
+ """
+        if normal_map is not None and normal is None:
+            normal = self.vertex_resample(normal_map)
+        elif normal_map is not None and normal is not None:
+            normal_map = None
+            warnings.warn('Redundant input, will only take `normal`.')
+        elif normal_map is None and normal is None:
+            raise ValueError('No valid input.')
+ batch_size = len(meshes)
+ if normal.ndim == 2:
+ normal = normal[None]
+ assert normal.shape[1:] == (self.NUM_VERTS, 3)
+ assert normal.shape[0] in [batch_size, 1]
+
+ if normal.shape[0] == 1:
+ normal = normal.repeat(batch_size, 1, 1)
+ meshes = meshes.clone()
+
+ meshes._set_verts_normals(normal)
+ return meshes
+
+ def wrap_displacement(
+ self,
+ meshes: Meshes,
+ displacement: torch.Tensor = None,
+ displacement_map: torch.Tensor = None,
+ ) -> Meshes:
+ """Offset a vertex displacement or displacement_map to the input
+ meshes.
+
+ Args:
+ meshes (Meshes): the input meshes.
+ displacement (torch.Tensor, optional): vertex displacement.
+ shape should be (N, V, 3).
+ Defaults to None.
+ displacement_map (torch.Tensor, optional): displacement_map,
+ shape should be (N, H, W, 3).
+ Defaults to None.
+
+ Returns:
+ Meshes: returned meshes.
+ """
+ if displacement_map is not None and displacement is None:
+ displacement = self.vertex_resample(displacement_map)
+ elif displacement_map is not None and displacement is not None:
+ displacement_map = None
+ warnings.warn('Redundant input, will only take displacement.')
+ elif displacement_map is None and displacement is None:
+ raise ValueError('No valid input.')
+ batch_size = len(meshes)
+ if displacement.ndim == 2:
+ displacement = displacement[None]
+ assert displacement.shape[1] == self.NUM_VERTS
+ assert displacement.shape[0] in [batch_size, 1]
+
+ if displacement.shape[0] == 1:
+ displacement = displacement.repeat(batch_size, 1, 1)
+ C = displacement.shape[-1]
+ if C == 1:
+ displacement = meshes.verts_normals_padded() * displacement
+
+ displacement = padded_to_packed(displacement)
+
+ meshes = meshes.to(self.device)
+ meshes = meshes.offset_verts(displacement)
+ return meshes
+
+ def wrap_texture(self,
+ texture_map: torch.Tensor,
+ resolution: Optional[Iterable[int]] = None,
+ mode: Optional[str] = 'bicubic',
+                     is_bgr: bool = True) -> TexturesUV:
+        """Wrap a texture map into a `TexturesUV` texture.
+
+ Args:
+ texture_map (torch.Tensor): the texture map to be wrapped.
+ Shape should be (N, H, W, 3)
+ resolution (Optional[Iterable[int]], optional): resolution to
+ override self.resolution. If None, will use self.resolution.
+ Defaults to None.
+ mode (Optional[str], optional): interpolate mode.
+ Should be in ['nearest', 'bilinear', 'trilinear', 'bicubic',
+ 'area'].
+ Defaults to 'bicubic'.
+ is_bgr (bool, optional): Whether the color channel is BGR.
+ Defaults to True.
+
+ Returns:
+            TexturesUV: the texture object to be assigned to a mesh.
+ """
+
+ assert texture_map.shape[-1] == 3
+ if texture_map.ndim == 3:
+ texture_map_padded = texture_map[None]
+ elif texture_map.ndim == 4:
+ texture_map_padded = texture_map
+ else:
+ raise ValueError(f'Wrong texture_map shape: {texture_map.shape}.')
+ N, H, W, _ = texture_map_padded.shape
+
+ resolution = resolution if resolution is not None else (H, W)
+
+ if resolution != (H, W):
+            texture_map_padded = F.interpolate(
+                texture_map_padded.permute(0, 3, 1, 2),
+                size=resolution,
+                mode=mode).permute(0, 2, 3, 1)
+ assert texture_map_padded.shape[0] in [N, 1]
+
+ if isinstance(texture_map_padded, np.ndarray):
+ texture_map_padded = array2tensor(texture_map_padded)
+ is_bgr = True
+ if is_bgr:
+ texture_map_padded = rgb2bgr(texture_map_padded)
+
+ if texture_map_padded.shape[0] == 1:
+ texture_map_padded = texture_map_padded.repeat(N, 1, 1, 1)
+
+ faces_uvs = self.faces_uv[None].repeat(N, 1, 1)
+ verts_uvs = self.verts_uv[None].repeat(N, 1, 1)
+ textures = TexturesUV(faces_uvs=faces_uvs,
+ verts_uvs=verts_uvs,
+ maps=texture_map_padded)
+ return textures
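+
+
+# Usage sketch (illustrative only, not part of the original module): resample
+# per-vertex attributes through UV space. 'data/smpl_uv.npz' is a placeholder
+# path to an SMPL UV parameterisation file containing 'verts_uv', 'faces_uv'
+# and 'faces' entries; supply your own asset.
+#
+#   uv_renderer = UVRenderer(resolution=512,
+#                            model_type='smpl',
+#                            uv_param_path='data/smpl_uv.npz')
+#   verts_attr = torch.rand(1, 6890, 3)
+#   uv_map = uv_renderer(verts_attr)                 # (1, 512, 512, 3)
+#   resampled = uv_renderer.vertex_resample(uv_map)  # (1, 6890, 3)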
diff --git a/detrsmpl/core/renderer/vedo_render.py b/detrsmpl/core/renderer/vedo_render.py
new file mode 100644
index 0000000000000000000000000000000000000000..c772f61b89a2a57c4ecc4358bc981caf4c19ddfc
--- /dev/null
+++ b/detrsmpl/core/renderer/vedo_render.py
@@ -0,0 +1,107 @@
+import numpy as np
+import vedo
+from scipy.spatial.transform import Rotation as scipy_Rotation
+
+
+class VedoRenderer(object):
+ """An interactive renderer for camera visualization."""
+ def __init__(self, scale=0.03):
+ """Visualize cameras in an interactive scene supported by vedo.
+
+ Args:
+ scale (float, optional):
+ Scale factor. Defaults to 0.03.
+ """
+ self.scale = scale
+ self.axis_list = self.__init_axis()
+ self.camera_list = []
+ self.frames_dir_path = ''
+ self.y_reverse = False
+
+ def __init_axis(self, axis_len=80):
+ """Prepare arrows for axis.
+
+ Args:
+ axis_len (int, optional):
+ Length of each axis.
+ Defaults to 80.
+
+ Returns:
+ List[Arrows]:
+ A list of three arrows.
+ """
+ arrow_end_np = np.eye(3) * axis_len * self.scale
+ colors = ['r', 'g', 'b'] # r-x, g-y, b-z
+ ret_list = []
+ for axis_index in range(3):
+ ret_list.append(
+ vedo.Arrows([[0, 0, 0]],
+ [arrow_end_np[axis_index]]).c(colors[axis_index]))
+ return ret_list
+
+ def set_y_reverse(self):
+ """Set y reverse before add_camera if it is needed.
+
+ Vedo defines y+ as up direction. When visualizing kinect cameras, y- is
+ up, call set_y_reverse in this situation to make text in correct
+ direction.
+ """
+ self.y_reverse = True
+ self.y_reverse_rotation = \
+ scipy_Rotation.from_euler('z', 180, degrees=True)
+
+ def add_camera(self, camera_parameter, arrow_len=30):
+ """Add an camera to the scene.
+
+ Args:
+ camera_parameter (CameraParameter):
+ An instance of class CameraParameter which stores
+ rotation, translation and name of a camera.
+ arrow_len (int, optional):
+ Length of the arrow. Defaults to 30.
+
+ Returns:
+ list:
+ A list of vedo items related to the input camera.
+ """
+ rot_mat = np.asarray(camera_parameter.get_value('rotation_mat'))
+ translation = np.asarray(camera_parameter.get_value('translation'))
+ cam_center = -np.linalg.inv(rot_mat).dot(translation)
+ arrow_end_origin = np.eye(3) * arrow_len * self.scale
+ colors = ['r', 'g', 'b'] # r-x, g-y, b-z
+ arrow_end_camera = \
+ np.einsum('ij,kj->ki', np.linalg.inv(rot_mat), arrow_end_origin)
+ if self.y_reverse:
+ cam_center = self.y_reverse_rotation.apply(cam_center)
+ for axis_index in range(3):
+ arrow_end_camera[axis_index, :] = \
+ self.y_reverse_rotation.apply(
+ arrow_end_camera[axis_index, :]
+ )
+ vedo_list = []
+ for i in range(3):
+ vedo_list.append(
+ vedo.Arrows([cam_center],
+ [cam_center + arrow_end_camera[i]]).c(colors[i]))
+ vedo_list.append(
+ vedo.Text3D(camera_parameter.name, cam_center, s=self.scale * 10))
+ self.camera_list += vedo_list
+ return vedo_list
+
+ def show(self, with_axis=True, interactive=True):
+ """Show cameras as well as axis arrow by vedo.show()
+
+ Args:
+ with_axis (bool, optional):
+ Whether to show the axis arrow. Defaults to True.
+ interactive (bool, optional):
+ Pause and interact with window (True) or
+ continue execution (False).
+ Defaults to True.
+ """
+ list_to_show = []
+ list_to_show += self.camera_list
+ if with_axis:
+ list_to_show += self.axis_list
+ vedo.show(*list_to_show, interactive=interactive, axes=1)
+ vedo.clear()
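+
+
+# Usage sketch (illustrative only, not part of the original module):
+# `cam_param` is assumed to be a detrsmpl `CameraParameter` carrying
+# 'rotation_mat' and 'translation' values.
+#
+#   renderer = VedoRenderer()
+#   renderer.set_y_reverse()
+#   renderer.add_camera(cam_param)
+#   renderer.show(with_axis=True, interactive=False)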
diff --git a/detrsmpl/core/visualization/__init__.py b/detrsmpl/core/visualization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..732bb4f496c953cc12204e57c21a888362a5dbad
--- /dev/null
+++ b/detrsmpl/core/visualization/__init__.py
@@ -0,0 +1,2 @@
+from .visualize_keypoints2d import visualize_kp2d # noqa:F401
+from .visualize_keypoints3d import visualize_kp3d # noqa:F401
diff --git a/detrsmpl/core/visualization/visualize_cameras.py b/detrsmpl/core/visualization/visualize_cameras.py
new file mode 100644
index 0000000000000000000000000000000000000000..7285a19915a483cd883fb81e5c3e267e0d4e766f
--- /dev/null
+++ b/detrsmpl/core/visualization/visualize_cameras.py
@@ -0,0 +1,82 @@
+import json
+import os
+
+from detrsmpl.core.cameras.camera_parameters import CameraParameter
+from detrsmpl.core.renderer.vedo_render import VedoRenderer
+from detrsmpl.utils.path_utils import check_path_suffix
+
+
+def visualize_chessboard_kinects_rgb(chessboard_path: str,
+ interactive: bool = True,
+ show: bool = True):
+ """Visualize all the RGB cameras in a chessboard file.
+
+ Args:
+ chessboard_path (str):
+ Path to the chessboard file.
+ interactive (bool, optional):
+ Pause and interact with window (True) or
+ continue execution (False).
+ Defaults to True.
+ show (bool, optional):
+ Whether to show in a window.
+ Defaults to True.
+ """
+ # Load camera parameter from a json file
+ camera_para_json_dict = json.load(open(chessboard_path))
+ camera_para_dict = {}
+ for camera_id in camera_para_json_dict.keys():
+        try:
+            camera_id_int = int(camera_id)
+        except ValueError:
+            continue
+        # if camera_id is an integer and it is divisible by 2,
+        # it is an RGB camera
+        if camera_id_int % 2 != 0:
+            continue
+ temp_camera_parameter = CameraParameter(name=camera_id)
+ temp_camera_parameter.load_from_chessboard(
+ camera_para_json_dict[camera_id], camera_id)
+ camera_para_dict[camera_id] = temp_camera_parameter
+ camera_vedo_renderer = VedoRenderer()
+ camera_vedo_renderer.set_y_reverse()
+ for camera_id in camera_para_dict.keys():
+ camera_vedo_renderer.add_camera(camera_para_dict[camera_id])
+ if show:
+ camera_vedo_renderer.show(with_axis=False, interactive=interactive)
+
+
+def visualize_dumped_camera_parameter(dumped_dir: str,
+ interactive: bool = True,
+ show: bool = True):
+ """Visualize all cameras dumped in a directory.
+
+ Args:
+ dumped_dir (str):
+ Path to the directory.
+ interactive (bool, optional):
+ Pause and interact with window (True) or
+ continue execution (False).
+ Defaults to True.
+ show (bool, optional):
+ Whether to show in a window.
+ Defaults to True.
+ """
+ file_list = os.listdir(dumped_dir)
+ camera_para_list = []
+ for file_name in file_list:
+ file_path = os.path.join(dumped_dir, file_name)
+        if not check_path_suffix(file_path, ['.json']):
+            continue
+        cam_para = CameraParameter()
+        cam_para.load(file_path)
+        camera_para_list.append(cam_para)
+ camera_vedo_renderer = VedoRenderer()
+ camera_vedo_renderer.set_y_reverse()
+ for camera_para in camera_para_list:
+ camera_vedo_renderer.add_camera(camera_para)
+ if show:
+ camera_vedo_renderer.show(with_axis=False, interactive=interactive)
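+
+
+# A minimal usage sketch (the directory below is illustrative; it should
+# contain dumped camera parameter json files):
+#
+#   visualize_dumped_camera_parameter('dumped_cameras/', interactive=True)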
diff --git a/detrsmpl/core/visualization/visualize_keypoints2d.py b/detrsmpl/core/visualization/visualize_keypoints2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..bff23cc1a2e2649ea9ad9c12b710774126f73437
--- /dev/null
+++ b/detrsmpl/core/visualization/visualize_keypoints2d.py
@@ -0,0 +1,610 @@
+import glob
+import os
+import os.path as osp
+import shutil
+import warnings
+from pathlib import Path
+from typing import Iterable, List, Optional, Tuple, Union
+
+import cv2
+import numpy as np
+from tqdm import tqdm
+
+from detrsmpl.core.conventions.keypoints_mapping import KEYPOINTS_FACTORY
+from detrsmpl.core.conventions.keypoints_mapping.human_data import (
+ HUMAN_DATA_LIMBS_INDEX,
+ HUMAN_DATA_PALETTE,
+)
+from detrsmpl.utils.demo_utils import get_different_colors
+from detrsmpl.utils.ffmpeg_utils import images_to_video, video_to_images
+from detrsmpl.utils.keypoint_utils import search_limbs
+from detrsmpl.utils.path_utils import (
+ Existence,
+ check_input_path,
+ check_path_existence,
+ check_path_suffix,
+ prepare_output_path,
+)
+
+
+def _plot_kp2d_frame(kp2d_person: np.ndarray,
+ canvas: np.ndarray,
+ limbs: Union[list, dict,
+ np.ndarray] = HUMAN_DATA_LIMBS_INDEX,
+ palette: Optional[Union[dict, np.ndarray]] = None,
+ draw_bbox: bool = False,
+ with_number: bool = False,
+ font_size: Union[float, int] = 0.5,
+ disable_limbs: bool = False) -> np.ndarray:
+ """Plot a single frame(array) with keypoints, limbs, bbox, index.
+
+ Args:
+ kp2d_person (np.ndarray): `np.ndarray` shape of (J * 2).
+ canvas (np.ndarray): cv2 image, (H * W * 3) array.
+ limbs (Union[list, dict, np.ndarray], optional): limbs in form of
+ `dict` or 2-dimensional `list` or `np.ndarray` of shape
+ (num_limb, 2).
+ `dict` is used mainly for function `visualize_kp2d`, you can also
+ get the limbs by function `search_limbs`.
+ Defaults to `HUMAN_DATA_LIMBS_INDEX`.
+        palette (Optional[Union[dict, np.ndarray, list]], optional):
+            Pass a (1, 3) `np.ndarray` or a `list` [B, G, R] if you want all
+            limbs and keypoints to share the same color.
+            Pass `None` to use our colorful palette.
+            Pass a (num_limb, 3) `np.ndarray` to give each limb its own
+            color.
+            `dict` is used mainly for function `visualize_kp2d`; you can also
+            get the palette by function `search_limbs`.
+            Defaults to None.
+        draw_bbox (bool, optional): whether to draw bounding boxes.
+            Defaults to False.
+        with_number (bool, optional): whether to draw index numbers.
+            Defaults to False.
+        font_size (Union[float, int], optional): the font size of the index.
+            Defaults to 0.5.
+        disable_limbs (bool, optional): whether to disable drawing limbs.
+            Defaults to False.
+
+ Returns:
+ np.ndarray: opencv image of shape (H * W * 3).
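+
+    Example:
+        A minimal sketch drawing one skeleton on a blank canvas (the limb
+        indices below are illustrative, not a real convention):
+
+        >>> import numpy as np
+        >>> canvas = np.ones((256, 256, 3), dtype=np.uint8) * 255
+        >>> kp2d_person = np.random.rand(4, 2) * 256
+        >>> limbs = np.array([[0, 1], [1, 2], [2, 3]])
+        >>> canvas = _plot_kp2d_frame(kp2d_person, canvas, limbs=limbs,
+        ...                           palette=[255, 0, 0])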
+ """
+ # slice the kp2d array
+ kp2d_person = kp2d_person.copy()
+ if kp2d_person.shape[-1] >= 3:
+        kp2d_person = kp2d_person[..., :2]
+        warnings.warn(
+            'The input array has more than 2-dimensional coordinates, will '
+            'keep only the first 2 dimensions of the last axis. The new '
+            f'array shape: {kp2d_person.shape}')
+ if kp2d_person.ndim == 3 and kp2d_person.shape[0] == 1:
+ kp2d_person = kp2d_person[0]
+ assert kp2d_person.ndim == 2 and kp2d_person.shape[
+ -1] == 2, f'Wrong input array shape {kp2d_person.shape}, \
+ should be (num_kp, 2)'
+
+ if draw_bbox:
+ bbox = _get_bbox(kp2d_person, canvas, expand=True)
+ else:
+ bbox = None
+
+ # determine the limb connections and palette
+ if not disable_limbs:
+ if isinstance(limbs, list):
+ limbs = {'body': limbs}
+ elif isinstance(limbs, np.ndarray):
+ limbs = {'body': limbs.reshape(-1, 2).astype(np.int32).tolist()}
+ else:
+ assert set(limbs.keys()).issubset(HUMAN_DATA_LIMBS_INDEX)
+
+ if palette is None:
+ palette = {'body': None}
+ elif isinstance(palette, dict):
+ assert set(palette.keys()) == set(limbs.keys())
+ else:
+ limbs = {'body': None}
+ # draw by part to specify the thickness and color
+ for part_name, part_limbs in limbs.items():
+ # scatter_points_index means the limb end points
+ if not disable_limbs:
+ scatter_points_index = list(
+ set(np.array([part_limbs]).reshape(-1).tolist()))
+ else:
+ scatter_points_index = list(range(len(kp2d_person)))
+ if isinstance(palette, dict) and part_name == 'body':
+ thickness = 2
+ radius = 3
+ color = get_different_colors(len(scatter_points_index))
+ elif disable_limbs and palette is None:
+ radius = 2
+ color = get_different_colors(len(scatter_points_index))
+ else:
+ thickness = 2
+ radius = 2
+ if isinstance(palette, np.ndarray):
+ color = palette.astype(np.int32)
+ elif isinstance(palette, dict):
+ color = np.array(palette[part_name]).astype(np.int32)
+ elif isinstance(palette, list):
+ color = np.array(palette).reshape(-1, 3).astype(np.int32)
+ if not disable_limbs:
+ for limb_index, limb in enumerate(part_limbs):
+ limb_index = min(limb_index, len(color) - 1)
+ cv2.line(canvas,
+ tuple(kp2d_person[limb[0]].astype(np.int32)),
+ tuple(kp2d_person[limb[1]].astype(np.int32)),
+ color=tuple(color[limb_index].tolist()),
+ thickness=thickness)
+ # draw the points inside the image region
+ for index in scatter_points_index:
+ x, y = kp2d_person[index, :2]
+ if np.isnan(x) or np.isnan(y):
+ continue
+ if 0 <= x < canvas.shape[1] and 0 <= y < canvas.shape[0]:
+ if disable_limbs:
+ point_color = color[index].tolist()
+ else:
+ point_color = color[min(color.shape[0] - 1,
+ len(scatter_points_index) -
+ 1)].tolist()
+
+ cv2.circle(canvas, (int(x), int(y)),
+ radius,
+ point_color,
+ thickness=-1)
+ if with_number:
+ cv2.putText(
+ canvas, str(index), (int(x), int(y)),
+ cv2.FONT_HERSHEY_SIMPLEX, font_size,
+ np.array([255, 255, 255]).astype(np.int32).tolist(), 2)
+ # draw the bboxes
+ if bbox is not None:
+ bbox = bbox.astype(np.int32)
+ cv2.rectangle(canvas, (bbox[0], bbox[2]), (bbox[1], bbox[3]),
+ (0, 255, 255), 1)
+ return canvas
+
+
+def _get_bbox(keypoint_np: np.ndarray,
+ img_mat: Optional[np.ndarray] = None,
+ expand: bool = False):
+ """get bbox of kp2d."""
+ x_max = np.max(keypoint_np[:, 0])
+ x_min = np.min(keypoint_np[:, 0])
+ y_max = np.max(keypoint_np[:, 1])
+ y_min = np.min(keypoint_np[:, 1])
+ if expand and img_mat is not None:
+ x_expand = (x_max - x_min) * 0.1
+ y_expand = (y_max - y_min) * 0.1
+ x_min = max(0, x_min - x_expand)
+ x_max = min(img_mat.shape[1], x_max + x_expand)
+ y_min = max(0, y_min - y_expand)
+ y_max = min(img_mat.shape[0], y_max + y_expand)
+ return np.asarray([x_min, x_max, y_min, y_max])
+
+
+def _prepare_limb_palette(limbs,
+ palette,
+ pop_parts,
+ data_source,
+ mask,
+ search_limbs_func=search_limbs):
+ """Prepare limbs and their palette for plotting.
+
+ Args:
+ limbs (Union[np.ndarray, List[int]]):
+ The preset limbs. This option is for free skeletons like BVH file.
+ In most cases, it's set to None,
+ this function will search a result for limbs automatically.
+ palette (Iterable):
+ The preset palette for limbs. Specified palette,
+ three int represents (B, G, R). Should be tuple or list.
+ In most cases, it's set to None,
+ a palette will be generated with the result of search_limbs.
+ pop_parts (Iterable[str]):
+ The body part names you do not
+ want to visualize.
+ When it's none, nothing will be removed.
+ data_source (str):
+ Data source type.
+        mask (Union[list, np.ndarray]):
+ A mask to mask out the incorrect points.
+
+ Returns:
+ Tuple[dict, dict]: (limbs_target, limbs_palette).
+ """
+ if limbs is not None:
+ limbs_target, limbs_palette = {
+ 'body': limbs.tolist() if isinstance(limbs, np.ndarray) else limbs
+ }, get_different_colors(len(limbs))
+ else:
+ limbs_target, limbs_palette = search_limbs_func(
+ data_source=data_source, mask=mask)
+
+ if palette:
+ limbs_palette = np.array(palette, dtype=np.uint8)[None]
+
+ # check and pop the pop_parts
+ assert set(pop_parts).issubset(
+ HUMAN_DATA_PALETTE
+ ), f'wrong part_names in pop_parts, supported parts are\
+ {set(HUMAN_DATA_PALETTE.keys())}'
+
+ for part_name in pop_parts:
+ if part_name in limbs_target:
+ limbs_target.pop(part_name)
+ limbs_palette.pop(part_name)
+ return limbs_target, limbs_palette
+
+
+def _prepare_output_path(output_path, overwrite):
+ """Prepare output path."""
+ prepare_output_path(output_path,
+ allowed_suffix=['.mp4', ''],
+ tag='output video',
+ path_type='auto',
+ overwrite=overwrite)
+ # output_path is a directory
+ if check_path_suffix(output_path, ['']):
+ temp_folder = output_path
+ os.makedirs(temp_folder, exist_ok=True)
+ else:
+ temp_folder = output_path + '_temp_images'
+ if check_path_existence(temp_folder, 'dir') in [
+ Existence.DirectoryExistNotEmpty, Existence.DirectoryExistEmpty
+ ]:
+ shutil.rmtree(temp_folder)
+ os.makedirs(temp_folder, exist_ok=True)
+ return temp_folder
+
+
+def _check_frame_path(frame_list):
+ """Check frame path."""
+ for frame_path in frame_list:
+ if check_path_existence(frame_path, 'file') != Existence.FileExist or \
+ not check_path_suffix(frame_path, ['.png', '.jpg', '.jpeg']):
+ raise FileNotFoundError(
+ f'The frame should be .png or .jp(e)g: {frame_path}')
+
+
+def _check_temp_path(temp_folder, frame_list, overwrite):
+ """Check temp frame folder path."""
+ if not overwrite and frame_list is not None and len(frame_list) > 0:
+ if Path(temp_folder).absolute() == \
+ Path(frame_list[0]).parent.absolute():
+ raise FileExistsError(
+ f'{temp_folder} exists (set --overwrite to overwrite).')
+
+
+class _CavasProducer:
+ """Prepare background canvas, pure white if not set."""
+ def __init__(self,
+ frame_list,
+ resolution,
+ kp2d=None,
+ image_array=None,
+ default_scale=1.5):
+ """Initialize a canvas writer."""
+ # check the origin background frames
+ if frame_list is not None:
+ _check_frame_path(frame_list)
+ self.frame_list = frame_list
+ else:
+ self.frame_list = []
+ self.resolution = resolution
+ self.kp2d = kp2d
+
+ # with numpy array frames
+ self.image_array = image_array
+
+ if self.resolution is None:
+ if self.image_array is not None:
+ self.auto_resolution = self.image_array.shape[1:3]
+ elif len(self.frame_list) > 1 and \
+ check_path_existence(
+ self.frame_list[0], 'file') == Existence.FileExist:
+ tmp_image_array = cv2.imread(self.frame_list[0])
+ self.auto_resolution = tmp_image_array.shape[:2]
+ else:
+
+ self.auto_resolution = [
+ int(np.max(kp2d) * default_scale),
+ int(np.max(kp2d) * default_scale)
+ ]
+ self.len = kp2d.shape[0]
+
+ if self.image_array is None:
+ self.len_frame = len(self.frame_list)
+ else:
+ self.len_frame = self.image_array.shape[0]
+
+ def __getitem__(self, frame_index):
+ """Get frame data from frame_list of image_array."""
+ # frame file exists, resolution not set
+ if frame_index < self.len_frame and self.resolution is None:
+ if self.image_array is not None:
+ canvas = self.image_array[frame_index]
+ else:
+ canvas = cv2.imread(self.frame_list[frame_index])
+ if self.kp2d is None:
+ kp2d_frame = None
+ else:
+ kp2d_frame = self.kp2d[frame_index]
+ # no frame file, resolution has been set
+ elif frame_index >= self.len_frame and self.resolution is not None:
+ canvas = np.ones((self.resolution[0], self.resolution[1], 3),
+ dtype=np.uint8) * 255
+ if self.kp2d is None:
+ kp2d_frame = None
+ else:
+ kp2d_frame = self.kp2d[frame_index]
+ # frame file exists, resolution has been set
+ elif frame_index < self.len_frame and self.resolution is not None:
+ if self.image_array is not None:
+ canvas = self.image_array[frame_index]
+ else:
+ canvas = cv2.imread(self.frame_list[frame_index])
+ w_scale = self.resolution[1] / canvas.shape[1]
+ h_scale = self.resolution[0] / canvas.shape[0]
+ canvas = cv2.resize(canvas,
+ (self.resolution[1], self.resolution[0]),
+                                    interpolation=cv2.INTER_CUBIC)
+ if self.kp2d is None:
+ kp2d_frame = None
+ else:
+ kp2d_frame = np.array([[w_scale, h_scale]
+ ]) * self.kp2d[frame_index]
+ # no frame file, no resolution
+ else:
+ canvas = np.ones(
+ (self.auto_resolution[0], self.auto_resolution[1], 3),
+ dtype=np.uint8) * 255
+ if self.kp2d is None:
+ kp2d_frame = None
+ else:
+ kp2d_frame = self.kp2d[frame_index]
+ return canvas, kp2d_frame
+
+ def __len__(self):
+ return self.len
+
+
+def update_frame_list(frame_list, origin_frames, img_format, start, end):
+ """Update frame list if have origin_frames."""
+ input_temp_folder = None
+ # choose in frame_list or origin_frames
+ if frame_list is None and origin_frames is None:
+ print('No background provided, will use pure white background.')
+ elif frame_list is not None and origin_frames is not None:
+ warnings.warn('Redundant input, will only use frame_list.')
+ origin_frames = None
+ if origin_frames is not None:
+ check_input_path(input_path=origin_frames,
+ allowed_suffix=['.mp4', '.gif', ''],
+ tag='origin frames',
+ path_type='auto')
+ if Path(origin_frames).is_file():
+ input_temp_folder = origin_frames + '_temp_images/'
+ video_to_images(origin_frames,
+ input_temp_folder,
+ start=start,
+ end=end)
+ frame_list = glob.glob(osp.join(input_temp_folder, '*.png'))
+ frame_list.sort()
+ else:
+ if img_format is None:
+ frame_list = []
+ for im_name in os.listdir(origin_frames):
+ if Path(im_name).suffix.lower() in [
+ '.png', '.jpg', '.jpeg'
+ ]:
+ frame_list.append(osp.join(origin_frames, im_name))
+ else:
+ frame_list = []
+ for index in range(start, end):
+ frame_path = osp.join(origin_frames, img_format % index)
+ if osp.exists(frame_path):
+ frame_list.append(frame_path)
+ frame_list.sort()
+ return frame_list, input_temp_folder
+
+
+def visualize_kp2d(
+ kp2d: np.ndarray,
+ output_path: Optional[str] = None,
+ frame_list: Optional[List[str]] = None,
+ origin_frames: Optional[str] = None,
+ image_array: Optional[np.ndarray] = None,
+ limbs: Optional[Union[np.ndarray, List[int]]] = None,
+ palette: Optional[Iterable[int]] = None,
+ data_source: str = 'coco',
+ mask: Optional[Union[list, np.ndarray]] = None,
+ img_format: str = '%06d.png',
+ start: int = 0,
+ end: int = -1,
+ overwrite: bool = False,
+ with_file_name: bool = True,
+ resolution: Optional[Union[Tuple[int, int], list]] = None,
+ fps: Union[float, int] = 30,
+ draw_bbox: bool = False,
+ with_number: bool = False,
+ pop_parts: Iterable[str] = None,
+ disable_tqdm: bool = False,
+ disable_limbs: bool = False,
+ return_array: Optional[bool] = False,
+ keypoints_factory: dict = KEYPOINTS_FACTORY,
+ remove_raw_file: bool = True,
+) -> Union[None, np.ndarray]:
+ """Visualize 2d keypoints to a video or into a folder of frames.
+
+ Args:
+        kp2d (np.ndarray): should be an array of shape (f * J * 2)
+            or (f * n * J * 2).
+ output_path (str): output video path or image folder.
+        frame_list (Optional[List[str]], optional): list of origin background
+            frame paths, each element in the list should be an image path
+            like `*.jpg` or `*.png`. Higher priority than `origin_frames`.
+            Use this when your file names are hard to sort or you only want
+            to render a small number of frames.
+            Defaults to None.
+ origin_frames (Optional[str], optional): origin background frame path,
+ could be `.mp4`, `.gif`(will be sliced into a folder) or an image
+ folder. Lower priority than `frame_list`.
+ Defaults to None.
+ limbs (Optional[Union[np.ndarray, List[int]]], optional):
+ if not specified, the limbs will be searched by search_limbs,
+ this option is for free skeletons like BVH file.
+ Defaults to None.
+ palette (Iterable, optional): specified palette, three int represents
+ (B, G, R). Should be tuple or list.
+ Defaults to None.
+ data_source (str, optional): data source type. Defaults to 'coco'.
+        mask (Optional[Union[list, np.ndarray]], optional):
+            mask to mask out the incorrect points.
+            Pass a `np.ndarray` of shape (J,) or a `list` of length J.
+            Defaults to None.
+        img_format (str, optional): input image format.
+            Defaults to '%06d.png'.
+ start (int, optional): start frame index. Defaults to 0.
+ end (int, optional): end frame index. Defaults to -1.
+ overwrite (bool, optional): whether replace the origin frames.
+ Defaults to False.
+        with_file_name (bool, optional): whether to write the origin frame
+            name on the images. Defaults to True.
+        resolution (Optional[Union[Tuple[int, int], list]], optional):
+            (height, width) of the output video;
+            will be the same size as the original images if not specified.
+            Defaults to None.
+ fps (Union[float, int], optional): fps. Defaults to 30.
+        draw_bbox (bool, optional): whether to draw bounding boxes.
+            Defaults to False.
+        with_number (bool, optional): whether to draw index numbers.
+            Defaults to False.
+        pop_parts (Iterable[str], optional): The body part names you do not
+            want to visualize. Supported parts are ['left_eye', 'right_eye',
+            'nose', 'mouth', 'face', 'left_hand', 'right_hand'].
+            Defaults to None.
+ disable_tqdm (bool, optional):
+ Whether to disable the entire progressbar wrapper.
+ Defaults to False.
+        disable_limbs (bool, optional): whether to disable drawing limbs.
+            Defaults to False.
+        return_array (bool, optional): Whether to return images as an opencv
+            array. Defaults to False.
+ keypoints_factory (dict, optional): Dict of all the conventions.
+ Defaults to KEYPOINTS_FACTORY.
+
+ Raises:
+ FileNotFoundError: check output video path.
+ FileNotFoundError: check input frame paths.
+
+ Returns:
+ Union[None, np.ndarray].
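+
+    Example:
+        A minimal sketch rendering random keypoints on a pure white
+        background (the shapes assume the default `coco` convention with
+        17 keypoints; the output path is illustrative):
+
+        >>> import numpy as np
+        >>> kp2d = np.random.rand(30, 1, 17, 2) * 512
+        >>> visualize_kp2d(kp2d,
+        ...                output_path='demo_out/kp2d.mp4',
+        ...                resolution=(512, 512),
+        ...                data_source='coco')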
+ """
+
+ # check the input array shape, reshape to (num_frames, num_person, J, 2)
+ kp2d = kp2d[..., :2].copy()
+ if kp2d.ndim == 3:
+ kp2d = kp2d[:, np.newaxis]
+ assert kp2d.ndim == 4
+ num_frames, num_person = kp2d.shape[0], kp2d.shape[1]
+ # slice the input array temporally
+ end = (min(num_frames - 1, end) + num_frames) % num_frames
+ kp2d = kp2d[start:end + 1]
+
+ if image_array is not None:
+ origin_frames = None
+ frame_list = None
+ return_array = True
+ input_temp_folder = None
+ else:
+ frame_list, input_temp_folder = update_frame_list(
+ frame_list, origin_frames, img_format, start, end)
+
+ kp2d = kp2d[:num_frames]
+ # check output path
+ if output_path is not None:
+ output_temp_folder = _prepare_output_path(output_path, overwrite)
+ # check whether temp_folder will overwrite frame_list by accident
+ _check_temp_path(output_temp_folder, frame_list, overwrite)
+ else:
+ output_temp_folder = None
+
+ # check data_source & mask
+ if data_source not in keypoints_factory:
+        raise ValueError('Wrong data_source. Should choose in '
+ f'{list(keypoints_factory.keys())}')
+ if mask is not None:
+ if isinstance(mask, list):
+ mask = np.array(mask).reshape(-1)
+ assert mask.shape == (
+ len(keypoints_factory[data_source]),
+ ), f'mask length should fit with keypoints number \
+ {len(keypoints_factory[data_source])}'
+
+ # search the limb connections and palettes from superset smplx
+ # check and pop the pop_parts
+ if pop_parts is None:
+ pop_parts = []
+
+ if disable_limbs:
+ limbs_target, limbs_palette = None, None
+ else:
+ limbs_target, limbs_palette = _prepare_limb_palette(
+ limbs, palette, pop_parts, data_source, mask)
+ canvas_producer = _CavasProducer(frame_list, resolution, kp2d, image_array)
+
+ out_image_array = []
+ # start plotting by frame
+ for frame_index in tqdm(range(kp2d.shape[0]), disable=disable_tqdm):
+ canvas, kp2d_frame = canvas_producer[frame_index]
+ # start plotting by person
+ for person_index in range(num_person):
+ if num_person >= 2 and not disable_limbs:
+ limbs_palette = get_different_colors(
+ num_person)[person_index].reshape(1, 3)
+ canvas = _plot_kp2d_frame(kp2d_person=kp2d_frame[person_index],
+ canvas=canvas,
+ limbs=limbs_target,
+ palette=limbs_palette,
+ draw_bbox=draw_bbox,
+ with_number=with_number,
+ font_size=0.5,
+ disable_limbs=disable_limbs)
+ if with_file_name and frame_list is not None:
+ h, w, _ = canvas.shape
+ if frame_index <= len(frame_list) - 1:
+ cv2.putText(
+ canvas, str(Path(frame_list[frame_index]).name),
+ (w // 2, h // 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5 * h / 500,
+ np.array([255, 255, 255]).astype(np.int32).tolist(), 2)
+ if output_path is not None:
+ # write the frame with opencv
+ if frame_list is not None and check_path_suffix(
+ output_path,
+ '') and len(frame_list) >= len(canvas_producer):
+ frame_path = os.path.join(output_temp_folder,
+ Path(frame_list[frame_index]).name)
+ img_format = None
+ else:
+ frame_path = \
+ os.path.join(output_temp_folder, f'{frame_index:06d}.png')
+ img_format = '%06d.png'
+ cv2.imwrite(frame_path, canvas)
+ if return_array:
+ out_image_array.append(canvas[None])
+
+ if input_temp_folder is not None:
+ shutil.rmtree(input_temp_folder)
+ # convert frames to video
+ if output_path is not None:
+ if check_path_suffix(output_path, ['.mp4']):
+ images_to_video(input_folder=output_temp_folder,
+ output_path=output_path,
+ remove_raw_file=remove_raw_file,
+ img_format=img_format,
+ fps=fps)
+
+ if return_array:
+ out_image_array = np.concatenate(out_image_array)
+ return out_image_array
diff --git a/detrsmpl/core/visualization/visualize_keypoints3d.py b/detrsmpl/core/visualization/visualize_keypoints3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd7b2579084b781f5716a6301649861f2834a988
--- /dev/null
+++ b/detrsmpl/core/visualization/visualize_keypoints3d.py
@@ -0,0 +1,218 @@
+import warnings
+from typing import Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+
+import detrsmpl.core.conventions.keypoints_mapping as keypoints_mapping
+from detrsmpl.core.renderer.matplotlib3d_renderer import Axes3dJointsRenderer
+from detrsmpl.utils.demo_utils import get_different_colors
+from detrsmpl.utils.keypoint_utils import search_limbs
+from detrsmpl.utils.path_utils import prepare_output_path
+
+
+def _norm_pose(pose_numpy: np.ndarray, min_value: Union[float, int],
+ max_value: Union[float, int], mask: Union[np.ndarray, list]):
+ """Normalize the poses and make the center close to axis center."""
+ assert max_value > min_value
+ pose_np_normed = pose_numpy.copy()
+ if not mask:
+ mask = list(range(pose_numpy.shape[-2]))
+ axis_num = 3
+ axis_stat = np.zeros(shape=[axis_num, 4])
+ for axis_index in range(axis_num):
+ axis_data = pose_np_normed[..., mask, axis_index]
+ axis_min = np.min(axis_data)
+ axis_max = np.max(axis_data)
+ axis_mid = (axis_min + axis_max) / 2.0
+ axis_span = axis_max - axis_min
+ axis_stat[axis_index] = np.asarray(
+ (axis_min, axis_max, axis_mid, axis_span))
+ target_mid = (max_value + min_value) / 2.0
+ max_span = np.max(axis_stat[:, 3])
+ target_span = max_value - min_value
+ for axis_index in range(axis_num):
+ pose_np_normed[..., axis_index] = \
+ pose_np_normed[..., axis_index] - \
+ axis_stat[axis_index, 2]
+ pose_np_normed = pose_np_normed / max_span * target_span
+ pose_np_normed = pose_np_normed + target_mid
+ return pose_np_normed
+
+
+def visualize_kp3d(
+ kp3d: np.ndarray,
+ output_path: Optional[str] = None,
+ limbs: Optional[Union[np.ndarray, List[int]]] = None,
+ palette: Optional[Iterable[int]] = None,
+ data_source: str = 'coco',
+ mask: Optional[Union[list, tuple, np.ndarray]] = None,
+ start: int = 0,
+ end: Optional[int] = None,
+ resolution: Union[list, Tuple[int, int]] = (1024, 1024),
+ fps: Union[float, int] = 30,
+ frame_names: Optional[Union[List[str], str]] = None,
+ orbit_speed: Union[float, int] = 0.5,
+ value_range: Union[Tuple[int, int], list] = (-100, 100),
+ pop_parts: Iterable[str] = (),
+ disable_limbs: bool = False,
+ return_array: Optional[bool] = None,
+ convention: str = 'opencv',
+ keypoints_factory: dict = keypoints_mapping.KEYPOINTS_FACTORY,
+) -> Union[None, np.ndarray]:
+ """Visualize 3d keypoints to a video with matplotlib. Support multi person
+ and specified limb connections.
+
+ Args:
+ kp3d (np.ndarray): shape could be (f * J * 4/3/2) or
+ (f * num_person * J * 4/3/2)
+        output_path (str): output video path or image folder.
+ limbs (Optional[Union[np.ndarray, List[int]]], optional):
+ if not specified, the limbs will be searched by search_limbs,
+ this option is for free skeletons like BVH file.
+ Defaults to None.
+ palette (Iterable, optional): specified palette, three int represents
+ (B, G, R). Should be tuple or list.
+ Defaults to None.
+ data_source (str, optional): data source type. Defaults to 'coco'.
+ choose in ['coco', 'smplx', 'smpl', 'coco_wholebody',
+ 'mpi_inf_3dhp', 'mpi_inf_3dhp_test', 'h36m', 'pw3d', 'mpii']
+ mask (Optional[Union[list, tuple, np.ndarray]], optional):
+ mask to mask out the incorrect points. Defaults to None.
+ start (int, optional): start frame index. Defaults to 0.
+ end (int, optional): end frame index.
+ Could be positive int or negative int or None.
+ None represents include all the frames.
+ Defaults to None.
+        resolution (Union[list, Tuple[int, int]], optional):
+            (width, height) of the output video.
+            Defaults to (1024, 1024).
+ fps (Union[float, int], optional): fps. Defaults to 30.
+        frame_names (Optional[Union[List[str], str]], optional): a list (with
+            the same length as the number of frames), a single string, or a
+            format string (like 'frame%06d') for the frame title; no title
+            if None.
+            Defaults to None.
+ orbit_speed (Union[float, int], optional): orbit speed of camera.
+ Defaults to 0.5.
+ value_range (Union[Tuple[int, int], list], optional):
+ range of axis value. Defaults to (-100, 100).
+        pop_parts (Iterable[str], optional): The body part names you do not
+            want to visualize. Choose in ['left_eye', 'right_eye', 'nose',
+            'mouth', 'face', 'left_hand', 'right_hand']. Defaults to ().
+ disable_limbs (bool, optional): whether need to disable drawing limbs.
+ Defaults to False.
+        return_array (bool, optional): Whether to return images as an opencv
+            array. If None, an array will be returned when the frame number
+            is below 100.
+            Defaults to None.
+ keypoints_factory (dict, optional): Dict of all the conventions.
+ Defaults to KEYPOINTS_FACTORY.
+ Raises:
+ TypeError: check the type of input keypoints.
+ FileNotFoundError: check the output video path.
+
+ Returns:
+ Union[None, np.ndarray].
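+
+    Example:
+        A minimal sketch with random 3d keypoints (the shapes assume the
+        default `coco` convention with 17 keypoints; the output path is
+        illustrative):
+
+        >>> import numpy as np
+        >>> kp3d = np.random.rand(30, 17, 3) * 100
+        >>> visualize_kp3d(kp3d,
+        ...                output_path='demo_out/kp3d.mp4',
+        ...                data_source='coco',
+        ...                value_range=(-100, 100))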
+ """
+ # check input shape
+ if not isinstance(kp3d, np.ndarray):
+ raise TypeError(
+ f'Input type is {type(kp3d)}, which should be numpy.ndarray.')
+ kp3d = kp3d.copy()
+ if kp3d.shape[-1] == 2:
+ kp3d = np.concatenate([kp3d, np.zeros_like(kp3d)[..., 0:1]], axis=-1)
+ warnings.warn(
+ 'The input array is 2-Dimensional coordinates, will concatenate ' +
+ f'zeros to the last axis. The new array shape: {kp3d.shape}')
+ elif kp3d.shape[-1] >= 4:
+ kp3d = kp3d[..., :3]
+ warnings.warn(
+ 'The input array has more than 3-Dimensional coordinates, will ' +
+ 'keep only the first 3-Dimensions of the last axis. The new ' +
+ f'array shape: {kp3d.shape}')
+ if kp3d.ndim == 3:
+ kp3d = np.expand_dims(kp3d, 1)
+ num_frames = kp3d.shape[0]
+ assert kp3d.ndim == 4
+ assert kp3d.shape[-1] == 3
+
+ if return_array is None:
+ if num_frames > 100:
+ return_array = False
+ else:
+ return_array = True
+
+ # check data_source & mask
+ if data_source not in keypoints_factory:
+        raise ValueError('Wrong data_source. Should choose in ' +
+ f'{list(keypoints_factory.keys())}')
+ if mask is not None:
+ if not isinstance(mask, np.ndarray):
+ mask = np.array(mask).reshape(-1)
+ assert mask.shape == (
+ len(keypoints_factory[data_source]),
+ ), f'mask length should fit with keypoints number \
+ {len(keypoints_factory[data_source])}'
+
+ # check the output path
+ if output_path is not None:
+ prepare_output_path(output_path,
+ path_type='auto',
+ tag='output video',
+ allowed_suffix=['.mp4', '.gif', ''])
+
+ # slice the frames
+ end = num_frames if end is None else end
+ kp3d = kp3d[start:end]
+ # norm the coordinates
+ if value_range is not None:
+ # norm pose location to value_range (70% value range)
+ mask_index = np.where(np.array(mask) > 0) if mask is not None else None
+        margin_width = abs(value_range[1] - value_range[0]) * 0.15
+ pose_np_normed = _norm_pose(kp3d, value_range[0] + margin_width,
+ value_range[1] - margin_width, mask_index)
+ input_pose_np = pose_np_normed
+ else:
+ input_pose_np = kp3d
+
+ # determine the limb connections and palettes
+ if limbs is not None:
+ limbs_target, limbs_palette = {
+ 'body': limbs.tolist() if isinstance(limbs, np.ndarray) else limbs
+ }, get_different_colors(len(limbs))
+ else:
+ limbs_target, limbs_palette = search_limbs(data_source=data_source,
+ mask=mask)
+ if palette is not None:
+ limbs_palette = np.array(palette, dtype=np.uint8)[None]
+
+ # check and pop the pop_parts
+ assert set(pop_parts).issubset(
+ keypoints_mapping.human_data.HUMAN_DATA_PALETTE.keys(
+ )), f'wrong part_names in pop_parts, could only \
+ choose in{set(keypoints_mapping.human_data.HUMAN_DATA_PALETTE.keys())}'
+
+ for part_name in pop_parts:
+ if part_name in limbs_target:
+ limbs_target.pop(part_name)
+
+ # initialize renderer and start render
+ renderer = Axes3dJointsRenderer()
+ renderer.init_camera(cam_hori_speed=orbit_speed, cam_elev_speed=0.2)
+ renderer.set_connections(limbs_target, limbs_palette)
+ if isinstance(frame_names, str):
+ if '%' in frame_names:
+ frame_names = [
+ frame_names % index for index in range(input_pose_np.shape[0])
+ ]
+ else:
+ frame_names = [frame_names] * input_pose_np.shape[0]
+ image_array = renderer.render_kp3d_to_video(input_pose_np,
+ output_path,
+ convention,
+ fps=fps,
+ resolution=resolution,
+ visual_range=value_range,
+ frame_names=frame_names,
+ disable_limbs=disable_limbs,
+ return_array=return_array)
+ return image_array
diff --git a/detrsmpl/core/visualization/visualize_smpl.py b/detrsmpl/core/visualization/visualize_smpl.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed68b904c90b1482978e47d84f47746b0c489431
--- /dev/null
+++ b/detrsmpl/core/visualization/visualize_smpl.py
@@ -0,0 +1,1209 @@
+# yapf: disable
+import copy
+import glob
+import os
+import os.path as osp
+import shutil
+import warnings
+from functools import partial
+from pathlib import Path
+from typing import List, Optional, Tuple, Union
+
+import mmcv
+import numpy as np
+import torch
+import torch.nn as nn
+from colormap import Color
+
+from detrsmpl.core.cameras import (
+ WeakPerspectiveCameras,
+ compute_orbit_cameras,
+)
+from detrsmpl.core.cameras.builder import build_cameras
+from detrsmpl.core.conventions.cameras.convert_convention import \
+ convert_camera_matrix # prevent yapf isort conflict
+from detrsmpl.core.conventions.segmentation import body_segmentation
+from detrsmpl.core.renderer.torch3d_renderer import render_runner
+from detrsmpl.core.renderer.torch3d_renderer.meshes import \
+ ParametricMeshes # noqa: E501
+from detrsmpl.core.renderer.torch3d_renderer.render_smpl_config import (
+ RENDER_CONFIGS,
+)
+from detrsmpl.core.renderer.torch3d_renderer.smpl_renderer import SMPLRenderer
+from detrsmpl.core.renderer.torch3d_renderer.utils import \
+ align_input_to_padded # noqa: E501
+from detrsmpl.models.body_models.builder import build_body_model
+from detrsmpl.utils.demo_utils import (
+ convert_bbox_to_intrinsic,
+ convert_crop_cam_to_orig_img,
+ convert_kp2d_to_bbox,
+ get_default_hmr_intrinsic,
+ get_different_colors,
+)
+from detrsmpl.utils.ffmpeg_utils import (
+ check_input_path,
+ images_to_array,
+ prepare_output_path,
+ vid_info_reader,
+ video_to_array,
+ video_to_images,
+)
+from detrsmpl.utils.mesh_utils import save_meshes_as_objs, save_meshes_as_plys
+from detrsmpl.utils.path_utils import check_path_suffix
+
+# yapf: enable
+
+try:
+ from typing import Literal
+except ImportError:
+ from typing_extensions import Literal
+
+
+def _prepare_background(image_array, frame_list, origin_frames, output_path,
+ start, end, img_format, overwrite, num_frames,
+ read_frames_batch):
+ """Compare among `image_array`, `frame_list` and `origin_frames` and decide
+ whether to save the temp background images."""
+ if num_frames > 300:
+ read_frames_batch = True
+
+ frames_folder = None
+ remove_folder = False
+
+ if isinstance(image_array, np.ndarray):
+
+ image_array = torch.Tensor(image_array)
+
+ if image_array is not None:
+ if image_array.ndim == 3:
+ image_array = image_array[None]
+ if image_array.shape[0] == 1:
+ image_array = image_array.repeat(num_frames, 1, 1, 1)
+ frame_list = None
+ origin_frames = None
+ image_array = image_array[start:end]
+
+ # check the output path and get the image_array
+ if output_path is not None:
+ prepare_output_path(output_path=output_path,
+                            allowed_suffix=['.mp4', '.gif', '.png', '.jpg', '.jpeg'],
+ tag='output video',
+ path_type='auto',
+ overwrite=overwrite)
+ if image_array is None:
+ # choose in frame_list or origin_frames
+ # if all None, will use pure white background
+ if frame_list is None and origin_frames is None:
+ print(
+ 'No background provided, will use pure white background.')
+ elif frame_list is not None and origin_frames is not None:
+ warnings.warn('Redundant input, will only use frame_list.')
+ origin_frames = None
+
+ # read the origin frames as array if any.
+ if frame_list is None and origin_frames is not None:
+ check_input_path(input_path=origin_frames,
+ allowed_suffix=['.mp4', '.gif', ''],
+ tag='origin frames',
+ path_type='auto')
+ # if origin_frames is a video, write it as a folder of images
+ # if read_frames_batch is True, else read directly as an array.
+ if Path(origin_frames).is_file():
+ if read_frames_batch:
+ frames_folder = osp.join(
+ Path(output_path).parent,
+ Path(output_path).name + '_input_temp')
+ os.makedirs(frames_folder, exist_ok=True)
+ video_to_images(origin_frames,
+ frames_folder,
+ img_format=img_format,
+ start=start,
+ end=end)
+ remove_folder = True
+ else:
+ remove_folder = False
+ frames_folder = None
+ image_array = video_to_array(origin_frames,
+ start=start,
+ end=end)
+ # if origin_frames is a folder, write it as a folder of images
+ # read the folder as an array if read_frames_batch is True
+ # else return frames_folder for reading during rendering.
+ else:
+ if read_frames_batch:
+ frames_folder = origin_frames
+ remove_folder = False
+ image_array = None
+ else:
+ image_array = images_to_array(origin_frames,
+ img_format=img_format,
+ start=start,
+ end=end)
+ remove_folder = False
+ frames_folder = origin_frames
+ # if frame_list is not None, move the images into a folder
+ # read the folder as an array if read_frames_batch is True
+ # else return frames_folder for reading during rendering.
+ elif frame_list is not None and origin_frames is None:
+ frames_folder = osp.join(
+ Path(output_path).parent,
+ Path(output_path).name + '_input_temp')
+ os.makedirs(frames_folder, exist_ok=True)
+ for frame_idx, frame_path in enumerate(frame_list):
+ if check_path_suffix(
+ path_str=frame_path,
+ allowed_suffix=['.jpg', '.png', '.jpeg']):
+ shutil.copy(
+ frame_path,
+ os.path.join(frames_folder,
+ '%06d.png' % frame_idx))
+ img_format = '%06d.png'
+ if not read_frames_batch:
+
+ image_array = images_to_array(frames_folder,
+ img_format=img_format,
+ remove_raw_files=True)
+ frames_folder = None
+ remove_folder = False
+ else:
+ image_array = None
+ remove_folder = True
+ return image_array, remove_folder, frames_folder
+
+
+def _prepare_body_model(body_model, body_model_config):
+ """Prepare `body_model` from `body_model_config` or existing
+ `body_model`."""
+ if body_model is None:
+ if body_model_config is not None:
+ body_model_config = copy.deepcopy(body_model_config)
+ model_path = body_model_config.get('model_path', None)
+
+ model_type = body_model_config.get('type').lower()
+ if model_type not in ['smpl', 'smplx']:
+ raise ValueError(f'Do not support {model_type}, please choose'
+                                 f' in `smpl` or `smplx`.')
+
+ if model_path and osp.isdir(model_path):
+ model_path = osp.join(model_path, model_type)
+ body_model_config.update(model_path=model_path)
+ body_model = build_body_model(body_model_config)
+ assert os.path.isdir(model_path)
+ else:
+ raise FileNotFoundError('Wrong model_path.'
+ ' File or directory does not exist.')
+ else:
+ raise ValueError('Please input body_model_config.')
+ else:
+ if body_model_config is not None:
+            warnings.warn('Redundant input, will take body_model directly '
+                          'and ignore body_model_config.')
+ return body_model
+
+
+def _prepare_input_pose(verts, poses, betas, transl):
+ """Prepare input pose data as tensor and ensure correct temporal slice."""
+ if verts is None and poses is None:
+ raise ValueError('Please input valid poses or verts.')
+ elif (verts is not None) and (poses is not None):
+ warnings.warn('Redundant input, will take verts and ignore poses & '
+ 'betas & transl.')
+ poses = None
+ transl = None
+ betas = None
+ elif isinstance(poses, dict):
+ transl = poses.get('transl', transl)
+ betas = poses.get('betas', betas)
+
+ if isinstance(verts, np.ndarray):
+ verts = torch.Tensor(verts)
+ num_frames = verts.shape[0]
+ elif isinstance(verts, torch.Tensor):
+ num_frames = verts.shape[0]
+
+ if isinstance(poses, np.ndarray):
+ poses = torch.Tensor(poses)
+ num_frames = poses.shape[0]
+ elif isinstance(poses, torch.Tensor):
+ num_frames = poses.shape[0]
+ elif isinstance(poses, dict):
+ for k, v in poses.items():
+ if isinstance(v, np.ndarray):
+ poses[k] = torch.tensor(v)
+ num_frames = poses['body_pose'].shape[0]
+
+ if isinstance(betas, np.ndarray):
+ betas = torch.Tensor(betas)
+
+ if betas is not None:
+ if betas.shape[0] != num_frames:
+ times = num_frames // betas.shape[0]
+ if betas.ndim == 2:
+ betas = betas.repeat(times, 1)[:num_frames]
+ elif betas.ndim == 3:
+ betas = betas.repeat(times, 1, 1)[:num_frames]
+ print(f'betas will be repeated by dim 0 for {times} times.')
+ if isinstance(transl, np.ndarray):
+ transl = torch.Tensor(transl)
+
+ return verts, poses, betas, transl
+
+
+def _prepare_mesh(poses, betas, transl, verts, start, end, body_model):
+ """Prepare the mesh info for rendering."""
+ NUM_JOINTS = body_model.NUM_JOINTS
+ NUM_BODY_JOINTS = body_model.NUM_BODY_JOINTS
+ NUM_DIM = 3 * (NUM_JOINTS + 1)
+ body_pose_keys = body_model.body_pose_keys
+ joints = None
+ if poses is not None:
+ if isinstance(poses, dict):
+ if not body_pose_keys.issubset(poses):
+ raise KeyError(
+ f'{str(poses.keys())}, Please make sure that your '
+ f'input dict has all of {", ".join(body_pose_keys)}')
+ num_frames = poses['body_pose'].shape[0]
+ _, num_person, _ = poses['body_pose'].view(
+ num_frames, -1, NUM_BODY_JOINTS * 3).shape
+
+ full_pose = body_model.dict2tensor(poses)
+ full_pose = full_pose[start:end]
+
+ elif isinstance(poses, torch.Tensor):
+ if poses.shape[-1] != NUM_DIM:
+ raise ValueError(
+                    f'Please make sure your poses are {NUM_DIM} dims in '
+ f'the last axis. Your input shape: {poses.shape}')
+ poses = poses.view(poses.shape[0], -1, (NUM_JOINTS + 1) * 3)
+ num_frames, num_person, _ = poses.shape
+ full_pose = poses[start:end]
+ else:
+ raise ValueError('Wrong pose type, should be `dict` or `tensor`.')
+
+ # multi person check
+ if num_person > 1:
+ if betas is not None:
+ num_betas = betas.shape[-1]
+ betas = betas.view(num_frames, -1, num_betas)
+
+ if betas.shape[1] == 1:
+ betas = betas.repeat(1, num_person, 1)
+ warnings.warn(
+ 'Only one betas for multi-person, will all be the '
+ 'same body shape.')
+ elif betas.shape[1] > num_person:
+ betas = betas[:, :num_person]
+ warnings.warn(
+ f'Betas shape exceed, will be sliced as {betas.shape}.'
+ )
+ elif betas.shape[1] == num_person:
+ pass
+ else:
+ raise ValueError(
+                        f'Odd betas shape: {betas.shape}, inconsistent '
+ f'with poses in num_person: {poses.shape}.')
+ else:
+ warnings.warn('None betas for multi-person, will all be the '
+ 'default body shape.')
+
+ if transl is not None:
+ transl = transl.view(poses.shape[0], -1, 3)
+ if transl.shape[1] == 1:
+ transl = transl.repeat(1, num_person, 1)
+ warnings.warn(
+ 'Only one transl for multi-person, will all be the '
+ 'same translation.')
+ elif transl.shape[1] > num_person:
+ transl = transl[:, :num_person]
+                    warnings.warn(f'Transl shape exceed, will be sliced as '
+ f'{transl.shape}.')
+ elif transl.shape[1] == num_person:
+ pass
+ else:
+ raise ValueError(
+                        f'Odd transl shape: {transl.shape}, inconsistent '
+ f'with poses in num_person: {poses.shape}.')
+ else:
+ warnings.warn('None transl for multi-person, will all be the '
+ 'default translation.')
+
+ # slice the input poses, betas, and transl.
+ betas = betas[start:end] if betas is not None else None
+ transl = transl[start:end] if transl is not None else None
+ pose_dict = body_model.tensor2dict(full_pose=full_pose,
+ betas=betas,
+ transl=transl)
+
+ # get new num_frames
+ num_frames = full_pose.shape[0]
+
+ model_output = body_model(**pose_dict)
+ vertices = model_output['vertices']
+ joints = model_output['joints'][0] # hardcode here
+
+ elif verts is not None:
+ if isinstance(verts, np.ndarray):
+ verts = torch.Tensor(verts)
+ verts = verts[start:end]
+ pose_dict = body_model.tensor2dict(torch.zeros(1,
+ (NUM_JOINTS + 1) * 3))
+
+ if verts.ndim == 3:
+ joints = torch.einsum('bik,ji->bjk',
+ [verts, body_model.J_regressor])
+ elif verts.ndim == 4:
+ joints = torch.einsum('fpik,ji->fpjk',
+ [verts, body_model.J_regressor])
+ num_verts = body_model.NUM_VERTS
+ assert verts.shape[-2] == num_verts, 'Wrong input verts shape.'
+ num_frames = verts.shape[0]
+ vertices = verts.view(num_frames, -1, num_verts, 3)
+ num_joints = joints.shape[-2]
+ joints = joints.view(num_frames, -1, num_joints, 3)
+ num_person = vertices.shape[1]
+ else:
+ raise ValueError('Poses and verts are all None.')
+ return vertices, joints, num_frames, num_person
+
+
+def _prepare_colors(palette, render_choice, num_person, num_verts, model_type):
+ """Prepare the `color` as a tensor of shape (num_person, num_verts, 3)
+ according to `palette`.
+
+ This is to make the identity in video clear.
+ """
+ if not len(palette) == num_person:
+ raise ValueError('Please give the right number of palette.')
+ body_segger = body_segmentation(model_type)
+
+ if render_choice == 'silhouette':
+ colors = torch.ones(num_person, num_verts, 3)
+ elif render_choice == 'part_silhouette':
+ colors = torch.zeros(num_person, num_verts, 3)
+ for i, k in enumerate(body_segger.keys()):
+ colors[:, body_segger[k]] = i + 1
+ else:
+ if isinstance(palette, torch.Tensor):
+ if palette.max() > 1:
+ palette = palette / 255.0
+ palette = torch.clip(palette, min=0, max=1)
+ colors = palette.view(num_person,
+ 3).unsqueeze(1).repeat(1, num_verts, 1)
+
+ elif isinstance(palette, list):
+ colors = []
+ for person_idx in range(num_person):
+
+ if palette[person_idx] == 'random':
+ color_person = get_different_colors(
+ num_person, int_dtype=False)[person_idx]
+ color_person = torch.FloatTensor(color_person)
+ color_person = torch.clip(color_person * 1.5,
+ min=0.6,
+ max=1)
+ color_person = color_person.view(1, 1, 3).repeat(
+ 1, num_verts, 1)
+ elif palette[person_idx] == 'segmentation':
+ verts_labels = torch.zeros(num_verts)
+ color_person = torch.ones(1, num_verts, 3)
+ color_part = get_different_colors(len(body_segger),
+ int_dtype=False)
+ for part_idx, k in enumerate(body_segger.keys()):
+ index = body_segger[k]
+ verts_labels[index] = part_idx
+ color_person[:, index] = torch.FloatTensor(
+ color_part[part_idx])
+ elif palette[person_idx] in Color.color_names:
+ color_person = torch.FloatTensor(
+ Color(palette[person_idx]).rgb).view(1, 1, 3).repeat(
+ 1, num_verts, 1)
+ else:
+ raise ValueError('Wrong palette string. '
+ 'Please choose in the pre-defined range.')
+ colors.append(color_person)
+ colors = torch.cat(colors, 0)
+ assert colors.shape == (num_person, num_verts, 3)
+ # the color passed to renderer will be (num_person, num_verts, 3)
+ else:
+ raise ValueError(
+ 'Palette should be tensor, array or list of strs.')
+ return colors
+
+
+def render_smpl(
+ # smpl parameters
+ poses: Optional[Union[torch.Tensor, np.ndarray, dict]] = None,
+ betas: Optional[Union[torch.Tensor, np.ndarray]] = None,
+ transl: Optional[Union[torch.Tensor, np.ndarray]] = None,
+ verts: Optional[Union[torch.Tensor, np.ndarray]] = None,
+ body_model: Optional[nn.Module] = None,
+ body_model_config: Optional[dict] = None,
+ # camera parameters
+ R: Optional[Union[torch.Tensor, np.ndarray]] = None,
+ T: Optional[Union[torch.Tensor, np.ndarray]] = None,
+ K: Optional[Union[torch.Tensor, np.ndarray]] = None,
+ orig_cam: Optional[Union[torch.Tensor, np.ndarray]] = None,
+ Ks: Optional[Union[torch.Tensor, np.ndarray]] = None,
+ in_ndc: bool = True,
+ convention: str = 'pytorch3d',
+ projection: Literal['weakperspective', 'perspective', 'fovperspective',
+ 'orthographics',
+ 'fovorthographics'] = 'perspective',
+ orbit_speed: Union[float, Tuple[float, float]] = 0.0,
+ # render choice parameters
+ render_choice: Literal['lq', 'mq', 'hq', 'silhouette', 'depth',
+ 'normal', 'pointcloud',
+ 'part_silhouette'] = 'hq',
+ palette: Union[List[str], str, np.ndarray, torch.Tensor] = 'white',
+ texture_image: Union[torch.Tensor, np.ndarray] = None,
+ resolution: Optional[Union[List[int], Tuple[int, int]]] = None,
+ start: int = 0,
+ end: Optional[int] = None,
+ alpha: float = 1.0,
+ no_grad: bool = True,
+ batch_size: int = 10,
+ device: Union[torch.device, str] = 'cuda',
+ # file io parameters
+ return_tensor: bool = False,
+ output_path: str = None,
+ origin_frames: Optional[str] = None,
+ frame_list: Optional[List[str]] = None,
+ image_array: Optional[Union[np.ndarray, torch.Tensor]] = None,
+ img_format: str = '%06d.png',
+ overwrite: bool = False,
+ mesh_file_path: Optional[str] = None,
+ read_frames_batch: bool = False,
+ # visualize keypoints
+ plot_kps: bool = False,
+ kp3d: Optional[Union[np.ndarray, torch.Tensor]] = None,
+ mask: Optional[Union[np.ndarray, List[int]]] = None,
+ vis_kp_index: bool = False,
+ verbose: bool = False) -> Union[None, torch.Tensor]:
+ """Render SMPL or SMPL-X mesh or silhouette into differentiable tensors,
+ and export video or images.
+
+ Args:
+ # smpl parameters:
+ poses (Union[torch.Tensor, np.ndarray, dict]):
+
+ 1). `tensor` or `array` and ndim is 2, shape should be
+ (frame, 72).
+
+ 2). `tensor` or `array` and ndim is 3, shape should be
+ (frame, num_person, 72/165). num_person equals 1 means
+ single-person.
+            Rendering predicted multi-person poses should be fed together
+            with multi-person weakperspective cameras; the meshes will be
+            computed and an identity intrinsic matrix will be used.
+
+ 3). `dict`, standard dict format defined in smplx.body_models.
+ will be treated as single-person.
+
+ Lower priority than `verts`.
+
+ Defaults to None.
+ betas (Optional[Union[torch.Tensor, np.ndarray]], optional):
+ 1). ndim is 2, shape should be (frame, 10).
+
+ 2). ndim is 3, shape should be (frame, num_person, 10). num_person
+ equals 1 means single-person. If poses are multi-person, betas
+ should be set to the same person number.
+
+ None will use default betas.
+
+ Defaults to None.
+ transl (Optional[Union[torch.Tensor, np.ndarray]], optional):
+ translations of smpl(x).
+
+ 1). ndim is 2, shape should be (frame, 3).
+
+ 2). ndim is 3, shape should be (frame, num_person, 3). num_person
+ equals 1 means single-person. If poses are multi-person,
+ transl should be set to the same person number.
+
+ Defaults to None.
+ verts (Optional[Union[torch.Tensor, np.ndarray]], optional):
+ 1). ndim is 3, shape should be (frame, num_verts, 3).
+
+ 2). ndim is 4, shape should be (frame, num_person, num_verts, 3).
+ num_person equals 1 means single-person.
+
+ Higher priority over `poses` & `betas` & `transl`.
+
+ Defaults to None.
+ body_model (nn.Module, optional): body_model created from smplx.create.
+ Higher priority than `body_model_config`. If `body_model` is not
+ None, it will override `body_model_config`.
+ Should not both be None.
+
+ Defaults to None.
+ body_model_config (dict, optional): body_model_config for build_model.
+ Lower priority than `body_model`. Should not both be None.
+ Defaults to None.
+
+ # camera parameters:
+
+ K (Optional[Union[torch.Tensor, np.ndarray]], optional):
+ shape should be (frame, 4, 4) or (frame, 3, 3), frame could be 1.
+ if (4, 4) or (3, 3), dim 0 will be added automatically.
+ Will be default `FovPerspectiveCameras` intrinsic if None.
+ Lower priority than `orig_cam`.
+ R (Optional[Union[torch.Tensor, np.ndarray]], optional):
+ shape should be (frame, 3, 3), If f equals 1, camera will have
+ identical rotation.
+ If `K` and `orig_cam` is None, will be generated by `look_at_view`.
+ If have `K` or `orig_cam` and `R` is None, will be generated by
+ `convert_camera_matrix`.
+
+ Defaults to None.
+ T (Optional[Union[torch.Tensor, np.ndarray]], optional):
+ shape should be (frame, 3). If f equals 1, camera will have
+ identical translation.
+ If `K` and `orig_cam` is None, will be generated by `look_at_view`.
+ If have `K` or `orig_cam` and `T` is None, will be generated by
+ `convert_camera_matrix`.
+
+ Defaults to None.
+ orig_cam (Optional[Union[torch.Tensor, np.ndarray]], optional):
+ shape should be (frame, 4) or (frame, num_person, 4). If f equals
+ 1, will be repeated to num_frames. num_person should be 1 if single
+ person. Usually for HMR, VIBE predicted cameras.
+ Higher priority than `K` & `R` & `T`.
+
+ Defaults to None.
+ Ks (Optional[Union[torch.Tensor, np.ndarray]], optional):
+ shape should be (frame, 4, 4).
+ This is for HMR or SPIN multi-person demo.
+        in_ndc (bool, optional): whether the camera parameters are defined
+            in NDC (normalized device coordinates). Defaults to True.
+ convention (str, optional): If want to use an existing convention,
+ choose in ['opengl', 'opencv', 'pytorch3d', 'pyrender', 'open3d',
+ 'maya', 'blender', 'unity'].
+ If want to use a new convention, define your convention in
+ (CAMERA_CONVENTION_FACTORY)[mmhuman3d/core/conventions/cameras/
+ __init__.py] by the order of right, front and up.
+
+ Defaults to 'pytorch3d'.
+        projection (Literal, optional): projection mode of cameras. Choose in
+            ['orthographics', 'fovperspective', 'perspective',
+            'weakperspective', 'fovorthographics'].
+ Defaults to 'perspective'.
+ orbit_speed (float, optional): orbit speed for viewing when no `K`
+ provided. `float` for only azim speed and Tuple for `azim` and
+ `elev`.
+
+ # render choice parameters:
+
+        render_choice (Literal, optional):
+ choose in ['lq', 'mq', 'hq', 'silhouette', 'depth', 'normal',
+ 'pointcloud', 'part_silhouette'] .
+
+ `lq`, `mq`, `hq` would output (frame, h, w, 4) FloatTensor.
+
+ `lq` means low quality, `mq` means medium quality,
+        `hq` means high quality.
+
+ `silhouette` would output (frame, h, w) soft binary FloatTensor.
+
+ `part_silhouette` would output (frame, h, w, 1) LongTensor.
+
+ Every pixel stores a class index.
+
+ `depth` will output a depth map of (frame, h, w, 1) FloatTensor
+ and 'normal' will output a normal map of (frame, h, w, 1).
+
+ `pointcloud` will output a (frame, h, w, 4) FloatTensor.
+
+        Defaults to 'hq'.
+ palette (Union[List[str], str, np.ndarray], optional):
+ color theme str or list of color str or `array`.
+
+ 1). If use str to represent the color,
+ should choose in ['segmentation', 'random'] or color from
+ Colormap https://en.wikipedia.org/wiki/X11_color_names.
+ If choose 'segmentation', will get a color for each part.
+
+ 2). If you have multi-person, better give a list of str or all
+ will be in the same color.
+
+ 3). If you want to define your specific color, use an `array`
+ of shape (3,) for single person and (N, 3) for multiple persons.
+
+ If (3,) for multiple persons, all will be in the same color.
+
+ Your `array` should be in range [0, 255] for 8 bit color.
+
+ Defaults to 'white'.
+
+ texture_image (Union[torch.Tensor, np.ndarray], optional):
+ Texture image to be wrapped on the smpl mesh. If not None,
+ the `palette` will be ignored, and the `body_model` is required
+ to have `uv_param_path`.
+ Should pass list or tensor of shape (num_person, H, W, 3).
+ The color channel should be `RGB`.
+
+ Defaults to None.
+
+ resolution (Union[Iterable[int], int], optional):
+ 1). If iterable, should be (height, width) of output images.
+
+ 2). If int, would be taken as (resolution, resolution).
+
+            Defaults to None, which uses the background size, or
+            (1024, 1024) if there is no background.
+
+ This will influence the overlay results when render with
+ backgrounds. The output video will be rendered following the
+ size of background images and finally resized to resolution.
+ start (int, optional): start frame index. Defaults to 0.
+
+ end (int, optional): end frame index. Exclusive.
+ Could be positive int or negative int or None.
+ None represents include all the frames.
+
+ Defaults to None.
+ alpha (float, optional): Transparency of the mesh.
+ Range in [0.0, 1.0]
+
+ Defaults to 1.0.
+        no_grad (bool, optional): Set to True if you do not need a
+            differentiable render.
+
+            Defaults to True.
+ batch_size (int, optional): Batch size for render.
+ Related to your gpu memory.
+
+ Defaults to 10.
+ # file io parameters:
+
+ return_tensor (bool, optional): Whether return the result tensors.
+
+ Defaults to False, will return None.
+ output_path (str, optional): output video or gif or image folder.
+
+            Defaults to None, which skips the export procedure.
+
+ # background frames, priority: image_array > frame_list > origin_frames
+
+ origin_frames (Optional[str], optional): origin background frame path,
+ could be `.mp4`, `.gif`(will be sliced into a folder) or an image
+ folder.
+
+ Defaults to None.
+ frame_list (Optional[List[str]], optional): list of origin background
+            frame paths, each element in the list should be an image path
+            like `*.jpg` or `*.png`.
+            Use this when your file names are hard to sort or you only want
+            to render a small number of frames.
+
+ Defaults to None.
+ image_array: (Optional[Union[np.ndarray, torch.Tensor]], optional):
+ origin background frame `tensor` or `array`, use this when you
+ want your frames in memory as array or tensor.
+ overwrite (bool, optional): whether overwriting the existing files.
+
+ Defaults to False.
+        mesh_file_path (str, optional): the directory path to store the `.ply`
+            or `.obj` files. Will be named like 'frame_idx_person_idx.ply'.
+
+ Defaults to None.
+ read_frames_batch (bool, optional): Whether read frames by batch.
+ Set it as True if your video is large in size.
+
+ Defaults to False.
+
+ # visualize keypoints
+ plot_kps (bool, optional): whether plot keypoints on the output video.
+
+ Defaults to False.
+ kp3d (Optional[Union[np.ndarray, torch.Tensor]], optional):
+            the keypoints of any convention, should pass `mask` if there are
+            any invalid points. Shape should be (frame, J, 3).
+
+ Defaults to None.
+ mask (Optional[Union[np.ndarray, List[int]]], optional):
+ Mask of keypoints existence.
+
+ Defaults to None.
+ vis_kp_index (bool, optional):
+ Whether plot keypoint index number on human mesh.
+
+ Defaults to False.
+ # visualize render progress
+ verbose (bool, optional):
+            Whether to print the progress bar for rendering.
+            Defaults to False.
+ Returns:
+ Union[None, torch.Tensor]: return the rendered image tensors or None.
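+
+    Example:
+        A minimal sketch rendering a single-person SMPL sequence on a pure
+        white background (`model_path` and the output path are placeholders
+        that must point to your own files):
+
+        >>> import numpy as np
+        >>> poses = np.zeros((30, 72))  # 30 frames of SMPL pose parameters
+        >>> render_smpl(
+        ...     poses=poses,
+        ...     body_model_config=dict(type='smpl',
+        ...                            model_path='path/to/body_models'),
+        ...     render_choice='hq',
+        ...     resolution=1024,
+        ...     output_path='demo_out/smpl_render.mp4')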
+ """
+ # initialize the device
+ device = torch.device(device) if isinstance(device, str) else device
+
+ if isinstance(resolution, int):
+ resolution = (resolution, resolution)
+ elif isinstance(resolution, list):
+ resolution = tuple(resolution)
+
+ verts, poses, betas, transl = _prepare_input_pose(verts, poses, betas,
+ transl)
+
+ body_model = _prepare_body_model(body_model, body_model_config)
+ model_type = body_model.name().replace('-', '').lower()
+ assert model_type in ['smpl', 'smplx']
+
+ vertices, joints, num_frames, num_person = _prepare_mesh(
+ poses, betas, transl, verts, start, end, body_model)
+ end = num_frames if end is None else end
+ vertices = vertices.view(num_frames, num_person, -1, 3)
+ num_verts = vertices.shape[-2]
+
+ if not plot_kps:
+ joints = None
+ if kp3d is not None:
+ warnings.warn('`plot_kps` is False, `kp3d` will be set as None.')
+ kp3d = None
+
+ image_array, remove_folder, frames_folder = _prepare_background(
+ image_array, frame_list, origin_frames, output_path, start, end,
+ img_format, overwrite, num_frames, read_frames_batch)
+
+ render_resolution = None
+ if image_array is not None:
+ render_resolution = (image_array.shape[1], image_array.shape[2])
+ elif frames_folder is not None:
+ frame_path_list = glob.glob(osp.join(
+ frames_folder, '*.jpg')) + glob.glob(
+ osp.join(frames_folder, '*.png')) + glob.glob(
+ osp.join(frames_folder, '*.jpeg'))
+ vid_info = vid_info_reader(frame_path_list[0])
+ render_resolution = (int(vid_info['height']), int(vid_info['width']))
+ if resolution is not None:
+ if render_resolution is not None:
+ if render_resolution != resolution:
+ warnings.warn(
+ f'Size of background: {render_resolution} !='
+ f' resolution: {resolution}, the output video will be '
+ f'resized as {resolution}')
+ final_resolution = resolution
+ elif render_resolution is None:
+ render_resolution = final_resolution = resolution
+ elif resolution is None:
+ if render_resolution is None:
+ render_resolution = final_resolution = (1024, 1024)
+ elif render_resolution is not None:
+ final_resolution = render_resolution
+
+ if isinstance(kp3d, np.ndarray):
+ kp3d = torch.Tensor(kp3d)
+
+ if kp3d is not None:
+ if mask is not None:
+ map_index = np.where(np.array(mask) != 0)[0]
+ kp3d = kp3d[map_index.tolist()]
+ kp3d = kp3d[start:end]
+ kp3d = kp3d.view(num_frames, -1, 3)
+
+ # prepare render_param_dict
+ render_param_dict = copy.deepcopy(RENDER_CONFIGS[render_choice.lower()])
+ if model_type == 'smpl':
+ render_param_dict.update(num_class=24)
+ elif model_type == 'smplx':
+ render_param_dict.update(num_class=27)
+
+ if render_choice not in [
+ 'hq', 'mq', 'lq', 'silhouette', 'part_silhouette', 'depth',
+ 'pointcloud', 'normal'
+ ]:
+ raise ValueError('Please choose the right render_choice.')
+
+ # body part colorful visualization should use flat shader to be sharper.
+ if texture_image is None:
+ if isinstance(palette, str):
+ palette = [palette] * num_person
+ elif isinstance(palette, np.ndarray):
+ palette = torch.Tensor(palette)
+ palette = palette.view(-1, 3)
+ if palette.shape[0] != num_person:
+ _times = num_person // palette.shape[0]
+ palette = palette.repeat(_times, 1)[:num_person]
+ if palette.shape[0] == 1:
+ print(f'Same color for all the {num_person} people')
+ else:
+ print('Repeat palette for multi-person.')
+ else:
+ raise ValueError('Wrong input palette type. '
+ 'Palette should be tensor, array or list of strs')
+ colors_all = _prepare_colors(palette, render_choice, num_person,
+ num_verts, model_type)
+ colors_all = colors_all.view(-1, num_person * num_verts, 3)
+ # verts of ParametricMeshes should be in (N, V, 3)
+ vertices = vertices.view(num_frames, -1, 3)
+ meshes = ParametricMeshes(
+ body_model=body_model,
+ verts=vertices,
+ N_individual_overdide=num_person,
+ model_type=model_type,
+ texture_image=texture_image,
+ use_nearest=bool(render_choice == 'part_silhouette'),
+ vertex_color=colors_all)
+
+ # write .ply or .obj files
+ if mesh_file_path is not None:
+ mmcv.mkdir_or_exist(mesh_file_path)
+
+ for person_idx in range(meshes.shape[1]):
+ mesh_person = meshes[:, person_idx]
+ if texture_image is None:
+ ply_paths = [
+ f'{mesh_file_path}/frame{frame_idx}_'
+ f'person{person_idx}.ply'
+ for frame_idx in range(num_frames)
+ ]
+ save_meshes_as_plys(meshes=mesh_person, files=ply_paths)
+
+ else:
+ obj_paths = [
+ f'{mesh_file_path}/frame{frame_idx}_'
+ f'person{person_idx}.obj'
+ for frame_idx in range(num_frames)
+ ]
+ save_meshes_as_objs(meshes=mesh_person, files=obj_paths)
+
+ vertices = meshes.verts_padded().view(num_frames, num_person, -1, 3)
+
+ # prepare camera matrices
+ if Ks is not None:
+ projection = 'perspective'
+ orig_cam = None
+ if isinstance(Ks, np.ndarray):
+ Ks = torch.Tensor(Ks)
+ Ks = Ks.view(-1, num_person, 3, 3)
+ Ks = Ks[start:end]
+ Ks = Ks.view(-1, 3, 3)
+ K = K.repeat(num_frames * num_person, 1, 1)
+
+ Ks = K.inverse() @ Ks @ K
+ vertices = vertices.view(num_frames * num_person, -1, 3)
+ if T is None:
+ T = torch.zeros(num_frames, num_person, 1, 3)
+ elif isinstance(T, np.ndarray):
+ T = torch.Tensor(T)
+ T = T[start:end]
+ T = T.view(num_frames * num_person, 1, 3)
+ vertices = torch.einsum('blc,bvc->bvl', Ks, vertices + T)
+
+ R = None
+ T = None
+ vertices = vertices.view(num_frames, num_person, -1, 3)
+
+ if orig_cam is not None:
+ if isinstance(orig_cam, np.ndarray):
+ orig_cam = torch.Tensor(orig_cam)
+ projection = 'weakperspective'
+ r = render_resolution[1] / render_resolution[0]
+ orig_cam = orig_cam[start:end]
+ orig_cam = orig_cam.view(num_frames, num_person, 4)
+ # if num_person > 1:
+ sx, sy, tx, ty = torch.unbind(orig_cam, -1)
+
+ vertices[..., 0] += tx.view(num_frames, num_person, 1)
+ vertices[..., 1] += ty.view(num_frames, num_person, 1)
+ vertices[..., 0] *= sx.view(num_frames, num_person, 1)
+ vertices[..., 1] *= sy.view(num_frames, num_person, 1)
+ orig_cam = torch.tensor([1.0, 1.0, 0.0,
+ 0.0]).view(1, 4).repeat(num_frames, 1)
+ K, R, T = WeakPerspectiveCameras.convert_orig_cam_to_matrix(
+ orig_cam=orig_cam,
+ znear=torch.min(vertices[..., 2] - 1),
+ aspect_ratio=r)
+
+ if num_person > 1:
+ vertices = vertices.reshape(num_frames, -1, 3)
+ else:
+ vertices = vertices.view(num_frames, -1, 3)
+ meshes = meshes.update_padded(new_verts_padded=vertices)
+
+ # orig_cam and K are None, use look_at_view
+ if K is None:
+ projection = 'fovperspective'
+ K, R, T = compute_orbit_cameras(at=(torch.mean(vertices.view(-1, 3),
+ 0)).detach().cpu(),
+ orbit_speed=orbit_speed,
+ batch_size=num_frames,
+ convention=convention)
+ convention = 'pytorch3d'
+
+ if isinstance(R, np.ndarray):
+ R = torch.Tensor(R).view(-1, 3, 3)
+ elif isinstance(R, torch.Tensor):
+ R = R.view(-1, 3, 3)
+ elif isinstance(R, list):
+ R = torch.Tensor(R).view(-1, 3, 3)
+ elif R is None:
+ pass
+ else:
+ raise ValueError(f'Wrong type of R: {type(R)}!')
+
+ if R is not None:
+ if len(R) > num_frames:
+ R = R[start:end]
+
+ if isinstance(T, np.ndarray):
+ T = torch.Tensor(T).view(-1, 3)
+ elif isinstance(T, torch.Tensor):
+ T = T.view(-1, 3)
+ elif isinstance(T, list):
+ T = torch.Tensor(T).view(-1, 3)
+ elif T is None:
+ pass
+ else:
+ raise ValueError(f'Wrong type of T: {type(T)}!')
+
+ if T is not None:
+ if len(T) > num_frames:
+ T = T[start:end]
+
+ if isinstance(K, np.ndarray):
+ K = torch.Tensor(K).view(-1, K.shape[-2], K.shape[-1])
+ elif isinstance(K, torch.Tensor):
+ K = K.view(-1, K.shape[-2], K.shape[-1])
+ elif isinstance(K, list):
+ K = torch.Tensor(K)
+ K = K.view(-1, K.shape[-2], K.shape[-1])
+ else:
+ raise ValueError(f'Wrong type of K: {type(K)}!')
+
+ if K is not None:
+ if len(K) > num_frames:
+ K = K[start:end]
+
+ assert projection in [
+ 'perspective', 'weakperspective', 'orthographics', 'fovorthographics',
+ 'fovperspective'
+ ], f'Wrong camera projection: {projection}'
+ if projection in ['fovperspective', 'perspective']:
+ is_perspective = True
+ elif projection in [
+ 'fovorthographics', 'weakperspective', 'orthographics'
+ ]:
+ is_perspective = False
+ if projection in ['fovperspective', 'fovorthographics', 'weakperspective']:
+ assert in_ndc
+
+ K, R, T = convert_camera_matrix(convention_dst='pytorch3d',
+ K=K,
+ R=R,
+ T=T,
+ is_perspective=is_perspective,
+ convention_src=convention,
+ resolution_src=render_resolution,
+ in_ndc_src=in_ndc,
+ in_ndc_dst=in_ndc)
+
+ # initialize the renderer.
+ renderer = SMPLRenderer(resolution=render_resolution,
+ device=device,
+ output_path=output_path,
+ return_tensor=return_tensor,
+ alpha=alpha,
+ read_img_format=img_format,
+ render_choice=render_choice,
+ frames_folder=frames_folder,
+ plot_kps=plot_kps,
+ vis_kp_index=vis_kp_index,
+ final_resolution=final_resolution,
+ **render_param_dict)
+
+ cameras = build_cameras(
+ dict(type=projection,
+ in_ndc=in_ndc,
+ device=device,
+ K=K,
+ R=R,
+ T=T,
+ resolution=render_resolution))
+
+ if image_array is not None:
+ image_array = torch.Tensor(image_array)
+ image_array = align_input_to_padded(image_array,
+ ndim=4,
+ batch_size=num_frames,
+ padding_mode='ones')
+ # prepare the render data.
+ render_data = dict(
+ images=image_array,
+ meshes=meshes,
+ cameras=cameras,
+ joints=joints,
+ joints_gt=kp3d,
+ )
+
+ results = render_runner.render(renderer=renderer,
+ device=device,
+ batch_size=batch_size,
+ output_path=output_path,
+ return_tensor=return_tensor,
+ no_grad=no_grad,
+ verbose=verbose,
+ **render_data)
+
+ if remove_folder:
+ if Path(frames_folder).is_dir():
+ shutil.rmtree(frames_folder)
+
+ if return_tensor:
+ return results
+ else:
+ return None
+
+
+def visualize_smpl_calibration(
+ K,
+ R,
+ T,
+ resolution,
+ **kwargs,
+) -> None:
+ """Visualize a smpl mesh which has opencv calibration matrix defined in
+ screen."""
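+ # Hypothetical sketch: `K_cv`, `R_cv` and `T_cv` come from your own OpenCV
+ # calibration; the remaining kwargs are forwarded to render_smpl.
+ #   visualize_smpl_calibration(
+ #       K=K_cv, R=R_cv, T=T_cv, resolution=(1080, 1920),
+ #       poses=poses, betas=betas, transl=transl,
+ #       body_model_config=dict(type='smpl', model_path='body_models/smpl'),
+ #       origin_frames='input_frames/', output_path='calib_overlay.mp4')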
+ assert K is not None, '`K` is required.'
+ assert resolution is not None, '`resolution` (h, w) is required.'
+ func = partial(render_smpl,
+ projection='perspective',
+ convention='opencv',
+ orig_cam=None,
+ in_ndc=False)
+ for k in func.keywords.keys():
+ if k in kwargs:
+ kwargs.pop(k)
+ return func(K=K, R=R, T=T, resolution=resolution, **kwargs)
+
+
+def visualize_smpl_hmr(cam_transl,
+ bbox=None,
+ kp2d=None,
+ focal_length=5000,
+ det_width=224,
+ det_height=224,
+ bbox_format='xyxy',
+ **kwargs) -> None:
+ """Simplest way to visualize HMR or SPIN or Smplify pred smpl with origin
+ frames and predicted cameras."""
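+ # Hypothetical sketch: `pred_cam` is the (N, 3) camera predicted by an
+ # HMR-style model and `bboxes` the matching detection boxes in xyxy format.
+ #   visualize_smpl_hmr(
+ #       cam_transl=pred_cam, bbox=bboxes, focal_length=5000,
+ #       poses=poses, betas=betas,
+ #       body_model_config=dict(type='smpl', model_path='body_models/smpl'),
+ #       origin_frames='input_frames/', output_path='hmr_overlay.mp4')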
+ if kp2d is not None:
+ bbox = convert_kp2d_to_bbox(kp2d, bbox_format=bbox_format)
+ Ks = convert_bbox_to_intrinsic(bbox, bbox_format=bbox_format)
+ K = torch.Tensor(
+ get_default_hmr_intrinsic(focal_length=focal_length,
+ det_height=det_height,
+ det_width=det_width))
+ func = partial(
+ render_smpl,
+ projection='perspective',
+ convention='opencv',
+ in_ndc=False,
+ K=None,
+ R=None,
+ orig_cam=None,
+ )
+ if isinstance(cam_transl, np.ndarray):
+ cam_transl = torch.Tensor(cam_transl)
+ T = torch.cat([
+ cam_transl[..., [1]], cam_transl[..., [2]], 2 * focal_length /
+ (det_width * cam_transl[..., [0]] + 1e-9)
+ ], -1)
+ for k in func.keywords.keys():
+ if k in kwargs:
+ kwargs.pop(k)
+ return func(Ks=Ks, K=K, T=T, **kwargs)
+
+
+def visualize_smpl_vibe(orig_cam=None,
+ pred_cam=None,
+ bbox=None,
+ output_path='sample.mp4',
+ resolution=None,
+ aspect_ratio=1.0,
+ bbox_scale_factor=1.25,
+ bbox_format='xyxy',
+ **kwargs) -> None:
+ """Simplest way to visualize pred smpl with origin frames and predicted
+ cameras."""
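+ # Hypothetical sketch: `pred_cam` and `bboxes` are VIBE-style outputs and
+ # `resolution` is the (height, width) of the original frames.
+ #   visualize_smpl_vibe(
+ #       pred_cam=pred_cam, bbox=bboxes, resolution=(720, 1280),
+ #       poses=poses, betas=betas,
+ #       body_model_config=dict(type='smpl', model_path='body_models/smpl'),
+ #       origin_frames='input_frames/', output_path='vibe_overlay.mp4')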
+ assert resolution is not None
+ if pred_cam is not None and bbox is not None:
+ orig_cam = torch.Tensor(
+ convert_crop_cam_to_orig_img(pred_cam, bbox, resolution[1],
+ resolution[0], aspect_ratio,
+ bbox_scale_factor, bbox_format))
+ assert orig_cam is not None, '`orig_cam` is required.'
+
+ func = partial(
+ render_smpl,
+ projection='weakperspective',
+ convention='opencv',
+ in_ndc=True,
+ )
+ for k in func.keywords.keys():
+ if k in kwargs:
+ kwargs.pop(k)
+ return func(orig_cam=orig_cam,
+ output_path=output_path,
+ resolution=resolution,
+ **kwargs)
+
+
+def visualize_T_pose(num_frames,
+ body_model_config=None,
+ body_model=None,
+ orbit_speed=1.0,
+ **kwargs) -> None:
+ """Simplest way to visualize a sequence of T pose."""
+ assert num_frames > 0, '`num_frames` is required.'
+ assert body_model_config is not None or body_model is not None
+ model_type = body_model_config[
+ 'type'] if body_model_config is not None else body_model.name(
+ ).replace('-', '').lower()
+ if model_type == 'smpl':
+ poses = torch.zeros(num_frames, 72)
+ else:
+ poses = torch.zeros(num_frames, 165)
+
+ func = partial(render_smpl,
+ betas=None,
+ transl=None,
+ verts=None,
+ convention='pytorch3d',
+ projection='fovperspective',
+ K=None,
+ R=None,
+ T=None,
+ origin_frames=None)
+ for k in func.keywords.keys():
+ if k in kwargs:
+ kwargs.pop(k)
+ return func(poses=poses,
+ body_model_config=body_model_config,
+ body_model=body_model,
+ orbit_speed=orbit_speed,
+ **kwargs)
+
+
+def visualize_smpl_pose(poses=None, verts=None, **kwargs) -> None:
+ """Simplest way to visualize a sequence of smpl pose.
+
+ Cameras will focus on the center of the smpl mesh. Setting
+ `orbit_speed` is recommended.
+ """
+ assert (poses
+ is not None) or (verts
+ is not None), 'Pass either `poses` or `verts`.'
+ func = partial(render_smpl,
+ convention='opencv',
+ projection='fovperspective',
+ K=None,
+ R=None,
+ T=None,
+ in_ndc=True,
+ origin_frames=None,
+ frame_list=None,
+ image_array=None)
+ for k in func.keywords.keys():
+ if k in kwargs:
+ kwargs.pop(k)
+ return func(poses=poses, verts=verts, **kwargs)
diff --git a/detrsmpl/data/__init__.py b/detrsmpl/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/detrsmpl/data/data_structures/__init__.py b/detrsmpl/data/data_structures/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/detrsmpl/data/data_structures/human_data.py b/detrsmpl/data/data_structures/human_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..c32def905edd8474d1dd89b3ec1d0515562a8c89
--- /dev/null
+++ b/detrsmpl/data/data_structures/human_data.py
@@ -0,0 +1,1413 @@
+import logging
+import pickle
+from enum import Enum
+from math import ceil
+from typing import Any, List, Optional, TypeVar, Union, overload
+
+import numpy as np
+import torch
+from mmcv.utils import print_log
+
+from detrsmpl.utils.path_utils import (
+ Existence,
+ check_path_existence,
+ check_path_suffix,
+)
+
+# In T = TypeVar('T'), T can be anything.
+# See definition of typing.TypeVar for details.
+_T1 = TypeVar('_T1')
+_KT = TypeVar('_KT')
+_VT = TypeVar('_VT')
+_HumanData = TypeVar('_HumanData')
+_CPU_DEVICE = torch.device('cpu')
+
+_HumanData_SUPPORTED_KEYS = {
+ 'image_path': {
+ 'type': list,
+ },
+ 'image_id': {
+ 'type': list,
+ },
+ 'bbox_xywh': {
+ 'type': np.ndarray,
+ 'shape': (-1, 5),
+ 'dim': 0
+ },
+ 'config': {
+ 'type': str,
+ 'dim': None
+ },
+ 'keypoints2d': {
+ 'type': np.ndarray,
+ 'shape': (-1, -1, 3),
+ 'dim': 0
+ },
+ 'keypoints3d': {
+ 'type': np.ndarray,
+ 'shape': (-1, -1, 4),
+ 'dim': 0
+ },
+ 'smpl': {
+ 'type': dict,
+ 'slice_key': 'betas',
+ 'dim': 0
+ },
+ 'smplh': {
+ 'type': dict,
+ 'slice_key': 'betas',
+ 'dim': 0
+ },
+ 'smplx': {
+ 'type': dict,
+ 'slice_key': 'betas',
+ 'dim': 0
+ },
+ 'meta': {
+ 'type': dict,
+ },
+ 'keypoints2d_mask': {
+ 'type': np.ndarray,
+ 'shape': (-1, ),
+ 'dim': None
+ },
+ 'keypoints2d_convention': {
+ 'type': str,
+ 'dim': None
+ },
+ 'keypoints3d_mask': {
+ 'type': np.ndarray,
+ 'shape': (-1, ),
+ 'dim': None
+ },
+ 'keypoints3d_convention': {
+ 'type': str,
+ 'dim': None
+ },
+ 'vertices': {
+ 'type': np.ndarray,
+ 'shape': (-1, ),
+ 'dim': None
+ },
+ 'focal_length': {
+ 'type': np.ndarray,
+ 'shape': (-1, ),
+ 'dim': 0
+ },
+ 'principal_point': {
+ 'type': np.ndarray,
+ 'shape': (-1, ),
+ 'dim': 0
+ },
+ 'misc': {
+ 'type': dict,
+ },
+}
+
+
+class _KeyCheck(Enum):
+ PASS = 0
+ WARN = 1
+ ERROR = 2
+
+
+class HumanData(dict):
+ logger = None
+ SUPPORTED_KEYS = _HumanData_SUPPORTED_KEYS
+ WARNED_KEYS = []
+
+ def __new__(cls: _HumanData, *args: Any, **kwargs: Any) -> _HumanData:
+ """New an instance of HumanData.
+
+ Args:
+ cls (HumanData): HumanData class.
+
+ Returns:
+ HumanData: An instance of HumanData.
+ """
+ ret_human_data = super().__new__(cls, args, kwargs)
+ setattr(ret_human_data, '__data_len__', -1)
+ setattr(ret_human_data, '__key_strict__', False)
+ setattr(ret_human_data, '__keypoints_compressed__', False)
+ return ret_human_data
+
+ @classmethod
+ def set_logger(cls, logger: Union[logging.Logger, str, None] = None):
+ """Set logger of HumanData class.
+
+ Args:
+ logger (logging.Logger | str | None, optional):
+ The way to print summary.
+ See `mmcv.utils.print_log()` for details.
+ Defaults to None.
+ """
+ cls.logger = logger
+
+ @classmethod
+ def fromfile(cls, npz_path: str) -> _HumanData:
+ """Construct a HumanData instance from an npz file.
+
+ Args:
+ npz_path (str):
+ Path to a dumped npz file.
+
+ Returns:
+ HumanData:
+ A HumanData instance loaded from file.
+ """
+ ret_human_data = cls()
+ ret_human_data.load(npz_path)
+ return ret_human_data
+
+ @classmethod
+ def new(cls,
+ source_dict: dict = None,
+ key_strict: bool = False) -> _HumanData:
+ """Construct a HumanData instance from a dict.
+
+ Args:
+ source_dict (dict, optional):
+ A dict with items in HumanData fashion.
+ Defaults to None.
+ key_strict (bool, optional):
+ Whether to raise error when setting unsupported keys.
+ Defaults to False.
+
+ Returns:
+ HumanData:
+ A HumanData instance.
+ """
+ if source_dict is None:
+ ret_human_data = cls()
+ else:
+ ret_human_data = cls(source_dict)
+ ret_human_data.set_key_strict(key_strict)
+ return ret_human_data
+
+ def get_key_strict(self) -> bool:
+ """Get value of attribute key_strict.
+
+ Returns:
+ bool:
+ Whether to raise error when setting unsupported keys.
+ """
+ return self.__key_strict__
+
+ def set_key_strict(self, value: bool):
+ """Set value of attribute key_strict.
+
+ Args:
+ value (bool, optional):
+ Whether to raise error when setting unsupported keys.
+ Defaults to True.
+ """
+ former__key_strict__ = self.__key_strict__
+ self.__key_strict__ = value
+ if former__key_strict__ is False and \
+ value is True:
+ self.pop_unsupported_items()
+
+ def check_keypoints_compressed(self) -> bool:
+ """Check whether the keypoints are compressed.
+
+ Returns:
+ bool:
+ Whether the keypoints are compressed.
+ """
+ return self.__keypoints_compressed__
+
+ def load(self, npz_path: str):
+ """Load data from npz_path and update them to self.
+
+ Args:
+ npz_path (str):
+ Path to a dumped npz file.
+ """
+ supported_keys = self.__class__.SUPPORTED_KEYS
+ with np.load(npz_path, allow_pickle=True) as npz_file:
+ tmp_data_dict = dict(npz_file)
+ for key, value in list(tmp_data_dict.items()):
+ if isinstance(value, np.ndarray) and\
+ len(value.shape) == 0:
+ # value is not an ndarray before dump
+ value = value.item()
+ elif key in supported_keys and\
+ type(value) != supported_keys[key]['type']:
+ value = supported_keys[key]['type'](value)
+ if value is None:
+ tmp_data_dict.pop(key)
+ elif key == '__key_strict__' or \
+ key == '__data_len__' or\
+ key == '__keypoints_compressed__':
+ self.__setattr__(key, value)
+ # pop the attributes to keep dict clean
+ tmp_data_dict.pop(key)
+ elif key == 'bbox_xywh' and value.shape[1] == 4:
+ value = np.hstack([value, np.ones([value.shape[0], 1])])
+ tmp_data_dict[key] = value
+ else:
+ tmp_data_dict[key] = value
+ self.update(tmp_data_dict)
+ self.__set_default_values__()
+
+ def dump(self, npz_path: str, overwrite: bool = True):
+ """Dump keys and items to an npz file.
+
+ Args:
+ npz_path (str):
+ Path to a dumped npz file.
+ overwrite (bool, optional):
+ Whether to overwrite if there is already a file.
+ Defaults to True.
+
+ Raises:
+ ValueError:
+ npz_path does not end with '.npz'.
+ FileExistsError:
+ When overwrite is False and file exists.
+ """
+ if not check_path_suffix(npz_path, ['.npz']):
+ raise ValueError('Not an npz file.')
+ if not overwrite:
+ if check_path_existence(npz_path, 'file') == Existence.FileExist:
+ raise FileExistsError
+ dict_to_dump = {
+ '__key_strict__': self.__key_strict__,
+ '__data_len__': self.__data_len__,
+ '__keypoints_compressed__': self.__keypoints_compressed__,
+ }
+ dict_to_dump.update(self)
+ np.savez_compressed(npz_path, **dict_to_dump)
+
+ def get_sliced_cache(self, slice_size=10) -> List:
+ """Slice the whole HumanData into pieces for HumanDataCacheWriter.
+
+ Args:
+ slice_size (int, optional):
+ The length of each unit in HumanData cache.
+ Defaults to 10.
+
+ Returns:
+ List:
+ Two dicts for HumanDataCacheWriter.
+ Init HumanDataCacheWriter by HumanDataCacheWriter(**Returns[0])
+ and set data by
+ human_data_cache_writer.update_sliced_dict(Returns[1]).
+ """
+ keypoints_info = {}
+ non_sliced_data = {}
+ sliced_data = {}
+ slice_num = ceil(self.__data_len__ / slice_size)
+ for slice_index in range(slice_num):
+ sliced_data[str(slice_index)] = {}
+ dim_dict = self.__get_slice_dim__()
+ for key, dim in dim_dict.items():
+ # no dim to slice
+ if dim is None:
+ if key.startswith('keypoints') and\
+ (key.endswith('_mask') or
+ key.endswith('_convention')):
+ keypoints_info[key] = self[key]
+ else:
+ non_sliced_data[key] = self[key]
+ elif isinstance(dim, dict):
+ value_dict = self.get_raw_value(key)
+ non_sliced_sub_dict = {}
+ for sub_key in value_dict.keys():
+ sub_value = value_dict[sub_key]
+ if dim[sub_key] is None:
+ non_sliced_sub_dict[sub_key] = sub_value
+ else:
+ sub_dim = dim[sub_key]
+ for slice_index in range(slice_num):
+ slice_start = slice_index * slice_size
+ slice_end = min((slice_index + 1) * slice_size,
+ self.__data_len__)
+ slice_range = slice(slice_start, slice_end)
+ sliced_sub_value = \
+ HumanData.__get_sliced_result__(
+ sub_value, sub_dim, slice_range
+ )
+ if key not in sliced_data[str(slice_index)]:
+ sliced_data[str(slice_index)][key] = {}
+ sliced_data[str(slice_index)][key][sub_key] = \
+ sliced_sub_value
+ if len(non_sliced_sub_dict) > 0:
+ non_sliced_data[key] = non_sliced_sub_dict
+ else:
+ value = self.get_raw_value(key)
+ # slice as ndarray
+ if isinstance(value, np.ndarray):
+ slice_list = [
+ slice(None),
+ ] * len(value.shape)
+ for slice_index in range(slice_num):
+ slice_start = slice_index * slice_size
+ slice_end = min((slice_index + 1) * slice_size,
+ self.__data_len__)
+ slice_list[dim] = slice(slice_start, slice_end)
+ sliced_value = value[tuple(slice_list)]
+ sliced_data[str(slice_index)][key] = sliced_value
+ # slice as list/tuple
+ else:
+ for slice_index in range(slice_num):
+ slice_start = slice_index * slice_size
+ slice_end = min((slice_index + 1) * slice_size,
+ self.__data_len__)
+ sliced_value = value[slice(slice_start, slice_end)]
+ sliced_data[str(slice_index)][key] = sliced_value
+ writer_args_dict = {
+ 'slice_size': slice_size,
+ 'keypoints_info': keypoints_info,
+ 'data_len': self.data_len,
+ 'non_sliced_data': non_sliced_data,
+ 'key_strict': self.get_key_strict()
+ }
+ return writer_args_dict, sliced_data
+
+ def to(self,
+ device: Optional[Union[torch.device, str]] = _CPU_DEVICE,
+ dtype: Optional[torch.dtype] = None,
+ non_blocking: Optional[bool] = False,
+ copy: Optional[bool] = False,
+ memory_format: Optional[torch.memory_format] = None) -> dict:
+ """Convert values in numpy.ndarray type to torch.Tensor, and move
+ Tensors to the target device. All keys will exist in the returned dict.
+
+ Args:
+ device (Union[torch.device, str], optional):
+ A specified device. Defaults to CPU_DEVICE.
+ dtype (torch.dtype, optional):
+ The data type of the expected torch.Tensor.
+ If dtype is None, it is decided according to the numpy.ndarray.
+ Defaults to None.
+ non_blocking (bool, optional):
+ When non_blocking, tries to convert asynchronously with
+ respect to the host if possible, e.g.,
+ converting a CPU Tensor with pinned memory to a CUDA Tensor.
+ Defaults to False.
+ copy (bool, optional):
+ When copy is set, a new Tensor is created even when
+ the Tensor already matches the desired conversion.
+ No matter what value copy is, Tensor constructed from numpy
+ will not share the same memory with the source numpy.ndarray.
+ Defaults to False.
+ memory_format (torch.memory_format, optional):
+ The desired memory format of returned Tensor.
+ Not supported by pytorch-cpu.
+ Defaults to None.
+
+ Returns:
+ dict:
+ A dict with all numpy.ndarray values converted into
+ torch.Tensor and all Tensors moved to the target device.
+ """
+ ret_dict = {}
+ for key in self.keys():
+ raw_value = self.get_raw_value(key)
+ tensor_value = None
+ if isinstance(raw_value, np.ndarray):
+ tensor_value = torch.from_numpy(raw_value).clone()
+ elif isinstance(raw_value, torch.Tensor):
+ tensor_value = raw_value
+ if tensor_value is None:
+ ret_dict[key] = raw_value
+ else:
+ if memory_format is None:
+ ret_dict[key] = \
+ tensor_value.to(device, dtype,
+ non_blocking, copy)
+ else:
+ ret_dict[key] = \
+ tensor_value.to(device, dtype,
+ non_blocking, copy,
+ memory_format=memory_format)
+ return ret_dict
+
+ def __getitem__(self, key: _KT) -> _VT:
+ """Get value defined by HumanData. This function will be called by
+ self[key]. In keypoints_compressed mode, if the key contains
+ 'keypoints', an array with zero-padding at absent keypoints will be
+ returned. Call self.get_raw_value(k) to get the value without padding.
+
+ Args:
+ key (_KT):
+ Key in HumanData.
+
+ Returns:
+ _VT:
+ Value to the key.
+ """
+ value = super().__getitem__(key)
+ if self.__keypoints_compressed__:
+ mask_key = f'{key}_mask'
+ if key in self and \
+ isinstance(value, np.ndarray) and \
+ 'keypoints' in key and \
+ mask_key in self:
+ mask_array = np.asarray(super().__getitem__(mask_key))
+ value = \
+ self.__class__.__add_zero_pad__(value, mask_array)
+ return value
+
+ def get_raw_value(self, key: _KT) -> _VT:
+ """Get raw value from the dict. It acts the same as
+ dict.__getitem__(k).
+
+ Args:
+ key (_KT):
+ Key in dict.
+
+ Returns:
+ _VT:
+ Value to the key.
+ """
+ value = super().__getitem__(key)
+ return value
+
+ def get_value_in_shape(self,
+ key: _KT,
+ shape: Union[list, tuple],
+ padding_constant: int = 0) -> np.ndarray:
+ """Get value in a specific shape. For each dim, if the required shape
+ is smaller than current shape, ndarray will be sliced. Otherwise, it
+ will be padded with padding_constant at the end.
+
+ Args:
+ key (_KT):
+ Key in dict. The value of this key must be
+ an instance of numpy.ndarray.
+ shape (Union[list, tuple]):
+ Shape of the returned array. Its length
+ must be equal to value.ndim. Set -1 for
+ a dimension if you do not want to edit it.
+ padding_constant (int, optional):
+ The value to set the padded values for each axis.
+ Defaults to 0.
+
+ Raises:
+ ValueError:
+ A value in shape is neither positive integer nor -1.
+
+ Returns:
+ np.ndarray:
+ An array in required shape.
+ """
+ value = self.get_raw_value(key)
+ assert isinstance(value, np.ndarray)
+ assert value.ndim == len(shape)
+ pad_width_list = []
+ slice_list = []
+ for dim_index in range(len(shape)):
+ if shape[dim_index] == -1:
+ # no pad or slice
+ pad_width_list.append((0, 0))
+ slice_list.append(slice(None))
+ elif shape[dim_index] > 0:
+ # valid shape value
+ wid = shape[dim_index] - value.shape[dim_index]
+ if wid > 0:
+ pad_width_list.append((0, wid))
+ else:
+ pad_width_list.append((0, 0))
+ slice_list.append(slice(0, shape[dim_index]))
+ else:
+ # invalid
+ raise ValueError
+ pad_value = np.pad(value,
+ pad_width=pad_width_list,
+ mode='constant',
+ constant_values=padding_constant)
+ return pad_value[tuple(slice_list)]
+
+ @overload
+ def get_slice(self, stop: int):
+ """Slice [0, stop, 1] of all sliceable values."""
+ ...
+
+ @overload
+ def get_slice(self, start: int, stop: int):
+ """Slice [start, stop, 1] of all sliceable values."""
+ ...
+
+ @overload
+ def get_slice(self, start: int, stop: int, step: int):
+ """Slice [start, stop, step] of all sliceable values."""
+ ...
+
+ def get_slice(self,
+ arg_0: int,
+ arg_1: Union[int, Any] = None,
+ step: int = 1) -> _HumanData:
+ """Slice all sliceable values along major_dim dimension.
+
+ Args:
+ arg_0 (int):
+ When arg_1 is None, arg_0 is stop and start=0.
+ When arg_1 is not None, arg_0 is start.
+ arg_1 (Union[int, Any], optional):
+ None or where to stop.
+ Defaults to None.
+ step (int, optional):
+ Length of step. Defaults to 1.
+
+ Returns:
+ HumanData:
+ A new HumanData instance with sliced values.
+ """
+ ret_human_data = \
+ HumanData.new(key_strict=self.get_key_strict())
+ if arg_1 is None:
+ start = 0
+ stop = arg_0
+ else:
+ start = arg_0
+ stop = arg_1
+ slice_index = slice(start, stop, step)
+ dim_dict = self.__get_slice_dim__()
+ for key, dim in dim_dict.items():
+ # keys not expected be sliced
+ if dim is None:
+ ret_human_data[key] = self[key]
+ elif isinstance(dim, dict):
+ value_dict = self.get_raw_value(key)
+ sliced_dict = {}
+ for sub_key in value_dict.keys():
+ sub_value = value_dict[sub_key]
+ if dim[sub_key] is None:
+ sliced_dict[sub_key] = sub_value
+ else:
+ sub_dim = dim[sub_key]
+ sliced_sub_value = \
+ HumanData.__get_sliced_result__(
+ sub_value, sub_dim, slice_index)
+ sliced_dict[sub_key] = sliced_sub_value
+ ret_human_data[key] = sliced_dict
+ else:
+ value = self[key]
+ sliced_value = \
+ HumanData.__get_sliced_result__(
+ value, dim, slice_index)
+ ret_human_data[key] = sliced_value
+ # check keypoints compressed
+ if self.check_keypoints_compressed():
+ ret_human_data.compress_keypoints_by_mask()
+ return ret_human_data
+
+ def __get_slice_dim__(self) -> dict:
+ """For each key in this HumanData, get the dimension for slicing. 0 for
+ default, if no other value specified.
+
+ Returns:
+ dict:
+ Keys are self.keys().
+ Values indicate where to slice.
+ None for keys that are not expected to be sliced
+ or whose slice dimension could not be determined.
+ """
+ supported_keys = self.__class__.SUPPORTED_KEYS
+ ret_dict = {}
+ for key in self.keys():
+ # keys not expected be sliced
+ if key in supported_keys and \
+ 'dim' in supported_keys[key] and \
+ supported_keys[key]['dim'] is None:
+ ret_dict[key] = None
+ else:
+ value = self[key]
+ if isinstance(value, dict) and len(value) > 0:
+ ret_dict[key] = {}
+ for sub_key in value.keys():
+ try:
+ sub_value_len = len(value[sub_key])
+ if sub_value_len != self.__data_len__:
+ ret_dict[key][sub_key] = None
+ elif 'dim' in value:
+ ret_dict[key][sub_key] = value['dim']
+ else:
+ ret_dict[key][sub_key] = 0
+ except TypeError:
+ ret_dict[key][sub_key] = None
+ continue
+ # instance cannot be sliced without len method
+ try:
+ value_len = len(value)
+ except TypeError:
+ ret_dict[key] = None
+ continue
+ # slice on dim 0 by default
+ slice_dim = 0
+ if key in supported_keys and \
+ 'dim' in supported_keys[key]:
+ slice_dim = \
+ supported_keys[key]['dim']
+ data_len = value_len if slice_dim == 0 \
+ else value.shape[slice_dim]
+ # dim not for slice
+ if data_len != self.__data_len__:
+ ret_dict[key] = None
+ continue
+ else:
+ ret_dict[key] = slice_dim
+ return ret_dict
+
+ def __setitem__(self, key: _KT, val: _VT) -> None:
+ """Set self[key] to value. Only be called when using
+ human_data[key] = val. Methods like update won't call __setitem__.
+ In keypoints_compressed mode, if the key contains 'keypoints',
+ and f'{key}_mask' is in self.keys(), invalid zeros
+ will be removed before setting value.
+
+ Args:
+ key (_KT):
+ Key in HumanData.
+ Better be an element in HumanData.SUPPORTED_KEYS.
+ If not, an Error will be raised in key_strict mode.
+ val (_VT):
+ Value to the key.
+
+ Raises:
+ KeyError:
+ self.get_key_strict() is True and
+ key cannot be found in
+ HumanData.SUPPORTED_KEYS.
+ ValueError:
+ Value is supported but doesn't match definition.
+ ValueError:
+ self.check_keypoints_compressed() is True and
+ mask of a keypoint item is missing.
+ """
+ self.__check_key__(key)
+ self.__check_value__(key, val)
+ # if it can be compressed by mask
+ if self.__keypoints_compressed__:
+ class_logger = self.__class__.logger
+ if 'keypoints' in key and \
+ '_mask' in key:
+ msg = 'Mask cannot be modified ' +\
+ 'in keypoints_compressed mode.'
+ print_log(msg=msg, logger=class_logger, level=logging.WARN)
+ return
+ elif isinstance(val, np.ndarray) and \
+ 'keypoints' in key and \
+ '_mask' not in key:
+ mask_key = f'{key}_mask'
+ if mask_key in self:
+ mask_array = np.asarray(super().__getitem__(mask_key))
+ val = \
+ self.__class__.__remove_zero_pad__(val, mask_array)
+ else:
+ msg = f'Mask for {key} has not been set.' +\
+ f' Please set {mask_key} before compression.'
+ print_log(msg=msg,
+ logger=class_logger,
+ level=logging.ERROR)
+ raise ValueError
+ dict.__setitem__(self, key, val)
+
+ def set_raw_value(self, key: _KT, val: _VT) -> None:
+ """Set the raw value of self[key] to val after key check. It acts the
+ same as dict.__setitem__(self, key, val) if the key satisfied
+ constraints.
+
+ Args:
+ key (_KT):
+ Key in dict.
+ val (_VT):
+ Value to the key.
+
+ Raises:
+ KeyError:
+ self.get_key_strict() is True and
+ key cannot be found in
+ HumanData.SUPPORTED_KEYS.
+ ValueError:
+ Value is supported but doesn't match definition.
+ """
+ self.__check_key__(key)
+ self.__check_value__(key, val)
+ dict.__setitem__(self, key, val)
+
+ def pop_unsupported_items(self) -> None:
+ """Find every item with a key not in HumanData.SUPPORTED_KEYS, and pop
+ it to save memory."""
+ for key in list(self.keys()):
+ if key not in self.__class__.SUPPORTED_KEYS:
+ self.pop(key)
+
+ def __check_key__(self, key: Any) -> _KeyCheck:
+ """Check whether the key matches definition in
+ HumanData.SUPPORTED_KEYS.
+
+ Args:
+ key (Any):
+ Key in HumanData.
+
+ Returns:
+ _KeyCheck:
+ PASS, WARN or ERROR.
+
+ Raises:
+ KeyError:
+ self.get_key_strict() is True and
+ key cannot be found in
+ HumanData.SUPPORTED_KEYS.
+ """
+ ret_key_check = _KeyCheck.PASS
+ if self.get_key_strict():
+ if key not in self.__class__.SUPPORTED_KEYS:
+ ret_key_check = _KeyCheck.ERROR
+ else:
+ if key not in self.__class__.SUPPORTED_KEYS and \
+ key not in self.__class__.WARNED_KEYS:
+ # log warning message at the first time
+ ret_key_check = _KeyCheck.WARN
+ self.__class__.WARNED_KEYS.append(key)
+ if ret_key_check == _KeyCheck.ERROR:
+ raise KeyError(self.__class__.__get_key_error_msg__(key))
+ elif ret_key_check == _KeyCheck.WARN:
+ class_logger = self.__class__.logger
+ if class_logger == 'silent':
+ pass
+ else:
+ print_log(msg=self.__class__.__get_key_warn_msg__(key),
+ logger=class_logger,
+ level=logging.WARN)
+ return ret_key_check
+
+ def __check_value__(self, key: Any, val: Any) -> bool:
+ """Check whether the value matches definition in
+ HumanData.SUPPORTED_KEYS.
+
+ Args:
+ key (Any):
+ Key in HumanData.
+ val (Any):
+ Value to the key.
+
+ Returns:
+ bool:
+ True for matched, otherwise False.
+
+ Raises:
+ ValueError:
+ Value is supported but doesn't match definition.
+ """
+ ret_bool = self.__check_value_type__(key, val) and\
+ self.__check_value_shape__(key, val) and\
+ self.__check_value_len__(key, val)
+ if not ret_bool:
+ raise ValueError(self.__class__.__get_value_error_msg__())
+ return ret_bool
+
+ def __check_value_type__(self, key: Any, val: Any) -> bool:
+ """Check whether the type of val matches definition in
+ HumanData.SUPPORTED_KEYS.
+
+ Args:
+ key (Any):
+ Key in HumanData.
+ val (Any):
+ Value to the key.
+
+ Returns:
+ bool:
+ If type doesn't match, return False.
+ Else return True.
+ """
+ ret_bool = True
+ supported_keys = self.__class__.SUPPORTED_KEYS
+ # check definition
+ if key in supported_keys:
+ # check type
+ if type(val) != supported_keys[key]['type']:
+ ret_bool = False
+ if not ret_bool:
+ expected_type = supported_keys[key]['type']
+ err_msg = 'Type check Failed:\n'
+ err_msg += f'key={str(key)}\n'
+ err_msg += f'type(val)={type(val)}\n'
+ err_msg += f'expected type={expected_type}\n'
+ print_log(msg=err_msg,
+ logger=self.__class__.logger,
+ level=logging.ERROR)
+ return ret_bool
+
+ def __check_value_shape__(self, key: Any, val: Any) -> bool:
+ """Check whether the shape of val matches definition in
+ HumanData.SUPPORTED_KEYS.
+
+ Args:
+ key (Any):
+ Key in HumanData.
+ val (Any):
+ Value to the key.
+
+ Returns:
+ bool:
+ If expected shape is defined and doesn't match,
+ return False.
+ Else return True.
+ """
+ ret_bool = True
+ supported_keys = self.__class__.SUPPORTED_KEYS
+ # check definition
+ if key in supported_keys:
+ # check shape
+ if 'shape' in supported_keys[key]:
+ val_shape = val.shape
+ for shape_ind in range(len(supported_keys[key]['shape'])):
+ # length not match
+ if shape_ind >= len(val_shape):
+ ret_bool = False
+ break
+ expect_val = supported_keys[key]['shape'][shape_ind]
+ # value not match
+ if expect_val > 0 and \
+ expect_val != val_shape[shape_ind]:
+ ret_bool = False
+ break
+ if not ret_bool:
+ expected_shape = str(supported_keys[key]['shape'])
+ expected_shape = expected_shape.replace('-1', 'Any')
+ err_msg = 'Shape check Failed:\n'
+ err_msg += f'key={str(key)}\n'
+ err_msg += f'val.shape={val_shape}\n'
+ err_msg += f'expected shape={expected_shape}\n'
+ print_log(msg=err_msg,
+ logger=self.__class__.logger,
+ level=logging.ERROR)
+ return ret_bool
+
+ @property
+ def data_len(self) -> int:
+ """Get the temporal length of this HumanData instance.
+
+ Returns:
+ int:
+ Number of frames related to this instance.
+ """
+ return self.__data_len__
+
+ @data_len.setter
+ def data_len(self, value: int):
+ """Set the temporal length of this HumanData instance.
+
+ Args:
+ value (int):
+ Number of frames related to this instance.
+ """
+ self.__data_len__ = value
+
+ def __check_value_len__(self, key: Any, val: Any) -> bool:
+ """Check whether the temporal length of val matches other values.
+
+ Args:
+ key (Any):
+ Key in HumanData.
+ val (Any):
+ Value to the key.
+
+ Returns:
+ bool:
+ If temporal dim is defined and temporal length doesn't match,
+ return False.
+ Else return True.
+ """
+ ret_bool = True
+ supported_keys = self.__class__.SUPPORTED_KEYS
+ # check definition
+ if key in supported_keys:
+ # check temporal length
+ if 'dim' in supported_keys[key] and \
+ supported_keys[key]['dim'] is not None:
+ val_slice_dim = supported_keys[key]['dim']
+ if supported_keys[key]['type'] == dict:
+ slice_key = supported_keys[key]['slice_key']
+ val_data_len = val[slice_key].shape[val_slice_dim]
+ else:
+ val_data_len = val.shape[val_slice_dim]
+ if self.data_len < 0:
+ # no data_len yet, assign a new one
+ self.data_len = val_data_len
+ else:
+ # check if val_data_len matches recorded data_len
+ if self.data_len != val_data_len:
+ ret_bool = False
+ if not ret_bool:
+ err_msg = 'Temporal check Failed:\n'
+ err_msg += f'key={str(key)}\n'
+ err_msg += f'val\'s data_len={val_data_len}\n'
+ err_msg += f'expected data_len={self.data_len}\n'
+ print_log(msg=err_msg,
+ logger=self.__class__.logger,
+ level=logging.ERROR)
+ return ret_bool
+
+ def generate_mask_from_confidence(self, keys=None) -> None:
+ """Generate mask from keypoints' confidence. Keypoints that have zero
+ confidence in all occurrences will have a zero mask. Note that the last
+ value of the keypoint is assumed to be confidence.
+
+ Args:
+ keys: None, str, or list of str.
+ None: all keys with `keypoint` in it will have mask
+ generated from their confidence.
+ str: key of the keypoint, the mask has name f'{key}_mask'.
+ list of str: a list of keys of the keypoints.
+ Generate mask for multiple keypoints.
+ Defaults to None.
+
+ Returns:
+ None
+
+ Raises:
+ KeyError:
+ A key is not found.
+ """
+ if keys is None:
+ keys = []
+ for key in self.keys():
+ val = self.get_raw_value(key)
+ if isinstance(val, np.ndarray) and \
+ 'keypoints' in key and \
+ '_mask' not in key:
+ keys.append(key)
+ elif isinstance(keys, str):
+ keys = [keys]
+ elif isinstance(keys, list):
+ for key in keys:
+ assert isinstance(key, str)
+ else:
+ raise TypeError(f'`Keys` must be None, str, or list of str, '
+ f'got {type(keys)}.')
+
+ update_dict = {}
+ for kpt_key in keys:
+ kpt_array = self.get_raw_value(kpt_key)
+ num_joints = kpt_array.shape[-2]
+ # if all conf of a joint are zero, this joint is masked
+ joint_conf = kpt_array[..., -1].reshape(-1, num_joints)
+ mask_array = (joint_conf > 0).astype(np.uint8).max(axis=0)
+ assert len(mask_array) == num_joints
+ # generate mask
+ update_dict[f'{kpt_key}_mask'] = mask_array
+ self.update(update_dict)
+
+ def compress_keypoints_by_mask(self) -> None:
+ """If a key contains 'keypoints', and f'{key}_mask' is in self.keys(),
+ invalid zeros will be removed and f'{key}_mask' will be locked.
+
+ Raises:
+ KeyError:
+ A key containing 'keypoints' has been found
+ but its corresponding mask is missing.
+ """
+ assert self.__keypoints_compressed__ is False
+ key_pairs = []
+ for key in self.keys():
+ mask_key = f'{key}_mask'
+ val = self.get_raw_value(key)
+ if isinstance(val, np.ndarray) and \
+ 'keypoints' in key and \
+ '_mask' not in key and 'has' not in key:
+ if mask_key in self:
+ key_pairs.append([key, mask_key])
+ else:
+ msg = f'Mask for {key} has not been set.' +\
+ f' Please set {mask_key} before compression.'
+ raise KeyError(msg)
+ compressed_dict = {}
+ for kpt_key, mask_key in key_pairs:
+ kpt_array = self.get_raw_value(kpt_key)
+ mask_array = np.asarray(self.get_raw_value(mask_key))
+ compressed_kpt = \
+ self.__class__.__remove_zero_pad__(kpt_array, mask_array)
+ compressed_dict[kpt_key] = compressed_kpt
+ # set value after all pairs are compressed
+ self.update(compressed_dict)
+ self.__keypoints_compressed__ = True
+
+ def decompress_keypoints(self) -> None:
+ """If a key contains 'keypoints', and f'{key}_mask' is in self.keys(),
+ invalid zeros will be inserted to the right places and f'{key}_mask'
+ will be unlocked.
+
+ Raises:
+ KeyError:
+ A key containing 'keypoints' has been found
+ but its corresponding mask is missing.
+ """
+ assert self.__keypoints_compressed__ is True
+ key_pairs = []
+ for key in self.keys():
+ mask_key = f'{key}_mask'
+ val = self.get_raw_value(key)
+ if isinstance(val, np.ndarray) and \
+ 'keypoints' in key and \
+ '_mask' not in key:
+ if mask_key in self:
+ key_pairs.append([key, mask_key])
+ else:
+ class_logger = self.__class__.logger
+ msg = f'Mask for {key} has not been found.' +\
+ f' Please remove {key} before decompression.'
+ print_log(msg=msg,
+ logger=class_logger,
+ level=logging.ERROR)
+ raise KeyError
+ decompressed_dict = {}
+ for kpt_key, mask_key in key_pairs:
+ mask_array = np.asarray(self.get_raw_value(mask_key))
+ compressed_kpt = self.get_raw_value(kpt_key)
+ kpt_array = \
+ self.__class__.__add_zero_pad__(compressed_kpt, mask_array)
+ decompressed_dict[kpt_key] = kpt_array
+ # set value after all pairs are decompressed
+ self.update(decompressed_dict)
+ self.__keypoints_compressed__ = False
+
+ def dump_by_pickle(self, pkl_path: str, overwrite: bool = True) -> None:
+ """Dump keys and items to a pickle file. It's a secondary dump method,
+ when a HumanData instance is too large to be dumped by self.dump()
+
+ Args:
+ pkl_path (str):
+ Path to a dumped pickle file.
+ overwrite (bool, optional):
+ Whether to overwrite if there is already a file.
+ Defaults to True.
+
+ Raises:
+ ValueError:
+ pkl_path does not end with '.pkl'.
+ FileExistsError:
+ When overwrite is False and file exists.
+ """
+ if not check_path_suffix(pkl_path, ['.pkl']):
+ raise ValueError('Not a pkl file.')
+ if not overwrite:
+ if check_path_existence(pkl_path, 'file') == Existence.FileExist:
+ raise FileExistsError
+ dict_to_dump = {
+ '__key_strict__': self.__key_strict__,
+ '__data_len__': self.__data_len__,
+ '__keypoints_compressed__': self.__keypoints_compressed__,
+ }
+ dict_to_dump.update(self)
+ with open(pkl_path, 'wb') as f_writeb:
+ pickle.dump(dict_to_dump,
+ f_writeb,
+ protocol=pickle.HIGHEST_PROTOCOL)
+
+ def load_by_pickle(self, pkl_path: str) -> None:
+ """Load data from pkl_path and update them to self.
+
+ When a HumanData instance was dumped by
+ self.dump_by_pickle(), use this method to load it.
+
+ Args:
+ pkl_path (str):
+ Path to a dumped pickle file.
+ """
+ with open(pkl_path, 'rb') as f_readb:
+ tmp_data_dict = pickle.load(f_readb)
+ for key, value in list(tmp_data_dict.items()):
+ if value is None:
+ tmp_data_dict.pop(key)
+ elif key == '__key_strict__' or \
+ key == '__data_len__' or\
+ key == '__keypoints_compressed__':
+ self.__setattr__(key, value)
+ # pop the attributes to keep dict clean
+ tmp_data_dict.pop(key)
+ elif key == 'bbox_xywh' and value.shape[1] == 4:
+ value = np.hstack([value, np.ones([value.shape[0], 1])])
+ tmp_data_dict[key] = value
+ else:
+ tmp_data_dict[key] = value
+ self.update(tmp_data_dict)
+ self.__set_default_values__()
+
+ def __set_default_values__(self) -> None:
+ """For older versions of HumanData, call this method to apply missing
+ values (also attributes)."""
+ supported_keys = self.__class__.SUPPORTED_KEYS
+ if self.__data_len__ == -1:
+ for key in supported_keys:
+ if key in self and \
+ 'dim' in supported_keys[key] and\
+ supported_keys[key]['dim'] is not None:
+ if 'slice_key' in supported_keys[key] and\
+ supported_keys[key]['type'] == dict:
+ sub_key = supported_keys[key]['slice_key']
+ slice_dim = supported_keys[key]['dim']
+ self.__data_len__ = \
+ self[key][sub_key].shape[slice_dim]
+ else:
+ slice_dim = supported_keys[key]['dim']
+ self.__data_len__ = self[key].shape[slice_dim]
+ break
+ for key in list(self.keys()):
+ convention_key = f'{key}_convention'
+ if key.startswith('keypoints') and \
+ not key.endswith('_mask') and \
+ not key.endswith('_convention') and \
+ convention_key not in self:
+ self[convention_key] = 'human_data'
+
+ @classmethod
+ def concatenate(cls, human_data_0: _HumanData,
+ human_data_1: _HumanData) -> _HumanData:
+ """Concatenate two human_data. All keys will be kept it the returned
+ human_data. If either value from human_data_0 or human_data_1 matches
+ data_len from its HumanData, the two values will be concatenated as a
+ single value. If not, postfix will be added to the key to specify
+ source of the value.
+
+ Args:
+ human_data_0 (_HumanData)
+ human_data_1 (_HumanData)
+
+ Returns:
+ _HumanData:
+ A new human_data instance with all concatenated data.
+ """
+ ret_human_data = cls.new(key_strict=False)
+ set_0 = set(human_data_0.keys())
+ set_1 = set(human_data_1.keys())
+ common_keys = set_0.intersection(set_1)
+ dim_dict_0 = human_data_0.__get_slice_dim__()
+ dim_dict_1 = human_data_1.__get_slice_dim__()
+ for key in common_keys:
+ value_0 = human_data_0[key]
+ value_1 = human_data_1[key]
+ # align type
+ value_0 = list(value_0) if isinstance(value_0, tuple)\
+ else value_0
+ value_1 = list(value_1) if isinstance(value_1, tuple)\
+ else value_1
+ assert type(value_0) == type(value_1)
+ # align convention
+ if key.startswith('keypoints') and\
+ key.endswith('_convention'):
+ assert value_0 == value_1
+ ret_human_data[key] = value_0
+ continue
+ # mask_0 and mask_1
+ elif key.startswith('keypoints') and\
+ key.endswith('_mask'):
+ new_mask = value_0 * value_1
+ ret_human_data[key] = new_mask
+ continue
+ # go through the sub dict
+ if isinstance(value_0, dict):
+ sub_dict = {}
+ for sub_key, sub_value_0 in value_0.items():
+ # only found in value_0
+ if sub_key not in value_1:
+ sub_dict[sub_key] = sub_value_0
+ # found in both values
+ else:
+ sub_value_1 = value_1[sub_key]
+ concat_sub_dict = cls.__concat_value__(
+ key=sub_key,
+ value_0=sub_value_0,
+ dim_0=dim_dict_0[key][sub_key],
+ value_1=sub_value_1,
+ dim_1=dim_dict_1[key][sub_key])
+ sub_dict.update(concat_sub_dict)
+ for sub_key, sub_value_1 in value_1.items():
+ if sub_key not in value_0:
+ sub_dict[sub_key] = sub_value_1
+
+ ret_human_data[key] = sub_dict
+ # try concat
+ else:
+ concat_dict = cls.__concat_value__(key=key,
+ value_0=value_0,
+ dim_0=dim_dict_0[key],
+ value_1=value_1,
+ dim_1=dim_dict_1[key])
+ ret_human_data.update(concat_dict)
+ # check exclusive keys
+ for key, value in human_data_0.items():
+ if key not in common_keys:
+ # value not for concat and slice
+ if dim_dict_0[key] is None:
+ ret_human_data[key] = value
+ # value aligned with data_len of HumanData_0
+ else:
+ ret_human_data[f'{key}_0'] = value
+ for key, value in human_data_1.items():
+ if key not in common_keys:
+ # same as above
+ if dim_dict_1[key] is None:
+ ret_human_data[key] = value
+ else:
+ ret_human_data[f'{key}_1'] = value
+ return ret_human_data
+
+ @classmethod
+ def __concat_value__(cls, key: Any, value_0: Any, value_1: Any,
+ dim_0: Union[None, int], dim_1: Union[None,
+ int]) -> dict:
+ """Concat two values from two different HumanData.
+
+ Args:
+ key (Any):
+ The common key of the two values.
+ value_0 (Any):
+ Value from 0.
+ value_1 (Any):
+ Value from 1.
+ dim_0 (Union[None, int]):
+ The dim for concat and slice. None for N/A.
+ dim_1 (Union[None, int]):
+ The dim for concat and slice. None for N/A.
+
+ Returns:
+ dict:
+ Dict for concatenated result.
+ """
+ ret_dict = {}
+ if dim_0 is None or dim_1 is None:
+ ret_dict[f'{key}_0'] = value_0
+ ret_dict[f'{key}_1'] = value_1
+ elif isinstance(value_0, list):
+ ret_dict[key] = value_0 + value_1
+ # elif isinstance(value_0, np.ndarray):
+ else:
+ ret_dict[key] = np.concatenate((value_0, value_1), axis=dim_0)
+ return ret_dict
+
+ @classmethod
+ def __add_zero_pad__(cls, compressed_array: np.ndarray,
+ mask_array: np.ndarray) -> np.ndarray:
+ """Pad zeros to a compressed keypoints array.
+
+ Args:
+ compressed_array (np.ndarray):
+ A compressed keypoints array.
+ mask_array (np.ndarray):
+ The mask records compression relationship.
+
+ Returns:
+ np.ndarray:
+ A keypoints array in full-size.
+ """
+ assert mask_array.sum() == compressed_array.shape[1]
+ data_len, _, dim = compressed_array.shape
+ mask_len = mask_array.shape[0]
+ ret_value = np.zeros(shape=[data_len, mask_len, dim],
+ dtype=compressed_array.dtype)
+ valid_mask_index = np.where(mask_array == 1)[0]
+ ret_value[:, valid_mask_index, :] = compressed_array
+ return ret_value
+
+ @classmethod
+ def __remove_zero_pad__(cls, zero_pad_array: np.ndarray,
+ mask_array: np.ndarray) -> np.ndarray:
+ """Remove zero-padding from a full-size keypoints array.
+
+ Args:
+ zero_pad_array (np.ndarray):
+ A keypoints array in full-size.
+ mask_array (np.ndarray):
+ The mask records compression relationship.
+
+ Returns:
+ np.ndarray:
+ A compressed keypoints array.
+ """
+ assert mask_array.shape[0] == zero_pad_array.shape[1]
+ valid_mask_index = np.where(mask_array == 1)[0]
+ ret_value = np.take(zero_pad_array, valid_mask_index, axis=1)
+ return ret_value
+
+ @classmethod
+ def __get_key_warn_msg__(cls, key: Any) -> str:
+ """Get the warning message when a key fails the check.
+
+ Args:
+ key (Any):
+ The key that failed the check.
+
+ Returns:
+ str:
+ The warning message.
+ """
+ class_name = cls.__name__
+ warn_message = \
+ f'{key} is absent in' +\
+ f' {class_name}.SUPPORTED_KEYS.\n'
+ suggestion_message = \
+ 'Ignore this if you know exactly' +\
+ ' what you are doing.\n' +\
+ 'Otherwise, Call self.set_key_strict(True)' +\
+ ' to avoid wrong keys.\n'
+ return warn_message + suggestion_message
+
+ @classmethod
+ def __get_key_error_msg__(cls, key: Any) -> str:
+ """Get the error message when a key fails the check.
+
+ Args:
+ key (Any):
+ The key that failed the check.
+
+ Returns:
+ str:
+ The error message.
+ """
+ class_name = cls.__name__
+ absent_message = \
+ f'{key} is absent in' +\
+ f' {class_name}.SUPPORTED_KEYS.\n'
+ suggestion_message = \
+ 'Call self.set_key_strict(False)' +\
+ ' to allow unsupported keys.\n'
+ return absent_message + suggestion_message
+
+ @classmethod
+ def __get_value_error_msg__(cls) -> str:
+ """Get the error message when a value fails the check.
+
+ Returns:
+ str:
+ The error message.
+ """
+ error_message = \
+ 'A supported value doesn\'t ' +\
+ 'match definition.\n'
+ suggestion_message = \
+ 'See error log for details.\n'
+ return error_message + suggestion_message
+
+ @classmethod
+ def __get_sliced_result__(
+ cls, input_data: Union[np.ndarray, list, tuple], slice_dim: int,
+ slice_range: slice) -> Union[np.ndarray, list, tuple]:
+ """Slice input_data along slice_dim with slice_range.
+
+ Args:
+ input_data (Union[np.ndarray, list, tuple]):
+ Data to be sliced.
+ slice_dim (int):
+ Dimension to be sliced.
+ slice_range (slice):
+ An instance of class slice.
+
+ Returns:
+ Union[np.ndarray, list, tuple]:
+ A slice of input_data.
+ """
+ if isinstance(input_data, np.ndarray):
+ slice_list = [
+ slice(None),
+ ] * len(input_data.shape)
+ slice_list[slice_dim] = slice_range
+ sliced_data = input_data[tuple(slice_list)]
+ else:
+ sliced_data = \
+ input_data[slice_range]
+ return sliced_data
diff --git a/detrsmpl/data/data_structures/human_data_cache.py b/detrsmpl/data/data_structures/human_data_cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..122b3ed1d5663e5ffa903b9b5d1e9d56ed102cec
--- /dev/null
+++ b/detrsmpl/data/data_structures/human_data_cache.py
@@ -0,0 +1,104 @@
+from typing import List
+
+import numpy as np
+
+from detrsmpl.utils.path_utils import (
+ Existence,
+ check_path_existence,
+ check_path_suffix,
+)
+from .human_data import HumanData
+
+
+class HumanDataCacheReader():
+ def __init__(self, npz_path: str):
+ self.npz_path = npz_path
+ npz_file = np.load(npz_path, allow_pickle=True)
+ self.slice_size = npz_file['slice_size'].item()
+ self.data_len = npz_file['data_len'].item()
+ self.keypoints_info = npz_file['keypoints_info'].item()
+ self.non_sliced_data = None
+ self.npz_file = None
+
+ def __del__(self):
+ if self.npz_file is not None:
+ self.npz_file.close()
+
+ def get_item(self, index, required_keys: List[str] = []):
+ if self.npz_file is None:
+ self.npz_file = np.load(self.npz_path, allow_pickle=True)
+ cache_key = str(int(index / self.slice_size))
+ base_data = self.npz_file[cache_key].item()
+ base_data.update(self.keypoints_info)
+ for key in required_keys:
+ non_sliced_value = self.get_non_sliced_data(key)
+ if isinstance(non_sliced_value, dict) and\
+ key in base_data and\
+ isinstance(base_data[key], dict):
+ base_data[key].update(non_sliced_value)
+ else:
+ base_data[key] = non_sliced_value
+ ret_human_data = HumanData.new(source_dict=base_data)
+ # data in cache is compressed
+ ret_human_data.__keypoints_compressed__ = True
+ # set missing values and attributes by default method
+ ret_human_data.__set_default_values__()
+ return ret_human_data
+
+ def get_non_sliced_data(self, key: str):
+ if self.non_sliced_data is None:
+ if self.npz_file is None:
+ npz_file = np.load(self.npz_path, allow_pickle=True)
+ self.non_sliced_data = npz_file['non_sliced_data'].item()
+ else:
+ self.non_sliced_data = self.npz_file['non_sliced_data'].item()
+ return self.non_sliced_data[key]
+
+
+class HumanDataCacheWriter():
+ def __init__(self,
+ slice_size: int,
+ data_len: int,
+ keypoints_info: dict,
+ non_sliced_data: dict,
+ key_strict: bool = True):
+ self.slice_size = slice_size
+ self.data_len = data_len
+ self.keypoints_info = keypoints_info
+ self.non_sliced_data = non_sliced_data
+ self.sliced_data = {}
+ self.key_strict = key_strict
+
+ def update_sliced_dict(self, sliced_dict):
+ self.sliced_data.update(sliced_dict)
+
+ def dump(self, npz_path: str, overwrite: bool = True):
+ """Dump keys and items to an npz file.
+
+ Args:
+ npz_path (str):
+ Path to a dumped npz file.
+ overwrite (bool, optional):
+ Whether to overwrite if there is already a file.
+ Defaults to True.
+
+ Raises:
+ ValueError:
+ npz_path does not end with '.npz'.
+ FileExistsError:
+ When overwrite is False and file exists.
+ """
+ if not check_path_suffix(npz_path, ['.npz']):
+ raise ValueError('Not an npz file.')
+ if not overwrite:
+ if check_path_existence(npz_path, 'file') == Existence.FileExist:
+ raise FileExistsError
+ dict_to_dump = {
+ 'slice_size': self.slice_size,
+ 'data_len': self.data_len,
+ 'keypoints_info': self.keypoints_info,
+ 'non_sliced_data': self.non_sliced_data,
+ 'key_strict': self.key_strict,
+ }
+ dict_to_dump.update(self.sliced_data)
+ np.savez_compressed(npz_path, **dict_to_dump)
diff --git a/detrsmpl/data/data_structures/multi_human_data.py b/detrsmpl/data/data_structures/multi_human_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..68ed9d9336572db13ee52f51bd45a1b4cd4be5da
--- /dev/null
+++ b/detrsmpl/data/data_structures/multi_human_data.py
@@ -0,0 +1,480 @@
+import logging
+import pickle
+from enum import Enum
+from typing import Any, TypeVar, Union
+
+import numpy as np
+from mmcv.utils import print_log
+
+from detrsmpl.data.data_structures.human_data import HumanData
+from detrsmpl.utils.path_utils import (
+ Existence,
+ check_path_existence,
+ check_path_suffix,
+)
+
+# In T = TypeVar('T'), T can be anything.
+# See definition of typing.TypeVar for details.
+_HumanData = TypeVar('_HumanData')
+
+_MultiHumanData_SUPPORTED_KEYS = HumanData.SUPPORTED_KEYS.copy()
+_MultiHumanData_SUPPORTED_KEYS.update(
+ {'optional': {
+ 'type': dict,
+ 'slice_key': 'frame_range',
+ 'dim': 0
+ }})
+
+
+class _KeyCheck(Enum):
+ PASS = 0
+ WARN = 1
+ ERROR = 2
+
+
+class MultiHumanData(HumanData):
+ SUPPORTED_KEYS = _MultiHumanData_SUPPORTED_KEYS
+
+ def __new__(cls: _HumanData, *args: Any, **kwargs: Any) -> _HumanData:
+ """New an instance of HumanData.
+
+ Args:
+ cls (HumanData): HumanData class.
+
+ Returns:
+ HumanData: An instance of Hu
+ """
+ ret_human_data = super().__new__(cls, args, kwargs)
+ setattr(ret_human_data, '__data_len__', -1)
+ setattr(ret_human_data, '__instance_num__', -1)
+ setattr(ret_human_data, '__key_strict__', False)
+ setattr(ret_human_data, '__keypoints_compressed__', False)
+ return ret_human_data
+
+ def load(self, npz_path: str):
+ """Load data from npz_path and update them to self.
+
+ Args:
+ npz_path (str):
+ Path to a dumped npz file.
+ """
+ supported_keys = self.__class__.SUPPORTED_KEYS
+ with np.load(npz_path, allow_pickle=True) as npz_file:
+ tmp_data_dict = dict(npz_file)
+ for key, value in list(tmp_data_dict.items()):
+ if isinstance(value, np.ndarray) and\
+ len(value.shape) == 0:
+ # value is not an ndarray before dump
+ value = value.item()
+ elif key in supported_keys and\
+ type(value) != supported_keys[key]['type']:
+ value = supported_keys[key]['type'](value)
+ if value is None:
+ tmp_data_dict.pop(key)
+ elif key == '__key_strict__' or \
+ key == '__data_len__' or\
+ key == '__instance_num__' or\
+ key == '__keypoints_compressed__':
+ self.__setattr__(key, value)
+ # pop the attributes to keep dict clean
+ tmp_data_dict.pop(key)
+ elif key == 'bbox_xywh' and value.shape[1] == 4:
+ value = np.hstack([value, np.ones([value.shape[0], 1])])
+ tmp_data_dict[key] = value
+ else:
+ tmp_data_dict[key] = value
+ self.update(tmp_data_dict)
+ self.__set_default_values__()
+
+ def dump(self, npz_path: str, overwrite: bool = True):
+ """Dump keys and items to an npz file.
+
+ Args:
+ npz_path (str):
+ Path to a dumped npz file.
+ overwrite (bool, optional):
+ Whether to overwrite if there is already a file.
+ Defaults to True.
+
+ Raises:
+ ValueError:
+ npz_path does not end with '.npz'.
+ FileExistsError:
+ When overwrite is False and file exists.
+ """
+ if not check_path_suffix(npz_path, ['.npz']):
+ raise ValueError('Not an npz file.')
+ if not overwrite:
+ if check_path_existence(npz_path, 'file') == Existence.FileExist:
+ raise FileExistsError
+ dict_to_dump = {
+ '__key_strict__': self.__key_strict__,
+ '__data_len__': self.__data_len__,
+ '__instance_num__': self.__instance_num__,
+ '__keypoints_compressed__': self.__keypoints_compressed__,
+ }
+ dict_to_dump.update(self)
+ np.savez_compressed(npz_path, **dict_to_dump)
+
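+    # Illustrative round trip (the file name and human_data_dict are
+    # placeholders): dump() stores the dict contents together with the
+    # bookkeeping attributes above, and load() restores them on a fresh
+    # instance.
+    #
+    #   data = MultiHumanData.new(source_dict=human_data_dict,
+    #                             key_strict=False)
+    #   data.dump('sample.npz')
+    #   restored = MultiHumanData()
+    #   restored.load('sample.npz')
+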
+ def dump_by_pickle(self, pkl_path: str, overwrite: bool = True) -> None:
+ """Dump keys and items to a pickle file. It's a secondary dump method,
+ when a HumanData instance is too large to be dumped by self.dump()
+
+ Args:
+ pkl_path (str):
+ Path to a dumped pickle file.
+ overwrite (bool, optional):
+ Whether to overwrite if there is already a file.
+ Defaults to True.
+
+ Raises:
+ ValueError:
+                pkl_path does not end with '.pkl'.
+ FileExistsError:
+ When overwrite is False and file exists.
+ """
+ if not check_path_suffix(pkl_path, ['.pkl']):
+            raise ValueError('Not a pkl file.')
+ if not overwrite:
+ if check_path_existence(pkl_path, 'file') == Existence.FileExist:
+ raise FileExistsError
+ dict_to_dump = {
+ '__key_strict__': self.__key_strict__,
+ '__data_len__': self.__data_len__,
+ '__instance_num__': self.__instance_num__,
+ '__keypoints_compressed__': self.__keypoints_compressed__,
+ }
+ dict_to_dump.update(self)
+ with open(pkl_path, 'wb') as f_writeb:
+ pickle.dump(dict_to_dump,
+ f_writeb,
+ protocol=pickle.HIGHEST_PROTOCOL)
+
+ def load_by_pickle(self, pkl_path: str) -> None:
+ """Load data from pkl_path and update them to self.
+
+        When a HumanData instance was dumped by
+        self.dump_by_pickle(), use this method to load it.
+
+        Args:
+            pkl_path (str):
+                Path to a dumped pickle file.
+ """
+ with open(pkl_path, 'rb') as f_readb:
+ tmp_data_dict = pickle.load(f_readb)
+ for key, value in list(tmp_data_dict.items()):
+ if value is None:
+ tmp_data_dict.pop(key)
+ elif key == '__key_strict__' or \
+ key == '__data_len__' or\
+ key == '__instance_num__' or\
+ key == '__keypoints_compressed__':
+ self.__setattr__(key, value)
+ # pop the attributes to keep dict clean
+ tmp_data_dict.pop(key)
+ elif key == 'bbox_xywh' and value.shape[1] == 4:
+ value = np.hstack([value, np.ones([value.shape[0], 1])])
+ tmp_data_dict[key] = value
+ else:
+ tmp_data_dict[key] = value
+ self.update(tmp_data_dict)
+ self.__set_default_values__()
+
+ @property
+ def instance_num(self) -> int:
+ """Get the human instance num of this MultiHumanData instance. In
+ MuliHumanData, an image may have multiple corresponding human
+ instances.
+
+ Returns:
+ int:
+ Number of human instance related to this instance.
+ """
+ return self.__instance_num__
+
+ @instance_num.setter
+ def instance_num(self, value: int):
+ """Set the human instance num of this MultiHumanData instance.
+
+ Args:
+ value (int):
+ Number of human instance related to this instance.
+ """
+ self.__instance_num__ = value
+
+ def get_slice(self,
+ arg_0: int,
+ arg_1: Union[int, Any] = None,
+ step: int = 1) -> _HumanData:
+ """Slice all sliceable values along major_dim dimension.
+
+ Args:
+ arg_0 (int):
+ When arg_1 is None, arg_0 is stop and start=0.
+ When arg_1 is not None, arg_0 is start.
+ arg_1 (Union[int, Any], optional):
+ None or where to stop.
+ Defaults to None.
+ step (int, optional):
+ Length of step. Defaults to 1.
+
+ Returns:
+ MultiHumanData:
+ A new MultiHumanData instance with sliced values.
+ """
+ ret_human_data = \
+ MultiHumanData.new(key_strict=self.get_key_strict())
+ if arg_1 is None:
+ start = 0
+ stop = arg_0
+ else:
+ start = arg_0
+ stop = arg_1
+ slice_index = slice(start, stop, step)
+ dim_dict = self.__get_slice_dim__()
+ # frame_range = self.get_raw_value('optional')['frame_range']
+ for key, dim in dim_dict.items():
+ # primary index
+ if key == 'optional':
+ frame_range = None
+ else:
+ frame_range = self.get_raw_value('optional')['frame_range']
+ # keys not expected be sliced
+ if dim is None:
+ ret_human_data[key] = self[key]
+ elif isinstance(dim, dict):
+ value_dict = self.get_raw_value(key)
+ sliced_dict = {}
+ for sub_key in value_dict.keys():
+ sub_value = value_dict[sub_key]
+ if dim[sub_key] is None:
+ sliced_dict[sub_key] = sub_value
+ else:
+ sub_dim = dim[sub_key]
+ sliced_sub_value = \
+ MultiHumanData.__get_sliced_result__(
+ sub_value, sub_dim, slice_index, frame_range)
+ sliced_dict[sub_key] = sliced_sub_value
+ ret_human_data[key] = sliced_dict
+ else:
+ value = self[key]
+ sliced_value = \
+ MultiHumanData.__get_sliced_result__(
+ value, dim, slice_index, frame_range)
+ ret_human_data[key] = sliced_value
+ # check keypoints compressed
+ if self.check_keypoints_compressed():
+ ret_human_data.compress_keypoints_by_mask()
+ return ret_human_data
+
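+    # Example of the slicing arguments (indices are illustrative):
+    # get_slice(100) keeps the first 100 primary entries, while
+    # get_slice(100, 200, step=2) keeps every second entry in [100, 200).
+    # Instance-aligned keys are gathered through 'optional'['frame_range'],
+    # so every human instance belonging to the selected frames is kept.
+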
+ def __get_slice_dim__(self) -> dict:
+ """For each key in this HumanData, get the dimension for slicing. 0 for
+ default, if no other value specified.
+
+ Returns:
+ dict:
+ Keys are self.keys().
+ Values indicate where to slice.
+ None for not expected to be sliced or
+ failed.
+ """
+ supported_keys = self.__class__.SUPPORTED_KEYS
+ ret_dict = {}
+ for key in self.keys():
+ # keys not expected be sliced
+ if key in supported_keys and \
+ 'dim' in supported_keys[key] and \
+ supported_keys[key]['dim'] is None:
+ ret_dict[key] = None
+ else:
+ value = self[key]
+ if isinstance(value, dict) and len(value) > 0:
+ ret_dict[key] = {}
+ for sub_key in value.keys():
+ try:
+ sub_value_len = len(value[sub_key])
+ if sub_value_len != self.instance_num and \
+ sub_value_len != self.data_len:
+ ret_dict[key][sub_key] = None
+ elif 'dim' in value:
+ ret_dict[key][sub_key] = value['dim']
+ else:
+ ret_dict[key][sub_key] = 0
+ except TypeError:
+ ret_dict[key][sub_key] = None
+ continue
+ # instance cannot be sliced without len method
+ try:
+ value_len = len(value)
+ except TypeError:
+ ret_dict[key] = None
+ continue
+ # slice on dim 0 by default
+ slice_dim = 0
+ if key in supported_keys and \
+ 'dim' in supported_keys[key]:
+ slice_dim = \
+ supported_keys[key]['dim']
+ data_len = value_len if slice_dim == 0 \
+ else value.shape[slice_dim]
+ # dim not for slice
+ if data_len != self.__instance_num__:
+ ret_dict[key] = None
+ continue
+ else:
+ ret_dict[key] = slice_dim
+ return ret_dict
+
+ # TODO: to support cache
+
+ def __check_value_len__(self, key: Any, val: Any) -> bool:
+ """Check whether the temporal length of val matches other values.
+
+ Args:
+ key (Any):
+ Key in MultiHumanData.
+ val (Any):
+ Value to the key.
+
+ Returns:
+ bool:
+ If temporal dim is defined and temporal length doesn't match,
+ return False.
+ Else return True.
+ """
+ ret_bool = True
+ supported_keys = self.__class__.SUPPORTED_KEYS
+
+ # MultiHumanData
+ instance_num = 0
+ if key == 'optional' and \
+ 'frame_range' in val:
+ for frame_range in val['frame_range']:
+ instance_num += (frame_range[-1] - frame_range[0])
+
+ if self.instance_num == -1:
+ # init instance_num for multi_human_data
+ self.instance_num = instance_num
+ elif self.instance_num != instance_num:
+ ret_bool = False
+
+ data_len = len(val['frame_range'])
+ if self.data_len == -1:
+ # init data_len
+ self.data_len = data_len
+ elif self.data_len == self.instance_num:
+ # update data_len
+ self.data_len = data_len
+ elif self.data_len != self.instance_num:
+ ret_bool = False
+
+ # check definition
+ elif key in supported_keys:
+ # check data length
+ if 'dim' in supported_keys[key] and \
+ supported_keys[key]['dim'] is not None:
+ val_slice_dim = supported_keys[key]['dim']
+ if supported_keys[key]['type'] == dict:
+ slice_key = supported_keys[key]['slice_key']
+ val_data_len = val[slice_key].shape[val_slice_dim]
+ else:
+ val_data_len = val.shape[val_slice_dim]
+
+ if self.instance_num < 0:
+ # Init instance_num for HumanData,
+ # which is equal to data_len.
+ self.instance_num = val_data_len
+ else:
+ # check if val_data_len matches recorded instance_num
+ if self.instance_num != val_data_len:
+ ret_bool = False
+
+ if self.data_len < 0:
+ # init data_len for HumanData, it's equal to
+ # instance_num.
+ # If it's MultiHumanData needs to be updated
+ self.data_len = val_data_len
+
+ if not ret_bool:
+ err_msg = 'Data length check Failed:\n'
+ err_msg += f'key={str(key)}\n'
+ if self.data_len != self.instance_num:
+ err_msg += f'val\'s instance_num={self.data_len}\n'
+ err_msg += f'expected instance_num={self.instance_num}\n'
+ print_log(msg=err_msg,
+ logger=self.__class__.logger,
+ level=logging.ERROR)
+ return ret_bool
+
+ def __set_default_values__(self) -> None:
+ """For older versions of HumanData, call this method to apply missing
+ values (also attributes).
+
+ Note:
+ 1. Older HumanData doesn't define `data_len`.
+            2. In newer HumanData, `data_len` equals `instance_num`.
+            3. In MultiHumanData, `instance_num` equals the number of
+                instances, and `data_len` equals the number of frames.
+ """
+ supported_keys = self.__class__.SUPPORTED_KEYS
+ if self.instance_num == -1:
+ # the loaded file is not multi_human_data
+ for key in supported_keys:
+ if key in self and \
+ 'dim' in supported_keys[key] and\
+ supported_keys[key]['dim'] is not None:
+ if 'slice_key' in supported_keys[key] and\
+ supported_keys[key]['type'] == dict:
+ sub_key = supported_keys[key]['slice_key']
+ slice_dim = supported_keys[key]['dim']
+ self.instance_num = self[key][sub_key].shape[slice_dim]
+ else:
+ slice_dim = supported_keys[key]['dim']
+ self.instance_num = self[key].shape[slice_dim]
+
+ # convert HumanData to MultiHumanData
+ self.data_len = self.instance_num
+ optional = {}
+ optional['frame_range'] = \
+ [[i, i + 1] for i in range(self.data_len)]
+ self['optional'] = optional
+ break
+
+ for key in list(self.keys()):
+ convention_key = f'{key}_convention'
+ if key.startswith('keypoints') and \
+ not key.endswith('_mask') and \
+ not key.endswith('_convention') and \
+ convention_key not in self:
+ self[convention_key] = 'human_data'
+
+ @classmethod
+ def __get_sliced_result__(
+ cls,
+ input_data: Union[np.ndarray, list, tuple],
+ slice_dim: int,
+ slice_range: slice,
+ frame_index: list = None) -> Union[np.ndarray, list, tuple]:
+
+ if frame_index is not None:
+ slice_data = []
+ for frame_range in frame_index[slice_range]:
+ slice_index = slice(frame_range[0], frame_range[-1], 1)
+ slice_result = \
+ HumanData.__get_sliced_result__(
+ input_data,
+ slice_dim,
+ slice_index)
+ for element in slice_result:
+ slice_data.append(element)
+ if isinstance(input_data, np.ndarray):
+ slice_data = np.array(slice_data)
+ else:
+ slice_data = type(input_data)(slice_data)
+ else:
+ # primary index
+ slice_data = \
+ HumanData.__get_sliced_result__(
+ input_data,
+ slice_dim,
+ slice_range)
+ return slice_data
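+
+
+# Worked example of the frame_range indirection above (values are
+# illustrative): with frame_index = [[0, 2], [2, 3]], slicing frames [0:1]
+# selects frame_range [0, 2] and gathers rows 0 and 1 of an instance-aligned
+# array, while slicing [0:2] additionally gathers row 2. Frame-level keys
+# such as 'optional' are sliced directly with the primary slice instead.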
diff --git a/detrsmpl/data/data_structures/smc_reader.py b/detrsmpl/data/data_structures/smc_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3916114f86290db810bc0222a9647b4e1b8e53
--- /dev/null
+++ b/detrsmpl/data/data_structures/smc_reader.py
@@ -0,0 +1,1021 @@
+import json
+
+import cv2
+import h5py
+import numpy as np
+import torch
+import tqdm
+
+from detrsmpl.models.body_models.builder import build_body_model
+from detrsmpl.models.body_models.utils import batch_transform_to_camera_frame
+
+
+class SMCReader:
+ def __init__(self, file_path, body_model=None):
+ """Read SenseMocapFile endswith ".smc", see: https://github.com/open-
+ mmlab/detrsmpl/blob/main/docs/smc.md.
+
+ Args:
+ file_path (str):
+ Path to an SMC file.
+ body_model (nn.Module or dict):
+ Only needed for SMPL transformation to device frame
+ if nn.Module: a body_model instance
+ if dict: a body_model config
+ """
+ self.smc = h5py.File(file_path, 'r')
+ self.__calibration_dict__ = None
+ self.action_id = self.smc.attrs['action_id']
+ self.actor_id = self.smc.attrs['actor_id']
+ self.datetime_str = self.smc.attrs['datetime_str'] # .decode()
+ self.kinect_num_frames = self.smc['Kinect'].attrs['num_frame']
+ self.num_kinects = self.smc['Kinect'].attrs['num_device']
+ self.kinect_color_resolution = self.get_kinect_color_resolution(0)
+ self.kinect_depth_resolution = self.get_kinect_depth_resolution(0)
+ self.iphone_exists = 'iPhone' in self.smc.keys()
+ self.num_iphones = 1
+ if self.iphone_exists:
+ self.iphone_num_frames = self.smc['iPhone'].attrs['num_frame']
+ self.iphone_color_resolution = \
+ self.smc['iPhone'].attrs['color_resolution'] # vertical
+ self.iphone_depth_resolution = \
+ self.smc['iPhone'].attrs['depth_resolution'] # vertical
+ self.keypoint_exists = 'Keypoints3D' in self.smc.keys()
+ if self.keypoint_exists:
+ self.keypoints_num_frames = self.smc['Keypoints3D'].attrs[
+ 'num_frame']
+ self.keypoints_convention = self.smc['Keypoints3D'].attrs[
+ 'convention']
+ self.keypoints_created_time = self.smc['Keypoints3D'].attrs[
+ 'created_time']
+ self.smpl_exists = 'SMPL' in self.smc.keys()
+ if self.smpl_exists:
+ self.smpl_num_frames = self.smc['SMPL'].attrs['num_frame']
+ self.smpl_created_time = self.smc['SMPL'].attrs['created_time']
+
+ # initialize body model
+ if isinstance(body_model, torch.nn.Module):
+ self.body_model = body_model
+ elif isinstance(body_model, dict):
+ self.body_model = build_body_model(body_model)
+ else:
+ # in most cases, SMCReader is instantiated for image reading
+ # only. Hence, it is wasteful to initialize a body model until
+ # really needed in get_smpl()
+ self.body_model = None
+ self.default_body_model_config = dict(
+ type='SMPL',
+ gender='neutral',
+ num_betas=10,
+ keypoint_src='smpl_45',
+ keypoint_dst='smpl_45',
+ model_path='data/body_models/smpl',
+ batch_size=1,
+ )
+
+ def get_kinect_color_extrinsics(self, kinect_id, homogeneous=True):
+ """Get extrinsics(cam2world) of a kinect RGB camera by kinect id.
+
+ Args:
+ kinect_id (int):
+ ID of a kinect, starts from 0.
+ homogeneous (bool, optional):
+ If true, returns rotation and translation in
+ one 4x4 matrix. Defaults to True.
+
+ Returns:
+ homogeneous is True
+ ndarray: A 4x4 matrix of rotation and translation(cam2world).
+ homogeneous is False
+ dict: A dict of rotation and translation,
+ keys are R and T,
+ each value is an ndarray.
+ """
+ R = np.asarray(self.calibration_dict[str(kinect_id * 2)]['R']).reshape(
+ 3, 3)
+ T = np.asarray(self.calibration_dict[str(kinect_id *
+ 2)]['T']).reshape(3)
+ if homogeneous:
+ extrinsics = np.identity(4, dtype=float)
+ extrinsics[:3, :3] = R
+ extrinsics[:3, 3] = T
+ return extrinsics
+ else:
+ return {'R': R, 'T': T}
+
+ @property
+ def calibration_dict(self):
+ """Get the dict of calibration.
+
+ Returns:
+ dict:
+ A dict of calibrated extrinsics.
+ """
+ if self.__calibration_dict__ is not None:
+ return self.__calibration_dict__
+ else:
+ return json.loads(self.smc['Extrinsics'][()])
+
+ def get_kinect_depth_extrinsics(self, kinect_id, homogeneous=True):
+ """Get extrinsics(cam2world) of a kinect depth camera by kinect id.
+
+ Args:
+ kinect_id (int):
+ ID of a kinect, starts from 0.
+ homogeneous (bool, optional):
+ If true, returns rotation and translation in
+ one 4x4 matrix. Defaults to True.
+
+ Returns:
+ homogeneous is True
+ ndarray: A 4x4 matrix of rotation and translation(cam2world).
+ homogeneous is False
+ dict: A dict of rotation and translation,
+ keys are R and T,
+ each value is an ndarray.
+ """
+ R = np.asarray(self.calibration_dict[str(kinect_id * 2 +
+ 1)]['R']).reshape(3, 3)
+ T = np.asarray(self.calibration_dict[str(kinect_id * 2 +
+ 1)]['T']).reshape(3)
+ if homogeneous:
+ extrinsics = np.identity(4, dtype=float)
+ extrinsics[:3, :3] = R
+ extrinsics[:3, 3] = T
+ return extrinsics
+ else:
+ return {'R': R, 'T': T}
+
+ def get_kinect_color_intrinsics(self, kinect_id):
+ """Get intrinsics of a kinect RGB camera by kinect id.
+
+ Args:
+ kinect_id (int):
+ ID of a kinect, starts from 0.
+
+ Returns:
+ ndarray: A 3x3 matrix.
+ """
+ kinect_dict = self.smc['Kinect'][str(kinect_id)]
+ intrinsics = \
+ kinect_dict['Calibration']['Color']['Intrinsics'][()]
+ cx, cy, fx, fy = intrinsics[:4]
+ intrinsics = \
+ np.asarray([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
+ return intrinsics
+
+ def get_kinect_color_resolution(self, kinect_id):
+ """Get resolution of a kinect RGB camera by kinect id.
+
+ Args:
+ kinect_id (int):
+ ID of a kinect, starts from 0.
+
+ Returns:
+ ndarray:
+ An ndarray of (width, height), shape=[2, ].
+ """
+ kinect_dict = self.smc['Kinect'][str(kinect_id)]
+ resolution = \
+ kinect_dict['Calibration']['Color']['Resolution'][()]
+ return resolution
+
+ def get_kinect_depth_resolution(self, kinect_id):
+ """Get resolution of a kinect depth camera by kinect id.
+
+ Args:
+ kinect_id (int):
+ ID of a kinect, starts from 0.
+
+ Returns:
+ ndarray:
+ An ndarray of (width, height), shape=[2, ].
+ """
+ kinect_dict = self.smc['Kinect'][str(kinect_id)]
+ resolution = \
+ kinect_dict['Calibration']['Depth']['Resolution'][()]
+ return resolution
+
+ def get_kinect_depth_intrinsics(self, kinect_id):
+ """Get intrinsics of a kinect depth camera by kinect id.
+
+ Args:
+ kinect_id (int):
+ ID of a kinect, starts from 0.
+
+ Returns:
+ ndarray: A 3x3 matrix.
+ """
+ kinect_dict = self.smc['Kinect'][str(kinect_id)]
+ intrinsics = \
+ kinect_dict['Calibration']['Depth']['Intrinsics'][()]
+ cx, cy, fx, fy = intrinsics[:4]
+ intrinsics = \
+ np.asarray([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
+ return intrinsics
+
+ def get_iphone_intrinsics(self, iphone_id=0, frame_id=0, vertical=True):
+ """Get intrinsics of an iPhone RGB camera by iPhone id.
+
+ Args:
+ iphone_id (int, optional):
+ ID of an iPhone, starts from 0.
+ Defaults to 0.
+ frame_id (int, optional):
+ int: frame id of one selected frame
+ Defaults to 0.
+ vertical (bool, optional):
+ iPhone assumes landscape orientation
+ if True, convert data to vertical orientation
+ Defaults to True.
+
+ Returns:
+ ndarray: A 3x3 matrix.
+ """
+ camera_info = self.smc['iPhone'][str(iphone_id)]['CameraInfo'][str(
+ frame_id)]
+ camera_info = json.loads(camera_info[()])
+ intrinsics = np.asarray(camera_info['cameraIntrinsics']).transpose()
+
+ # Intrinsics have to be adjusted to achieve rotation
+ # 1. swapping fx, fy
+ # 2. cx -> image height - cy; cy -> cx
+ if vertical:
+ fx, fy = intrinsics[0, 0], intrinsics[1, 1]
+ cx, cy = intrinsics[0, 2], intrinsics[1, 2]
+ W, H = self.get_iphone_color_resolution(vertical=False)
+ intrinsics = np.eye(3)
+ intrinsics[0, 0], intrinsics[1, 1] = fy, fx
+ intrinsics[0, 2], intrinsics[1, 2] = H - cy, cx
+
+ return intrinsics
+
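+    # Worked example of the adjustment above (numbers are illustrative): for
+    # a landscape 1920x1440 frame with (fx, fy, cx, cy) = (1500, 1500, 960,
+    # 720), the vertical intrinsics become fx' = fy = 1500, fy' = fx = 1500,
+    # cx' = H - cy = 1440 - 720 = 720 and cy' = cx = 960.
+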
+ def get_iphone_extrinsics(self,
+ iphone_id=0,
+ homogeneous=True,
+ vertical=True):
+ """Get extrinsics(cam2world) of an iPhone RGB camera by iPhone id.
+
+ Args:
+ iphone_id (int, optional):
+ ID of an iPhone, starts from 0.
+ Defaults to 0.
+ homogeneous (bool, optional):
+ If true, returns rotation and translation in
+ one 4x4 matrix. Defaults to True.
+ vertical (bool, optional):
+ iPhone assumes landscape orientation
+ if True, convert data to vertical orientation
+ Defaults to True.
+
+ Returns:
+ homogeneous is True
+ ndarray: A 4x4 transformation matrix(cam2world).
+ homogeneous is False
+ dict: A dict of rotation and translation,
+ keys are R and T,
+ each value is an ndarray.
+ """
+ if iphone_id != 0:
+ raise KeyError('Currently only one iPhone.')
+ R = np.asarray(self.calibration_dict['iPhone']['R']).reshape(3, 3)
+ T = np.asarray(self.calibration_dict['iPhone']['T']).reshape(3)
+
+ # cam2world
+ extrinsics = np.identity(4, dtype=float)
+ extrinsics[:3, :3] = R
+ extrinsics[:3, 3] = T
+
+ # Extrinsics have to be adjusted to achieve rotation
+ # A rotation matrix is applied on the extrinsics
+ if vertical:
+ # 90-degree clockwise rotation around z-axis
+ R = np.eye(4)
+ R[:2, :2] = np.array([[0, -1], [1, 0]])
+ # Note the extrinsics is cam2world
+ # world2cam_adjusted = R @ world2cam
+ # => cam2world_adjusted = cam2world @ inv(R)
+ extrinsics = extrinsics @ np.linalg.inv(R)
+ R = extrinsics[:3, :3]
+ T = extrinsics[:3, 3]
+
+ if homogeneous:
+ return extrinsics
+ else:
+ return {'R': R, 'T': T}
+
+ def get_iphone_color_resolution(self, iphone_id=0, vertical=True):
+ """Get color image resolution of an iPhone RGB camera by iPhone id.
+
+ Args:
+ iphone_id (int, optional):
+ ID of an iPhone, starts from 0.
+ Defaults to 0.
+ vertical (bool, optional):
+ iPhone assumes landscape orientation
+ if True, convert data to vertical orientation
+ Defaults to True.
+
+ Returns:
+            ndarray:
+ An ndarray of (width, height), shape=[2, ].
+ """
+ if iphone_id != 0:
+ raise KeyError('Currently only one iPhone.')
+ if vertical:
+ W_horizontal, H_horizontal = self.iphone_color_resolution
+ W_vertical, H_vertical = H_horizontal, W_horizontal
+ return np.array([W_vertical, H_vertical])
+ else:
+ return self.iphone_color_resolution
+
+ def get_kinect_color(self, kinect_id, frame_id=None, disable_tqdm=True):
+ """Get several frames captured by a kinect RGB camera.
+
+ Args:
+ kinect_id (int):
+ ID of a kinect, starts from 0.
+ frame_id (int, list or None, optional):
+ int: frame id of one selected frame
+ list: a list of frame id
+ None: all frames will be returned
+ Defaults to None.
+ disable_tqdm (bool, optional):
+ Whether to disable the entire progressbar wrapper.
+ Defaults to True.
+
+ Returns:
+ ndarray:
+ An ndarray in shape [frame_number, height, width, channels].
+ """
+ frames = []
+ if frame_id is None:
+ frame_list = range(self.get_kinect_num_frames())
+ elif isinstance(frame_id, list):
+ frame_list = frame_id
+ elif isinstance(frame_id, int):
+ assert frame_id < self.get_kinect_num_frames(),\
+ 'Index out of range...'
+ frame_list = [frame_id]
+ else:
+ raise TypeError('frame_id should be int, list or None.')
+ for i in tqdm.tqdm(frame_list, disable=disable_tqdm):
+ frames.append(
+ self.__read_color_from_bytes__(
+ self.smc['Kinect'][str(kinect_id)]['Color'][str(i)][()]))
+ return np.stack(frames, axis=0)
+
+ def get_kinect_rgbd(self,
+ kinect_id,
+ frame_id,
+ mode='color2depth',
+ threshold=0):
+ if mode == 'color2depth':
+ mapped_color = \
+ self.__map_color_to_depth__(
+ kinect_id, frame_id, threshold=threshold
+ )
+ depth = self.get_kinect_depth(kinect_id, frame_id)[0]
+ return mapped_color, depth
+ else:
+            print('Mode {} is not supported...'.format(mode))
+
+ def get_kinect_depth(self, kinect_id, frame_id=None, disable_tqdm=True):
+ """Get several frames captured by a kinect depth camera.
+
+ Args:
+ kinect_id (int):
+ ID of a kinect, starts from 0.
+ frame_id (int, list or None, optional):
+ int: frame id of one selected frame
+ list: a list of frame id
+ None: all frames will be returned
+ Defaults to None.
+ disable_tqdm (bool, optional):
+ Whether to disable the entire progressbar wrapper.
+ Defaults to True.
+
+ Returns:
+ ndarray:
+ An ndarray in shape [frame_number, height, width, channels].
+ """
+ frames = []
+        if frame_id is None:
+            frame_list = range(self.get_kinect_num_frames())
+        elif isinstance(frame_id, list):
+            frame_list = frame_id
+        elif isinstance(frame_id, int):
+            assert frame_id < self.get_kinect_num_frames(),\
+                'Index out of range...'
+            frame_list = [frame_id]
+        else:
+            raise TypeError('frame_id should be int, list or None.')
+ for i in tqdm.tqdm(frame_list, disable=disable_tqdm):
+ frames.append(
+ self.smc['Kinect'][str(kinect_id)]['Depth'][str(i)][()])
+ return np.stack(frames, axis=0)
+
+ def __read_color_from_bytes__(self, color_array):
+ """Decode an RGB image from an encoded byte array."""
+ return cv2.cvtColor(cv2.imdecode(color_array, cv2.IMREAD_COLOR),
+ cv2.COLOR_BGR2RGB)
+
+ def get_num_kinect(self):
+ """Get the number of Kinect devices.
+
+ Returns:
+ int:
+ Number of Kinect devices.
+ """
+ return self.num_kinects
+
+ def get_kinect_num_frames(self):
+ """Get the number of frames recorded by one Kinect RGB camera.
+
+ Returns:
+ int:
+ Number of frames.
+ """
+ return self.kinect_num_frames
+
+ def get_iphone_num_frames(self):
+ """Get the number of frames recorded by one iPhone RGB camera.
+
+ Returns:
+ int:
+ Number of frames.
+ """
+ return self.iphone_num_frames
+
+ def get_depth_mask(self, device_id, frame_id):
+ return self.smc['Kinect'][str(device_id)]['Mask'][str(frame_id)][()]
+
+ def get_kinect_mask(self, device_id, frame_id):
+ kinect_dict = self.smc['Kinect'][str(device_id)]
+ return kinect_dict['Mask_k4abt'][str(frame_id)][()]
+
+ def get_num_iphone(self):
+ """Get the number of iPhone devices.
+
+ Returns:
+ int:
+ Number of iPhone devices.
+ """
+ return self.num_iphones
+
+ def get_iphone_color(self,
+ iphone_id=0,
+ frame_id=None,
+ disable_tqdm=True,
+ vertical=True):
+ """Get several frames captured by an iPhone RGB camera.
+
+ Args:
+ iphone_id (int):
+ ID of an iPhone, starts from 0.
+ frame_id (int, list or None, optional):
+ int: frame id of one selected frame
+ list: a list of frame id
+ None: all frames will be returned
+ Defaults to None.
+ disable_tqdm (bool, optional):
+ Whether to disable the entire progressbar wrapper.
+ Defaults to True.
+ vertical (bool, optional):
+ iPhone assumes horizontal orientation
+ if True, convert data to vertical orientation
+ Defaults to True.
+
+ Returns:
+ frames:
+ An ndarray in shape [frame_number, height, width, channels].
+ """
+ frames = []
+ if frame_id is None:
+ frame_list = range(self.get_iphone_num_frames())
+ elif isinstance(frame_id, list):
+ frame_list = frame_id
+ elif isinstance(frame_id, int):
+ assert frame_id < self.get_iphone_num_frames(),\
+ 'Index out of range...'
+ frame_list = [frame_id]
+ else:
+ raise TypeError('frame_id should be int, list or None.')
+ for i in tqdm.tqdm(frame_list, disable=disable_tqdm):
+ frame = self.__read_color_from_bytes__(
+ self.smc['iPhone'][str(iphone_id)]['Color'][str(i)][()])
+ if vertical:
+ frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)
+ frames.append(frame)
+ return np.stack(frames, axis=0)
+
+ def get_iphone_depth(self,
+ iphone_id=0,
+ frame_id=None,
+ disable_tqdm=True,
+ vertical=True):
+ """Get several frames captured by an iPhone RGB camera.
+
+ Args:
+ iphone_id (int):
+ ID of an iPhone, starts from 0.
+ frame_id (int, list or None, optional):
+ int: frame id of one selected frame
+ list: a list of frame id
+ None: all frames will be returned
+ Defaults to None.
+ disable_tqdm (bool, optional):
+ Whether to disable the entire progressbar wrapper.
+ Defaults to True.
+ vertical (bool, optional):
+ iPhone assumes horizontal orientation
+ if True, convert data to vertical orientation
+ Defaults to True.
+
+ Returns:
+ frames:
+ An ndarray in shape [frame_number, height, width, channels].
+ """
+ frames = []
+ if frame_id is None:
+ frame_list = range(self.get_iphone_num_frames())
+ elif isinstance(frame_id, list):
+ frame_list = frame_id
+ elif isinstance(frame_id, int):
+ assert frame_id < self.get_iphone_num_frames(),\
+ 'Index out of range...'
+ frame_list = [frame_id]
+ else:
+ raise TypeError('frame_id should be int, list or None.')
+ for i in tqdm.tqdm(frame_list, disable=disable_tqdm):
+ frame = self.smc['iPhone'][str(iphone_id)]['Depth'][str(i)][()]
+ if vertical:
+ frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)
+ frames.append(frame)
+ return np.stack(frames, axis=0)
+
+ def get_kinect_transformation_depth_to_color(self, device_id):
+ """Get transformation matrix from depth to color from a single kinect.
+
+ Args:
+            device_id (int):
+                ID of a Kinect, starts from 0.
+
+ Returns:
+ ndarray: A 4x4 transformation matrix.
+ """
+ return np.linalg.inv(self.get_kinect_color_extrinsics(
+ device_id)) @ self.get_kinect_depth_extrinsics(device_id)
+
+ def get_kinect_transformation_color_to_depth(self, device_id):
+ """Get transformation matrix from color to depth from a single kinect.
+
+ Args:
+            device_id (int):
+                ID of a Kinect, starts from 0.
+
+ Returns:
+ ndarray: A 4x4 transformation matrix.
+ """
+ return np.linalg.inv(self.get_kinect_depth_extrinsics(
+ device_id)) @ self.get_kinect_color_extrinsics(device_id)
+
+ def __map_color_to_depth__(self, device_id, frame_id, threshold=100):
+ color_image = self.get_kinect_color(device_id, frame_id)[0]
+ depth_image = self.get_kinect_depth(device_id, frame_id)[0]
+ color_intrinsic = self.get_kinect_color_intrinsics(device_id)
+ depth_intrinsic = self.get_kinect_depth_intrinsics(device_id)
+
+ mask = self.get_depth_mask(device_id, frame_id)
+
+ Td2c = self.get_kinect_transformation_depth_to_color(device_id)
+
+ colidx = np.arange(depth_image.shape[1])
+ rowidx = np.arange(depth_image.shape[0])
+ colidx_map, rowidx_map = np.meshgrid(colidx, rowidx)
+ col_indices = colidx_map[mask >= threshold]
+ row_indices = rowidx_map[mask >= threshold]
+
+ homo_padding = \
+ np.ones((col_indices.shape[0], 1), dtype=np.float32)
+ homo_indices = \
+ np.concatenate(
+ (col_indices[..., None], row_indices[..., None], homo_padding),
+ axis=1
+ )
+
+ depth_intrinsic_inv = np.linalg.inv(depth_intrinsic)
+ normalized_points = \
+ depth_intrinsic_inv[None, ...] @ homo_indices[..., None]
+
+ z_values = (depth_image / 1000)[mask >= threshold]
+ valid_points = \
+ normalized_points.squeeze() * z_values[..., None]
+
+ R = Td2c[:3, :3]
+ T = Td2c[:3, 3]
+ valid_points = \
+ R[None, ...] @ valid_points[..., None] + T[None, ..., None]
+ valid_uvs = \
+ color_intrinsic[None, ...] @\
+ valid_points / valid_points[:, 2][..., None]
+ valid_uvs = np.int32(valid_uvs.squeeze()[..., :2] + 0.5)
+ valid_uvs[:, 0] = np.clip(valid_uvs[:, 0], 0, color_image.shape[1] - 1)
+ valid_uvs[:, 1] = np.clip(valid_uvs[:, 1], 0, color_image.shape[0] - 1)
+ mapped_color = np.ones((depth_image.shape[0], depth_image.shape[1], 3),
+ dtype=np.uint8) * 255
+ mapped_color[mask >= threshold] = \
+ color_image[valid_uvs[:, 1], valid_uvs[:, 0]]
+
+ if threshold == 1:
+ return valid_uvs
+ return mapped_color
+
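+    # The mapping above implements, per depth pixel (u, v) with depth z in
+    # metres (after the /1000 scaling):
+    #   X_d = z * K_d^-1 @ [u, v, 1]^T      (back-project to the depth frame)
+    #   X_c = R_d2c @ X_d + T_d2c           (depth -> color extrinsics)
+    #   [u', v', 1]^T ~ K_c @ X_c / X_c.z   (project into the color image)
+    # where K_d, K_c are the depth/color intrinsics and [R_d2c | T_d2c] comes
+    # from get_kinect_transformation_depth_to_color().
+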
+ def get_kinect_skeleton_3d(self, device_id, frame_id):
+ """Get the 3D skeleton key points from a certain kinect.
+
+ Args:
+ device_id (int):
+ ID of a kinect, starts from 0.
+
+ Returns:
+ list:
+ A list with 3D keypoints
+ """
+ kinect_dict = self.smc['Kinect'][str(device_id)]
+ return json.loads(kinect_dict['Skeleton_k4abt'][str(frame_id)][()])
+
+ def get_depth_floor(self, device_id: int) -> dict:
+ """Get the floor plane defined by a normal vector and a center point
+ from a certain kinect.
+
+ Args:
+ device_id (int):
+ ID of a kinect, starts from 0.
+
+ Raises:
+ KeyError:
+                Key 'floor' not found for the given kinect.
+
+ Returns:
+ dict:
+ A dict with 'center', 'normal' and 'pnum'.
+ """
+ device_dict = self.calibration_dict[str(device_id * 2 + 1)]
+ if 'floor' in device_dict:
+ return device_dict['floor']
+ else:
+ raise KeyError(f'Kinect {device_id} has no floor data.')
+
+ def get_keypoints2d(self, device, device_id, frame_id=None, vertical=True):
+ """Get keypoints2d projected from keypoints3d.
+
+ Args:
+ device (str):
+ Device name, should be Kinect or iPhone.
+ device_id (int):
+ ID of a device, starts from 0.
+ frame_id (int, list or None, optional):
+ int: frame id of one selected frame
+ list: a list of frame id
+ None: all frames will be returned
+ Defaults to None.
+ vertical (bool, optional):
+ Only applicable to iPhone as device
+ iPhone assumes horizontal orientation
+ if True, convert data to vertical orientation
+ Defaults to True.
+
+ Returns:
+ Tuple[np.ndarray, np.ndarray]:
+ keypoints2d (N, J, 3) and its mask (J, )
+ """
+ assert device in {
+ 'Kinect', 'iPhone'
+ }, f'Undefined device: {device}, should be "Kinect" or "iPhone"'
+ assert device_id >= 0
+
+ kps2d_dict = self.smc['Keypoints2D'][device][str(device_id)]
+ keypoints2d = kps2d_dict['keypoints2d'][...]
+ keypoints2d_mask = kps2d_dict['keypoints2d_mask'][...]
+
+ if frame_id is None:
+ frame_list = range(self.get_keypoints_num_frames())
+ elif isinstance(frame_id, list):
+ frame_list = frame_id
+ elif isinstance(frame_id, int):
+ assert frame_id < self.get_keypoints_num_frames(),\
+ 'Index out of range...'
+ frame_list = [frame_id]
+ else:
+ raise TypeError('frame_id should be int, list or None.')
+
+ keypoints2d = keypoints2d[frame_list, ...]
+
+ if device == 'iPhone' and vertical:
+ # rotate keypoints 2D clockwise by 90 degrees
+ W, H = self.get_iphone_color_resolution(vertical=False)
+ xs, ys, conf = \
+ keypoints2d[..., 0], keypoints2d[..., 1], keypoints2d[..., 2]
+ xs, ys = H - ys, xs # horizontal -> vertical
+ keypoints2d[..., 0], keypoints2d[..., 1] = xs.copy(), ys.copy()
+ keypoints2d[conf == 0.0] = 0.0
+
+ return keypoints2d, keypoints2d_mask
+
+ def get_kinect_keypoints2d(self, device_id, frame_id=None):
+ """Get Kinect 2D keypoints.
+
+ Args:
+ device_id (int):
+ ID of Kinect, starts from 0.
+ frame_id (int, list or None, optional):
+ int: frame id of one selected frame
+ list: a list of frame id
+ None: all frames will be returned
+ Defaults to None.
+
+ Returns:
+ Tuple[np.ndarray, np.ndarray]:
+ keypoints2d (N, J, 3) and its mask (J, )
+ """
+ assert self.num_kinects > device_id >= 0
+ return self.get_keypoints2d('Kinect', device_id, frame_id)
+
+ def get_iphone_keypoints2d(self,
+ device_id=0,
+ frame_id=None,
+ vertical=True):
+ """Get iPhone 2D keypoints.
+
+ Args:
+ device_id (int):
+ ID of iPhone, starts from 0.
+ frame_id (int, list or None, optional):
+ int: frame id of one selected frame
+ list: a list of frame id
+ None: all frames will be returned
+ Defaults to None.
+ vertical (bool, optional):
+ iPhone assumes horizontal orientation
+ if True, convert data to vertical orientation
+ Defaults to True.
+
+ Returns:
+ Tuple[np.ndarray, np.ndarray]:
+ keypoints2d (N, J, 3) and its mask (J, )
+ """
+ assert device_id >= 0
+ return self.get_keypoints2d('iPhone',
+ device_id,
+ frame_id,
+ vertical=vertical)
+
+ def get_color(self,
+ device,
+ device_id,
+ frame_id=None,
+ disable_tqdm=True,
+ vertical=True):
+ """Get RGB image(s) from Kinect RGB or iPhone RGB camera.
+
+ Args:
+ device (str):
+ Device name, should be Kinect or iPhone.
+ device_id (int):
+ Device ID, starts from 0.
+ frame_id (int, list or None, optional):
+ int: frame id of one selected frame
+ list: a list of frame id
+ None: all frames will be returned
+ Defaults to None.
+ disable_tqdm (bool, optional):
+ Whether to disable the entire progressbar wrapper.
+ Defaults to True.
+ vertical (bool, optional):
+ Only applicable to iPhone as device
+ iPhone assumes horizontal orientation
+ if True, convert data to vertical orientation
+ Defaults to True.
+
+ Returns:
+ img (ndarray):
+ An ndarray in shape [frame_number, height, width, channels].
+ """
+
+ assert device in {
+ 'Kinect', 'iPhone'
+ }, f'Undefined device: {device}, should be "Kinect" or "iPhone"'
+
+ if device == 'Kinect':
+ img = self.get_kinect_color(device_id, frame_id, disable_tqdm)
+ else:
+ img = self.get_iphone_color(device_id,
+ frame_id,
+ disable_tqdm,
+ vertical=vertical)
+
+ return img
+
+ def get_keypoints_num_frames(self):
+ return self.keypoints_num_frames
+
+ def get_keypoints_convention(self):
+ return self.keypoints_convention
+
+ def get_keypoints_created_time(self):
+ return self.keypoints_created_time
+
+ def get_keypoints3d(self,
+ device=None,
+ device_id=None,
+ frame_id=None,
+ vertical=True):
+ """Get keypoints3d (world coordinate) computed by mocap processing
+ pipeline.
+
+ Args:
+ device (str):
+ Device name, should be Kinect or iPhone.
+ None: world coordinate
+ Defaults to None.
+ device_id (int):
+ ID of a device, starts from 0.
+ None: world coordinate
+ Defaults to None
+ frame_id (int, list or None, optional):
+ int: frame id of one selected frame
+ list: a list of frame id
+ None: all frames will be returned
+ Defaults to None.
+ vertical (bool, optional):
+ Only applicable to iPhone as device
+ iPhone assumes horizontal orientation
+ if True, convert data to vertical orientation
+ Defaults to True.
+
+ Returns:
+ Tuple[np.ndarray, np.ndarray]:
+ keypoints3d (N, J, 4) and its mask (J, )
+ """
+ assert (device is None and device_id is None) or \
+ (device is not None and device_id is not None), \
+ 'device and device_id should be both None or both not None.'
+ if device is not None:
+ assert device in {
+ 'Kinect', 'iPhone'
+ }, f'Undefined device: {device}, should be "Kinect" or "iPhone"'
+ if device_id is not None:
+ assert device_id >= 0
+
+ if frame_id is None:
+ frame_list = range(self.get_keypoints_num_frames())
+ elif isinstance(frame_id, list):
+ frame_list = frame_id
+ elif isinstance(frame_id, int):
+ assert frame_id < self.get_keypoints_num_frames(),\
+ 'Index out of range...'
+ frame_list = [frame_id]
+ else:
+ raise TypeError('frame_id should be int, list or None.')
+
+ kps3d_dict = self.smc['Keypoints3D']
+
+ # keypoints3d are in world coordinate system
+ keypoints3d_world = kps3d_dict['keypoints3d'][...]
+ keypoints3d_world = keypoints3d_world[frame_list, ...]
+ keypoints3d_mask = kps3d_dict['keypoints3d_mask'][...]
+
+ # return keypoints3d in world coordinate system
+ if device is None:
+ return keypoints3d_world, keypoints3d_mask
+
+ # return keypoints3d in device coordinate system
+ else:
+ if device == 'Kinect':
+ cam2world = self.get_kinect_color_extrinsics(
+ kinect_id=device_id, homogeneous=True)
+ else:
+ cam2world = self.get_iphone_extrinsics(iphone_id=device_id,
+ vertical=vertical)
+
+ xyz, conf = keypoints3d_world[..., :3], keypoints3d_world[..., [3]]
+ xyz_homogeneous = np.ones([*xyz.shape[:-1], 4])
+ xyz_homogeneous[..., :3] = xyz
+ world2cam = np.linalg.inv(cam2world)
+ keypoints3d = np.einsum('ij,kmj->kmi', world2cam, xyz_homogeneous)
+ keypoints3d = np.concatenate([keypoints3d[..., :3], conf], axis=-1)
+
+ return keypoints3d, keypoints3d_mask
+
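+    # The device-frame conversion above, per keypoint: with
+    # world2cam = inv(cam2world), a homogeneous point p_w = [x, y, z, 1]^T is
+    # mapped to p_c = world2cam @ p_w (the einsum applies this to every
+    # keypoint), and the original confidence column is re-attached afterwards.
+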
+ def get_smpl_num_frames(self):
+ return self.smpl_num_frames
+
+ def get_smpl_created_time(self):
+ return self.smpl_created_time
+
+ def get_smpl(self,
+ device=None,
+ device_id=None,
+ frame_id=None,
+ vertical=True):
+ """Get SMPL (world coordinate) computed by mocap processing pipeline.
+
+ Args:
+ device (str):
+ Device name, should be Kinect or iPhone.
+ None: world coordinate
+ Defaults to None.
+ device_id (int):
+ ID of a device, starts from 0.
+ None: world coordinate
+ Defaults to None
+ frame_id (int, list or None, optional):
+ int: frame id of one selected frame
+ list: a list of frame id
+ None: all frames will be returned
+ Defaults to None.
+ vertical (bool, optional):
+ Only applicable to iPhone as device
+ iPhone assumes horizontal orientation
+ if True, convert data to vertical orientation
+ Defaults to True.
+
+ Returns:
+ dict:
+ 'global_orient': np.ndarray of shape (N, 3)
+ 'body_pose': np.ndarray of shape (N, 69)
+ 'transl': np.ndarray of shape (N, 3)
+ 'betas': np.ndarray of shape (N, 10)
+ """
+ smpl_dict = self.smc['SMPL']
+ global_orient = smpl_dict['global_orient'][...]
+ body_pose = smpl_dict['body_pose'][...]
+ transl = smpl_dict['transl'][...]
+ betas = smpl_dict['betas'][...]
+
+ if frame_id is None:
+ frame_list = range(self.get_smpl_num_frames())
+ elif isinstance(frame_id, list):
+ frame_list = frame_id
+ elif isinstance(frame_id, int):
+            assert frame_id < self.get_smpl_num_frames(),\
+ 'Index out of range...'
+ frame_list = [frame_id]
+ else:
+ raise TypeError('frame_id should be int, list or None.')
+
+ body_pose = body_pose[frame_list, ...]
+ global_orient = global_orient[frame_list, ...]
+ transl = transl[frame_list, ...]
+
+ # return SMPL parameters in world coordinate system
+ if device is None:
+ smpl_dict = dict(global_orient=global_orient,
+ body_pose=body_pose,
+ transl=transl,
+ betas=betas)
+
+ return smpl_dict
+
+ # return SMPL parameters in device coordinate system
+ else:
+
+ if self.body_model is None:
+ self.body_model = \
+ build_body_model(self.default_body_model_config)
+ torch_device = self.body_model.global_orient.device
+
+ assert device in {
+ 'Kinect', 'iPhone'
+ }, f'Undefined device: {device}, should be "Kinect" or "iPhone"'
+ assert device_id >= 0
+
+ if device == 'Kinect':
+ T_cam2world = self.get_kinect_color_extrinsics(
+ kinect_id=device_id, homogeneous=True)
+ else:
+ T_cam2world = self.get_iphone_extrinsics(iphone_id=device_id,
+ vertical=vertical)
+
+ T_world2cam = np.linalg.inv(T_cam2world)
+
+ output = self.body_model(
+ global_orient=torch.tensor(global_orient, device=torch_device),
+ body_pose=torch.tensor(body_pose, device=torch_device),
+ transl=torch.tensor(transl, device=torch_device),
+ betas=torch.tensor(betas, device=torch_device))
+ joints = output['joints'].detach().cpu().numpy()
+ pelvis = joints[:, 0, :]
+
+ new_global_orient, new_transl = batch_transform_to_camera_frame(
+ global_orient=global_orient,
+ transl=transl,
+ pelvis=pelvis,
+ extrinsic=T_world2cam)
+
+ smpl_dict = dict(global_orient=new_global_orient,
+ body_pose=body_pose,
+ transl=new_transl,
+ betas=betas)
+
+ return smpl_dict
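+
+
+# Illustrative end-to-end usage of SMCReader (the file path and frame ids are
+# placeholders; it assumes the file contains the Keypoints3D and SMPL groups,
+# and get_smpl() additionally needs the SMPL model files referenced by
+# default_body_model_config):
+#
+#   reader = SMCReader('sample.smc')
+#   imgs = reader.get_kinect_color(kinect_id=0, frame_id=[0, 1, 2])
+#   kps3d, kps3d_mask = reader.get_keypoints3d(device='Kinect', device_id=0)
+#   smpl_params = reader.get_smpl(device='iPhone', device_id=0)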
diff --git a/detrsmpl/data/datasets/__init__.py b/detrsmpl/data/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f8c20e70bc4719a3fb537d6f35bb513eca8dfce
--- /dev/null
+++ b/detrsmpl/data/datasets/__init__.py
@@ -0,0 +1,21 @@
+from .adversarial_dataset import AdversarialDataset
+from .base_dataset import BaseDataset
+from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
+from .dataset_wrappers import ConcatDataset, RepeatDataset
+from .human_hybrik_dataset import HybrIKHumanImageDataset
+from .human_image_dataset import HumanImageDataset
+from .human_image_smplx_dataset import HumanImageSMPLXDataset
+from .human_video_dataset import HumanVideoDataset
+from .mesh_dataset import MeshDataset
+from .mixed_dataset import MixedDataset
+from .multi_human_image_dataset import MultiHumanImageDataset
+from .pipelines import Compose
+from .samplers import DistributedSampler
+
+__all__ = [
+ 'BaseDataset', 'HumanImageDataset', 'HumanImageSMPLXDataset',
+ 'build_dataloader', 'build_dataset', 'Compose', 'DistributedSampler',
+ 'ConcatDataset', 'RepeatDataset', 'DATASETS', 'PIPELINES', 'MixedDataset',
+ 'AdversarialDataset', 'MeshDataset', 'HumanVideoDataset',
+ 'HybrIKHumanImageDataset', 'MultiHumanImageDataset'
+]
diff --git a/detrsmpl/data/datasets/adversarial_dataset.py b/detrsmpl/data/datasets/adversarial_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..6484d3e0b83326883cf985157087671180bfe036
--- /dev/null
+++ b/detrsmpl/data/datasets/adversarial_dataset.py
@@ -0,0 +1,40 @@
+import numpy as np
+from torch.utils.data import Dataset
+
+from .builder import DATASETS, build_dataset
+
+
+@DATASETS.register_module()
+class AdversarialDataset(Dataset):
+ """Mix Dataset for the adversarial training in 3D human mesh estimation
+ task.
+
+ The dataset combines data from two datasets and
+ return a dict containing data from two datasets.
+ Args:
+ train_dataset (:obj:`Dataset`): Dataset for 3D human mesh estimation.
+ adv_dataset (:obj:`Dataset`): Dataset for adversarial learning.
+ """
+ def __init__(self, train_dataset: Dataset, adv_dataset: Dataset):
+ super().__init__()
+ self.train_dataset = build_dataset(train_dataset)
+ self.adv_dataset = build_dataset(adv_dataset)
+ self.num_train_data = len(self.train_dataset)
+ self.num_adv_data = len(self.adv_dataset)
+
+ def __len__(self):
+ """Get the size of the dataset."""
+ return self.num_train_data
+
+ def __getitem__(self, idx: int):
+ """Given index, get the data from train dataset and randomly sample an
+ item from adversarial dataset.
+
+ Return a dict containing data from train and adversarial dataset.
+ """
+ data = self.train_dataset[idx]
+ adv_idx = np.random.randint(low=0, high=self.num_adv_data, dtype=int)
+ adv_data = self.adv_dataset[adv_idx]
+ for k, v in adv_data.items():
+ data['adv_' + k] = v
+ return data
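+
+
+# Illustrative output contract (the two configs are placeholders for anything
+# build_dataset accepts): each sample contains the train-dataset item plus the
+# keys of a randomly drawn adversarial item, each prefixed with 'adv_'.
+#
+#   dataset = AdversarialDataset(train_dataset=train_cfg, adv_dataset=adv_cfg)
+#   sample = dataset[0]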
diff --git a/detrsmpl/data/datasets/base_dataset.py b/detrsmpl/data/datasets/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..494b836db1095c34ac25004b317dc9427434d2a1
--- /dev/null
+++ b/detrsmpl/data/datasets/base_dataset.py
@@ -0,0 +1,71 @@
+import copy
+from abc import ABCMeta, abstractmethod
+from typing import Optional, Union
+
+from torch.utils.data import Dataset
+
+from .pipelines import Compose
+
+
+class BaseDataset(Dataset, metaclass=ABCMeta):
+ """Base dataset.
+
+ Args:
+ data_prefix (str): the prefix of data path.
+ pipeline (list): a list of dict, where each element represents
+            an operation defined in `detrsmpl.data.datasets.pipelines`.
+ ann_file (str | None, optional): the annotation file. When ann_file is
+ str, the subclass is expected to read from the ann_file. When
+ ann_file is None, the subclass is expected to read according
+ to data_prefix.
+        test_mode (bool): in train mode or test mode. Default: False.
+ dataset_name (str | None, optional): the name of dataset. It is used
+ to identify the type of evaluation metric. Default: None.
+ """
+ # metric
+ ALLOWED_METRICS = {
+ 'mpjpe', 'pa-mpjpe', 'pve', '3dpck', 'pa-3dpck', '3dauc', 'pa-3dauc',
+ '3DRMSE', 'pa-pve'
+ }
+
+ def __init__(self,
+ data_prefix: str,
+ pipeline: list,
+ ann_file: Optional[Union[str, None]] = None,
+ test_mode: Optional[bool] = False,
+ dataset_name: Optional[Union[str, None]] = None):
+ super(BaseDataset, self).__init__()
+
+ self.ann_file = ann_file
+ self.data_prefix = data_prefix
+ self.test_mode = test_mode
+ self.pipeline = Compose(pipeline)
+ if dataset_name is not None:
+ self.dataset_name = dataset_name
+
+ self.load_annotations()
+
+ @abstractmethod
+ def load_annotations(self):
+ """Load annotations from ``ann_file``"""
+ pass
+
+ def prepare_data(self, idx: int):
+ """"Prepare raw data for the f'{idx'}-th data."""
+ results = copy.deepcopy(self.data_infos[idx])
+ results['dataset_name'] = self.dataset_name
+ results['sample_idx'] = idx
+ return self.pipeline(results)
+
+ def __len__(self):
+ """Return the length of current dataset."""
+ return self.num_data
+
+ def __getitem__(self, idx: int):
+ """Prepare data for the ``idx``-th data.
+
+ As for video dataset, we can first parse raw data for each frame. Then
+ we combine annotations from all frames. This interface is used to
+ simplify the logic of video dataset and other special datasets.
+ """
+ return self.prepare_data(idx)
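+
+
+# Subclassing sketch (names are illustrative): a concrete dataset implements
+# load_annotations() and is expected to populate self.data_infos (the
+# per-sample dicts consumed by prepare_data) and self.num_data (used by
+# __len__), as e.g. HybrIKHumanImageDataset does.
+#
+#   class MyImageDataset(BaseDataset):
+#       def load_annotations(self):
+#           paths = []  # hypothetical list of image files
+#           self.data_infos = [dict(image_path=p) for p in paths]
+#           self.num_data = len(self.data_infos)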
diff --git a/detrsmpl/data/datasets/builder.py b/detrsmpl/data/datasets/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..27d3be684dc7d49ef1b3fa8d4c90ca84cf42d1c3
--- /dev/null
+++ b/detrsmpl/data/datasets/builder.py
@@ -0,0 +1,124 @@
+import platform
+import random
+from functools import partial
+from typing import Optional, Union
+
+import numpy as np
+from mmcv.parallel import collate
+from mmcv.runner import get_dist_info
+from mmcv.utils import Registry, build_from_cfg
+from torch.utils.data import DataLoader
+from torch.utils.data.dataset import Dataset
+
+from .samplers import DistributedSampler
+
+if platform.system() != 'Windows':
+ # https://github.com/pytorch/pytorch/issues/973
+ import resource
+ rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+ base_soft_limit = rlimit[0]
+ hard_limit = rlimit[1]
+ soft_limit = min(max(4096, base_soft_limit), hard_limit)
+ resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
+
+DATASETS = Registry('dataset')
+PIPELINES = Registry('pipeline')
+
+
+def build_dataset(cfg: Union[dict, list, tuple],
+ default_args: Optional[Union[dict, None]] = None):
+ """"Build dataset by the given config."""
+ from .dataset_wrappers import (
+ ConcatDataset,
+ RepeatDataset,
+ )
+ if isinstance(cfg, (list, tuple)):
+ dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
+ elif cfg['type'] == 'RepeatDataset':
+ dataset = RepeatDataset(build_dataset(cfg['dataset'], default_args),
+ cfg['times'])
+ else:
+ dataset = build_from_cfg(cfg, DATASETS, default_args)
+
+ return dataset
+
+
+def build_dataloader(dataset: Dataset,
+ samples_per_gpu: int,
+ workers_per_gpu: int,
+ num_gpus: Optional[int] = 1,
+ dist: Optional[bool] = True,
+ shuffle: Optional[bool] = True,
+ round_up: Optional[bool] = True,
+ seed: Optional[Union[int, None]] = None,
+ persistent_workers: Optional[bool] = True,
+ **kwargs):
+ """Build PyTorch DataLoader.
+
+ In distributed training, each GPU/process has a dataloader.
+ In non-distributed training, there is only one dataloader for all GPUs.
+
+ Args:
+ dataset (:obj:`Dataset`): A PyTorch dataset.
+ samples_per_gpu (int): Number of training samples on each GPU, i.e.,
+ batch size of each GPU.
+ workers_per_gpu (int): How many subprocesses to use for data loading
+ for each GPU.
+ num_gpus (int, optional): Number of GPUs. Only used in non-distributed
+ training.
+ dist (bool, optional): Distributed training/test or not. Default: True.
+ shuffle (bool, optional): Whether to shuffle the data at every epoch.
+ Default: True.
+ round_up (bool, optional): Whether to round up the length of dataset by
+ adding extra samples to make it evenly divisible. Default: True.
+        persistent_workers (bool): If True, the data loader will not shut
+            down the worker processes after a dataset has been consumed
+            once. This keeps the workers' Dataset instances alive.
+            The argument only takes effect in PyTorch>=1.7.0.
+            Default: True.
+ kwargs: any keyword argument to be used to initialize DataLoader
+
+ Returns:
+ DataLoader: A PyTorch dataloader.
+ """
+ rank, world_size = get_dist_info()
+ if dist:
+ sampler = DistributedSampler(dataset,
+ world_size,
+ rank,
+ shuffle=shuffle,
+ round_up=round_up)
+ shuffle = False
+ batch_size = samples_per_gpu
+ num_workers = workers_per_gpu
+ else:
+ sampler = None
+ batch_size = num_gpus * samples_per_gpu
+ num_workers = num_gpus * workers_per_gpu
+
+ init_fn = partial(
+ worker_init_fn, num_workers=num_workers, rank=rank,
+ seed=seed) if seed is not None else None
+
+ data_loader = DataLoader(dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ num_workers=num_workers,
+ collate_fn=partial(
+ collate, samples_per_gpu=samples_per_gpu),
+ pin_memory=False,
+ shuffle=shuffle,
+ worker_init_fn=init_fn,
+ persistent_workers=persistent_workers,
+ **kwargs)
+
+ return data_loader
+
+
+def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
+ """Init random seed for each worker."""
+    # The seed of each worker equals
+    # num_workers * rank + worker_id + user_seed.
+ worker_seed = num_workers * rank + worker_id + seed
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
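+
+
+# Illustrative configs (the dataset type and its arguments are placeholders;
+# any dataset registered in DATASETS works the same way):
+#
+#   dataset_cfg = dict(type='HumanImageDataset',
+#                      data_prefix='data',
+#                      pipeline=[],
+#                      dataset_name='h36m',
+#                      ann_file='h36m_train.npz')
+#   dataset = build_dataset(dataset_cfg)
+#   loader = build_dataloader(dataset,
+#                             samples_per_gpu=32,
+#                             workers_per_gpu=2,
+#                             dist=False,
+#                             shuffle=True)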
diff --git a/detrsmpl/data/datasets/dataset_wrappers.py b/detrsmpl/data/datasets/dataset_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4228b1d2520b3019f90c42f1af12f9abf642cf8
--- /dev/null
+++ b/detrsmpl/data/datasets/dataset_wrappers.py
@@ -0,0 +1,45 @@
+from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
+from torch.utils.data.dataset import Dataset
+
+from .builder import DATASETS
+
+
+@DATASETS.register_module()
+class ConcatDataset(_ConcatDataset):
+ """A wrapper of concatenated dataset.
+
+ Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
+ add `get_cat_ids` function.
+
+ Args:
+ datasets (list[:obj:`Dataset`]): A list of datasets.
+ """
+ def __init__(self, datasets: list):
+ super(ConcatDataset, self).__init__(datasets)
+
+
+@DATASETS.register_module()
+class RepeatDataset(object):
+ """A wrapper of repeated dataset.
+
+ The length of repeated dataset will be `times` larger than the original
+ dataset. This is useful when the data loading time is long but the dataset
+ is small. Using RepeatDataset can reduce the data loading time between
+ epochs.
+
+ Args:
+ dataset (:obj:`Dataset`): The dataset to be repeated.
+ times (int): Repeat times.
+ """
+ def __init__(self, dataset: Dataset, times: int):
+ self.dataset = dataset
+ self.times = times
+ self.CLASSES = dataset.CLASSES
+
+ self._ori_len = len(self.dataset)
+
+ def __getitem__(self, idx: int):
+ return self.dataset[idx % self._ori_len]
+
+ def __len__(self):
+ return self.times * self._ori_len
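+
+
+# RepeatDataset is usually constructed through build_dataset (see builder.py);
+# an illustrative config wrapping another, placeholder dataset config:
+#
+#   cfg = dict(type='RepeatDataset',
+#              times=10,
+#              dataset=dict(type='HumanImageDataset', ...))
+#   dataset = build_dataset(cfg)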
diff --git a/detrsmpl/data/datasets/human_hybrik_dataset.py b/detrsmpl/data/datasets/human_hybrik_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2b4034cd7b065423fb407dd767408aa9902266a
--- /dev/null
+++ b/detrsmpl/data/datasets/human_hybrik_dataset.py
@@ -0,0 +1,452 @@
+import json
+import os
+import os.path
+from abc import ABCMeta
+from collections import OrderedDict
+from typing import List, Optional, Union
+
+import mmcv
+import numpy as np
+import torch
+
+from detrsmpl.core.conventions.keypoints_mapping import get_mapping
+from detrsmpl.core.evaluation import (
+ keypoint_3d_auc,
+ keypoint_3d_pck,
+ keypoint_mpjpe,
+ vertice_pve,
+)
+from detrsmpl.data.data_structures.human_data import HumanData
+from detrsmpl.models.body_models.builder import build_body_model
+from detrsmpl.utils.demo_utils import box2cs, xyxy2xywh
+from .base_dataset import BaseDataset
+from .builder import DATASETS
+
+
+@DATASETS.register_module()
+class HybrIKHumanImageDataset(BaseDataset, metaclass=ABCMeta):
+ """Dataset for HybrIK training. The dataset loads raw features and apply
+ specified transforms to return a dict containing the image tensors and
+ other information.
+
+ Args:
+
+ data_prefix (str): Path to a directory where preprocessed datasets are
+ held.
+ pipeline (list[dict | callable]): A sequence of data transforms.
+ dataset_name (str): accepted names include 'h36m', 'pw3d',
+ 'mpi_inf_3dhp', 'coco'
+ ann_file (str): Name of annotation file.
+ test_mode (bool): Store True when building test dataset.
+ Default: False.
+ """
+ # metric
+ ALLOWED_METRICS = {
+ 'mpjpe', 'pa-mpjpe', 'pve', '3dpck', 'pa-3dpck', '3dauc', 'pa-3dauc'
+ }
+
+ def __init__(self,
+ data_prefix: str,
+ pipeline: list,
+ dataset_name: str,
+ body_model: Optional[Union[dict, None]] = None,
+ ann_file: Optional[Union[str, None]] = None,
+ test_mode: Optional[bool] = False):
+ if dataset_name is not None:
+ self.dataset_name = dataset_name
+ self.test_mode = test_mode
+ super(HybrIKHumanImageDataset, self).__init__(data_prefix, pipeline,
+ ann_file, test_mode)
+ if body_model is not None:
+ self.body_model = build_body_model(body_model)
+ else:
+ self.body_model = None
+
+ def get_annotation_file(self):
+ """Obtain annotation file path from data prefix."""
+ ann_prefix = os.path.join(self.data_prefix, 'preprocessed_datasets')
+ self.ann_file = os.path.join(ann_prefix, self.ann_file)
+
+ @staticmethod
+ def get_3d_keypoints_vis(keypoints):
+ """Get 3d keypoints and visibility mask
+ Args:
+ keypoints (np.ndarray): 2d (NxKx3) or 3d (NxKx4) keypoints with
+ visibility. N refers to number of datapoints, K refers to number
+ of keypoints.
+
+ Returns:
+ joint_img (np.ndarray): (NxKx3) 3d keypoints
+ joint_vis (np.ndarray): (NxKx3) visibility mask for keypoints
+ """
+ keypoints, keypoints_vis = keypoints[:, :, :-1], keypoints[:, :, -1]
+ num_datapoints, num_keypoints, dim = keypoints.shape
+ joint_img = np.zeros((num_datapoints, num_keypoints, 3),
+ dtype=np.float32)
+ joint_vis = np.zeros((num_datapoints, num_keypoints, 3),
+ dtype=np.float32)
+ joint_img[:, :, :dim] = keypoints
+ joint_vis[:, :, :dim] = np.tile(np.expand_dims(keypoints_vis, axis=2),
+ (1, dim))
+ return joint_img, joint_vis
+
+ def load_annotations(self):
+ """Load annotations."""
+ self.get_annotation_file()
+ data = HumanData()
+ data.load(self.ann_file)
+
+ self.image_path = data['image_path']
+ self.num_data = len(self.image_path)
+
+ self.bbox_xyxy = data['bbox_xywh']
+ self.width = data['image_width']
+ self.height = data['image_height']
+ self.depth_factor = data['depth_factor']
+
+ try:
+ self.keypoints3d, self.keypoints3d_vis = self.get_3d_keypoints_vis(
+ data['keypoints2d'])
+ except KeyError:
+ self.keypoints3d, self.keypoints3d_vis = self.get_3d_keypoints_vis(
+ data['keypoints3d'])
+
+ try:
+ self.smpl = data['smpl']
+ if 'has_smpl' not in data.keys():
+ self.has_smpl = np.ones((self.num_data)).astype(np.float32)
+ else:
+ self.has_smpl = data['has_smpl'].astype(np.float32)
+ self.thetas = self.smpl['thetas'].astype(np.float32)
+ self.betas = self.smpl['betas'].astype(np.float32)
+
+ self.keypoints3d_relative, _ = self.get_3d_keypoints_vis(
+ data['keypoints3d_relative'])
+ self.keypoints3d17, self.keypoints3d17_vis = \
+ self.get_3d_keypoints_vis(data['keypoints3d17'])
+ self.keypoints3d17_relative, _ = self.get_3d_keypoints_vis(
+ data['keypoints3d17_relative'])
+
+ if self.test_mode:
+ self.keypoints3d_cam, _ = self.get_3d_keypoints_vis(
+ data['keypoints3d_cam'])
+ except KeyError:
+ self.has_smpl = np.zeros((self.num_data)).astype(np.float32)
+ if self.test_mode:
+ self.keypoints3d, self.keypoints3d_vis = \
+ self.get_3d_keypoints_vis(data['keypoints3d'])
+ self.keypoints3d_cam, _ = self.get_3d_keypoints_vis(
+ data['keypoints3d_cam'])
+
+ try:
+ self.intrinsic = data['cam_param']['intrinsic']
+ except KeyError:
+ self.intrinsic = np.zeros((self.num_data, 3, 3))
+
+ try:
+ self.target_twist = data['phi']
+ # self.target_twist_weight = np.ones_like((self.target_twist))
+ self.target_twist_weight = data['phi_weight']
+ except KeyError:
+ self.target_twist = np.zeros((self.num_data, 23, 2))
+ self.target_twist_weight = np.zeros_like((self.target_twist))
+
+ try:
+ self.root_cam = data['root_cam']
+ except KeyError:
+ self.root_cam = np.zeros((self.num_data, 3))
+
+ self.data_infos = []
+
+ for idx in range(self.num_data):
+ info = {}
+ info['ann_info'] = {}
+ info['img_prefix'] = None
+ info['image_path'] = os.path.join(self.data_prefix, 'datasets',
+ self.dataset_name,
+ self.image_path[idx])
+ bbox_xyxy = self.bbox_xyxy[idx]
+ info['bbox'] = bbox_xyxy[:4]
+ bbox_xywh = xyxy2xywh(bbox_xyxy)
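+            # convert the box into a (center, scale) crop with square aspect
+            # ratio, enlarged by a factor of 1.25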
+ center, scale = box2cs(bbox_xywh,
+ aspect_ratio=1.0,
+ bbox_scale_factor=1.25)
+
+ info['center'] = center
+ info['scale'] = scale
+ info['rotation'] = 0
+ info['ann_info']['dataset_name'] = self.dataset_name
+ info['ann_info']['height'] = self.height[idx]
+ info['ann_info']['width'] = self.width[idx]
+ info['depth_factor'] = float(self.depth_factor[idx])
+ info['has_smpl'] = int(self.has_smpl[idx])
+ info['joint_root'] = self.root_cam[idx].astype(np.float32)
+ info['intrinsic_param'] = self.intrinsic[idx].astype(np.float32)
+ info['target_twist'] = self.target_twist[idx].astype(
+ np.float32) # twist_phi
+ info['target_twist_weight'] = self.target_twist_weight[idx].astype(
+ np.float32)
+ info['keypoints3d'] = self.keypoints3d[idx]
+ info['keypoints3d_vis'] = self.keypoints3d_vis[idx]
+
+ if info['has_smpl']:
+ info['pose'] = self.thetas[idx]
+ info['beta'] = self.betas[idx].astype(np.float32)
+ info['keypoints3d_relative'] = self.keypoints3d_relative[idx]
+ info['keypoints3d17'] = self.keypoints3d17[idx]
+ info['keypoints3d17_vis'] = self.keypoints3d17_vis[idx]
+ info['keypoints3d17_relative'] = self.keypoints3d17_relative[
+ idx]
+
+ if self.test_mode:
+ info['joint_relative_17'] = self.keypoints3d17_relative[
+ idx].astype(np.float32)
+
+ else:
+ if self.test_mode:
+ info['joint_relative_17'] = self.keypoints3d_cam[
+ idx].astype(np.float32)
+
+ self.data_infos.append(info)
+
+ def evaluate(self,
+ outputs: list,
+ res_folder: str,
+ metric: Optional[Union[str, List[str]]] = 'pa-mpjpe',
+ **kwargs: dict):
+ """Evaluate 3D keypoint results.
+
+ Args:
+ outputs (list): results from model inference.
+ res_folder (str): path to store results.
+            metric (Optional[Union[str, List[str]]]):
+ the type of metric. Default: 'pa-mpjpe'
+ kwargs (dict): other arguments.
+ Returns:
+ dict:
+ A dict of all evaluation results.
+ """
+ metrics = metric if isinstance(metric, list) else [metric]
+ for metric in metrics:
+ if metric not in self.ALLOWED_METRICS:
+ raise ValueError(f'metric {metric} is not supported')
+
+ res_file = os.path.join(res_folder, 'result_keypoints.json')
+
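+        # index results by image id so that multi-gpu outputs are restored to
+        # the original dataset order before dumping and evaluation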
+ res_dict = {}
+ for out in outputs:
+ target_id = out['image_idx']
+ batch_size = len(out['xyz_17'])
+ for i in range(batch_size):
+ res_dict[int(target_id[i])] = dict(
+ keypoints=out['xyz_17'][i],
+ poses=out['smpl_pose'][i],
+ betas=out['smpl_beta'][i],
+ )
+
+ keypoints, poses, betas = [], [], []
+ for i in range(self.num_data):
+ keypoints.append(res_dict[i]['keypoints'])
+ poses.append(res_dict[i]['poses'])
+ betas.append(res_dict[i]['betas'])
+
+ res = dict(keypoints=keypoints, poses=poses, betas=betas)
+ mmcv.dump(res, res_file)
+
+ name_value_tuples = []
+ for _metric in metrics:
+ if _metric == 'mpjpe':
+ _nv_tuples = self._report_mpjpe(res)
+ elif _metric == 'pa-mpjpe':
+ _nv_tuples = self._report_mpjpe(res, metric='pa-mpjpe')
+ elif _metric == '3dpck':
+ _nv_tuples = self._report_3d_pck(res)
+ elif _metric == 'pa-3dpck':
+ _nv_tuples = self._report_3d_pck(res, metric='pa-3dpck')
+ elif _metric == '3dauc':
+ _nv_tuples = self._report_3d_auc(res)
+ elif _metric == 'pa-3dauc':
+ _nv_tuples = self._report_3d_auc(res, metric='pa-3dauc')
+ elif _metric == 'pve':
+ _nv_tuples = self._report_pve(res)
+ else:
+ raise NotImplementedError
+ name_value_tuples.extend(_nv_tuples)
+
+ name_value = OrderedDict(name_value_tuples)
+ return name_value
+
+ @staticmethod
+ def _write_keypoint_results(keypoints, res_file):
+ """Write results into a json file."""
+ with open(res_file, 'w') as f:
+ json.dump(keypoints, f, sort_keys=True, indent=4)
+
+ def _parse_result(self, res, mode='keypoint'):
+ """Parse results."""
+ gts = self.data_infos
+ if mode == 'vertice':
+ pred_pose = torch.FloatTensor(res['poses'])
+ pred_beta = torch.FloatTensor(res['betas'])
+ pred_output = self.body_model(
+ betas=pred_beta,
+ body_pose=pred_pose[:, 1:],
+ global_orient=pred_pose[:, 0].unsqueeze(1),
+ pose2rot=False)
+ pred_vertices = pred_output['vertices'].detach().cpu().numpy()
+
+ gt_pose = torch.FloatTensor([gt['pose']
+ for gt in gts]).view(-1, 72)
+ gt_beta = torch.FloatTensor([gt['beta'] for gt in gts])
+ gt_output = self.body_model(betas=gt_beta,
+ body_pose=gt_pose[:, 3:],
+ global_orient=gt_pose[:, :3])
+ gt_vertices = gt_output['vertices'].detach().cpu().numpy()
+ gt_mask = np.ones(gt_vertices.shape[:-1])
+ assert len(pred_vertices) == self.num_data
+
+ return pred_vertices * 1000., gt_vertices * 1000., gt_mask
+ elif mode == 'keypoint':
+ pred_keypoints3d = res['keypoints']
+ assert len(pred_keypoints3d) == self.num_data
+ # (B, 17, 3)
+ pred_keypoints3d = np.array(pred_keypoints3d)
+ factor, root_idx_17 = 1, 0
+
+ if self.dataset_name == 'mpi_inf_3dhp':
+ _, hp3d_idxs, _ = get_mapping('human_data',
+ 'mpi_inf_3dhp_test')
+ gt_keypoints3d = np.array(
+ [gt['joint_relative_17'][hp3d_idxs] for gt in gts])
+ joint_mapper = [
+ 14, 11, 12, 13, 8, 9, 10, 15, 1, 16, 0, 5, 6, 7, 2, 3, 4
+ ]
+ gt_keypoints3d_mask = np.ones(
+ (len(gt_keypoints3d), len(joint_mapper)))
+ else:
+ _, h36m_idxs, _ = get_mapping('human_data', 'h36m')
+ gt_keypoints3d = np.array(
+ [gt['joint_relative_17'][h36m_idxs] for gt in gts])
+ joint_mapper = [
+ 6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10
+ ]
+ gt_keypoints3d_mask = np.ones(
+ (len(gt_keypoints3d), len(joint_mapper)))
+ if self.dataset_name == 'pw3d':
+ factor = 1000
+
+ assert len(pred_keypoints3d) == self.num_data
+
+ pred_keypoints3d = pred_keypoints3d * (2000 / factor)
+ if self.dataset_name == 'mpi_inf_3dhp':
+ gt_keypoints3d = gt_keypoints3d[:, joint_mapper, :]
+ # root joint alignment
+ pred_keypoints3d = (
+ pred_keypoints3d -
+ pred_keypoints3d[:, None, root_idx_17]) * factor
+ gt_keypoints3d = (gt_keypoints3d -
+ gt_keypoints3d[:, None, root_idx_17]) * factor
+
+ if self.dataset_name == 'pw3d' or self.dataset_name == 'h36m':
+ # select eval 14 joints
+ pred_keypoints3d = pred_keypoints3d[:, joint_mapper, :]
+ gt_keypoints3d = gt_keypoints3d[:, joint_mapper, :]
+
+ gt_keypoints3d_mask = gt_keypoints3d_mask > 0
+
+ return pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask
+
+ else:
+ raise NotImplementedError()
+
+ def _report_mpjpe(self, res_file, metric='mpjpe'):
+ """Cauculate mean per joint position error (MPJPE) or its variants PA-
+ MPJPE.
+
+ Report mean per joint position error (MPJPE) and mean per joint
+ position error after rigid alignment (PA-MPJPE)
+ """
+ pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \
+ self._parse_result(res_file, mode='keypoint')
+
+ err_name = metric.upper()
+ if metric == 'mpjpe':
+ alignment = 'none'
+ elif metric == 'pa-mpjpe':
+ alignment = 'procrustes'
+ else:
+ raise ValueError(f'Invalid metric: {metric}')
+
+ error = keypoint_mpjpe(pred_keypoints3d, gt_keypoints3d,
+ gt_keypoints3d_mask, alignment)
+ info_str = [(err_name, error)]
+
+ return info_str
+
+ def _report_3d_pck(self, res_file, metric='3dpck'):
+ """Cauculate Percentage of Correct Keypoints (3DPCK) w. or w/o
+ Procrustes alignment.
+ Args:
+ keypoint_results (list): Keypoint predictions. See
+ 'Body3DMpiInf3dhpDataset.evaluate' for details.
+ metric (str): Specify mpjpe variants. Supported options are:
+ - ``'3dpck'``: Standard 3DPCK.
+ - ``'pa-3dpck'``:
+ 3DPCK after aligning prediction to groundtruth
+ via a rigid transformation (scale, rotation and
+ translation).
+ """
+
+ pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \
+ self._parse_result(res_file, mode='keypoint')
+
+ err_name = metric.upper()
+ if metric == '3dpck':
+ alignment = 'none'
+ elif metric == 'pa-3dpck':
+ alignment = 'procrustes'
+ else:
+ raise ValueError(f'Invalid metric: {metric}')
+
+ error = keypoint_3d_pck(pred_keypoints3d, gt_keypoints3d,
+ gt_keypoints3d_mask, alignment)
+ name_value_tuples = [(err_name, error)]
+
+ return name_value_tuples
+
+ def _report_3d_auc(self, res_file, metric='3dauc'):
+ """Cauculate the Area Under the Curve (AUC) computed for a range of
+ 3DPCK thresholds.
+ Args:
+ keypoint_results (list): Keypoint predictions. See
+ 'Body3DMpiInf3dhpDataset.evaluate' for details.
+ metric (str): Specify mpjpe variants. Supported options are:
+ - ``'3dauc'``: Standard 3DAUC.
+ - ``'pa-3dauc'``: 3DAUC after aligning prediction to
+ groundtruth via a rigid transformation (scale, rotation and
+ translation).
+ """
+
+ pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \
+ self._parse_result(res_file, mode='keypoint')
+
+ err_name = metric.upper()
+ if metric == '3dauc':
+ alignment = 'none'
+ elif metric == 'pa-3dauc':
+ alignment = 'procrustes'
+ else:
+ raise ValueError(f'Invalid metric: {metric}')
+
+ error = keypoint_3d_auc(pred_keypoints3d, gt_keypoints3d,
+ gt_keypoints3d_mask, alignment)
+ name_value_tuples = [(err_name, error)]
+
+ return name_value_tuples
+
+ def _report_pve(self, res_file):
+ """Cauculate per vertex error."""
+ pred_verts, gt_verts, _ = \
+ self._parse_result(res_file, mode='vertice')
+ error = vertice_pve(pred_verts, gt_verts)
+ return [('PVE', error)]
diff --git a/detrsmpl/data/datasets/human_image_dataset.py b/detrsmpl/data/datasets/human_image_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..89b6c121892d889ddb3ff3b3aece38c150ef22ef
--- /dev/null
+++ b/detrsmpl/data/datasets/human_image_dataset.py
@@ -0,0 +1,662 @@
+import json
+import os
+import os.path
+from abc import ABCMeta
+from collections import OrderedDict
+from typing import Any, List, Optional, Union
+
+import mmcv
+import numpy as np
+import torch
+import torch.distributed as dist
+from mmcv.runner import get_dist_info
+
+from detrsmpl.core.conventions.keypoints_mapping import (
+ convert_kps,
+ get_keypoint_num,
+ get_mapping,
+)
+from detrsmpl.core.evaluation import (
+ keypoint_3d_auc,
+ keypoint_3d_pck,
+ keypoint_mpjpe,
+ vertice_pve,
+)
+from detrsmpl.data.data_structures.human_data import HumanData
+from detrsmpl.data.data_structures.human_data_cache import (
+ HumanDataCacheReader,
+ HumanDataCacheWriter,
+)
+from detrsmpl.models.body_models.builder import build_body_model
+from .base_dataset import BaseDataset
+from .builder import DATASETS
+
+
+@DATASETS.register_module()
+class HumanImageDataset(BaseDataset, metaclass=ABCMeta):
+ """Human Image Dataset.
+
+ Args:
+ data_prefix (str): the prefix of data path.
+        pipeline (list): a list of dicts, where each element represents
+            an operation defined in `detrsmpl.datasets.pipelines`.
+ dataset_name (str | None): the name of dataset. It is used to
+ identify the type of evaluation metric. Default: None.
+ body_model (dict | None, optional): the config for body model,
+ which will be used to generate meshes and keypoints.
+ Default: None.
+ ann_file (str | None, optional): the annotation file. When ann_file
+ is str, the subclass is expected to read from the ann_file.
+ When ann_file is None, the subclass is expected to read
+ according to data_prefix.
+ convention (str, optional): keypoints convention. Keypoints will be
+ converted from "human_data" to the given one.
+ Default: "human_data"
+ cache_data_path (str | None, optional): the path to store the cache
+ file. When cache_data_path is None, each dataset will store a copy
+ into memory. If cache_data_path is set, the dataset will first
+ create one cache file and then use a cache reader to reduce memory
+            cost and initialization time. The cache file will be generated
+            only once if it is not found at the path. Otherwise, only a
+            cache reader will be established.
+ test_mode (bool, optional): in train mode or test mode.
+ Default: False.
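+
+    Example:
+        >>> # illustrative config; the paths and file names below are
+        >>> # placeholders, not files shipped with this repository
+        >>> cfg = dict(
+        ...     type='HumanImageDataset',
+        ...     data_prefix='data',
+        ...     pipeline=[],
+        ...     dataset_name='h36m',
+        ...     ann_file='h36m_train.npz',
+        ...     convention='human_data')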
+ """
+ # metric
+ ALLOWED_METRICS = {
+ 'mpjpe', 'pa-mpjpe', 'pve', '3dpck', 'pa-3dpck', '3dauc', 'pa-3dauc',
+ 'ihmr'
+ }
+
+ def __init__(self,
+ data_prefix: str,
+ pipeline: list,
+ dataset_name: str,
+ body_model: Optional[Union[dict, None]] = None,
+ ann_file: Optional[Union[str, None]] = None,
+ convention: Optional[str] = 'human_data',
+ cache_data_path: Optional[Union[str, None]] = None,
+ test_mode: Optional[bool] = False):
+ self.convention = convention
+ self.num_keypoints = get_keypoint_num(convention)
+ self.cache_data_path = cache_data_path
+ super(HumanImageDataset,
+ self).__init__(data_prefix, pipeline, ann_file, test_mode,
+ dataset_name)
+ if body_model is not None:
+ self.body_model = build_body_model(body_model)
+ else:
+ self.body_model = None
+
+ def get_annotation_file(self):
+ """Get path of the annotation file."""
+ ann_prefix = os.path.join(self.data_prefix, 'preprocessed_datasets')
+ self.ann_file = os.path.join(ann_prefix, self.ann_file)
+
+ def load_annotations(self):
+ """Load annotation from the annotation file.
+
+ Here we simply use :obj:`HumanData` to parse the annotation.
+ """
+ rank, world_size = get_dist_info()
+ self.get_annotation_file()
+ if self.cache_data_path is None:
+ use_human_data = True
+ elif rank == 0 and not os.path.exists(self.cache_data_path):
+ use_human_data = True
+ else:
+ use_human_data = False
+ if use_human_data:
+ self.human_data = HumanData.fromfile(self.ann_file)
+
+ if self.human_data.check_keypoints_compressed():
+ self.human_data.decompress_keypoints()
+ # change keypoint from 'human_data' to the given convention
+ if 'keypoints3d' in self.human_data:
+ keypoints3d = self.human_data['keypoints3d']
+ assert 'keypoints3d_mask' in self.human_data
+ keypoints3d_mask = self.human_data['keypoints3d_mask']
+ keypoints3d, keypoints3d_mask = \
+ convert_kps(
+ keypoints3d,
+ src='human_data',
+ dst=self.convention,
+ mask=keypoints3d_mask)
+ self.human_data.__setitem__('keypoints3d', keypoints3d)
+ self.human_data.__setitem__('keypoints3d_convention',
+ self.convention)
+ self.human_data.__setitem__('keypoints3d_mask',
+ keypoints3d_mask)
+ if 'keypoints2d' in self.human_data:
+ keypoints2d = self.human_data['keypoints2d']
+ assert 'keypoints2d_mask' in self.human_data
+ keypoints2d_mask = self.human_data['keypoints2d_mask']
+ keypoints2d, keypoints2d_mask = \
+ convert_kps(
+ keypoints2d,
+ src='human_data',
+ dst=self.convention,
+ mask=keypoints2d_mask)
+ self.human_data.__setitem__('keypoints2d', keypoints2d)
+ self.human_data.__setitem__('keypoints2d_convention',
+ self.convention)
+ self.human_data.__setitem__('keypoints2d_mask',
+ keypoints2d_mask)
+ self.human_data.compress_keypoints_by_mask()
+
+ if self.cache_data_path is not None:
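+            # only rank 0 writes the sliced cache file; the other ranks wait
+            # at the barrier and then read it through a lightweight reader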
+ if rank == 0 and not os.path.exists(self.cache_data_path):
+ writer_kwargs, sliced_data = self.human_data.get_sliced_cache()
+ writer = HumanDataCacheWriter(**writer_kwargs)
+ writer.update_sliced_dict(sliced_data)
+ writer.dump(self.cache_data_path)
+ if world_size > 1:
+ dist.barrier()
+ self.cache_reader = HumanDataCacheReader(
+ npz_path=self.cache_data_path)
+ self.num_data = self.cache_reader.data_len
+ self.human_data = None
+ else:
+ self.cache_reader = None
+ self.num_data = self.human_data.data_len
+
+ def prepare_raw_data(self, idx: int):
+ """Get item from self.human_data."""
+ sample_idx = idx
+ if self.cache_reader is not None:
+ self.human_data = self.cache_reader.get_item(idx)
+ idx = idx % self.cache_reader.slice_size
+ info = {}
+ info['img_prefix'] = None
+ image_path = self.human_data['image_path'][idx]
+ info['image_path'] = os.path.join(self.data_prefix, 'datasets',
+ self.dataset_name, image_path)
+ if image_path.endswith('smc'):
+ device, device_id, frame_id = self.human_data['image_id'][idx]
+ info['image_id'] = (device, int(device_id), int(frame_id))
+
+ info['dataset_name'] = self.dataset_name
+ info['sample_idx'] = sample_idx
+ if 'bbox_xywh' in self.human_data:
+ info['bbox_xywh'] = self.human_data['bbox_xywh'][idx]
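+            # build a square crop around the box: keep its center and use the
+            # longer side for both scale dimensions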
+ x, y, w, h, s = info['bbox_xywh']
+ cx = x + w / 2
+ cy = y + h / 2
+ w = h = max(w, h)
+ info['center'] = np.array([cx, cy])
+ info['scale'] = np.array([w, h])
+ else:
+ info['bbox_xywh'] = np.zeros((5))
+ info['center'] = np.zeros((2))
+ info['scale'] = np.zeros((2))
+
+ # in later modules, we will check validity of each keypoint by
+ # its confidence. Therefore, we do not need the mask of keypoints.
+
+ if 'keypoints2d' in self.human_data:
+ info['keypoints2d'] = self.human_data['keypoints2d'][idx]
+ info['has_keypoints2d'] = 1
+ else:
+ info['keypoints2d'] = np.zeros((self.num_keypoints, 3))
+ info['has_keypoints2d'] = 0
+ if 'keypoints3d' in self.human_data:
+ info['keypoints3d'] = self.human_data['keypoints3d'][idx]
+ info['has_keypoints3d'] = 1
+ else:
+ info['keypoints3d'] = np.zeros((self.num_keypoints, 4))
+ info['has_keypoints3d'] = 0
+
+ if 'smpl' in self.human_data:
+ smpl_dict = self.human_data['smpl']
+ else:
+ smpl_dict = {}
+
+ if 'smpl' in self.human_data:
+ if 'has_smpl' in self.human_data:
+ info['has_smpl'] = int(self.human_data['has_smpl'][idx])
+ else:
+ info['has_smpl'] = 1
+ else:
+ info['has_smpl'] = 0
+ if 'body_pose' in smpl_dict:
+ info['smpl_body_pose'] = smpl_dict['body_pose'][idx]
+ else:
+ info['smpl_body_pose'] = np.zeros((23, 3))
+
+ if 'global_orient' in smpl_dict:
+ info['smpl_global_orient'] = smpl_dict['global_orient'][idx]
+ else:
+ info['smpl_global_orient'] = np.zeros((3))
+
+ if 'betas' in smpl_dict:
+ info['smpl_betas'] = smpl_dict['betas'][idx]
+ else:
+ info['smpl_betas'] = np.zeros((10))
+
+ if 'transl' in smpl_dict:
+ info['smpl_transl'] = smpl_dict['transl'][idx]
+ else:
+ info['smpl_transl'] = np.zeros((3))
+
+ return info
+
+ def prepare_data(self, idx: int):
+ """Generate and transform data."""
+ info = self.prepare_raw_data(idx)
+ return self.pipeline(info)
+
+ def evaluate(self,
+ outputs: list,
+ res_folder: str,
+ metric: Optional[Union[str, List[str]]] = 'pa-mpjpe',
+ **kwargs: dict):
+ """Evaluate 3D keypoint results.
+
+ Args:
+ outputs (list): results from model inference.
+ res_folder (str): path to store results.
+            metric (Optional[Union[str, List[str]]]):
+ the type of metric. Default: 'pa-mpjpe'
+ kwargs (dict): other arguments.
+ Returns:
+ dict:
+ A dict of all evaluation results.
+ """
+ metrics = metric if isinstance(metric, list) else [metric]
+ for metric in metrics:
+ if metric not in self.ALLOWED_METRICS:
+ raise KeyError(f'metric {metric} is not supported')
+
+ res_file = os.path.join(res_folder, 'result_keypoints.json')
+ # for keeping correctness during multi-gpu test, we sort all results
+
+ res_dict = {}
+ for out in outputs:
+ target_id = out['image_idx']
+ batch_size = len(out['keypoints_3d'])
+ for i in range(batch_size):
+ res_dict[int(target_id[i])] = dict(
+ keypoints=out['keypoints_3d'][i],
+ poses=out['smpl_pose'][i],
+ betas=out['smpl_beta'][i],
+ )
+
+ keypoints, poses, betas = [], [], []
+ for i in range(self.num_data):
+ keypoints.append(res_dict[i]['keypoints'])
+ poses.append(res_dict[i]['poses'])
+ betas.append(res_dict[i]['betas'])
+
+ res = dict(keypoints=keypoints, poses=poses, betas=betas)
+ mmcv.dump(res, res_file)
+
+ name_value_tuples = []
+ for _metric in metrics:
+ if _metric == 'mpjpe':
+ _nv_tuples = self._report_mpjpe(res)
+ elif _metric == 'pa-mpjpe':
+ _nv_tuples = self._report_mpjpe(res, metric='pa-mpjpe')
+ elif _metric == '3dpck':
+ _nv_tuples = self._report_3d_pck(res)
+ elif _metric == 'pa-3dpck':
+ _nv_tuples = self._report_3d_pck(res, metric='pa-3dpck')
+ elif _metric == '3dauc':
+ _nv_tuples = self._report_3d_auc(res)
+ elif _metric == 'pa-3dauc':
+ _nv_tuples = self._report_3d_auc(res, metric='pa-3dauc')
+ elif _metric == 'pve':
+ _nv_tuples = self._report_pve(res)
+ elif _metric == 'ihmr':
+ _nv_tuples = self._report_ihmr(res)
+ else:
+ raise NotImplementedError
+ name_value_tuples.extend(_nv_tuples)
+
+ name_value = OrderedDict(name_value_tuples)
+ return name_value
+
+ @staticmethod
+ def _write_keypoint_results(keypoints: Any, res_file: str):
+ """Write results into a json file."""
+
+ with open(res_file, 'w') as f:
+ json.dump(keypoints, f, sort_keys=True, indent=4)
+
+ def _parse_result(self, res, mode='keypoint', body_part=None):
+ """Parse results."""
+
+ if mode == 'vertice':
+ # gt
+ gt_beta, gt_pose, gt_global_orient, gender = [], [], [], []
+ gt_smpl_dict = self.human_data['smpl']
+ for idx in range(self.num_data):
+ gt_beta.append(gt_smpl_dict['betas'][idx])
+ gt_pose.append(gt_smpl_dict['body_pose'][idx])
+ gt_global_orient.append(gt_smpl_dict['global_orient'][idx])
+ if self.human_data['meta']['gender'][idx] == 'm':
+ gender.append(0)
+ else:
+ gender.append(1)
+ gt_beta = torch.FloatTensor(gt_beta)
+ gt_pose = torch.FloatTensor(gt_pose).view(-1, 69)
+ gt_global_orient = torch.FloatTensor(gt_global_orient)
+ gender = torch.Tensor(gender)
+ gt_output = self.body_model(betas=gt_beta,
+ body_pose=gt_pose,
+ global_orient=gt_global_orient,
+ gender=gender)
+ gt_vertices = gt_output['vertices'].detach().cpu().numpy() * 1000.
+ gt_mask = np.ones(gt_vertices.shape[:-1])
+ # pred
+ pred_pose = torch.FloatTensor(res['poses'])
+ pred_beta = torch.FloatTensor(res['betas'])
+ pred_output = self.body_model(
+ betas=pred_beta,
+ body_pose=pred_pose[:, 1:],
+ global_orient=pred_pose[:, 0].unsqueeze(1),
+ pose2rot=False,
+ gender=gender)
+ pred_vertices = pred_output['vertices'].detach().cpu().numpy(
+ ) * 1000.
+
+ assert len(pred_vertices) == self.num_data
+
+ return pred_vertices, gt_vertices, gt_mask
+ elif mode == 'keypoint':
+ pred_keypoints3d = res['keypoints']
+ assert len(pred_keypoints3d) == self.num_data
+ # (B, 17, 3)
+ pred_keypoints3d = np.array(pred_keypoints3d)
+
+ if self.dataset_name == 'pw3d':
+ betas = []
+ body_pose = []
+ global_orient = []
+ gender = []
+ smpl_dict = self.human_data['smpl']
+ for idx in range(self.num_data):
+ betas.append(smpl_dict['betas'][idx])
+ body_pose.append(smpl_dict['body_pose'][idx])
+ global_orient.append(smpl_dict['global_orient'][idx])
+ if self.human_data['meta']['gender'][idx] == 'm':
+ gender.append(0)
+ else:
+ gender.append(1)
+ betas = torch.FloatTensor(betas)
+ body_pose = torch.FloatTensor(body_pose).view(-1, 69)
+ global_orient = torch.FloatTensor(global_orient)
+ gender = torch.Tensor(gender)
+ gt_output = self.body_model(betas=betas,
+ body_pose=body_pose,
+ global_orient=global_orient,
+ gender=gender)
+ gt_keypoints3d = gt_output['joints'].detach().cpu().numpy()
+ gt_keypoints3d_mask = np.ones((len(pred_keypoints3d), 24))
+ elif self.dataset_name == 'h36m':
+ _, h36m_idxs, _ = get_mapping('human_data', 'h36m')
+ gt_keypoints3d = \
+ self.human_data['keypoints3d'][:, h36m_idxs, :3]
+ gt_keypoints3d_mask = np.ones((len(pred_keypoints3d), 17))
+ elif self.dataset_name == 'humman':
+ betas = []
+ body_pose = []
+ global_orient = []
+ smpl_dict = self.human_data['smpl']
+ for idx in range(self.num_data):
+ betas.append(smpl_dict['betas'][idx])
+ body_pose.append(smpl_dict['body_pose'][idx])
+ global_orient.append(smpl_dict['global_orient'][idx])
+ betas = torch.FloatTensor(betas)
+ body_pose = torch.FloatTensor(body_pose).view(-1, 69)
+ global_orient = torch.FloatTensor(global_orient)
+ gt_output = self.body_model(betas=betas,
+ body_pose=body_pose,
+ global_orient=global_orient)
+ gt_keypoints3d = gt_output['joints'].detach().cpu().numpy()
+ gt_keypoints3d_mask = np.ones((len(pred_keypoints3d), 24))
+ else:
+ raise NotImplementedError()
+
+ # SMPL_49 only!
+ if gt_keypoints3d.shape[1] == 49:
+ assert pred_keypoints3d.shape[1] == 49
+
+ gt_keypoints3d = gt_keypoints3d[:, 25:, :]
+ pred_keypoints3d = pred_keypoints3d[:, 25:, :]
+
+ joint_mapper = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18]
+ gt_keypoints3d = gt_keypoints3d[:, joint_mapper, :]
+ pred_keypoints3d = pred_keypoints3d[:, joint_mapper, :]
+
+ # we only evaluate on 14 lsp joints
+ pred_pelvis = (pred_keypoints3d[:, 2] +
+ pred_keypoints3d[:, 3]) / 2
+ gt_pelvis = (gt_keypoints3d[:, 2] + gt_keypoints3d[:, 3]) / 2
+
+ # H36M for testing!
+ elif gt_keypoints3d.shape[1] == 17:
+ assert pred_keypoints3d.shape[1] == 17
+
+ H36M_TO_J17 = [
+ 6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9
+ ]
+ H36M_TO_J14 = H36M_TO_J17[:14]
+ joint_mapper = H36M_TO_J14
+
+ pred_pelvis = pred_keypoints3d[:, 0]
+ gt_pelvis = gt_keypoints3d[:, 0]
+
+ gt_keypoints3d = gt_keypoints3d[:, joint_mapper, :]
+ pred_keypoints3d = pred_keypoints3d[:, joint_mapper, :]
+
+ # keypoint 24
+ elif gt_keypoints3d.shape[1] == 24:
+ assert pred_keypoints3d.shape[1] == 24
+
+ joint_mapper = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18]
+ gt_keypoints3d = gt_keypoints3d[:, joint_mapper, :]
+ pred_keypoints3d = pred_keypoints3d[:, joint_mapper, :]
+
+ # we only evaluate on 14 lsp joints
+ pred_pelvis = (pred_keypoints3d[:, 2] +
+ pred_keypoints3d[:, 3]) / 2
+ gt_pelvis = (gt_keypoints3d[:, 2] + gt_keypoints3d[:, 3]) / 2
+
+ else:
+ pass
+
+ pred_keypoints3d = (pred_keypoints3d -
+ pred_pelvis[:, None, :]) * 1000
+ gt_keypoints3d = (gt_keypoints3d - gt_pelvis[:, None, :]) * 1000
+
+ gt_keypoints3d_mask = gt_keypoints3d_mask[:, joint_mapper] > 0
+
+ return pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask
+
+ def _report_mpjpe(self, res_file, metric='mpjpe', body_part=''):
+ """Cauculate mean per joint position error (MPJPE) or its variants PA-
+ MPJPE.
+
+ Report mean per joint position error (MPJPE) and mean per joint
+ position error after rigid alignment (PA-MPJPE)
+ """
+ pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \
+ self._parse_result(res_file, mode='keypoint', body_part=body_part)
+
+ err_name = metric.upper()
+ if body_part != '':
+ err_name = body_part.upper() + ' ' + err_name
+
+ if metric == 'mpjpe':
+ alignment = 'none'
+ elif metric == 'pa-mpjpe':
+ alignment = 'procrustes'
+ else:
+ raise ValueError(f'Invalid metric: {metric}')
+
+ error = keypoint_mpjpe(pred_keypoints3d, gt_keypoints3d,
+ gt_keypoints3d_mask, alignment)
+ info_str = [(err_name, error)]
+
+ return info_str
+
+ def _report_3d_pck(self, res_file, metric='3dpck'):
+ """Cauculate Percentage of Correct Keypoints (3DPCK) w. or w/o
+ Procrustes alignment.
+ Args:
+ keypoint_results (list): Keypoint predictions. See
+ 'Body3DMpiInf3dhpDataset.evaluate' for details.
+ metric (str): Specify mpjpe variants. Supported options are:
+ - ``'3dpck'``: Standard 3DPCK.
+ - ``'pa-3dpck'``:
+ 3DPCK after aligning prediction to groundtruth
+ via a rigid transformation (scale, rotation and
+ translation).
+ """
+
+ pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \
+ self._parse_result(res_file)
+
+ err_name = metric.upper()
+ if metric == '3dpck':
+ alignment = 'none'
+ elif metric == 'pa-3dpck':
+ alignment = 'procrustes'
+ else:
+ raise ValueError(f'Invalid metric: {metric}')
+
+ error = keypoint_3d_pck(pred_keypoints3d, gt_keypoints3d,
+ gt_keypoints3d_mask, alignment)
+ name_value_tuples = [(err_name, error)]
+
+ return name_value_tuples
+
+ def _report_3d_auc(self, res_file, metric='3dauc'):
+ """Cauculate the Area Under the Curve (AUC) computed for a range of
+ 3DPCK thresholds.
+ Args:
+ keypoint_results (list): Keypoint predictions. See
+ 'Body3DMpiInf3dhpDataset.evaluate' for details.
+ metric (str): Specify mpjpe variants. Supported options are:
+ - ``'3dauc'``: Standard 3DAUC.
+ - ``'pa-3dauc'``: 3DAUC after aligning prediction to
+ groundtruth via a rigid transformation (scale, rotation and
+ translation).
+ """
+
+ pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \
+ self._parse_result(res_file)
+
+ err_name = metric.upper()
+ if metric == '3dauc':
+ alignment = 'none'
+ elif metric == 'pa-3dauc':
+ alignment = 'procrustes'
+ else:
+ raise ValueError(f'Invalid metric: {metric}')
+
+ error = keypoint_3d_auc(pred_keypoints3d, gt_keypoints3d,
+ gt_keypoints3d_mask, alignment)
+ name_value_tuples = [(err_name, error)]
+
+ return name_value_tuples
+
+ def _report_pve(self, res_file, metric='pve', body_part=''):
+ """Cauculate per vertex error."""
+ pred_verts, gt_verts, _ = \
+ self._parse_result(res_file, mode='vertice', body_part=body_part)
+ err_name = metric.upper()
+ if body_part != '':
+ err_name = body_part.upper() + ' ' + err_name
+
+ if metric == 'pve':
+ alignment = 'none'
+ elif metric == 'pa-pve':
+ alignment = 'procrustes'
+ else:
+ raise ValueError(f'Invalid metric: {metric}')
+ error = vertice_pve(pred_verts, gt_verts, alignment)
+ return [(err_name, error)]
+
+ def _report_ihmr(self, res_file):
+ """Calculate IHMR metric.
+
+ https://arxiv.org/abs/2203.16427
+ """
+ pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \
+ self._parse_result(res_file, mode='keypoint')
+
+ pred_verts, gt_verts, _ = \
+ self._parse_result(res_file, mode='vertice')
+
+ from detrsmpl.utils.geometry import rot6d_to_rotmat
+ mean_param_path = 'data/body_models/smpl_mean_params.npz'
+ mean_params = np.load(mean_param_path)
+ mean_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
+ mean_shape = torch.from_numpy(
+ mean_params['shape'][:].astype('float32')).unsqueeze(0)
+ mean_pose = rot6d_to_rotmat(mean_pose).view(1, 24, 3, 3)
+ mean_output = self.body_model(betas=mean_shape,
+ body_pose=mean_pose[:, 1:],
+ global_orient=mean_pose[:, :1],
+ pose2rot=False)
+ mean_verts = mean_output['vertices'].detach().cpu().numpy() * 1000.
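+        # per-sample mean vertex distance (in mm) between the GT mesh and the
+        # mean-parameter mesh, used to rank samples from hardest to easiest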
+ dis = (gt_verts - mean_verts) * (gt_verts - mean_verts)
+ dis = np.sqrt(dis.sum(axis=-1)).mean(axis=-1)
+ # from the most remote one to the nearest one
+ idx_order = np.argsort(dis)[::-1]
+ num_data = idx_order.shape[0]
+
+ def report_ihmr_idx(idx):
+ mpvpe = vertice_pve(pred_verts[idx], gt_verts[idx])
+ mpjpe = keypoint_mpjpe(pred_keypoints3d[idx], gt_keypoints3d[idx],
+ gt_keypoints3d_mask[idx], 'none')
+ pampjpe = keypoint_mpjpe(pred_keypoints3d[idx],
+ gt_keypoints3d[idx],
+ gt_keypoints3d_mask[idx], 'procrustes')
+ return (mpvpe, mpjpe, pampjpe)
+
+ def report_ihmr_tail(percentage):
+ cur_data = int(num_data * percentage / 100.0)
+ idx = idx_order[:cur_data]
+ mpvpe, mpjpe, pampjpe = report_ihmr_idx(idx)
+ res_mpvpe = ('bMPVPE Tail ' + str(percentage) + '%', mpvpe)
+ res_mpjpe = ('bMPJPE Tail ' + str(percentage) + '%', mpjpe)
+ res_pampjpe = ('bPA-MPJPE Tail ' + str(percentage) + '%', pampjpe)
+ return [res_mpvpe, res_mpjpe, res_pampjpe]
+
+ def report_ihmr_all(num_bin):
+ num_per_bin = np.array([0 for _ in range(num_bin)
+ ]).astype(np.float32)
+ sum_mpvpe = np.array([0
+ for _ in range(num_bin)]).astype(np.float32)
+ sum_mpjpe = np.array([0
+ for _ in range(num_bin)]).astype(np.float32)
+ sum_pampjpe = np.array([0 for _ in range(num_bin)
+ ]).astype(np.float32)
+ max_dis = dis[idx_order[0]]
+ min_dis = dis[idx_order[-1]]
+ delta = (max_dis - min_dis) / num_bin
+ for i in range(num_data):
+ idx = int((dis[i] - min_dis) / delta - 0.001)
+ res_mpvpe, res_mpjpe, res_pampjpe = report_ihmr_idx([i])
+ num_per_bin[idx] += 1
+ sum_mpvpe[idx] += res_mpvpe
+ sum_mpjpe[idx] += res_mpjpe
+ sum_pampjpe[idx] += res_pampjpe
+ for i in range(num_bin):
+ if num_per_bin[i] > 0:
+ sum_mpvpe[i] = sum_mpvpe[i] / num_per_bin[i]
+ sum_mpjpe[i] = sum_mpjpe[i] / num_per_bin[i]
+ sum_pampjpe[i] = sum_pampjpe[i] / num_per_bin[i]
+ valid_idx = np.where(num_per_bin > 0)
+ res_mpvpe = ('bMPVPE All', sum_mpvpe[valid_idx].mean())
+ res_mpjpe = ('bMPJPE All', sum_mpjpe[valid_idx].mean())
+ res_pampjpe = ('bPA-MPJPE All', sum_pampjpe[valid_idx].mean())
+ return [res_mpvpe, res_mpjpe, res_pampjpe]
+
+ metrics = []
+ metrics.extend(report_ihmr_all(num_bin=100))
+ metrics.extend(report_ihmr_tail(percentage=10))
+ metrics.extend(report_ihmr_tail(percentage=5))
+ return metrics
diff --git a/detrsmpl/data/datasets/human_image_smplx_dataset.py b/detrsmpl/data/datasets/human_image_smplx_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cc8ae84734317bea3031d293a4d9c97ed2fd7a4
--- /dev/null
+++ b/detrsmpl/data/datasets/human_image_smplx_dataset.py
@@ -0,0 +1,386 @@
+import os
+import os.path
+import pickle
+from collections import OrderedDict
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+
+from detrsmpl.core.conventions.keypoints_mapping import (
+ get_keypoint_idx,
+ get_keypoint_idxs_by_part,
+)
+from detrsmpl.core.evaluation import fg_vertices_to_mesh_distance
+from detrsmpl.utils.transforms import aa_to_rotmat
+from .builder import DATASETS
+from .human_image_dataset import HumanImageDataset
+
+
+@DATASETS.register_module()
+class HumanImageSMPLXDataset(HumanImageDataset):
+
+ # metric
+ ALLOWED_METRICS = {
+ 'mpjpe', 'pa-mpjpe', 'pve', '3dpck', 'pa-3dpck', '3dauc', 'pa-3dauc',
+ '3DRMSE', 'pa-pve'
+ }
+
+ def __init__(
+ self,
+ data_prefix: str,
+ pipeline: list,
+ dataset_name: str,
+ body_model: Optional[Union[dict, None]] = None,
+ ann_file: Optional[Union[str, None]] = None,
+ convention: Optional[str] = 'human_data',
+ cache_data_path: Optional[Union[str, None]] = None,
+ test_mode: Optional[bool] = False,
+ num_betas: Optional[int] = 10,
+ num_expression: Optional[int] = 10,
+ face_vertex_ids_path: Optional[str] = None,
+ hand_vertex_ids_path: Optional[str] = None,
+ ):
+ super().__init__(data_prefix, pipeline, dataset_name, body_model,
+ ann_file, convention, cache_data_path, test_mode)
+ self.num_betas = num_betas
+ self.num_expression = num_expression
+ if face_vertex_ids_path is not None:
+ if os.path.exists(face_vertex_ids_path):
+ self.face_vertex_ids = np.load(face_vertex_ids_path).astype(
+ np.int32)
+ if hand_vertex_ids_path is not None:
+ if os.path.exists(hand_vertex_ids_path):
+ with open(hand_vertex_ids_path, 'rb') as f:
+ vertex_idxs_data = pickle.load(f, encoding='latin1')
+ self.left_hand_vertex_ids = vertex_idxs_data['left_hand']
+ self.right_hand_vertex_ids = vertex_idxs_data['right_hand']
+
+ def prepare_raw_data(self, idx: int):
+ """Get item from self.human_data."""
+ info = super().prepare_raw_data(idx)
+ if self.cache_reader is not None:
+ self.human_data = self.cache_reader.get_item(idx)
+ idx = idx % self.cache_reader.slice_size
+
+ if 'smplx' in self.human_data:
+ smplx_dict = self.human_data['smplx']
+ info['has_smplx'] = 1
+ else:
+ smplx_dict = {}
+ info['has_smplx'] = 0
+ if 'global_orient' in smplx_dict:
+ info['smplx_global_orient'] = smplx_dict['global_orient'][idx]
+ info['has_smplx_global_orient'] = 1
+ else:
+ info['smplx_global_orient'] = np.zeros((3), dtype=np.float32)
+ info['has_smplx_global_orient'] = 0
+
+ if 'body_pose' in smplx_dict:
+ info['smplx_body_pose'] = smplx_dict['body_pose'][idx]
+ info['has_smplx_body_pose'] = 1
+ else:
+ info['smplx_body_pose'] = np.zeros((21, 3), dtype=np.float32)
+ info['has_smplx_body_pose'] = 0
+
+ if 'right_hand_pose' in smplx_dict:
+ info['smplx_right_hand_pose'] = smplx_dict['right_hand_pose'][idx]
+ info['has_smplx_right_hand_pose'] = 1
+ else:
+ info['smplx_right_hand_pose'] = np.zeros((15, 3), dtype=np.float32)
+ info['has_smplx_right_hand_pose'] = 0
+
+ if 'left_hand_pose' in smplx_dict:
+ info['smplx_left_hand_pose'] = smplx_dict['left_hand_pose'][idx]
+ info['has_smplx_left_hand_pose'] = 1
+ else:
+ info['smplx_left_hand_pose'] = np.zeros((15, 3), dtype=np.float32)
+ info['has_smplx_left_hand_pose'] = 0
+
+ if 'jaw_pose' in smplx_dict:
+ info['smplx_jaw_pose'] = smplx_dict['jaw_pose'][idx]
+ info['has_smplx_jaw_pose'] = 1
+ else:
+ info['smplx_jaw_pose'] = np.zeros((3), dtype=np.float32)
+ info['has_smplx_jaw_pose'] = 0
+
+ if 'betas' in smplx_dict:
+ info['smplx_betas'] = smplx_dict['betas'][idx]
+ info['has_smplx_betas'] = 1
+ else:
+ info['smplx_betas'] = np.zeros((self.num_betas), dtype=np.float32)
+ info['has_smplx_betas'] = 0
+
+ if 'expression' in smplx_dict:
+ info['smplx_expression'] = smplx_dict['expression'][idx]
+ info['has_smplx_expression'] = 1
+ else:
+ info['smplx_expression'] = np.zeros((self.num_expression),
+ dtype=np.float32)
+ info['has_smplx_expression'] = 0
+
+ return info
+
+ def _parse_result(self, res, mode='keypoint', body_part=''):
+ if mode == 'vertice':
+ # pred
+ pred_vertices = res['vertices'] * 1000.
+ # gt
+ if 'vertices' in self.human_data: # stirling or ehf
+ gt_vertices = self.human_data['vertices'].copy()
+ if self.dataset_name == 'EHF':
+ gt_vertices = gt_vertices * 1000.
+ else:
+ gt_param_dict = self.human_data['smplx'].copy()
+ for key, value in gt_param_dict.items():
+ new_value = torch.FloatTensor(value)
+ if ('pose' in key or key
+ == 'global_orient') and value.shape[-2] != 3:
+ new_value = aa_to_rotmat(new_value)
+ gt_param_dict[key] = new_value
+ gt_output = self.body_model(**gt_param_dict)
+ gt_vertices = gt_output['vertices'].detach().cpu().numpy(
+ ) * 1000.
+
+ if body_part == 'right_hand':
+ pred_vertices = pred_vertices[:, self.right_hand_vertex_ids]
+ gt_vertices = gt_vertices[:, self.right_hand_vertex_ids]
+ elif body_part == 'left_hand':
+ pred_vertices = pred_vertices[:, self.left_hand_vertex_ids]
+ gt_vertices = gt_vertices[:, self.left_hand_vertex_ids]
+ elif body_part == 'face':
+ pred_vertices = pred_vertices[:, self.face_vertex_ids]
+ gt_vertices = gt_vertices[:, self.face_vertex_ids]
+
+ gt_mask = np.ones(gt_vertices.shape[:-1])
+ assert len(pred_vertices) == self.num_data
+
+ return pred_vertices, gt_vertices, gt_mask
+ elif mode == 'keypoint':
+ pred_keypoints3d = res['keypoints']
+ assert len(pred_keypoints3d) == self.num_data
+ if self.dataset_name in {'pw3d', '3DPW', '3dpw'}:
+ betas = []
+ body_pose = []
+ global_orient = []
+ gender = []
+ smpl_dict = self.human_data['smpl']
+ for idx in range(self.num_data):
+ betas.append(smpl_dict['betas'][idx])
+ body_pose.append(smpl_dict['body_pose'][idx])
+ global_orient.append(smpl_dict['global_orient'][idx])
+ if self.human_data['meta']['gender'][idx] == 'm':
+ gender.append(0)
+ else:
+ gender.append(1)
+ betas = torch.FloatTensor(betas)
+ body_pose = torch.FloatTensor(body_pose).view(-1, 69)
+ global_orient = torch.FloatTensor(global_orient)
+ gender = torch.Tensor(gender)
+ gt_output = self.body_model(betas=betas,
+ body_pose=body_pose,
+ global_orient=global_orient,
+ gender=gender)
+ gt_keypoints3d = gt_output['joints'].detach().cpu().numpy()
+ gt_keypoints3d_mask = np.ones(
+ (len(pred_keypoints3d), gt_keypoints3d.shape[1]))
+ elif self.dataset_name == 'EHF':
+ gt_vertices = self.human_data['vertices'].copy()
+ if body_part == 'J14':
+ gt_keypoints3d = torch.einsum('bik,ji->bjk', [
+ torch.from_numpy(gt_vertices).float(),
+ self.body_model.joints_regressor
+ ]).numpy()
+ pred_vertices = res['vertices']
+ pred_keypoints3d = torch.einsum('bik,ji->bjk', [
+ torch.from_numpy(pred_vertices).float(),
+ self.body_model.joints_regressor
+ ]).numpy()
+ gt_keypoints3d_mask = np.ones(
+ (len(pred_keypoints3d), gt_keypoints3d.shape[1]))
+ else:
+ gt_keypoints3d = torch.einsum('bik,ji->bjk', [
+ torch.from_numpy(gt_vertices).float(),
+ self.body_model.J_regressor
+ ]).numpy()
+ extra_joints_idxs = np.array([
+ 9120, 9929, 9448, 616, 6, 5770, 5780, 8846, 8463, 8474,
+ 8635, 5361, 4933, 5058, 5169, 5286, 8079, 7669, 7794,
+ 7905, 8022
+ ])
+ gt_keypoints3d = np.concatenate(
+ (gt_keypoints3d, gt_vertices[:, extra_joints_idxs]),
+ axis=1)
+ pred_vertices = res['vertices']
+ pred_keypoints3d = torch.einsum('bik,ji->bjk', [
+ torch.from_numpy(pred_vertices).float(),
+ self.body_model.J_regressor
+ ]).numpy()
+ pred_keypoints3d = np.concatenate(
+ (pred_keypoints3d, pred_vertices[:,
+ extra_joints_idxs]),
+ axis=1)
+
+ idxs = list(range(0, gt_keypoints3d.shape[1]))
+ if body_part == 'right_hand':
+ idxs = get_keypoint_idxs_by_part(
+ 'right_hand', self.convention)
+ idxs.append(
+ get_keypoint_idx('right_wrist', self.convention))
+ elif body_part == 'left_hand':
+ idxs = get_keypoint_idxs_by_part(
+ 'left_hand', self.convention)
+ idxs.append(
+ get_keypoint_idx('left_wrist', self.convention))
+ elif body_part == 'body':
+ idxs = get_keypoint_idxs_by_part(
+ 'body', self.convention)
+ gt_keypoints3d = gt_keypoints3d[:, idxs]
+ pred_keypoints3d = pred_keypoints3d[:, idxs]
+ gt_keypoints3d_mask = np.ones(
+ (len(pred_keypoints3d), gt_keypoints3d.shape[1]))
+ else:
+ gt_keypoints3d = self.human_data['keypoints3d'][:, :, :3]
+ gt_keypoints3d_mask = np.ones(
+ (len(pred_keypoints3d), gt_keypoints3d.shape[1]))
+
+ if gt_keypoints3d.shape[1] == 17:
+ # SMPLX_to_J14
+ assert pred_keypoints3d.shape[1] == 14
+ H36M_TO_J17 = [
+ 6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9
+ ]
+ H36M_TO_J14 = H36M_TO_J17[:14]
+ joint_mapper = H36M_TO_J14
+ gt_keypoints3d = gt_keypoints3d[:, joint_mapper, :]
+ pred_pelvis = pred_keypoints3d[:,
+ [2, 3], :].mean(axis=1,
+ keepdims=True)
+ gt_pelvis = gt_keypoints3d[:, [2, 3], :].mean(axis=1,
+ keepdims=True)
+ gt_keypoints3d_mask = gt_keypoints3d_mask[:, joint_mapper]
+ pred_keypoints3d = pred_keypoints3d - pred_pelvis
+ gt_keypoints3d = gt_keypoints3d - gt_pelvis
+ elif gt_keypoints3d.shape[1] == 14:
+ assert pred_keypoints3d.shape[1] == 14
+ pred_pelvis = pred_keypoints3d[:,
+ [2, 3], :].mean(axis=1,
+ keepdims=True)
+ gt_pelvis = gt_keypoints3d[:, [2, 3], :].mean(axis=1,
+ keepdims=True)
+ pred_keypoints3d = pred_keypoints3d - pred_pelvis
+ gt_keypoints3d = gt_keypoints3d - gt_pelvis
+ elif gt_keypoints3d.shape[1] == 21:
+ pred_pelvis = pred_keypoints3d[:, :1, :]
+ gt_pelvis = gt_keypoints3d[:, :1, :]
+ pred_keypoints3d = pred_keypoints3d - pred_pelvis
+ gt_keypoints3d = gt_keypoints3d - gt_pelvis
+ else:
+ pass
+
+ pred_keypoints3d = pred_keypoints3d * 1000
+ if self.dataset_name != 'stirling':
+ gt_keypoints3d = gt_keypoints3d * 1000
+ gt_keypoints3d_mask = gt_keypoints3d_mask > 0
+
+ return pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask
+
+ def _report_3d_rmse(self, res_file):
+ """compute the 3DRMSE between a predicted 3D face shape and the 3D
+ ground truth scan."""
+ pred_vertices, gt_vertices, _ = self._parse_result(res_file,
+ mode='vertice')
+ pred_keypoints3d, gt_keypoints3d, _ = self._parse_result(
+ res_file, mode='keypoint')
+ errors = []
+ for pred_vertice, gt_vertice, pred_points, gt_points in zip(
+ pred_vertices, gt_vertices, pred_keypoints3d, gt_keypoints3d):
+ error = fg_vertices_to_mesh_distance(gt_vertice, gt_points,
+ pred_vertice,
+ self.body_model.faces,
+ pred_points)
+ errors.append(error)
+
+ error = np.array(errors).mean()
+ name_value_tuples = [('3DRMSE', error)]
+ return name_value_tuples
+
+ def evaluate(self,
+ outputs: list,
+ res_folder: str,
+ metric: Optional[Union[str, List[str]]] = 'pa-mpjpe',
+ **kwargs: dict):
+ """Evaluate 3D keypoint results.
+
+ Args:
+ outputs (list): results from model inference.
+ res_folder (str): path to store results.
+            metric (Optional[Union[str, List[str]]]):
+ the type of metric. Default: 'pa-mpjpe'
+ kwargs (dict): other arguments.
+ Returns:
+ dict:
+ A dict of all evaluation results.
+ """
+ metrics = metric if isinstance(metric, list) else [metric]
+ for metric in metrics:
+ if metric not in self.ALLOWED_METRICS:
+ raise KeyError(f'metric {metric} is not supported')
+
+ # for keeping correctness during multi-gpu test, we sort all results
+ res_dict = {}
+ for out in outputs:
+ target_id = out['image_idx']
+ batch_size = len(out['keypoints_3d'])
+ for i in range(batch_size):
+ res_dict[int(target_id[i])] = dict(
+ keypoints=out['keypoints_3d'][i],
+ vertices=out['vertices'][i],
+ )
+ keypoints, vertices = [], []
+ for i in range(self.num_data):
+ keypoints.append(res_dict[i]['keypoints'])
+ vertices.append(res_dict[i]['vertices'])
+ keypoints = np.stack(keypoints)
+ vertices = np.stack(vertices)
+ res = dict(keypoints=keypoints, vertices=vertices)
+ name_value_tuples = []
+ for index, _metric in enumerate(metrics):
+ if 'body_part' in kwargs:
+ body_parts = kwargs['body_part'][index]
+ for body_part in body_parts:
+ if _metric == 'pa-mpjpe':
+ _nv_tuples = self._report_mpjpe(res,
+ metric='pa-mpjpe',
+ body_part=body_part)
+ elif _metric == 'pa-pve':
+ _nv_tuples = self._report_pve(res,
+ metric='pa-pve',
+ body_part=body_part)
+ else:
+ raise NotImplementedError
+ name_value_tuples.extend(_nv_tuples)
+ else:
+ if _metric == 'mpjpe':
+ _nv_tuples = self._report_mpjpe(res)
+ elif _metric == 'pa-mpjpe':
+ _nv_tuples = self._report_mpjpe(res, metric='pa-mpjpe')
+ elif _metric == '3dpck':
+ _nv_tuples = self._report_3d_pck(res)
+ elif _metric == 'pa-3dpck':
+ _nv_tuples = self._report_3d_pck(res, metric='pa-3dpck')
+ elif _metric == '3dauc':
+ _nv_tuples = self._report_3d_auc(res)
+ elif _metric == 'pa-3dauc':
+ _nv_tuples = self._report_3d_auc(res, metric='pa-3dauc')
+ elif _metric == 'pve':
+ _nv_tuples = self._report_pve(res)
+ elif _metric == 'pa-pve':
+ _nv_tuples = self._report_pve(res, metric='pa-pve')
+ elif _metric == '3DRMSE':
+ _nv_tuples = self._report_3d_rmse(res)
+ else:
+ raise NotImplementedError
+ name_value_tuples.extend(_nv_tuples)
+ name_value = OrderedDict(name_value_tuples)
+ return name_value
diff --git a/detrsmpl/data/datasets/human_video_dataset.py b/detrsmpl/data/datasets/human_video_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..240cbfa7d786ba8c0c30e95f6d8982fcb80fc1e8
--- /dev/null
+++ b/detrsmpl/data/datasets/human_video_dataset.py
@@ -0,0 +1,164 @@
+import copy
+from typing import Optional, Union
+
+import numpy as np
+import torch
+from mmcv.parallel import DataContainer as DC
+from skimage.util.shape import view_as_windows
+
+from .builder import DATASETS
+from .human_image_dataset import HumanImageDataset
+
+
+def get_vid_name(image_path: str):
+ """Get base_dir of the given path."""
+ content = image_path.split('/')
+ vid_name = '/'.join(content[:-1])
+ return vid_name
+
+
+def split_into_chunks(data_infos: list, seq_len: int, stride: int,
+ test_mode: bool, only_vid_name: bool):
+ """Split annotations into chunks.
+ Adapted from https://github.com/mkocabas/VIBE
+ Args:
+ data_infos (list): parsed annotations.
+ seq_len (int): the length of each chunk.
+ stride (int): the interval between chunks.
+ test_mode (bool): if test_mode is true, then an additional chunk
+ will be added to cover all frames. Otherwise, last few frames
+ will be dropped.
+ only_vid_name (bool): if only_vid_name is true, image_path only
+ contains the video name. Otherwise, image_path contains both
+ video_name and frame index.
+
+    Returns:
+ list:
+ shape: [N, 4]. Each chunk contains four parameters: start_frame,
+ end_frame, valid_start_frame, valid_end_frame. The last two
+ parameters are used to suppress redundant frames.
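+
+    Example:
+        >>> # illustrative: five frames of one video with seq_len=3 and
+        >>> # stride=2 produce two chunks
+        >>> split_into_chunks(
+        ...     ['v/000.jpg', 'v/001.jpg', 'v/002.jpg',
+        ...      'v/003.jpg', 'v/004.jpg'],
+        ...     seq_len=3, stride=2, test_mode=False, only_vid_name=False)
+        [[0, 2, 0, 2], [2, 4, 2, 4]]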
+ """
+ vid_names = []
+ for image_path in data_infos:
+ if only_vid_name:
+ vid_name = image_path
+ else:
+ vid_name = get_vid_name(image_path)
+ vid_names.append(vid_name)
+ vid_names = np.array(vid_names)
+ video_start_end_indices = []
+
+ video_names, group = np.unique(vid_names, return_index=True)
+ perm = np.argsort(group)
+ video_names, group = video_names[perm], group[perm]
+
+ indices = np.split(np.arange(0, vid_names.shape[0]), group[1:])
+
+ for idx in range(len(video_names)):
+ indexes = indices[idx]
+ if indexes.shape[0] < seq_len:
+ continue
+ chunks = view_as_windows(indexes, (seq_len, ), step=stride)
+ start_finish = chunks[:, (0, -1, 0, -1)].tolist()
+ video_start_end_indices += start_finish
+ if chunks[-1][-1] < indexes[-1] and test_mode:
+ start_frame = indexes[-1] - seq_len + 1
+ end_frame = indexes[-1]
+ valid_start_frame = chunks[-1][-1] + 1
+ valid_end_frame = indexes[-1]
+ extra_start_finish = [[
+ start_frame, end_frame, valid_start_frame, valid_end_frame
+ ]]
+ video_start_end_indices += extra_start_finish
+
+ return video_start_end_indices
+
+
+@DATASETS.register_module()
+class HumanVideoDataset(HumanImageDataset):
+ """Human Video Dataset.
+
+ Args:
+ data_prefix (str): the prefix of data path.
+        pipeline (list): a list of dicts, where each element represents
+            an operation defined in `detrsmpl.datasets.pipelines`.
+ dataset_name (str | None): the name of dataset. It is used to
+ identify the type of evaluation metric. Default: None.
+ seq_len (int, optional): the length of input sequence. Default: 16.
+ overlap (float, optional): the overlap between different sequences.
+ Default: 0
+ only_vid_name (bool, optional): the format of image_path.
+ If only_vid_name is true, image_path only contains the video
+ name. Otherwise, image_path contains both video_name and frame
+ index.
+ body_model (dict | None, optional): the config for body model,
+ which will be used to generate meshes and keypoints.
+ Default: None.
+ ann_file (str | None, optional): the annotation file. When ann_file
+ is str, the subclass is expected to read from the ann_file. When
+ ann_file is None, the subclass is expected to read according
+ to data_prefix.
+ convention (str, optional): keypoints convention. Keypoints will be
+ converted from "human_data" to the given one.
+ Default: "human_data"
+ test_mode (bool, optional): in train mode or test mode. Default: False.
+ """
+ def __init__(self,
+ data_prefix: str,
+ pipeline: list,
+ dataset_name: str,
+ seq_len: Optional[int] = 16,
+ overlap: Optional[float] = 0.,
+ only_vid_name: Optional[bool] = False,
+ body_model: Optional[Union[dict, None]] = None,
+ ann_file: Optional[Union[str, None]] = None,
+ convention: Optional[str] = 'human_data',
+ test_mode: Optional[bool] = False):
+ super(HumanVideoDataset, self).__init__(data_prefix=data_prefix,
+ pipeline=pipeline,
+ dataset_name=dataset_name,
+ body_model=body_model,
+ convention=convention,
+ ann_file=ann_file,
+ test_mode=test_mode)
+ self.seq_len = seq_len
+ self.stride = int(seq_len * (1 - overlap))
+ self.vid_indices = split_into_chunks(self.human_data['image_path'],
+ self.seq_len, self.stride,
+ test_mode, only_vid_name)
+ self.vid_indices = np.array(self.vid_indices)
+
+ def __len__(self):
+ return len(self.vid_indices)
+
+ def prepare_data(self, idx: int):
+ """Prepare data for each chunk.
+
+ Step 1: get annotation from each frame. Step 2: add metas of each
+ chunk.
+ """
+ start_idx, end_idx = self.vid_indices[idx][:2]
+ batch_results = []
+ image_path = []
+ for frame_idx in range(start_idx, end_idx + 1):
+ frame_results = copy.deepcopy(self.prepare_raw_data(frame_idx))
+ image_path.append(frame_results.pop('image_path'))
+ if 'features' in self.human_data:
+ frame_results['features'] = \
+ copy.deepcopy(self.human_data['features'][frame_idx])
+ frame_results = self.pipeline(frame_results)
+ batch_results.append(frame_results)
+ video_results = {}
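+        # collate the per-frame results: tensors are stacked along a new
+        # leading time dimension, other values are kept as lists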
+ for key in batch_results[0].keys():
+ batch_anno = []
+ for item in batch_results:
+ batch_anno.append(item[key])
+ if isinstance(batch_anno[0], torch.Tensor):
+ batch_anno = torch.stack(batch_anno, dim=0)
+ video_results[key] = batch_anno
+ img_metas = {
+ 'frame_idx': self.vid_indices[idx],
+ 'image_path': image_path
+ }
+ video_results['img_metas'] = DC(img_metas, cpu_only=True)
+ return video_results
diff --git a/detrsmpl/data/datasets/mesh_dataset.py b/detrsmpl/data/datasets/mesh_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a162e066b640f2d8c47409b867cd096b0be1fbe
--- /dev/null
+++ b/detrsmpl/data/datasets/mesh_dataset.py
@@ -0,0 +1,63 @@
+import os
+from abc import ABCMeta
+from typing import Optional, Union
+
+import numpy as np
+
+from .base_dataset import BaseDataset
+from .builder import DATASETS
+
+
+@DATASETS.register_module()
+class MeshDataset(BaseDataset, metaclass=ABCMeta):
+ """Mesh Dataset. This dataset only contains smpl data.
+
+ Args:
+ data_prefix (str): the prefix of data path.
+        pipeline (list): a list of dicts, where each element represents
+            an operation defined in `detrsmpl.datasets.pipelines`.
+ dataset_name (str | None): the name of dataset. It is used to
+ identify the type of evaluation metric. Default: None.
+ ann_file (str | None, optional): the annotation file. When ann_file
+ is str, the subclass is expected to read from the ann_file. When
+ ann_file is None, the subclass is expected to read according
+ to data_prefix.
+ test_mode (bool, optional): in train mode or test mode. Default: False.
+ """
+ def __init__(self,
+ data_prefix: str,
+ pipeline: list,
+ dataset_name: str,
+ ann_file: Optional[Union[str, None]] = None,
+ test_mode: Optional[bool] = False):
+ self.dataset_name = dataset_name
+ super(MeshDataset, self).__init__(data_prefix=data_prefix,
+ pipeline=pipeline,
+ ann_file=ann_file,
+ test_mode=test_mode)
+
+ def get_annotation_file(self):
+ ann_prefix = os.path.join(self.data_prefix, 'preprocessed_datasets')
+ self.ann_file = os.path.join(ann_prefix, self.ann_file)
+
+ def load_annotations(self):
+
+ self.get_annotation_file()
+ data = np.load(self.ann_file, allow_pickle=True)
+
+ self.smpl = data['smpl'].item()
+ num_data = self.smpl['global_orient'].shape[0]
+ if 'transl' not in self.smpl:
+ self.smpl['transl'] = np.zeros((num_data, 3))
+ self.has_smpl = np.ones((num_data))
+
+ data_infos = []
+
+ for idx in range(num_data):
+ info = {}
+ for k, v in self.smpl.items():
+ info['smpl_' + k] = v[idx]
+
+ data_infos.append(info)
+ self.num_data = len(data_infos)
+ self.data_infos = data_infos
diff --git a/detrsmpl/data/datasets/mixed_dataset.py b/detrsmpl/data/datasets/mixed_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..31558f05aa8b0b7447b3d3b6d6c56a3e4a0b346a
--- /dev/null
+++ b/detrsmpl/data/datasets/mixed_dataset.py
@@ -0,0 +1,48 @@
+from typing import Optional, Union
+
+import numpy as np
+from torch.utils.data import ConcatDataset, Dataset, WeightedRandomSampler
+
+from .builder import DATASETS, build_dataset
+
+
+@DATASETS.register_module()
+class MixedDataset(Dataset):
+ """Mixed Dataset.
+
+ Args:
+        configs (list): the list of configs for the datasets to be mixed.
+ partition (list): the ratio of datasets in each batch.
+ num_data (int | None, optional): if num_data is not None, the number
+ of iterations is set to this fixed value. Otherwise, the number of
+ iterations is set to the maximum size of each single dataset.
+ Default: None.
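+
+    Example:
+        >>> # illustrative; ``h36m_cfg`` and ``coco_cfg`` stand for any two
+        >>> # registered dataset configs
+        >>> dataset = MixedDataset(configs=[h36m_cfg, coco_cfg],
+        ...                        partition=[0.75, 0.25])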
+ """
+ def __init__(self,
+ configs: list,
+ partition: list,
+ num_data: Optional[Union[int, None]] = None):
+ """Load data from multiple datasets."""
+ assert min(partition) >= 0
+ datasets = [build_dataset(cfg) for cfg in configs]
+ self.dataset = ConcatDataset(datasets)
+ if num_data is not None:
+ self.length = num_data
+ else:
+ self.length = max(len(ds) for ds in datasets)
+ weights = [
+ np.ones(len(ds)) * p / len(ds)
+ for (p, ds) in zip(partition, datasets)
+ ]
+ weights = np.concatenate(weights, axis=0)
+ self.sampler = WeightedRandomSampler(weights, 1)
+
+ def __len__(self):
+ """Get the size of the dataset."""
+ return self.length
+
+ def __getitem__(self, idx):
+ """Given index, sample the data from multiple datasets with the given
+ proportion."""
+ idx_new = list(self.sampler)[0]
+ return self.dataset[idx_new]
diff --git a/detrsmpl/data/datasets/multi_human_image_dataset.py b/detrsmpl/data/datasets/multi_human_image_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..426e363653ca87c5b93e284fee1843a1b15f9593
--- /dev/null
+++ b/detrsmpl/data/datasets/multi_human_image_dataset.py
@@ -0,0 +1,757 @@
+import json
+import os
+import os.path
+from abc import ABCMeta
+from collections import OrderedDict
+from typing import Any, List, Optional, Union
+
+import mmcv
+import numpy as np
+import torch
+import torch.distributed as dist
+from mmcv.runner import get_dist_info
+
+from detrsmpl.core.conventions.keypoints_mapping import (
+ convert_kps,
+ get_keypoint_num,
+ get_mapping,
+)
+
+from detrsmpl.core.evaluation import (
+ keypoint_3d_auc,
+ keypoint_3d_pck,
+ keypoint_mpjpe,
+ vertice_pve,
+)
+
+from detrsmpl.data.data_structures.multi_human_data import MultiHumanData
+from detrsmpl.models.body_models.builder import build_body_model
+from .base_dataset import BaseDataset
+from .builder import DATASETS
+
+
+@DATASETS.register_module()
+class MultiHumanImageDataset(BaseDataset, metaclass=ABCMeta):
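+    """Image dataset with multiple human instances per sample.
+
+    Annotations are loaded from a preprocessed ``MultiHumanData`` file, and
+    instances belonging to the same image are grouped by ``frame_range``
+    (see ``load_annotations`` and ``prepare_raw_data``).
+    """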
+ def __init__(self,
+ data_prefix: str,
+ pipeline: list,
+ body_model: Optional[Union[dict, None]] = None,
+ ann_file: Optional[Union[str, None]] = None,
+ convention: Optional[str] = 'human_data',
+ test_mode: Optional[bool] = False,
+ dataset_name: Optional[Union[str, None]] = None):
+ self.num_keypoints = get_keypoint_num(convention)
+ self.convention = convention
+ super(MultiHumanImageDataset,
+ self).__init__(data_prefix, pipeline, ann_file, test_mode,
+ dataset_name)
+
+ if body_model is not None:
+ self.body_model = build_body_model(body_model)
+ else:
+ self.body_model = None
+
+ def get_annotation_file(self):
+ """Get path of the annotation file."""
+ ann_prefix = os.path.join(self.data_prefix, 'preprocessed_datasets')
+ self.ann_file = os.path.join(ann_prefix, self.ann_file)
+
+ def load_annotations(self):
+ """Load annotations."""
+ self.get_annotation_file()
+ self.human_data = MultiHumanData()
+ self.human_data.load(self.ann_file)
+
+ self.instance_num = self.human_data.instance_num
+ self.image_path = self.human_data['image_path']
+ self.num_data = self.human_data.data_len
+
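+        # 'frame_range' maps each sample to a [start, end) slice over all
+        # instances, so one image can contain several people. If it is
+        # missing, fall back to one instance per sample.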
+ try:
+ self.frame_range = self.human_data['frame_range']
+ except KeyError:
+ self.frame_range = \
+ np.array([[i, i + 1] for i in range(self.num_data)])
+
+ self.num_data = self.frame_range.shape[0]
+ if self.human_data.check_keypoints_compressed():
+ self.human_data.decompress_keypoints()
+
+ # change keypoint from 'human_data' to the given convention
+ if 'keypoints3d_ori' in self.human_data:
+ keypoints3d_ori = self.human_data['keypoints3d_ori']
+ assert 'keypoints3d_ori_mask' in self.human_data
+ keypoints3d_ori_mask = self.human_data['keypoints3d_ori_mask']
+ keypoints3d_ori, keypoints3d_ori_mask = \
+ convert_kps(
+ keypoints3d_ori,
+ src='human_data',
+ dst=self.convention,
+ mask=keypoints3d_ori_mask)
+ self.human_data.__setitem__('keypoints3d_ori', keypoints3d_ori)
+ self.human_data.__setitem__('keypoints3d_ori_convention',
+ self.convention)
+ self.human_data.__setitem__('keypoints3d_ori_mask',
+ keypoints3d_ori_mask)
+ elif 'keypoints3d' in self.human_data:
+ keypoints3d_ori = self.human_data['keypoints3d']
+ assert 'keypoints3d_mask' in self.human_data
+ keypoints3d_ori_mask = self.human_data['keypoints3d_mask']
+ keypoints3d_ori, keypoints3d_ori_mask = \
+ convert_kps(
+ keypoints3d_ori,
+ src='human_data',
+ dst=self.convention,
+ mask=keypoints3d_ori_mask)
+ self.human_data.__setitem__('keypoints3d_ori', keypoints3d_ori)
+ self.human_data.__setitem__('keypoints3d_ori_convention',
+ self.convention)
+ self.human_data.__setitem__('keypoints3d_ori_mask',
+ keypoints3d_ori_mask)
+
+ if 'keypoints2d_ori' in self.human_data:
+ keypoints2d_ori = self.human_data['keypoints2d_ori']
+ assert 'keypoints2d_ori_mask' in self.human_data
+ keypoints2d_ori_mask = self.human_data['keypoints2d_ori_mask']
+ keypoints2d_ori, keypoints2d_ori_mask = \
+ convert_kps(
+ keypoints2d_ori,
+ src='human_data',
+ dst=self.convention,
+ mask=keypoints2d_ori_mask)
+ self.human_data.__setitem__('keypoints2d_ori', keypoints2d_ori)
+ self.human_data.__setitem__('keypoints2d_ori_convention',
+ self.convention)
+ self.human_data.__setitem__('keypoints2d_ori_mask',
+ keypoints2d_ori_mask)
+ ori_mask = keypoints2d_ori[:, :, 2]
+ elif 'keypoints2d' in self.human_data:
+ keypoints2d_ori = self.human_data['keypoints2d']
+ assert 'keypoints2d_mask' in self.human_data
+ keypoints2d_ori_mask = self.human_data['keypoints2d_mask']
+ keypoints2d_ori, keypoints2d_ori_mask = \
+ convert_kps(
+ keypoints2d_ori,
+ src='human_data',
+ dst=self.convention,
+ mask=keypoints2d_ori_mask)
+ self.human_data.__setitem__('keypoints2d_ori', keypoints2d_ori)
+ self.human_data.__setitem__('keypoints2d_ori_convention',
+ self.convention)
+ self.human_data.__setitem__('keypoints2d_ori_mask',
+ keypoints2d_ori_mask)
+
+ # if 'has_smpl' in self.human_data:
+ # index = ori_mask.sum(-1)>=8
+ # self.human_data['has_smpl']=self.human_data['has_smpl'][:147270]*index
+ # change keypoint from 'human_data' to the given convention
+ if 'keypoints3d_smpl' in self.human_data:
+ keypoints3d_smpl = self.human_data['keypoints3d_smpl']
+ assert 'keypoints3d_smpl_mask' in self.human_data
+ keypoints3d_smpl_mask = self.human_data['keypoints3d_smpl_mask']
+ keypoints3d_smpl, keypoints3d_smpl_mask = \
+ convert_kps(
+ keypoints3d_smpl,
+ src='human_data',
+ dst=self.convention,
+ mask=keypoints3d_smpl_mask)
+ # index = ori_mask.sum(-1)<8
+ # keypoints3d_smpl[index]=np.concatenate(
+ # [keypoints3d_smpl[index][:,:,:3],
+ # keypoints2d_ori[index][:,:,[2]]],
+ # -1)
+ self.human_data.__setitem__('keypoints3d_smpl', keypoints3d_smpl)
+ self.human_data.__setitem__('keypoints3d_smpl_convention',
+ self.convention)
+ self.human_data.__setitem__('keypoints3d_smpl_mask',
+ keypoints3d_smpl_mask)
+
+ if 'keypoints2d_smpl' in self.human_data:
+ keypoints2d_smpl = self.human_data['keypoints2d_smpl']
+ assert 'keypoints2d_smpl_mask' in self.human_data
+ keypoints2d_smpl_mask = self.human_data['keypoints2d_smpl_mask']
+ keypoints2d_smpl, keypoints2d_smpl_mask = \
+ convert_kps(
+ keypoints2d_smpl,
+ src='human_data',
+ dst=self.convention,
+ mask=keypoints2d_smpl_mask)
+ # index = ori_mask.sum(-1)<8
+ # keypoints2d_smpl[index]=np.concatenate(
+ # [keypoints2d_smpl[index][:,:,:2],
+ # keypoints2d_ori[index][:,:,[2]]],
+ # -1)
+ # keypoints2d_smpl[index][:,:,2]=keypoints2d_ori[index][:, :,2]
+ self.human_data.__setitem__('keypoints2d_smpl', keypoints2d_smpl)
+ self.human_data.__setitem__('keypoints2d_smpl_convention',
+ self.convention)
+ self.human_data.__setitem__('keypoints2d_smpl_mask',
+ keypoints2d_smpl_mask)
+ self.human_data.compress_keypoints_by_mask()
+
+
+ def prepare_raw_data(self, idx: int):
+ """Get item from self.human_data."""
+ sample_idx = idx
+ frame_start, frame_end = self.frame_range[idx]
+ frame_num = frame_end - frame_start
+ # TODO: Support cache_reader?
+ info = {}
+ info['img_prefix'] = None
+ image_path = self.human_data['image_path'][frame_start]
+ info['image_path'] = os.path.join(self.data_prefix, 'datasets',
+ self.dataset_name, image_path)
+ # TODO: Support smc?
+ info['dataset_name'] = self.dataset_name
+ info['sample_idx'] = sample_idx
+ if 'bbox_xywh' in self.human_data:
+ info['bbox_xywh'] = self.human_data['bbox_xywh'][
+ frame_start:frame_end]
+ center, scale = [], []
+ for bbox in info['bbox_xywh']:
+ x, y, w, h, s = bbox
+ cx = x + w / 2
+ cy = y + h / 2
+ # TODO: verify if we should keep w = h = max(w, h) for multi human data
+ w = h = max(w, h)
+ center.append([cx, cy])
+ scale.append([w, h])
+ info['center'] = np.array(center)
+ info['scale'] = np.array(scale)
+ else:
+ info['bbox_xywh'] = np.zeros((frame_num, 5))
+ info['center'] = np.zeros((frame_num, 2))
+ info['scale'] = np.zeros((frame_num, 2))
+
+ if 'keypoints2d_ori' in self.human_data:
+ info['keypoints2d_ori'] = self.human_data['keypoints2d_ori'][
+ frame_start:frame_end]
+ conf = info['keypoints2d_ori'][..., -1].sum(-1) > 0
+ info['has_keypoints2d_ori'] = np.ones(
+ (frame_num, 1)) * conf[..., None]
+ else:
+ info['keypoints2d_ori'] = np.zeros(
+ (frame_num, self.num_keypoints, 3))
+ info['has_keypoints2d_ori'] = np.zeros((frame_num, 1))
+
+ if 'keypoints3d_ori' in self.human_data:
+ info['keypoints3d_ori'] = self.human_data['keypoints3d_ori'][
+ frame_start:frame_end]
+ conf = info['keypoints3d_ori'][..., -1].sum(-1) > 0
+ info['has_keypoints3d_ori'] = np.ones(
+ (frame_num, 1)) * conf[..., None]
+ else:
+ info['keypoints3d_ori'] = np.zeros(
+ (frame_num, self.num_keypoints, 4))
+ info['has_keypoints3d_ori'] = np.zeros((frame_num, 1))
+
+ if 'keypoints2d_smpl' in self.human_data:
+ info['keypoints2d_smpl'] = self.human_data['keypoints2d_smpl'][
+ frame_start:frame_end]
+ conf = info['keypoints2d_smpl'][..., -1].sum(-1) > 0
+ info['has_keypoints2d_smpl'] = np.ones(
+ (frame_num, 1)) * conf[..., None]
+ else:
+ info['keypoints2d_smpl'] = np.zeros(
+ (frame_num, self.num_keypoints, 3))
+ info['has_keypoints2d_smpl'] = np.zeros((frame_num, 1))
+
+ if 'keypoints3d_smpl' in self.human_data:
+ info['keypoints3d_smpl'] = self.human_data['keypoints3d_smpl'][
+ frame_start:frame_end]
+ conf = info['keypoints3d_smpl'][..., -1].sum(-1) > 0
+ info['has_keypoints3d_smpl'] = np.ones(
+ (frame_num, 1)) * conf[..., None]
+ else:
+ info['keypoints3d_smpl'] = np.zeros(
+ (frame_num, self.num_keypoints, 4))
+ info['has_keypoints3d_smpl'] = np.zeros((frame_num, 1))
+
+ if 'smpl' in self.human_data:
+ if 'has_smpl' in self.human_data:
+ info['has_smpl'] = \
+ self.human_data['has_smpl'][frame_start:frame_end]
+ else:
+ info['has_smpl'] = np.ones((frame_num, 1))
+ smpl_dict = self.human_data['smpl']
+ else:
+ info['has_smpl'] = np.zeros((frame_num, 1))
+ smpl_dict = {}
+
+ if 'body_pose' in smpl_dict:
+ info['smpl_body_pose'] = smpl_dict['body_pose'][
+ frame_start:frame_end]
+ else:
+ info['smpl_body_pose'] = np.zeros((frame_num, 23, 3))
+
+ if 'global_orient' in smpl_dict:
+ info['smpl_global_orient'] = smpl_dict['global_orient'][
+ frame_start:frame_end]
+ else:
+ info['smpl_global_orient'] = np.zeros((frame_num, 3))
+
+ if 'betas' in smpl_dict:
+ info['smpl_betas'] = smpl_dict['betas'][frame_start:frame_end]
+ else:
+ info['smpl_betas'] = np.zeros((frame_num, 10))
+
+ if 'transl' in smpl_dict:
+ info['smpl_transl'] = smpl_dict['transl'][frame_start:frame_end]
+ else:
+ info['smpl_transl'] = np.zeros((frame_num, 3))
+
+ if 'area' in self.human_data:
+ info['area'] = self.human_data['area'][frame_start:frame_end]
+ else:
+ info['area'] = np.zeros((frame_num, 0))
+
+ return info
+
+ def prepare_data(self, idx: int):
+ """Generate and transform data."""
+ info = self.prepare_raw_data(idx)
+ return self.pipeline(info)
+
+ def evaluate(self,
+ outputs: list,
+ res_folder: str,
+ metric: Optional[Union[str, List[str]]] = 'pa-mpjpe',
+ **kwargs: dict):
+ """Evaluate 3D keypoint results.
+
+ Args:
+ outputs (list): results from model inference.
+ res_folder (str): path to store results.
+            metric (Optional[Union[str, List[str]]]):
+ the type of metric. Default: 'pa-mpjpe'
+ kwargs (dict): other arguments.
+ Returns:
+ dict:
+ A dict of all evaluation results.
+ """
+ metrics = metric if isinstance(metric, list) else [metric]
+ for metric in metrics:
+ if metric not in self.ALLOWED_METRICS:
+ raise KeyError(f'metric {metric} is not supported')
+
+ res_file = os.path.join(res_folder, 'result_keypoints.json')
+ # for keeping correctness during multi-gpu test, we sort all results
+ res_dict = {}
+ # 'scores', 'labels', 'boxes', 'keypoints', 'pred_smpl_pose',
+ # 'pred_smpl_beta', 'pred_smpl_cam', 'pred_smpl_kp3d',
+ # 'gt_smpl_pose', 'gt_smpl_beta', 'gt_smpl_kp3d', 'gt_boxes',
+ # 'gt_keypoints', 'image_idx'
+ for out in outputs:
+ target_id = out['image_idx']
+ batch_size = len(out['pred_smpl_kp3d'])
+ for i in range(batch_size):
+ res_dict[int(target_id[i])] = dict(
+ keypoints=out['pred_smpl_kp3d'][i],
+ gt_poses=out['gt_smpl_pose'][i],
+ gt_betas=out['gt_smpl_beta'][i],
+ pred_poses=out['pred_smpl_pose'][i],
+ pred_betas=out['pred_smpl_beta'][i])
+ keypoints, gt_poses, gt_betas, pred_poses, pred_betas = \
+ [], [], [], [], []
+ # print(self.num_data)
+ for i in range(self.num_data):
+ keypoints.append(res_dict[i]['keypoints'])
+ gt_poses.append(res_dict[i]['gt_poses'])
+ gt_betas.append(res_dict[i]['gt_betas'])
+ pred_poses.append(res_dict[i]['pred_poses'])
+ pred_betas.append(res_dict[i]['pred_betas'])
+
+ res = dict(keypoints=keypoints,
+ gt_poses=gt_poses,
+ gt_betas=gt_betas,
+ pred_poses=pred_poses,
+ pred_betas=pred_betas)
+ # mmcv.dump(res, res_file)
+ name_value_tuples = []
+ for _metric in metrics:
+ if _metric == 'mpjpe':
+ _nv_tuples = self._report_mpjpe(res)
+ elif _metric == 'pa-mpjpe':
+ _nv_tuples = self._report_mpjpe(res, metric='pa-mpjpe')
+ print(_nv_tuples)
+ elif _metric == '3dpck':
+ _nv_tuples = self._report_3d_pck(res)
+ elif _metric == 'pa-3dpck':
+ _nv_tuples = self._report_3d_pck(res, metric='pa-3dpck')
+ elif _metric == '3dauc':
+ _nv_tuples = self._report_3d_auc(res)
+ elif _metric == 'pa-3dauc':
+ _nv_tuples = self._report_3d_auc(res, metric='pa-3dauc')
+ elif _metric == 'pve':
+ _nv_tuples = self._report_pve(res)
+ elif _metric == 'ihmr':
+ _nv_tuples = self._report_ihmr(res)
+ else:
+ raise NotImplementedError
+ name_value_tuples.extend(_nv_tuples)
+
+ name_value = OrderedDict(name_value_tuples)
+ return name_value
+
+ @staticmethod
+ def _write_keypoint_results(keypoints: Any, res_file: str):
+ """Write results into a json file."""
+
+ with open(res_file, 'w') as f:
+ json.dump(keypoints, f, sort_keys=True, indent=4)
+
+ def _parse_result(self, res, mode='keypoint', body_part=None):
+ """Parse results."""
+
+ if mode == 'vertice':
+ # gt
+ gt_beta, gt_pose, gt_global_orient, gender = [], [], [], []
+ gt_smpl_dict = self.human_data['smpl']
+ for idx in range(self.num_data):
+ gt_beta.append(gt_smpl_dict['betas'][idx])
+ gt_pose.append(gt_smpl_dict['body_pose'][idx])
+ gt_global_orient.append(gt_smpl_dict['global_orient'][idx])
+ if self.human_data['meta']['gender'][idx] == 'm':
+ gender.append(0)
+ else:
+ gender.append(1)
+ gt_beta = torch.FloatTensor(gt_beta)
+ gt_pose = torch.FloatTensor(gt_pose).view(-1, 69)
+ gt_global_orient = torch.FloatTensor(gt_global_orient)
+ gender = torch.Tensor(gender)
+ gt_output = self.body_model(betas=gt_beta,
+ body_pose=gt_pose,
+ global_orient=gt_global_orient,
+ gender=gender)
+ gt_vertices = gt_output['vertices'].detach().cpu().numpy() * 1000.
+ gt_mask = np.ones(gt_vertices.shape[:-1])
+ # pred
+ pred_pose = torch.FloatTensor(res['pred_poses'])
+ pred_beta = torch.FloatTensor(res['pred_betas'])
+ pred_output = self.body_model(
+ betas=pred_beta[:, 0],
+ body_pose=pred_pose[:, 0, 1:],
+ global_orient=pred_pose[:, 0, 0].unsqueeze(1),
+ pose2rot=False)
+ pred_vertices = pred_output['vertices'].detach().cpu().numpy(
+ ) * 1000.
+
+ assert len(pred_vertices) == self.num_data
+
+ return pred_vertices, gt_vertices, gt_mask
+ elif mode == 'keypoint':
+ pred_keypoints3d = res['keypoints']
+ assert len(pred_keypoints3d) == self.num_data
+ # (B, 17, 3)
+ pred_keypoints3d = np.array(pred_keypoints3d).reshape(
+ len(pred_keypoints3d), -1, 3)
+ # pred_keypoints3d,_ = convert_kps(
+ # pred_keypoints3d,
+ # src='smpl_54',
+ # dst='h36m',
+ # )
+
+ gt_smpl_pose = np.array(res['gt_poses'])
+ gt_body_pose = gt_smpl_pose[..., 1:, :]
+ gt_global_orient = gt_smpl_pose[..., 0, :]
+ gt_betas = np.array(res['gt_betas'])
+ gender = np.zeros([gt_betas.shape[0], gt_betas.shape[1]])
+ if self.dataset_name == 'pw3d':
+ # betas = []
+ # body_pose = []
+ # global_orient = []
+ # gender = []
+ # smpl_dict = self.human_data['smpl']
+
+ # for idx in range(self.num_data):
+ # betas.append(smpl_dict['betas'][idx])
+ # body_pose.append(smpl_dict['body_pose'][idx])
+ # global_orient.append(smpl_dict['global_orient'][idx])
+ # if self.human_data['meta']['gender'][idx] == 'm':
+ # gender.append(0)
+ # else:
+ # gender.append(1)
+ betas = torch.FloatTensor(gt_betas).view(-1, 10)
+ body_pose = torch.FloatTensor(gt_body_pose).view(-1, 69)
+ global_orient = torch.FloatTensor(gt_global_orient).view(-1, 3)
+ gender = torch.Tensor(gender).view(-1)
+ gt_output = self.body_model(betas=betas,
+ body_pose=body_pose,
+ global_orient=global_orient,
+ gender=gender)
+ gt_keypoints3d = gt_output['joints'].detach().cpu().numpy()
+ # gt_keypoints3d,_ = convert_kps(
+ # gt_keypoints3d,
+ # src='smpl_54',
+ # dst='h36m')
+ gt_keypoints3d_mask = np.ones((len(pred_keypoints3d), 17))
+ elif self.dataset_name == 'h36m':
+ _, h36m_idxs, _ = get_mapping('human_data', 'h36m')
+ gt_keypoints3d = \
+ self.human_data['keypoints3d'][:, h36m_idxs, :3]
+ gt_keypoints3d_mask = np.ones((len(pred_keypoints3d), 17))
+ elif self.dataset_name == 'humman':
+ betas = []
+ body_pose = []
+ global_orient = []
+ smpl_dict = self.human_data['smpl']
+ for idx in range(self.num_data):
+ betas.append(smpl_dict['betas'][idx])
+ body_pose.append(smpl_dict['body_pose'][idx])
+ global_orient.append(smpl_dict['global_orient'][idx])
+ betas = torch.FloatTensor(betas)
+ body_pose = torch.FloatTensor(body_pose).view(-1, 69)
+ global_orient = torch.FloatTensor(global_orient)
+ gt_output = self.body_model(betas=betas,
+ body_pose=body_pose,
+ global_orient=global_orient)
+ gt_keypoints3d = gt_output['joints'].detach().cpu().numpy()
+ gt_keypoints3d_mask = np.ones((len(pred_keypoints3d), 24))
+ else:
+ raise NotImplementedError()
+
+ # SMPL_49 only!
+ if gt_keypoints3d.shape[1] == 49:
+ assert pred_keypoints3d.shape[1] == 49
+
+ gt_keypoints3d = gt_keypoints3d[:, 25:, :]
+ pred_keypoints3d = pred_keypoints3d[:, 25:, :]
+
+ joint_mapper = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18]
+ gt_keypoints3d = gt_keypoints3d[:, joint_mapper, :]
+ pred_keypoints3d = pred_keypoints3d[:, joint_mapper, :]
+
+ # we only evaluate on 14 lsp joints
+ pred_pelvis = (pred_keypoints3d[:, 2] +
+ pred_keypoints3d[:, 3]) / 2
+ gt_pelvis = (gt_keypoints3d[:, 2] + gt_keypoints3d[:, 3]) / 2
+
+ # H36M for testing!
+ elif gt_keypoints3d.shape[1] == 17:
+ assert pred_keypoints3d.shape[-2] == 17
+
+ H36M_TO_J17 = [
+ 6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9
+ ]
+ H36M_TO_J14 = H36M_TO_J17[:14]
+ joint_mapper = H36M_TO_J14
+
+ pred_pelvis = pred_keypoints3d[:, 0]
+ gt_pelvis = gt_keypoints3d[:, 0]
+
+ gt_keypoints3d = gt_keypoints3d[:, joint_mapper, :]
+ pred_keypoints3d = pred_keypoints3d[:, joint_mapper, :]
+
+ # keypoint 24
+ elif gt_keypoints3d.shape[1] == 24:
+ assert pred_keypoints3d.shape[1] == 24
+
+ joint_mapper = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18]
+ gt_keypoints3d = gt_keypoints3d[:, joint_mapper, :]
+ pred_keypoints3d = pred_keypoints3d[:, joint_mapper, :]
+
+ # we only evaluate on 14 lsp joints
+ pred_pelvis = (pred_keypoints3d[:, 2] +
+ pred_keypoints3d[:, 3]) / 2
+ gt_pelvis = (gt_keypoints3d[:, 2] + gt_keypoints3d[:, 3]) / 2
+
+ else:
+ pass
+
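+            # Root-align predictions and ground truth at the pelvis and scale
+            # by 1000 (metres to millimetres) before computing errors.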
+ pred_keypoints3d = (pred_keypoints3d -
+ pred_pelvis[:, None, :]) * 1000
+ gt_keypoints3d = (gt_keypoints3d - gt_pelvis[:, None, :]) * 1000
+
+ gt_keypoints3d_mask = gt_keypoints3d_mask[:, joint_mapper] > 0
+
+ return pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask
+
+ def _report_mpjpe(self, res_file, metric='mpjpe', body_part=''):
+        """Calculate mean per joint position error (MPJPE) or its variant
+        PA-MPJPE.
+
+ Report mean per joint position error (MPJPE) and mean per joint
+ position error after rigid alignment (PA-MPJPE)
+ """
+ pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \
+ self._parse_result(res_file, mode='keypoint', body_part=body_part)
+
+ err_name = metric.upper()
+ if body_part != '':
+ err_name = body_part.upper() + ' ' + err_name
+
+ if metric == 'mpjpe':
+ alignment = 'none'
+ elif metric == 'pa-mpjpe':
+ alignment = 'procrustes'
+ else:
+ raise ValueError(f'Invalid metric: {metric}')
+
+ error = keypoint_mpjpe(pred_keypoints3d, gt_keypoints3d,
+ gt_keypoints3d_mask, alignment)
+ info_str = [(err_name, error)]
+
+ return info_str
+
+ def _report_3d_pck(self, res_file, metric='3dpck'):
+        """Calculate the Percentage of Correct Keypoints (3DPCK) with or without
+ Procrustes alignment.
+ Args:
+ keypoint_results (list): Keypoint predictions. See
+ 'Body3DMpiInf3dhpDataset.evaluate' for details.
+ metric (str): Specify mpjpe variants. Supported options are:
+ - ``'3dpck'``: Standard 3DPCK.
+ - ``'pa-3dpck'``:
+ 3DPCK after aligning prediction to groundtruth
+ via a rigid transformation (scale, rotation and
+ translation).
+ """
+
+ pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \
+ self._parse_result(res_file)
+
+ err_name = metric.upper()
+ if metric == '3dpck':
+ alignment = 'none'
+ elif metric == 'pa-3dpck':
+ alignment = 'procrustes'
+ else:
+ raise ValueError(f'Invalid metric: {metric}')
+
+ error = keypoint_3d_pck(pred_keypoints3d, gt_keypoints3d,
+ gt_keypoints3d_mask, alignment)
+ name_value_tuples = [(err_name, error)]
+
+ return name_value_tuples
+
+ def _report_3d_auc(self, res_file, metric='3dauc'):
+        """Calculate the Area Under the Curve (AUC) computed for a range of
+ 3DPCK thresholds.
+ Args:
+ keypoint_results (list): Keypoint predictions. See
+ 'Body3DMpiInf3dhpDataset.evaluate' for details.
+ metric (str): Specify mpjpe variants. Supported options are:
+ - ``'3dauc'``: Standard 3DAUC.
+ - ``'pa-3dauc'``: 3DAUC after aligning prediction to
+ groundtruth via a rigid transformation (scale, rotation and
+ translation).
+ """
+
+ pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \
+ self._parse_result(res_file)
+
+ err_name = metric.upper()
+ if metric == '3dauc':
+ alignment = 'none'
+ elif metric == 'pa-3dauc':
+ alignment = 'procrustes'
+ else:
+ raise ValueError(f'Invalid metric: {metric}')
+
+ error = keypoint_3d_auc(pred_keypoints3d, gt_keypoints3d,
+ gt_keypoints3d_mask, alignment)
+ name_value_tuples = [(err_name, error)]
+
+ return name_value_tuples
+
+ def _report_pve(self, res_file, metric='pve', body_part=''):
+        """Calculate per-vertex error (PVE)."""
+ pred_verts, gt_verts, _ = \
+ self._parse_result(res_file, mode='vertice', body_part=body_part)
+ err_name = metric.upper()
+ if body_part != '':
+ err_name = body_part.upper() + ' ' + err_name
+
+ if metric == 'pve':
+ alignment = 'none'
+ elif metric == 'pa-pve':
+ alignment = 'procrustes'
+ else:
+ raise ValueError(f'Invalid metric: {metric}')
+ error = vertice_pve(pred_verts, gt_verts, alignment)
+ return [(err_name, error)]
+
+ def _report_ihmr(self, res_file):
+ """Calculate IHMR metric.
+
+ https://arxiv.org/abs/2203.16427
+ """
+ pred_keypoints3d, gt_keypoints3d, gt_keypoints3d_mask = \
+ self._parse_result(res_file, mode='keypoint')
+
+ pred_verts, gt_verts, _ = \
+ self._parse_result(res_file, mode='vertice')
+
+ from detrsmpl.utils.geometry import rot6d_to_rotmat
+ mean_param_path = 'data/body_models/smpl_mean_params.npz'
+ mean_params = np.load(mean_param_path)
+ mean_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
+ mean_shape = torch.from_numpy(
+ mean_params['shape'][:].astype('float32')).unsqueeze(0)
+ mean_pose = rot6d_to_rotmat(mean_pose).view(1, 24, 3, 3)
+ mean_output = self.body_model(betas=mean_shape,
+ body_pose=mean_pose[:, 1:],
+ global_orient=mean_pose[:, :1],
+ pose2rot=False)
+ mean_verts = mean_output['vertices'].detach().cpu().numpy() * 1000.
+ dis = (gt_verts - mean_verts) * (gt_verts - mean_verts)
+ dis = np.sqrt(dis.sum(axis=-1)).mean(axis=-1)
+ # from the most remote one to the nearest one
+ idx_order = np.argsort(dis)[::-1]
+ num_data = idx_order.shape[0]
+
+ def report_ihmr_idx(idx):
+ mpvpe = vertice_pve(pred_verts[idx], gt_verts[idx])
+ mpjpe = keypoint_mpjpe(pred_keypoints3d[idx], gt_keypoints3d[idx],
+ gt_keypoints3d_mask[idx], 'none')
+ pampjpe = keypoint_mpjpe(pred_keypoints3d[idx],
+ gt_keypoints3d[idx],
+ gt_keypoints3d_mask[idx], 'procrustes')
+ return (mpvpe, mpjpe, pampjpe)
+
+ def report_ihmr_tail(percentage):
+ cur_data = int(num_data * percentage / 100.0)
+ idx = idx_order[:cur_data]
+ mpvpe, mpjpe, pampjpe = report_ihmr_idx(idx)
+ res_mpvpe = ('bMPVPE Tail ' + str(percentage) + '%', mpvpe)
+ res_mpjpe = ('bMPJPE Tail ' + str(percentage) + '%', mpjpe)
+ res_pampjpe = ('bPA-MPJPE Tail ' + str(percentage) + '%', pampjpe)
+ return [res_mpvpe, res_mpjpe, res_pampjpe]
+
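+        # Bin samples by their mean vertex distance to the mean-shape mesh and
+        # average the per-bin errors, so that rare hard cases are not swamped
+        # by the many easy ones.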
+        def report_ihmr_all(num_bin):
+            num_per_bin = np.zeros(num_bin, dtype=np.float32)
+            sum_mpvpe = np.zeros(num_bin, dtype=np.float32)
+            sum_mpjpe = np.zeros(num_bin, dtype=np.float32)
+            sum_pampjpe = np.zeros(num_bin, dtype=np.float32)
+ max_dis = dis[idx_order[0]]
+ min_dis = dis[idx_order[-1]]
+ delta = (max_dis - min_dis) / num_bin
+ for i in range(num_data):
+ idx = int((dis[i] - min_dis) / delta - 0.001)
+ res_mpvpe, res_mpjpe, res_pampjpe = report_ihmr_idx([i])
+ num_per_bin[idx] += 1
+ sum_mpvpe[idx] += res_mpvpe
+ sum_mpjpe[idx] += res_mpjpe
+ sum_pampjpe[idx] += res_pampjpe
+ for i in range(num_bin):
+ if num_per_bin[i] > 0:
+ sum_mpvpe[i] = sum_mpvpe[i] / num_per_bin[i]
+ sum_mpjpe[i] = sum_mpjpe[i] / num_per_bin[i]
+ sum_pampjpe[i] = sum_pampjpe[i] / num_per_bin[i]
+ valid_idx = np.where(num_per_bin > 0)
+ res_mpvpe = ('bMPVPE All', sum_mpvpe[valid_idx].mean())
+ res_mpjpe = ('bMPJPE All', sum_mpjpe[valid_idx].mean())
+ res_pampjpe = ('bPA-MPJPE All', sum_pampjpe[valid_idx].mean())
+ return [res_mpvpe, res_mpjpe, res_pampjpe]
+
+ metrics = []
+ metrics.extend(report_ihmr_all(num_bin=100))
+ metrics.extend(report_ihmr_tail(percentage=10))
+ metrics.extend(report_ihmr_tail(percentage=5))
+ return metrics
diff --git a/detrsmpl/data/datasets/pipelines/__init__.py b/detrsmpl/data/datasets/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bba193f4c44ef84b0d7b61965a50a5981207405d
--- /dev/null
+++ b/detrsmpl/data/datasets/pipelines/__init__.py
@@ -0,0 +1,65 @@
+from .compose import Compose
+from .formatting import (
+ Collect,
+ ImageToTensor,
+ ToNumpy,
+ ToPIL,
+ ToTensor,
+ Transpose,
+ to_tensor,
+)
+from .hybrik_transforms import (
+ GenerateHybrIKTarget,
+ HybrIKAffine,
+ HybrIKRandomFlip,
+ NewKeypointsSelection,
+ RandomDPG,
+ RandomOcclusion,
+)
+from .loading import LoadImageFromFile
+from .synthetic_occlusion_augmentation import SyntheticOcclusion
+from .transforms import (
+ BBoxCenterJitter,
+ CenterCrop,
+ ColorJitter,
+ GetRandomScaleRotation,
+ Lighting,
+ MeshAffine,
+ MeshAffineED,
+ Normalize,
+ RandomChannelNoise,
+ RandomHorizontalFlip,
+ Rotation,
+ SimulateLowRes,
+)
+
+__all__ = [
+ 'Compose',
+ 'to_tensor',
+ 'ToTensor',
+ 'ImageToTensor',
+ 'ToPIL',
+ 'ToNumpy',
+ 'Transpose',
+ 'Collect',
+ 'LoadImageFromFile',
+ 'CenterCrop',
+ 'RandomHorizontalFlip',
+ 'ColorJitter',
+ 'Lighting',
+ 'RandomChannelNoise',
+ 'GetRandomScaleRotation',
+ 'MeshAffine',
+ 'MeshAffineED',
+ 'HybrIKRandomFlip',
+ 'HybrIKAffine',
+ 'GenerateHybrIKTarget',
+ 'RandomDPG',
+ 'RandomOcclusion',
+ 'Rotation',
+ 'NewKeypointsSelection',
+ 'Normalize',
+ 'SyntheticOcclusion',
+ 'BBoxCenterJitter',
+ 'SimulateLowRes',
+]
diff --git a/detrsmpl/data/datasets/pipelines/compose.py b/detrsmpl/data/datasets/pipelines/compose.py
new file mode 100644
index 0000000000000000000000000000000000000000..23e1875bdfd6c042f403979ae855ab95b134a6fd
--- /dev/null
+++ b/detrsmpl/data/datasets/pipelines/compose.py
@@ -0,0 +1,41 @@
+from collections.abc import Sequence
+
+from mmcv.utils import build_from_cfg
+
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class Compose(object):
+ """Compose a data pipeline with a sequence of transforms.
+
+ Args:
+ transforms (list[dict | callable]):
+ Either config dicts of transforms or transform objects.
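+
+    Example:
+        >>> # Illustrative only; 'demo.png' is a placeholder image path.
+        >>> pipeline = Compose([
+        >>>     dict(type='LoadImageFromFile'),
+        >>>     dict(type='ToTensor', keys=['img']),
+        >>> ])
+        >>> data = pipeline(dict(img_prefix=None, image_path='demo.png'))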
+ """
+ def __init__(self, transforms):
+ assert isinstance(transforms, Sequence)
+ self.transforms = []
+ for transform in transforms:
+ if isinstance(transform, dict):
+ transform = build_from_cfg(transform, PIPELINES)
+ self.transforms.append(transform)
+ elif callable(transform):
+ self.transforms.append(transform)
+ else:
+ raise TypeError('transform must be callable or a dict, but got'
+ f' {type(transform)}')
+
+ def __call__(self, data):
+ for t in self.transforms:
+ data = t(data)
+ if data is None:
+ return None
+ return data
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + '('
+ for t in self.transforms:
+ format_string += f'\n {t}'
+ format_string += '\n)'
+ return format_string
diff --git a/detrsmpl/data/datasets/pipelines/formatting.py b/detrsmpl/data/datasets/pipelines/formatting.py
new file mode 100644
index 0000000000000000000000000000000000000000..8260ffff860fe98b6025288c3df91c6a33e722f7
--- /dev/null
+++ b/detrsmpl/data/datasets/pipelines/formatting.py
@@ -0,0 +1,319 @@
+from collections.abc import Sequence
+
+import mmcv
+import numpy as np
+import torch
+from mmcv.parallel import DataContainer as DC
+from PIL import Image
+
+from ..builder import PIPELINES
+
+
+def to_tensor(data):
+ """Convert objects of various python types to :obj:`torch.Tensor`.
+
+ Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
+ :class:`Sequence`, :class:`int` and :class:`float`.
+ """
+ if isinstance(data, torch.Tensor):
+ return data
+ elif isinstance(data, np.ndarray):
+ return torch.from_numpy(data)
+ elif isinstance(data, Sequence) and not mmcv.is_str(data):
+ return torch.tensor(data)
+ elif isinstance(data, int):
+ return torch.LongTensor([data])
+ elif isinstance(data, float):
+ return torch.FloatTensor([data])
+ else:
+ raise TypeError(
+ f'Type {type(data)} cannot be converted to tensor.'
+ 'Supported types are: `numpy.ndarray`, `torch.Tensor`, '
+ '`Sequence`, `int` and `float`')
+
+
+@PIPELINES.register_module()
+class ToTensor(object):
+ def __init__(self, keys):
+ self.keys = keys
+
+ def __call__(self, results):
+ for key in self.keys:
+ results[key] = to_tensor(results[key])
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(keys={self.keys})'
+
+
+@PIPELINES.register_module()
+class ImageToTensor(object):
+ def __init__(self, keys):
+ self.keys = keys
+
+ def __call__(self, results):
+ for key in self.keys:
+ img = results[key]
+ if len(img.shape) < 3:
+ img = np.expand_dims(img, -1)
+ results[key] = to_tensor(img.transpose(2, 0, 1))
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(keys={self.keys})'
+
+
+@PIPELINES.register_module()
+class Transpose(object):
+ def __init__(self, keys, order):
+ self.keys = keys
+ self.order = order
+
+ def __call__(self, results):
+ for key in self.keys:
+ results[key] = results[key].transpose(self.order)
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + \
+ f'(keys={self.keys}, order={self.order})'
+
+
+@PIPELINES.register_module()
+class ToPIL(object):
+ def __init__(self):
+ pass
+
+ def __call__(self, results):
+ results['img'] = Image.fromarray(results['img'])
+ return results
+
+
+@PIPELINES.register_module()
+class ToNumpy(object):
+ def __init__(self):
+ pass
+
+ def __call__(self, results):
+ results['img'] = np.array(results['img'], dtype=np.float32)
+ return results
+
+
+@PIPELINES.register_module()
+class Collect(object):
+ """Collect data from the loader relevant to the specific task.
+
+ This is usually the last stage of the data loader pipeline. Typically keys
+ is set to some subset of "img" and "gt_label".
+
+ Args:
+ keys (Sequence[str]): Keys of results to be collected in ``data``.
+ meta_keys (Sequence[str], optional): Meta keys to be converted to
+ ``mmcv.DataContainer`` and collected in ``data[img_metas]``.
+            Default: ``('filename', 'ori_filename', 'ori_shape', 'img_shape',
+            'flip', 'flip_direction', 'img_norm_cfg')``
+
+ Returns:
+ dict: The result dict contains the following keys
+            - keys in ``self.keys``
+ - ``img_metas`` if available
+ """
+ def __init__(self,
+ keys,
+ meta_keys=('filename', 'ori_filename', 'ori_shape',
+ 'img_shape', 'flip', 'flip_direction',
+ 'img_norm_cfg')):
+ self.keys = keys
+ self.meta_keys = meta_keys
+
+ def __call__(self, results):
+ data = {}
+ img_meta = {}
+ for key in self.meta_keys:
+ if key in results:
+ img_meta[key] = results[key]
+ data['img_metas'] = DC(img_meta, cpu_only=True)
+ for key in self.keys:
+ data[key] = results[key]
+ return data
+
+ def __repr__(self):
+ return self.__class__.__name__ + \
+ f'(keys={self.keys}, meta_keys={self.meta_keys})'
+
+
+@PIPELINES.register_module()
+class ToDataContainer:
+ """Convert results to :obj:`mmcv.DataContainer` by given fields.
+
+ Args:
+ fields (Sequence[dict]): Each field is a dict like
+ ``dict(key='xxx', **kwargs)``. The ``key`` in result will
+ be converted to :obj:`mmcv.DataContainer` with ``**kwargs``.
+ Default: ``(dict(key='img', stack=True), dict(key='gt_bboxes'),
+ dict(key='gt_labels'))``.
+ """
+ def __init__(self,
+ fields=(dict(key='img', stack=True), dict(key='gt_bboxes'),
+ dict(key='gt_labels'))):
+ self.fields = fields
+
+ def __call__(self, results):
+ """Call function to convert data in results to
+ :obj:`mmcv.DataContainer`.
+
+ Args:
+ results (dict): Result dict contains the data to convert.
+
+ Returns:
+ dict: The result dict contains the data converted to \
+ :obj:`mmcv.DataContainer`.
+ """
+
+ for field in self.fields:
+ field = field.copy()
+ key = field.pop('key')
+ results[key] = DC(results[key], **field)
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(fields={self.fields})'
+
+
+@PIPELINES.register_module()
+class DefaultFormatBundle:
+ """Default formatting bundle.
+
+ It simplifies the pipeline of formatting common fields, including "img",
+ "proposals", "gt_bboxes", "gt_labels", "gt_masks" and "gt_semantic_seg".
+ These fields are formatted as follows.
+
+ - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True)
+ - proposals: (1)to tensor, (2)to DataContainer
+ - gt_bboxes: (1)to tensor, (2)to DataContainer
+ - gt_bboxes_ignore: (1)to tensor, (2)to DataContainer
+ - gt_labels: (1)to tensor, (2)to DataContainer
+ - gt_masks: (1)to tensor, (2)to DataContainer (cpu_only=True)
+ - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, \
+ (3)to DataContainer (stack=True)
+
+ Args:
+ img_to_float (bool): Whether to force the image to be converted to
+ float type. Default: True.
+ pad_val (dict): A dict for padding value in batch collating,
+ the default value is `dict(img=0, masks=0, seg=255)`.
+ Without this argument, the padding value of "gt_semantic_seg"
+ will be set to 0 by default, which should be 255.
+ """
+ def __init__(self,
+ img_to_float=True,
+ pad_val=dict(img=0, masks=0, seg=255)):
+ self.img_to_float = img_to_float
+ self.pad_val = pad_val
+
+ def __call__(self, results):
+ """Call function to transform and format common fields in results.
+
+ Args:
+ results (dict): Result dict contains the data to convert.
+
+ Returns:
+ dict: The result dict contains the data that is formatted with \
+ default bundle.
+ """
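+        # Ground-truth keys to be converted to tensors and wrapped in
+        # DataContainer below; the image itself is additionally transposed,
+        # stacked and padded.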
+ data_keys = [
+ 'center', 'scale', 'rotation', 'smpl_body_pose',
+ 'smpl_global_orient', 'smpl_betas', 'smpl_transl', 'area',
+ 'bbox_xywh', 'has_smpl', 'keypoints2d_ori', 'keypoints3d_ori',
+ 'keypoints2d_smpl', 'keypoints3d_smpl', 'has_keypoints2d_ori',
+ 'has_keypoints3d_ori', 'has_keypoints2d_smpl',
+ 'has_keypoints3d_smpl'
+ ]
+ if 'img' in results:
+ img = results['img']
+ if self.img_to_float is True and img.dtype == np.uint8:
+                # Normally, the image is of uint8 type without normalization.
+                # In that case it must be explicitly converted to float32,
+                # otherwise model training and inference will be wrong.
+                # Currently this is only needed for YOLOX.
+ img = img.astype(np.float32)
+ # add default meta keys
+ results = self._add_default_meta_keys(results)
+ if len(img.shape) < 3:
+ img = np.expand_dims(img, -1)
+ img = np.ascontiguousarray(img.transpose(2, 0, 1))
+ results['img'] = DC(to_tensor(img),
+ padding_value=self.pad_val['img'],
+ stack=True)
+ for key in data_keys:
+ if key not in results:
+ continue
+ results[key] = DC(to_tensor(results[key]))
+ # if 'gt_masks' in results:
+ # results['gt_masks'] = DC(
+ # results['gt_masks'],
+ # padding_value=self.pad_val['masks'],
+ # cpu_only=True)
+ # if 'gt_semantic_seg' in results:
+ # results['gt_semantic_seg'] = DC(
+ # to_tensor(results['gt_semantic_seg'][None, ...]),
+ # padding_value=self.pad_val['seg'],
+ # stack=True)
+ return results
+
+ def _add_default_meta_keys(self, results):
+ """Add default meta keys.
+
+ We set default meta keys including `pad_shape`, `scale_factor` and
+ `img_norm_cfg` to avoid the case where no `Resize`, `Normalize` and
+ `Pad` are implemented during the whole pipeline.
+
+ Args:
+ results (dict): Result dict contains the data to convert.
+
+ Returns:
+ results (dict): Updated result dict contains the data to convert.
+ """
+ img = results['img']
+ results.setdefault('pad_shape', img.shape)
+ results.setdefault('scale_factor', 1.0)
+ num_channels = 1 if len(img.shape) < 3 else img.shape[2]
+ results.setdefault(
+ 'img_norm_cfg',
+ dict(mean=np.zeros(num_channels, dtype=np.float32),
+ std=np.ones(num_channels, dtype=np.float32),
+ to_rgb=False))
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + \
+ f'(img_to_float={self.img_to_float})'
+
+
+@PIPELINES.register_module()
+class WrapFieldsToLists(object):
+ """Wrap fields of the data dictionary into lists for evaluation.
+
+ This class can be used as a last step of a test or validation
+ pipeline for single image evaluation or inference.
+
+ Example:
+        >>> test_pipeline = [
+        >>>     dict(type='LoadImageFromFile'),
+        >>>     dict(type='Normalize',
+        >>>          mean=[123.675, 116.28, 103.53],
+        >>>          std=[58.395, 57.12, 57.375],
+        >>>          to_rgb=True),
+        >>>     dict(type='ImageToTensor', keys=['img']),
+        >>>     dict(type='Collect', keys=['img']),
+        >>>     dict(type='WrapFieldsToLists')
+        >>> ]
+ """
+ def __call__(self, results):
+ # Wrap dict fields into lists
+ for key, val in results.items():
+ results[key] = [val]
+ return results
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}()'
diff --git a/detrsmpl/data/datasets/pipelines/hybrik_transforms.py b/detrsmpl/data/datasets/pipelines/hybrik_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..a00f17626c722414f587f6e53cbb037e29692c16
--- /dev/null
+++ b/detrsmpl/data/datasets/pipelines/hybrik_transforms.py
@@ -0,0 +1,877 @@
+import math
+import random
+
+import cv2
+import mmcv
+import numpy as np
+
+from detrsmpl.core.conventions.keypoints_mapping import get_flip_pairs
+from detrsmpl.utils.demo_utils import box2cs, xyxy2xywh
+from ..builder import PIPELINES
+from .transforms import (
+ _rotate_smpl_pose,
+ affine_transform,
+ get_affine_transform,
+)
+
+
+def get_bbox(bbox_xywh, w, h):
+    """Obtain a bbox in xyxy format from a bbox in xywh format, clipping it
+    so that it stays within the image bounds.
+
+    Args:
+        bbox_xywh (numpy.ndarray): bbox in format (x, y, w, h).
+        w (int): image width.
+        h (int): image height.
+
+    Returns:
+        bbox (numpy.ndarray): converted bbox in format (xmin, ymin,
+        xmax, ymax).
+ """
+ bbox_xywh = bbox_xywh.reshape(1, 4)
+ xmin, ymin, xmax, ymax = bbox_clip_xyxy(bbox_xywh_to_xyxy(bbox_xywh), w, h)
+ bbox = np.array([xmin, ymin, xmax, ymax])
+ return bbox
+
+
+def heatmap2coord(pred_jts,
+ pred_scores,
+ hm_shape,
+ bbox,
+ output_3d=False,
+ mean_bbox_scale=None):
+ """Retrieve predicted keypoints and scores from heatmap."""
+ hm_width, hm_height = hm_shape
+
+ ndims = pred_jts.dim()
+ assert ndims in [2, 3], 'Dimensions of input heatmap should be 2 or 3'
+ if ndims == 2:
+ pred_jts = pred_jts.unsqueeze(0)
+ pred_scores = pred_scores.unsqueeze(0)
+
+ coords = pred_jts.cpu().numpy()
+ coords = coords.astype(float)
+ pred_scores = pred_scores.cpu().numpy()
+ pred_scores = pred_scores.astype(float)
+
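+    # Predicted joint coordinates are normalized to [-0.5, 0.5]; shift and
+    # scale them to heatmap pixels before mapping them back to the input
+    # image with the inverse affine transform.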
+ coords[:, :, 0] = (coords[:, :, 0] + 0.5) * hm_width
+ coords[:, :, 1] = (coords[:, :, 1] + 0.5) * hm_height
+
+ preds = np.zeros_like(coords)
+ # transform bbox to scale
+ xmin, ymin, xmax, ymax = bbox
+ w = xmax - xmin
+ h = ymax - ymin
+ center = np.array([xmin + w * 0.5, ymin + h * 0.5])
+ scale = np.array([w, h])
+ # Transform back
+ for i in range(coords.shape[0]):
+ for j in range(coords.shape[1]):
+ preds[i, j, 0:2] = transform_preds(coords[i, j, 0:2], center,
+ scale, [hm_width, hm_height])
+ if output_3d:
+ if mean_bbox_scale is not None:
+ zscale = scale[0] / mean_bbox_scale
+ preds[i, j, 2] = coords[i, j, 2] / zscale
+ else:
+ preds[i, j, 2] = coords[i, j, 2]
+ # maxvals = np.ones((*preds.shape[:2], 1), dtype=float)
+ # score_mul = 1 if norm_name == 'sigmoid' else 5
+
+ return preds, pred_scores
+
+
+def transform_preds(coords, center, scale, output_size):
+ """Transform heatmap coordinates to image coordinates."""
+ target_coords = np.zeros(coords.shape)
+ trans = get_affine_transform(center,
+ scale,
+ 0,
+ output_size,
+ inv=1,
+ pixel_std=1)
+ target_coords[0:2] = affine_transform(coords[0:2], trans)
+ return target_coords
+
+
+def bbox_xywh_to_xyxy(xywh):
+ """Convert bounding boxes from format (x, y, w, h) to (xmin, ymin, xmax,
+ ymax)
+
+ Args:
+ xywh (list, tuple or numpy.ndarray): bbox in format (x, y, w, h).
+ If numpy.ndarray is provided, we expect multiple bounding boxes with
+ shape `(N, 4)`.
+
+ Returns:
+ xyxy (tuple or numpy.ndarray): Converted bboxes in format (xmin, ymin,
+        xmax, ymax). A numpy.ndarray is returned if the input is a numpy.ndarray.
+ """
+ if isinstance(xywh, (tuple, list)):
+ if not len(xywh) == 4:
+ raise IndexError(
+ 'Bounding boxes must have 4 elements, given {}'.format(
+ len(xywh)))
+ w, h = np.maximum(xywh[2] - 1, 0), np.maximum(xywh[3] - 1, 0)
+ return (xywh[0], xywh[1], xywh[0] + w, xywh[1] + h)
+ elif isinstance(xywh, np.ndarray):
+ if not xywh.size % 4 == 0:
+ raise IndexError(
+ 'Bounding boxes must have n * 4 elements, given {}'.format(
+ xywh.shape))
+ xyxy = np.hstack(
+ (xywh[:, :2], xywh[:, :2] + np.maximum(0, xywh[:, 2:4] - 1)))
+ return xyxy
+ else:
+ raise TypeError(
+            'Expect input xywh to be a list, tuple or numpy.ndarray, given {}'.
+ format(type(xywh)))
+
+
+def bbox_clip_xyxy(xyxy, width, height):
+ """Clip bounding box with format (xmin, ymin, xmax, ymax) to `(0, 0, width,
+ height)`.
+
+ Args:
+ xyxy (list, tuple or numpy.ndarray): bbox in format (xmin, ymin,
+ xmax, ymax). If numpy.ndarray is provided, we expect multiple bounding
+ boxes with shape `(N, 4)`.
+ width (int or float): Boundary width.
+ height (int or float): Boundary height.
+
+ Returns:
+ xyxy (list, tuple or numpy.ndarray): clipped bbox in format (xmin, ymin,
+ xmax, ymax) and input type
+ """
+ if isinstance(xyxy, (tuple, list)):
+ if not len(xyxy) == 4:
+ raise IndexError(
+ 'Bounding boxes must have 4 elements, given {}'.format(
+ len(xyxy)))
+ x1 = np.minimum(width - 1, np.maximum(0, xyxy[0]))
+ y1 = np.minimum(height - 1, np.maximum(0, xyxy[1]))
+ x2 = np.minimum(width - 1, np.maximum(0, xyxy[2]))
+ y2 = np.minimum(height - 1, np.maximum(0, xyxy[3]))
+ return (x1, y1, x2, y2)
+ elif isinstance(xyxy, np.ndarray):
+ if not xyxy.size % 4 == 0:
+ raise IndexError(
+ 'Bounding boxes must have n * 4 elements, given {}'.format(
+ xyxy.shape))
+ x1 = np.minimum(width - 1, np.maximum(0, xyxy[:, 0]))
+ y1 = np.minimum(height - 1, np.maximum(0, xyxy[:, 1]))
+ x2 = np.minimum(width - 1, np.maximum(0, xyxy[:, 2]))
+ y2 = np.minimum(height - 1, np.maximum(0, xyxy[:, 3]))
+ return np.hstack((x1, y1, x2, y2))
+ else:
+            'Expect input xyxy to be a list, tuple or numpy.ndarray, given {}'.
+ 'Expect input xywh a list, tuple or numpy.ndarray, given {}'.
+ format(type(xyxy)))
+
+
+def cam2pixel(cam_coord, f, c):
+    """Convert coordinates from the camera frame to the image frame given f and c.
+
+ Args:
+ cam_coord (np.ndarray): Coordinates in camera frame
+ f (list): focal length, fx, fy
+ c (list): principal point offset, x0, y0
+
+ Returns:
+ img_coord (np.ndarray): Coordinates in image frame
+ """
+
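+    # Pinhole projection: u = fx * X / Z + cx, v = fy * Y / Z + cy; a small
+    # epsilon guards against division by zero.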
+ x = cam_coord[:, 0] / (cam_coord[:, 2] + 1e-8) * f[0] + c[0]
+ y = cam_coord[:, 1] / (cam_coord[:, 2] + 1e-8) * f[1] + c[1]
+ z = cam_coord[:, 2]
+ img_coord = np.concatenate((x[:, None], y[:, None], z[:, None]), 1)
+ return img_coord
+
+
+def get_intrinsic_matrix(f, c, inv=False):
+    """Get the intrinsic matrix (or its inverse) given f and c.
+
+ Args:
+ f (list): focal length, fx, fy
+ c (list): principal point offset, x0, y0
+        inv (bool): Set True to return the inverse. Default: False.
+
+ Returns:
+ intrinsic matrix (np.ndarray): 3x3 intrinsic matrix or its inverse
+ """
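+    # K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]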
+    intrinsic_matrix = np.zeros((3, 3)).astype(np.float32)
+    intrinsic_matrix[0, 0] = f[0]
+    intrinsic_matrix[0, 2] = c[0]
+    intrinsic_matrix[1, 1] = f[1]
+    intrinsic_matrix[1, 2] = c[1]
+    intrinsic_matrix[2, 2] = 1
+
+    if inv:
+        intrinsic_matrix = np.linalg.inv(intrinsic_matrix).astype(np.float32)
+    return intrinsic_matrix
+
+
+def aa_to_quat_numpy(axis_angle):
+ """Convert rotations given as axis/angle to quaternions.
+
+ Args:
+ axis_angle: Rotations given as a vector in axis angle form,
+ as a np.ndarray of shape (..., 3), where the magnitude is
+ the angle turned anticlockwise in radians around the
+ vector's direction.
+
+ Returns:
+ quaternions with real part first, as np.ndarray of shape (..., 4).
+ """
+ angles = np.linalg.norm(axis_angle, ord=2, axis=-1, keepdims=True)
+ half_angles = 0.5 * angles
+ eps = 1e-6
+ small_angles = np.abs(angles) < eps
+ sin_half_angles_over_angles = np.empty_like(angles)
+ sin_half_angles_over_angles[~small_angles] = (
+ np.sin(half_angles[~small_angles]) / angles[~small_angles])
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
+ sin_half_angles_over_angles[small_angles] = (
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48)
+ quaternions = np.concatenate(
+ [np.cos(half_angles), axis_angle * sin_half_angles_over_angles],
+ axis=-1)
+ return quaternions
+
+
+def flip_thetas(thetas, theta_pairs):
+ """Flip thetas.
+
+ Args:
+ thetas (np.ndarray): joints in shape (num_thetas, 3)
+ theta_pairs (list): flip pairs for thetas
+
+ Returns:
+ thetas_flip (np.ndarray): flipped thetas with shape (num_thetas, 3)
+ """
+ thetas_flip = thetas.copy()
+ # reflect horizontally
+ thetas_flip[:, 1] = -1 * thetas_flip[:, 1]
+ thetas_flip[:, 2] = -1 * thetas_flip[:, 2]
+ # change left-right parts
+ for pair in theta_pairs:
+ thetas_flip[pair[0], :], thetas_flip[pair[1], :] = \
+ thetas_flip[pair[1], :], thetas_flip[pair[0], :].copy()
+
+ return thetas_flip
+
+
+def flip_joints_3d(joints_3d, joints_3d_visible, width, flip_pairs):
+ """Flip 3d joints.
+
+ Args:
+        joints_3d (np.ndarray): joints in shape (N, 3, 2)
+        joints_3d_visible (np.ndarray): joint visibility in shape (N, 3, 2)
+        width (int): image width
+        flip_pairs (list): flip pairs for joints
+
+ Returns:
+ joints_3d_flipped (np.ndarray): flipped joints with shape (N, 3, 2)
+ joints_3d_visible_flipped (np.ndarray): visibility of (N, 3, 2)
+ """
+
+ assert len(joints_3d) == len(joints_3d_visible)
+ joints_3d[:, 0] = width - joints_3d[:, 0] - 1
+ joints_3d_flipped = joints_3d.copy()
+ joints_3d_visible_flipped = joints_3d_visible.copy()
+
+ # Swap left-right parts
+ for left, right in flip_pairs:
+ joints_3d_flipped[left, :] = joints_3d[right, :]
+ joints_3d_flipped[right, :] = joints_3d[left, :]
+
+ joints_3d_visible_flipped[left, :] = joints_3d_visible[right, :]
+ joints_3d_visible_flipped[right, :] = joints_3d_visible[left, :]
+
+ joints_3d_flipped = joints_3d_flipped * joints_3d_visible_flipped
+
+ return joints_3d_flipped, joints_3d_visible_flipped
+
+
+def flip_xyz_joints_3d(joints_3d, flip_pairs):
+ """Flip 3d xyz joints.
+
+ Args:
+ joints_3d (np.ndarray): Joints in shape (N, 3)
+        flip_pairs (list): flip pairs for joints
+
+ Returns:
+ joints_3d_flipped (np.ndarray): flipped joints with shape (N, 3)
+ """
+
+ joints_3d[:, 0] = -1 * joints_3d[:, 0]
+ joints_3d_flipped = joints_3d.copy()
+ # change left-right parts
+ for left, right in flip_pairs:
+ joints_3d_flipped[left, :] = joints_3d[right, :]
+ joints_3d_flipped[right, :] = joints_3d[left, :]
+
+ return joints_3d_flipped
+
+
+def flip_twist(twist_phi, twist_weight, twist_pairs):
+ """Flip twist and weight.
+
+ Args:
+ twist_phi (np.ndarray): twist in shape (num_twist, 2)
+ twist_weight (np.ndarray): weight in shape (num_twist, 2)
+ twist_pairs (list): flip pairs for twist
+
+ Returns:
+ twist_flip (np.ndarray): flipped twist with shape (num_twist, 2)
+ weight_flip (np.ndarray): flipped weights with shape (num_twist, 2)
+ """
+ # twist_flip = -1 * twist_phi.copy() # 23 x 2
+ twist_flip = np.zeros_like(twist_phi)
+ weight_flip = twist_weight.copy()
+
+ twist_flip[:, 0] = twist_phi[:, 0].copy() # cos
+ twist_flip[:, 1] = -1 * twist_phi[:, 1].copy() # sin
+ for pair in twist_pairs:
+ idx0 = pair[0] - 1
+ idx1 = pair[1] - 1
+ twist_flip[idx0, :], twist_flip[idx1, :] = \
+ twist_flip[idx1, :], twist_flip[idx0, :].copy()
+
+ weight_flip[idx0, :], weight_flip[idx1, :] = \
+ weight_flip[idx1, :], weight_flip[idx0, :].copy()
+
+ return twist_flip, weight_flip
+
+
+def _center_scale_to_box(center, scale):
+    """Convert a center/scale pair into a bbox in xyxy format.
+
+    Args:
+        center (np.ndarray): bbox center (x, y)
+        scale (np.ndarray): bbox scale (w, h)
+
+    Returns:
+        bbox (list): bbox in format [xmin, ymin, xmax, ymax]
+ """
+ pixel_std = 1.0
+ w = scale[0] * pixel_std
+ h = scale[1] * pixel_std
+ xmin = center[0] - w * 0.5
+ ymin = center[1] - h * 0.5
+ xmax = xmin + w
+ ymax = ymin + h
+ bbox = [xmin, ymin, xmax, ymax]
+ return bbox
+
+
+@PIPELINES.register_module()
+class RandomDPG(object):
+    """Add DPG for data augmentation, including random crop and random
+    sampling.
+
+    Required keys: 'bbox', 'ann_info'.
+    Modifies keys: 'bbox', 'center', 'scale'.
+
+ Args:
+ dpg_prob (float): Probability of dpg
+ """
+ def __init__(self, dpg_prob):
+ self.dpg_prob = dpg_prob
+
+ def __call__(self, results):
+ if np.random.rand() > self.dpg_prob:
+ return results
+
+ bbox = results['bbox']
+ imgwidth = results['ann_info']['width']
+ imgheight = results['ann_info']['height']
+
+ PatchScale = random.uniform(0, 1)
+ width = bbox[2] - bbox[0]
+ ht = bbox[3] - bbox[1]
+
+ if PatchScale > 0.85:
+ ratio = ht / width
+ if (width < ht):
+ patchWidth = PatchScale * width
+ patchHt = patchWidth * ratio
+ else:
+ patchHt = PatchScale * ht
+ patchWidth = patchHt / ratio
+
+ xmin = bbox[0] + random.uniform(0, 1) * (width - patchWidth)
+ ymin = bbox[1] + random.uniform(0, 1) * (ht - patchHt)
+ xmax = xmin + patchWidth + 1
+ ymax = ymin + patchHt + 1
+ else:
+ xmin = max(
+ 1,
+ min(bbox[0] + np.random.normal(-0.0142, 0.1158) * width,
+ imgwidth - 3))
+ ymin = max(
+ 1,
+ min(bbox[1] + np.random.normal(0.0043, 0.068) * ht,
+ imgheight - 3))
+ xmax = min(
+ max(xmin + 2,
+ bbox[2] + np.random.normal(0.0154, 0.1337) * width),
+ imgwidth - 3)
+ ymax = min(
+ max(ymin + 2,
+ bbox[3] + np.random.normal(-0.0013, 0.0711) * ht),
+ imgheight - 3)
+ bbox_xyxy = np.array([xmin, ymin, xmax, ymax])
+ bbox_xywh = xyxy2xywh(bbox_xyxy)
+ center, scale = box2cs(bbox_xywh,
+ aspect_ratio=1.0,
+ bbox_scale_factor=1.0)
+ results['bbox'] = bbox_xyxy
+ results['center'] = center
+ results['scale'] = scale
+
+ return results
+
+
+@PIPELINES.register_module()
+class HybrIKRandomFlip:
+ """Data augmentation with random image flip.
+
+    Required keys: 'img', 'keypoints3d', 'keypoints3d_vis', 'center',
+    'ann_info' and 'has_smpl'.
+ Additional keys required if has_smpl: 'keypoints3d17', 'keypoints3d17_vis',
+ 'keypoints3d_relative', 'keypoints3d17_relative', 'pose'
+
+    Modifies keys: 'img', 'keypoints3d', 'keypoints3d_vis', 'center', 'pose'.
+ Additional keys modified if has_smpl: 'keypoints3d17', 'keypoints3d17_vis',
+ 'keypoints3d_relative', 'keypoints3d17_relative', 'pose'
+
+ Args:
+ flip_prob (float): probability of the image being flipped. Default: 0.5
+ flip_pairs (list[int]): list of left-right keypoint pairs for flipping
+ """
+ def __init__(self, flip_prob=0.5, flip_pairs=None):
+ assert 0 <= flip_prob <= 1
+ self.flip_prob = flip_prob
+ self.flip_pairs = flip_pairs
+
+ def __call__(self, results):
+ """Perform data augmentation with random image flip."""
+ if np.random.rand() > self.flip_prob:
+ results['is_flipped'] = np.array([0])
+ return results
+
+ results['is_flipped'] = np.array([1])
+
+ # flip image
+ for key in results.get('img_fields', ['img']):
+ results[key] = mmcv.imflip(results[key], direction='horizontal')
+
+ width = results['img'][:, ::-1, :].shape[1]
+ # flip bbox center
+ center = results['center']
+ center[0] = width - 1 - center[0]
+ results['center'] = center
+
+ keypoints3d = results['keypoints3d']
+ keypoints3d_vis = results['keypoints3d_vis']
+
+ keypoints3d, keypoints3d_vis = flip_joints_3d(keypoints3d,
+ keypoints3d_vis, width,
+ self.flip_pairs)
+
+ if results['has_smpl']:
+ pose = results['pose']
+ smpl_flip_pairs = get_flip_pairs('smpl')
+ pose = flip_thetas(pose, smpl_flip_pairs)
+
+ keypoints3d17 = results['keypoints3d17']
+ keypoints3d17_vis = results['keypoints3d17_vis']
+ keypoints3d17_relative = results['keypoints3d17_relative']
+ keypoints3d_relative = results['keypoints3d_relative']
+
+ keypoints3d17, keypoints3d17_vis = flip_joints_3d(
+ keypoints3d17, keypoints3d17_vis, width, self.flip_pairs)
+ keypoints3d17_relative = flip_xyz_joints_3d(
+ keypoints3d17_relative, self.flip_pairs)
+ keypoints3d_relative = flip_xyz_joints_3d(keypoints3d_relative,
+ self.flip_pairs)
+ twist_phi, twist_weight = results['target_twist'], results[
+ 'target_twist_weight']
+ results['target_twist'], results[
+ 'target_twist_weight'] = flip_twist(twist_phi, twist_weight,
+ smpl_flip_pairs)
+
+ results['keypoints3d17_relative'] = keypoints3d17_relative.astype(
+ np.float32)
+ results['keypoints3d_relative'] = keypoints3d_relative.astype(
+ np.float32)
+ results['keypoints3d17'] = keypoints3d17.astype(np.float32)
+ results['keypoints3d17_vis'] = keypoints3d17_vis.astype(np.float32)
+ results['pose'] = pose.astype(np.float32)
+
+ results['keypoints3d'] = keypoints3d.astype(np.float32)
+ results['keypoints3d_vis'] = keypoints3d_vis.astype(np.float32)
+ return results
+
+
+@PIPELINES.register_module()
+class HybrIKAffine:
+    """Affine transform the image to get the model input, and apply the same
+    affine transform to the 2D keypoints, 3D keypoints and IUV image.
+
+ Required keys: 'img', 'keypoints3d', 'keypoints3d_vis', 'pose', 'ann_info',
+ 'scale', 'keypoints3d17', 'keypoints3d17_vis', 'rotation' and 'center'.
+ Modifies key: 'img', 'keypoints3d','keypoints3d_vis', 'pose',
+ 'keypoints3d17', 'keypoints3d17_vis'
+ """
+ def __init__(self, img_res):
+ self.image_size = np.array([img_res, img_res])
+
+ def __call__(self, results):
+
+ img = results['img']
+ keypoints3d = results['keypoints3d']
+ num_joints = len(keypoints3d)
+ keypoints3d_vis = results['keypoints3d_vis']
+ has_smpl = results['has_smpl']
+
+ c = results['center']
+ s = results['scale']
+ r = results['rotation']
+ trans = get_affine_transform(c, s, r, self.image_size, pixel_std=1)
+ img = cv2.warpAffine(
+ img,
+ trans, (int(self.image_size[0]), int(self.image_size[1])),
+ flags=cv2.INTER_LINEAR)
+
+ for i in range(num_joints):
+ if keypoints3d_vis[i, 0] > 0.0:
+ keypoints3d[i, 0:2] = affine_transform(keypoints3d[i, 0:2],
+ trans)
+
+ if has_smpl:
+
+ keypoints3d17 = results['keypoints3d17']
+ keypoints3d17_vis = results['keypoints3d17_vis']
+ for i in range(17):
+ if keypoints3d17_vis[i, 0] > 0.0:
+ keypoints3d17[i, 0:2] = affine_transform(
+ keypoints3d17[i, 0:2], trans)
+ results['keypoints3d17'] = keypoints3d17
+ results['keypoints3d17_vis'] = keypoints3d17_vis
+
+ # to rotate poses
+ pose = results['pose']
+ pose = _rotate_smpl_pose(pose.reshape(-1), r)
+ results['pose'] = pose.reshape(24, 3)
+
+ results['img'] = img.astype(np.float32)
+ results['keypoints3d_vis'] = keypoints3d_vis.astype(np.float32)
+ results['keypoints3d'] = keypoints3d.astype(np.float32)
+
+ return results
+
+
+@PIPELINES.register_module()
+class RandomOcclusion:
+ """Add random occlusion.
+
+ Add random occlusion based on occlusion probability.
+
+ Args:
+ occlusion_prob (float): probability of the image having
+ occlusion. Default: 0.5
+ """
+ def __init__(self, occlusion_prob=0.5):
+ self.occlusion_prob = occlusion_prob
+
+ def __call__(self, results):
+
+ if np.random.rand() > self.occlusion_prob:
+ return results
+
+ xmin, ymin, xmax, ymax = results['bbox']
+ imgwidth = results['ann_info']['width']
+ imgheight = results['ann_info']['height']
+ img = results['img']
+
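+        # Sample an occluder covering up to 70% of the bbox area, with aspect
+        # ratio in [0.3, 1/0.3], and fill it with random noise if it fits
+        # entirely inside the image.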
+ area_min = 0.0
+ area_max = 0.7
+ synth_area = (random.random() * (area_max - area_min) +
+ area_min) * (xmax - xmin) * (ymax - ymin)
+
+ ratio_min = 0.3
+ ratio_max = 1 / 0.3
+ synth_ratio = (random.random() * (ratio_max - ratio_min) + ratio_min)
+
+ synth_h = math.sqrt(synth_area * synth_ratio)
+ synth_w = math.sqrt(synth_area / synth_ratio)
+ synth_xmin = random.random() * ((xmax - xmin) - synth_w - 1) + xmin
+ synth_ymin = random.random() * ((ymax - ymin) - synth_h - 1) + ymin
+
+ if synth_xmin >= 0 and synth_ymin >= 0 and \
+ synth_xmin + synth_w < imgwidth and \
+ synth_ymin + synth_h < imgheight:
+ synth_xmin = int(synth_xmin)
+ synth_ymin = int(synth_ymin)
+ synth_w = int(synth_w)
+ synth_h = int(synth_h)
+ img[synth_ymin:synth_ymin + synth_h, synth_xmin:synth_xmin +
+ synth_w, :] = np.random.rand(synth_h, synth_w, 3) * 255
+
+ results['img'] = img
+
+ return results
+
+
+@PIPELINES.register_module()
+class GenerateHybrIKTarget:
+ """Generate the targets required for training.
+
+ Required keys: 'keypoints3d', 'keypoints3d_vis', 'ann_info', 'depth_factor'
+    Additional required keys if has_smpl: 'keypoints3d17', 'keypoints3d17_vis',
+    'keypoints3d_relative', 'keypoints3d17_relative'.
+    Added keys: 'target_uvd_29', 'target_xyz_24', 'target_weight_24',
+    'target_weight_29', 'target_xyz_17', 'target_weight_17', 'target_theta',
+    'target_beta', 'target_smpl_weight', 'target_theta_weight', 'trans_inv',
+    'bbox'.
+ """
+ def __init__(self, img_res, test_mode):
+ self.test_mode = test_mode
+ self.image_size = np.array([img_res, img_res])
+
+ def _integral_uvd_target_generator(self,
+ joints_3d,
+ num_joints,
+ patch_height,
+ patch_width,
+ depth_factor,
+ test_mode=False):
+
+ target_weight = np.ones((num_joints, 3), dtype=np.float32)
+ target_weight[:, 0] = joints_3d[:, 0, 1]
+ target_weight[:, 1] = joints_3d[:, 0, 1]
+ target_weight[:, 2] = joints_3d[:, 0, 1]
+
+ target = np.zeros((num_joints, 3), dtype=np.float32)
+ target[:, 0] = joints_3d[:, 0, 0] / patch_width - 0.5
+ target[:, 1] = joints_3d[:, 1, 0] / patch_height - 0.5
+ target[:, 2] = joints_3d[:, 2, 0] / depth_factor
+
+ target_weight[target[:, 0] > 0.5] = 0
+ target_weight[target[:, 0] < -0.5] = 0
+ target_weight[target[:, 1] > 0.5] = 0
+ target_weight[target[:, 1] < -0.5] = 0
+ target_weight[target[:, 2] > 0.5] = 0
+ target_weight[target[:, 2] < -0.5] = 0
+
+ target = target.reshape((-1))
+ target_weight = target_weight.reshape((-1))
+ return target, target_weight
+
+ def _integral_target_generator(self, joints_3d, num_joints, patch_height,
+ patch_width, depth_factor):
+ target_weight = np.ones((num_joints, 3), dtype=np.float32)
+ target_weight[:, 0] = joints_3d[:, 0, 1]
+ target_weight[:, 1] = joints_3d[:, 0, 1]
+ target_weight[:, 2] = joints_3d[:, 0, 1]
+
+ target = np.zeros((num_joints, 3), dtype=np.float32)
+ target[:, 0] = joints_3d[:, 0, 0] / patch_width - 0.5
+ target[:, 1] = joints_3d[:, 1, 0] / patch_height - 0.5
+ target[:, 2] = joints_3d[:, 2, 0] / depth_factor
+
+ target_weight[target[:, 0] > 0.5] = 0
+ target_weight[target[:, 0] < -0.5] = 0
+ target_weight[target[:, 1] > 0.5] = 0
+ target_weight[target[:, 1] < -0.5] = 0
+ target_weight[target[:, 2] > 0.5] = 0
+ target_weight[target[:, 2] < -0.5] = 0
+
+ target = target.reshape((-1))
+ target_weight = target_weight.reshape((-1))
+ return target, target_weight
+
+ def _integral_xyz_target_generator(self, joints_3d, joints_3d_vis,
+ num_joints, depth_factor):
+ target_weight = np.ones((num_joints, 3), dtype=np.float32)
+ target_weight[:, 0] = joints_3d_vis[:, 0]
+ target_weight[:, 1] = joints_3d_vis[:, 1]
+ target_weight[:, 2] = joints_3d_vis[:, 2]
+
+ target = np.zeros((num_joints, 3), dtype=np.float32)
+ target[:, 0] = joints_3d[:, 0] / int(depth_factor)
+ target[:, 1] = joints_3d[:, 1] / int(depth_factor)
+ target[:, 2] = joints_3d[:, 2] / int(depth_factor)
+
+ target = target.reshape((-1))
+ target_weight = target_weight.reshape((-1))
+ return target, target_weight
+
+ def _integral_target_generator_coco(self, joints_3d, num_joints,
+ patch_height, patch_width):
+ target_weight = np.ones((num_joints, 2), dtype=np.float32)
+ target_weight[:, 0] = joints_3d[:, 0, 1]
+ target_weight[:, 1] = joints_3d[:, 0, 1]
+
+ target = np.zeros((num_joints, 2), dtype=np.float32)
+ target[:, 0] = joints_3d[:, 0, 0] / patch_width - 0.5
+ target[:, 1] = joints_3d[:, 1, 0] / patch_height - 0.5
+
+ target = target.reshape((-1))
+ target_weight = target_weight.reshape((-1))
+ return target, target_weight
+
+ def __call__(self, results):
+
+ has_smpl = results['has_smpl']
+ inp_h, inp_w = self.image_size[0], self.image_size[1]
+
+ keypoints3d = results['keypoints3d']
+ num_joints = len(keypoints3d)
+ keypoints3d_vis = results['keypoints3d_vis']
+ depth_factor = results['depth_factor']
+
+ c = results['center']
+ s = results['scale']
+ r = results['rotation']
+
+ # generate new keys
+ trans_inv = get_affine_transform(c,
+ s,
+ r,
+ self.image_size,
+ inv=True,
+ pixel_std=1).astype(np.float32)
+ results['trans_inv'] = trans_inv.astype(np.float32)
+ bbox = _center_scale_to_box(c, s)
+ results['bbox'] = np.array(bbox, dtype=np.float32)
+
+ if has_smpl:
+ theta = results['pose']
+ # aa to quat
+ results['target_theta'] = aa_to_quat_numpy(theta).reshape(
+ 24 * 4).astype(np.float32)
+ theta_24_weights = np.ones((24, 4))
+ results['target_theta_weight'] = theta_24_weights.reshape(
+ 24 * 4).astype(np.float32)
+
+ results['target_beta'] = results['beta'].astype(np.float32)
+ results['target_smpl_weight'] = np.ones(1).astype(np.float32)
+
+ keypoints3d17_vis = results['keypoints3d17_vis']
+ keypoints3d17_relative = results['keypoints3d17_relative']
+ joints24_relative_3d = results['keypoints3d_relative'][:24, :]
+
+ gt_joints_29 = np.zeros((29, 3, 2), dtype=np.float32)
+ gt_joints_29[:, :, 0] = keypoints3d.copy()
+ gt_joints_29[:, :, 1] = keypoints3d_vis.copy()
+
+ target_uvd_29, target_weight_29 = \
+ self._integral_uvd_target_generator(
+ gt_joints_29, 29, inp_h, inp_w, depth_factor)
+ target_xyz_17, target_weight_17 = \
+ self._integral_xyz_target_generator(
+ keypoints3d17_relative, keypoints3d17_vis, 17,
+ depth_factor)
+ target_xyz_24, target_weight_24 = \
+ self._integral_xyz_target_generator(
+ joints24_relative_3d, keypoints3d_vis[:24, :], 24,
+ depth_factor)
+ target_weight_29 *= keypoints3d_vis.reshape(-1)
+ target_weight_24 *= keypoints3d_vis[:24, :].reshape(-1)
+ target_weight_17 *= keypoints3d17_vis.reshape(-1)
+
+ results['target_uvd_29'] = target_uvd_29.astype(np.float32)
+ results['target_xyz_24'] = target_xyz_24.astype(np.float32)
+ results['target_weight_29'] = target_weight_29.astype(np.float32)
+ results['target_weight_24'] = target_weight_24.astype(np.float32)
+ results['target_xyz_17'] = target_xyz_17.astype(np.float32)
+ results['target_weight_17'] = target_weight_17.astype(np.float32)
+ else:
+ label_uvd_29 = np.zeros((29, 3))
+ label_xyz_24 = np.zeros((24, 3))
+ label_uvd_29_mask = np.zeros((29, 3))
+ label_xyz_17 = np.zeros((17, 3))
+ label_xyz_17_mask = np.zeros((17, 3))
+
+ gt_joints = np.zeros((num_joints, 3, 2), dtype=np.float32)
+ gt_joints[:, :, 0] = keypoints3d.copy()
+ gt_joints[:, :, 1] = keypoints3d_vis.copy()
+ mask_idx = [1, 2, 6, 9, 10, 11]
+
+ if results['ann_info']['dataset_name'] == 'coco':
+ target, target_weight = self._integral_target_generator_coco(
+ gt_joints, num_joints, inp_h, inp_w)
+
+ label_jts_origin = target * target_weight
+ label_jts_mask_origin = target_weight
+
+ label_jts_origin = label_jts_origin.reshape(num_joints, 2)
+ label_jts_mask_origin = label_jts_mask_origin.reshape(
+ num_joints, 2)
+ label_jts_origin[mask_idx] = label_jts_origin[mask_idx] * 0
+                label_jts_mask_origin[
+                    mask_idx] = label_jts_mask_origin[mask_idx] * 0
+ label_uvd_29 = np.hstack([label_jts_origin, np.zeros([29, 1])])
+ label_uvd_29_mask = np.hstack(
+ [label_jts_mask_origin,
+ np.zeros([29, 1])])
+
+ elif results['ann_info']['dataset_name'] == 'mpi_inf_3dhp':
+ if not self.test_mode:
+ target, target_weight = self._integral_target_generator(
+ gt_joints, num_joints, inp_h, inp_w, depth_factor)
+ target_weight *= keypoints3d_vis.reshape(-1)
+
+ label_jts_origin = target * target_weight
+ label_jts_mask_origin = target_weight
+
+ label_jts_origin = label_jts_origin.reshape(num_joints, 3)
+ label_jts_mask_origin = label_jts_mask_origin.reshape(
+ num_joints, 3)
+ label_jts_origin[mask_idx] = label_jts_origin[mask_idx] * 0
+                label_jts_mask_origin[
+                    mask_idx] = label_jts_mask_origin[mask_idx] * 0
+ label_uvd_29 = label_jts_origin
+ label_uvd_29_mask = label_jts_mask_origin
+
+ label_uvd_29 = label_uvd_29.reshape(-1)
+ label_xyz_24 = label_xyz_24.reshape(-1)
+ label_uvd_24_mask = label_uvd_29_mask[:24, :].reshape(-1)
+ label_uvd_29_mask = label_uvd_29_mask.reshape(-1)
+ label_xyz_17 = label_xyz_17.reshape(-1)
+ label_xyz_17_mask = label_xyz_17_mask.reshape(-1)
+
+ results['target_uvd_29'] = label_uvd_29.astype(np.float32)
+ results['target_xyz_24'] = label_xyz_24.astype(np.float32)
+ results['target_weight_24'] = label_uvd_24_mask.astype(np.float32)
+ results['target_weight_29'] = label_uvd_29_mask.astype(np.float32)
+ results['target_xyz_17'] = label_xyz_17.astype(np.float32)
+ results['target_weight_17'] = label_xyz_17_mask.astype(np.float32)
+ results['target_theta'] = np.zeros(24 * 4).astype(np.float32)
+ results['target_beta'] = np.zeros(10).astype(np.float32)
+ results['target_smpl_weight'] = np.zeros(1).astype(np.float32)
+ results['target_theta_weight'] = np.zeros(24 * 4).astype(
+ np.float32)
+
+ return results
+
+
+@PIPELINES.register_module()
+class NewKeypointsSelection:
+ """Select keypoints.
+
+    Modifies the specified keypoint keys.
+
+    Args:
+        maps (list[dict]): each dict provides the 'keypoints' keys to modify
+            and the 'keypoints_index' used to select them.
+ """
+ def __init__(self, maps):
+ self.maps = maps
+
+ def __call__(self, results):
+ """Perform keypoints selection."""
+
+        for mapping in self.maps:
+            for keypoint in mapping['keypoints']:
+                keypoints_index = mapping['keypoints_index']
+ if keypoint in results:
+ results[keypoint] = results[keypoint][...,
+ keypoints_index, :]
+ return results
diff --git a/detrsmpl/data/datasets/pipelines/loading.py b/detrsmpl/data/datasets/pipelines/loading.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3a3f01c4a9242542b1552eb13519f2297295fc8
--- /dev/null
+++ b/detrsmpl/data/datasets/pipelines/loading.py
@@ -0,0 +1,85 @@
+import os.path as osp
+
+import cv2
+import mmcv
+import numpy as np
+
+from detrsmpl.data.data_structures.smc_reader import SMCReader
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class LoadImageFromFile(object):
+ """Load an image from file.
+
+ Required keys are "img_prefix" and "img_info" (a dict that must contain the
+ key "filename"). Added or updated keys are "filename", "img", "img_shape",
+ "ori_shape" (same as `img_shape`) and "img_norm_cfg" (means=0 and stds=1).
+ Both "img_shape" and "ori_shape" use (height, width) convention.
+
+ Args:
+ to_float32 (bool): Whether to convert the loaded image to a float32
+            numpy array. If set to False, the loaded image is a uint8 array.
+ Defaults to False.
+ color_type (str): The flag argument for :func:`mmcv.imfrombytes()`.
+ Defaults to 'color'.
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details.
+ Defaults to ``dict(backend='disk')``.
+ """
+ def __init__(self,
+ to_float32=False,
+ color_type='color',
+ file_client_args=dict(backend='disk')):
+ self.to_float32 = to_float32
+ self.color_type = color_type
+ self.file_client_args = file_client_args.copy()
+ self.file_client = None
+
+ def __call__(self, results):
+ if self.file_client is None:
+ self.file_client = mmcv.FileClient(**self.file_client_args)
+
+ if results['img_prefix'] is not None:
+ filename = osp.join(results['img_prefix'], results['image_path'])
+ else:
+ filename = results['image_path']
+
+ if filename.endswith('smc'):
+ assert 'image_id' in results, 'Load image from .smc, ' \
+ 'but image_id is not provided.'
+ device, device_id, frame_id = results['image_id']
+ smc_reader = SMCReader(filename)
+ img = smc_reader.get_color(device,
+ device_id,
+ frame_id,
+ disable_tqdm=True)
+ img = img.squeeze() # (1, H, W, 3) -> (H, W, 3)
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # BGR is used
+ del smc_reader
+ else:
+ img_bytes = self.file_client.get(filename)
+ img = mmcv.imfrombytes(img_bytes, flag=self.color_type)
+
+ if self.to_float32:
+ img = img.astype(np.float32)
+
+ results['filename'] = filename
+ results['ori_filename'] = results['image_path']
+ results['img'] = img
+ results['img_shape'] = img.shape[:2]
+ results['ori_shape'] = img.shape[:2]
+ num_channels = 1 if len(img.shape) < 3 else img.shape[2]
+ results['img_norm_cfg'] = dict(mean=np.zeros(num_channels,
+ dtype=np.float32),
+ std=np.ones(num_channels,
+ dtype=np.float32),
+ to_rgb=False)
+ return results
+
+ def __repr__(self):
+ repr_str = (f'{self.__class__.__name__}('
+ f'to_float32={self.to_float32}, '
+ f"color_type='{self.color_type}', "
+ f'file_client_args={self.file_client_args})')
+ return repr_str
diff --git a/detrsmpl/data/datasets/pipelines/synthetic_occlusion_augmentation.py b/detrsmpl/data/datasets/pipelines/synthetic_occlusion_augmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ca81ef0f1dfeebe44637741fd1f4a57d644e414
--- /dev/null
+++ b/detrsmpl/data/datasets/pipelines/synthetic_occlusion_augmentation.py
@@ -0,0 +1,137 @@
+"""This script is modified from https://github.com/ isarandi/synthetic-
+occlusion.
+
+Original license please see docs/additional_licenses.md.
+"""
+import os.path
+import random
+
+import cv2
+import numpy as np
+
+from ..builder import PIPELINES
+
+
+def load_pascal_occluders(occluders_file):
+ """load pascal occluders from the occluder file."""
+
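+    # The .npy file is expected to hold a pickled sequence of RGBA occluder
+    # patches; the alpha channel is later used by `paste_over` for blending.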
+ if os.path.isfile(occluders_file):
+ return np.load(occluders_file, allow_pickle=True)
+    else:
+        raise FileNotFoundError(
+            f'Occluders file does not exist: {occluders_file}')
+
+
+def occlude_with_pascal_objects(im, occluders):
+ """Returns an augmented version of `im`, containing some occluders from the
+ Pascal VOC dataset."""
+
+ result = im.copy()
+ width_height = np.asarray([im.shape[1], im.shape[0]])
+ im_scale_factor = min(width_height) / 256
+ count = np.random.randint(1, 8)
+
+ # logger.debug(f'Number of augmentation objects: {count}')
+
+ for _ in range(count):
+ occluder = random.choice(occluders)
+
+ center = np.random.uniform([0, 0], width_height)
+ random_scale_factor = np.random.uniform(0.2, 1.0)
+ scale_factor = random_scale_factor * im_scale_factor
+
+ # logger.debug(f'occluder size: {occluder.shape},
+ # scale_f: {scale_factor}, img_scale: {im_scale_factor}')
+ occluder = resize_by_factor(occluder, scale_factor)
+
+ paste_over(im_src=occluder, im_dst=result, center=center)
+
+ return result
+
+
+def paste_over(im_src, im_dst, center):
+ """Pastes `im_src` onto `im_dst` at a specified position, with alpha
+ blending, in place.
+
+ Locations outside the bounds of `im_dst`
+ are handled as expected (only a part or none of `im_src` becomes visible).
+
+ Args:
+ im_src: The RGBA image to be pasted onto `im_dst`.
+ Its size can be arbitrary.
+        im_dst: The target image.
+        center: Coordinates in `im_dst` where the center of `im_src`
+            should be placed.
+
+    The alpha channel of `im_src` (0-255) controls the blending at each
+    pixel; larger values mean more visibility for `im_src`.
+ """
+
+ width_height_src = np.asarray([im_src.shape[1], im_src.shape[0]])
+ width_height_dst = np.asarray([im_dst.shape[1], im_dst.shape[0]])
+
+ center = np.round(center).astype(np.int32)
+ raw_start_dst = center - width_height_src // 2
+ raw_end_dst = raw_start_dst + width_height_src
+
+ start_dst = np.clip(raw_start_dst, 0, width_height_dst)
+ end_dst = np.clip(raw_end_dst, 0, width_height_dst)
+ region_dst = im_dst[start_dst[1]:end_dst[1], start_dst[0]:end_dst[0]]
+
+ start_src = start_dst - raw_start_dst
+ end_src = width_height_src + (end_dst - raw_end_dst)
+ region_src = im_src[start_src[1]:end_src[1], start_src[0]:end_src[0]]
+ color_src = region_src[..., 0:3]
+ alpha = region_src[..., 3:].astype(np.float32) / 255
+
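+    # Standard alpha compositing, written back into the destination region
+    # in place: out = alpha * src + (1 - alpha) * dst.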
+ im_dst[start_dst[1]:end_dst[1],
+ start_dst[0]:end_dst[0]] = (alpha * color_src +
+ (1 - alpha) * region_dst)
+
+
+def resize_by_factor(im, factor):
+ """Returns a copy of `im` resized by `factor`, using bilinear interp for up
+ and area interp for downscaling."""
+ new_size = tuple(
+ np.round(np.array([im.shape[1], im.shape[0]]) * factor).astype(int))
+ interp = cv2.INTER_LINEAR if factor > 1.0 else cv2.INTER_AREA
+ return cv2.resize(im, new_size, fx=factor, fy=factor, interpolation=interp)
+
+
+def list_filepaths(dirpath):
+ """list the file paths."""
+ names = os.listdir(dirpath)
+ paths = [os.path.join(dirpath, name) for name in names]
+ return sorted(filter(os.path.isfile, paths))
+
+
+@PIPELINES.register_module()
+class SyntheticOcclusion:
+ """Data augmentation with synthetic occlusion.
+
+ Required keys: 'img'
+ Modifies key: 'img'
+    Args:
+        occluders_file (str): path to a .npy file with pre-generated
+            Pascal VOC occluder patches. Used when `occluders` is None.
+        occluders: pre-loaded occluder patches. If given, `occluders_file`
+            is ignored.
+ """
+ def __init__(self, occluders_file='', occluders=None):
+ self.occluders = None
+ if occluders is not None:
+ self.occluders = occluders
+
+ else:
+ self.occluders = load_pascal_occluders(
+ occluders_file=occluders_file, )
+
+ def __call__(self, results):
+ """Perform data augmentation with random channel noise."""
+ img = results['img']
+
+ img = occlude_with_pascal_objects(img, self.occluders)
+
+ results['img'] = img
+ return results
diff --git a/detrsmpl/data/datasets/pipelines/transforms.py b/detrsmpl/data/datasets/pipelines/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..826088283cf16be2371f29902704fe5518b033d7
--- /dev/null
+++ b/detrsmpl/data/datasets/pipelines/transforms.py
@@ -0,0 +1,1284 @@
+import math
+import random
+from collections.abc import Iterable
+
+import cv2
+import mmcv
+import numpy as np
+
+from detrsmpl.utils.demo_utils import xywh2xyxy, xyxy2xywh
+from detrsmpl.core.conventions.keypoints_mapping import get_flip_pairs
+from detrsmpl.utils.transforms import aa_to_rotmat, rotmat_to_aa
+from ..builder import PIPELINES
+from .compose import Compose
+
+
+def get_affine_transform(center,
+ scale,
+ rot,
+ output_size,
+ shift=(0., 0.),
+ inv=False,
+ pixel_std=1.0):
+ """Get the affine transform matrix, given the center/scale/rot/output_size.
+
+ Args:
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
+ scale (np.ndarray[2, ]): Scale of the bounding box
+ wrt [width, height].
+ rot (float): Rotation angle (degree).
+ output_size (np.ndarray[2, ] | list(2,)): Size of the
+ destination heatmaps.
+        shift (tuple[float]): Shift translation ratio wrt the width/height,
+            in the range 0-1. Default: (0., 0.).
+        inv (bool): Option to invert the affine transform direction
+            (inv=False: src->dst, inv=True: dst->src).
+        pixel_std (float): Standard deviation of pixels used to scale the
+            bounding-box `scale`. Default: 1.0.
+ Returns:
+ np.ndarray: The transform matrix.
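+
+    Example (illustrative values only, not tied to any dataset):
+        >>> img = np.zeros((480, 640, 3), dtype=np.uint8)
+        >>> trans = get_affine_transform(np.array([320., 240.]),
+        ...                              np.array([200., 200.]), 0.,
+        ...                              np.array([256, 256]))
+        >>> patch = cv2.warpAffine(img, trans, (256, 256))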
+ """
+ assert len(center) == 2
+ assert len(scale) == 2
+ assert len(output_size) == 2
+ assert len(shift) == 2
+
+ scale_tmp = scale * pixel_std
+
+ shift = np.array(shift)
+ src_h = scale_tmp[1]
+ dst_w = output_size[0]
+ dst_h = output_size[1]
+
+ rot_rad = np.pi * rot / 180
+ src_dir = rotate_point([0., src_h * -0.5], rot_rad)
+ dst_dir = np.array([0., dst_h * -0.5])
+
+ src = np.zeros((3, 2), dtype=np.float32)
+ src[0, :] = center + scale_tmp * shift
+ src[1, :] = center + src_dir + scale_tmp * shift
+ src[2, :] = _get_3rd_point(src[0, :], src[1, :])
+
+ dst = np.zeros((3, 2), dtype=np.float32)
+ dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
+ dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
+ dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
+
+ if inv:
+ trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
+ else:
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+
+ return trans
+
+
+def affine_transform(pt, trans_mat):
+ """Apply an affine transformation to the points.
+
+ Args:
+ pt (np.ndarray): a 2 dimensional point to be transformed
+ trans_mat (np.ndarray): 2x3 matrix of an affine transform
+ Returns:
+ np.ndarray: Transformed points.
+ """
+ if pt.ndim == 2:
+ new_pt = np.einsum('ij,mj->im', pt, trans_mat)
+ elif pt.ndim == 3:
+ new_pt = np.einsum('nij,mj->nim', pt, trans_mat)
+ else:
+        msg = f'Expected pt to have ndim of 2 or 3, but got {pt.ndim}.'
+ raise ValueError(msg)
+ # new_pt = np.array(trans_mat) @ np.array([pt[0], pt[1], 1.])
+
+ return new_pt
+
+
+def _get_3rd_point(a, b):
+ """To calculate the affine matrix, three pairs of points are required. This
+ function is used to get the 3rd point, given 2D points a & b.
+
+ The 3rd point is defined by rotating vector `a - b` by 90 degrees
+ anticlockwise, using b as the rotation center.
+ Args:
+ a (np.ndarray): point(x,y)
+ b (np.ndarray): point(x,y)
+ Returns:
+ np.ndarray: The 3rd point.
+ """
+ assert len(a) == 2
+ assert len(b) == 2
+ direction = a - b
+ third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
+
+ return third_pt
+
+
+def rotate_point(pt, angle_rad):
+ """Rotate a point by an angle.
+
+ Args:
+ pt (list[float]): 2 dimensional point to be rotated
+        angle_rad (float): rotation angle in radians
+ Returns:
+ list[float]: Rotated point.
+ """
+ assert len(pt) == 2
+ sn, cs = np.sin(angle_rad), np.cos(angle_rad)
+ new_x = pt[0] * cs - pt[1] * sn
+ new_y = pt[0] * sn + pt[1] * cs
+ rotated_pt = [new_x, new_y]
+
+ return rotated_pt
+
+
+def get_warp_matrix(theta, size_input, size_dst, size_target):
+ """Calculate the transformation matrix under the constraint of unbiased.
+
+ Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased
+ Data Processing for Human Pose Estimation (CVPR 2020).
+ Args:
+ theta (float): Rotation angle in degrees.
+ size_input (np.ndarray): Size of input image [w, h].
+ size_dst (np.ndarray): Size of output image [w, h].
+ size_target (np.ndarray): Size of ROI in input plane [w, h].
+ Returns:
+ matrix (np.ndarray): A matrix for transformation.
+ """
+ theta = np.deg2rad(theta)
+ matrix = np.zeros((2, 3), dtype=np.float32)
+ scale_x = size_dst[0] / size_target[0]
+ scale_y = size_dst[1] / size_target[1]
+ matrix[0, 0] = math.cos(theta) * scale_x
+ matrix[0, 1] = -math.sin(theta) * scale_x
+ matrix[0, 2] = scale_x * (-0.5 * size_input[0] * math.cos(theta) +
+ 0.5 * size_input[1] * math.sin(theta) +
+ 0.5 * size_target[0])
+ matrix[1, 0] = math.sin(theta) * scale_y
+ matrix[1, 1] = math.cos(theta) * scale_y
+ matrix[1, 2] = scale_y * (-0.5 * size_input[0] * math.sin(theta) -
+ 0.5 * size_input[1] * math.cos(theta) +
+ 0.5 * size_target[1])
+ return matrix
+
+
+def warp_affine_joints(joints, mat):
+ """Apply affine transformation defined by the transform matrix on the
+ joints.
+
+ Args:
+        joints (np.ndarray[..., 2]): Original coordinates of joints.
+        mat (np.ndarray[2, 3]): The affine transform matrix.
+    Returns:
+        np.ndarray[..., 2]: Transformed coordinates of joints.
+ """
+ joints = np.array(joints)
+ shape = joints.shape
+ joints = joints.reshape(-1, 2)
+ return np.dot(np.concatenate((joints, joints[:, 0:1] * 0 + 1), axis=1),
+ mat.T).reshape(shape)
+
+
+def _construct_rotation_matrix(rot, size=3):
+ """Construct the in-plane rotation matrix.
+
+ Args:
+ rot (float): Rotation angle (degree).
+ size (int): The size of the rotation matrix.
+ Candidate Values: 2, 3. Defaults to 3.
+ Returns:
+        rot_mat (np.ndarray([size, size])): Rotation matrix.
+ """
+ rot_mat = np.eye(size, dtype=np.float32)
+ if rot != 0:
+ rot_rad = np.deg2rad(rot)
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
+ rot_mat[0, :2] = [cs, -sn]
+ rot_mat[1, :2] = [sn, cs]
+
+ return rot_mat
+
+
+def _flip_smpl_pose(pose):
+ """Flip SMPL pose parameters horizontally.
+
+ Args:
+ pose (np.ndarray([72])): SMPL pose parameters
+ Returns:
+ pose_flipped
+ """
+
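+    # `flippedParts` permutes the 24 axis-angle joints (72 = 24 * 3 values)
+    # so that left and right body parts are swapped; negating the y and z
+    # components then mirrors each rotation for a horizontal flip.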
+ flippedParts = [
+ 0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11, 15, 16, 17, 12, 13, 14, 18, 19,
+ 20, 24, 25, 26, 21, 22, 23, 27, 28, 29, 33, 34, 35, 30, 31, 32, 36, 37,
+ 38, 42, 43, 44, 39, 40, 41, 45, 46, 47, 51, 52, 53, 48, 49, 50, 57, 58,
+ 59, 54, 55, 56, 63, 64, 65, 60, 61, 62, 69, 70, 71, 66, 67, 68
+ ]
+ pose_flipped = pose[..., flippedParts]
+ # Negate the second and the third dimension of the axis-angle
+ pose_flipped[..., 1::3] = -pose_flipped[..., 1::3]
+ pose_flipped[..., 2::3] = -pose_flipped[..., 2::3]
+ return pose_flipped
+
+
+def _flip_smplx_pose(pose):
+ """Flip SMPLX pose parameters horizontally.
+
+ Args:
+ pose (np.ndarray([63])): SMPLX pose parameters
+ Returns:
+ pose_flipped (np.ndarray([21,3]))
+ """
+ flippedParts = np.array([
+ 6, 7, 8, 3, 4, 5, 9, 10, 11, 15, 16, 17, 12, 13, 14, 18, 19, 20, 24,
+ 25, 26, 21, 22, 23, 27, 28, 29, 33, 34, 35, 30, 31, 32, 36, 37, 38, 42,
+ 43, 44, 39, 40, 41, 45, 46, 47, 51, 52, 53, 48, 49, 50, 57, 58, 59, 54,
+ 55, 56, 63, 64, 65, 60, 61, 62
+ ],
+ dtype=np.int32) - 3
+ dim_flip = np.array([1, -1, -1], dtype=pose.dtype)
+ pose = (pose[..., flippedParts].reshape(-1, 21, 3) * dim_flip).copy()
+ return pose
+
+
+def _flip_axis_angle(r):
+ """Flip axis_angle horizontally.
+
+ Args:
+ r (np.ndarray([3]))
+ Returns:
+ f_flipped
+ """
+ dim_flip = np.array([1, -1, -1], dtype=r.dtype)
+ r = r * dim_flip
+ return r
+
+
+def _flip_hand_pose(r_pose, l_pose):
+ dim_flip = np.array([1, -1, -1], dtype=r_pose.dtype)
+ ret_l_pose = r_pose * dim_flip
+ ret_r_pose = l_pose * dim_flip
+ return ret_r_pose, ret_l_pose
+
+
+def _flip_keypoints(keypoints, flip_pairs, img_width=None):
+ """Flip human joints horizontally.
+
+ Note:
+ num_keypoints: K
+ num_dimension: D
+ Args:
+ keypoints (np.ndarray([K, D])): Coordinates of keypoints.
+ flip_pairs (list[tuple()]): Pairs of keypoints which are mirrored
+ (for example, left ear -- right ear).
+ img_width (int | None, optional): The width of the original image.
+ To flip 2D keypoints, image width is needed. To flip 3D keypoints,
+ we simply negate the value of x-axis. Default: None.
+ Returns:
+ keypoints_flipped
+ """
+
+ keypoints_flipped = keypoints.copy()
+
+ # Swap left-right parts
+ for left, right in flip_pairs:
+ keypoints_flipped[..., left, :] = keypoints[..., right, :]
+ keypoints_flipped[..., right, :] = keypoints[..., left, :]
+
+ # Flip horizontally
+ if img_width is None:
+ keypoints_flipped[..., 0] = -keypoints_flipped[..., 0]
+ else:
+ keypoints_flipped[..., 0] = img_width - 1 - keypoints_flipped[..., 0]
+
+ return keypoints_flipped
+
+
+def _rotate_joints_3d(joints_3d, rot):
+ """Rotate the 3D joints in the local coordinates.
+
+ Notes:
+ Joints number: K
+ Args:
+ joints_3d (np.ndarray([K, 3])): Coordinates of keypoints.
+ rot (float): Rotation angle (degree).
+ Returns:
+ joints_3d_rotated
+ """
+ # in-plane rotation
+ # 3D joints are rotated counterclockwise,
+ # so the rot angle is inversed.
+ rot_mat = _construct_rotation_matrix(-rot, 3)
+ if joints_3d.ndim == 2:
+ joints_3d_rotated = np.einsum('ij,kj->ki', rot_mat, joints_3d)
+ elif joints_3d.ndim == 3:
+ joints_3d_rotated = np.einsum('ij,mkj->mki', rot_mat, joints_3d)
+ else:
+        msg = ('Expected joints_3d to have ndim of 2 or 3, '
+               f'but got {joints_3d.ndim}.')
+ raise ValueError(msg)
+ joints_3d_rotated = joints_3d_rotated.astype('float32')
+ return joints_3d_rotated
+
+
+def _rotate_smpl_pose(pose, rot):
+ """Rotate SMPL pose parameters.
+
+ SMPL (https://smpl.is.tue.mpg.de/) is a 3D
+ human model.
+ Args:
+ pose (np.ndarray([72])): SMPL pose parameters
+ rot (float): Rotation angle (degree).
+ Returns:
+ pose_rotated
+ """
+ pose_rotated = pose.copy()
+ if rot != 0:
+ # rot_mat = _construct_rotation_matrix(-rot)
+ # orient = pose[:3]
+ # # find the rotation of the body in camera frame
+ # per_rdg, _ = cv2.Rodrigues(orient.astype(np.float32))
+ # # apply the global rotation to the global orientation
+ # res_rot, _ = cv2.Rodrigues(np.dot(rot_mat, per_rdg))
+ # pose_rotated[:3] = (res_rot.T)[0]
+
+ # use pytorch3d
+ rot_mat = _construct_rotation_matrix(-rot)
+ orient = pose[..., :3]
+ per_rdg = aa_to_rotmat(orient)
+
+ if pose.ndim == 1:
+ tmp_rot = np.einsum('ij,jk->ik', rot_mat, per_rdg)
+ elif pose.ndim == 2:
+ tmp_rot = np.einsum('ij,mjk->mik', rot_mat, per_rdg)
+ else:
+            msg = f'Expected pose to have ndim of 1 or 2, but got {pose.ndim}.'
+ raise ValueError(msg)
+
+ res_rot = rotmat_to_aa(tmp_rot)
+ pose_rotated[..., :3] = res_rot
+
+ # use cv2
+ # rot_mat = _construct_rotation_matrix(-rot)
+ # for i in range(pose.shape[0]):
+ # orient = pose[i, :3]
+ # # find the rotation of the body in camera frame
+ # per_rdg, _ = cv2.Rodrigues(orient.astype(np.float32))
+ # # apply the global rotation to the global orientation
+ # res_rot, _ = cv2.Rodrigues(np.dot(rot_mat, per_rdg))
+ # pose_rotated[i, :3] = (res_rot.T)[0]
+
+ return pose_rotated
+
+
+def _bbox_flip(bboxes, img_shape, direction):
+ """Flip bboxes horizontally.
+
+ Args:
+ bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k)
+ img_shape (tuple[int]): Image shape (height, width)
+        direction (str): Flip direction. Options are 'horizontal',
+            'vertical' and 'diagonal'.
+
+ Returns:
+ numpy.ndarray: Flipped bounding boxes.
+ """
+
+ assert bboxes.shape[-1] % 5 == 0
+ flipped = bboxes.copy()
+ if direction == 'horizontal':
+ w = img_shape[1]
+ flipped[..., 0::4] = w - bboxes[..., 2::4]
+ flipped[..., 2::4] = w - bboxes[..., 0::4]
+ elif direction == 'vertical':
+ h = img_shape[0]
+ flipped[..., 1::4] = h - bboxes[..., 3::4]
+ flipped[..., 3::4] = h - bboxes[..., 1::4]
+ elif direction == 'diagonal':
+ w = img_shape[1]
+ h = img_shape[0]
+ flipped[..., 0::4] = w - bboxes[..., 2::4]
+ flipped[..., 1::4] = h - bboxes[..., 3::4]
+ flipped[..., 2::4] = w - bboxes[..., 0::4]
+ flipped[..., 3::4] = h - bboxes[..., 1::4]
+ else:
+ raise ValueError(f"Invalid flipping direction '{direction}'")
+ return flipped
+
+
+@PIPELINES.register_module()
+class RandomHorizontalFlip(object):
+ """Flip the image randomly.
+
+    Flip the image randomly based on the flip probability.
+
+    Args:
+        flip_prob (float): probability of the image being flipped. Default: 0.5
+        convention (str): keypoint convention used to look up the
+            left-right flip pairs.
+ """
+ def __init__(self, flip_prob=0.5, convention=None):
+ assert 0 <= flip_prob <= 1
+ self.flip_prob = flip_prob
+ self.flip_pairs = get_flip_pairs(convention)
+
+ def __call__(self, results):
+ """Call function to flip image and annotations.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Flipped results, 'flip' key is added into
+ result dict.
+ """
+ if np.random.rand() > self.flip_prob:
+ results['is_flipped'] = np.array([0])
+ return results
+
+ results['is_flipped'] = np.array([1])
+
+ # flip image
+ for key in results.get('img_fields', ['img']):
+ results[key] = mmcv.imflip(results[key], direction='horizontal')
+
+ # flip keypoints2d
+ if 'keypoints2d' in results:
+ assert self.flip_pairs is not None
+ width = results['img'][:, ::-1, :].shape[1]
+ keypoints2d = results['keypoints2d'].copy()
+ keypoints2d = _flip_keypoints(keypoints2d, self.flip_pairs, width)
+ results['keypoints2d'] = keypoints2d
+ elif 'keypoints2d_ori' in results:
+ assert self.flip_pairs is not None
+ width = results['img'][:, ::-1, :].shape[1]
+ keypoints2d = results['keypoints2d_ori'].copy()
+ keypoints2d = _flip_keypoints(keypoints2d, self.flip_pairs, width)
+ results['keypoints2d_ori'] = keypoints2d
+
+ if 'keypoints2d_smpl' in results:
+ assert self.flip_pairs is not None
+ width = results['img'][:, ::-1, :].shape[1]
+ keypoints2d = results['keypoints2d_smpl'].copy()
+ keypoints2d = _flip_keypoints(keypoints2d, self.flip_pairs, width)
+ results['keypoints2d_smpl'] = keypoints2d
+
+        # flip bbox center
+        width = results['img'].shape[1]
+        center = results['center']
+        center[..., 0] = width - 1 - center[..., 0]
+        results['center'] = center
+
+ # flip keypoints3d
+ if 'keypoints3d' in results:
+ assert self.flip_pairs is not None
+ keypoints3d = results['keypoints3d'].copy()
+ keypoints3d = _flip_keypoints(keypoints3d, self.flip_pairs)
+ results['keypoints3d'] = keypoints3d
+ elif 'keypoints3d_ori' in results:
+ assert self.flip_pairs is not None
+ keypoints3d = results['keypoints3d_ori'].copy()
+ keypoints3d = _flip_keypoints(keypoints3d, self.flip_pairs)
+ results['keypoints3d_ori'] = keypoints3d
+
+ if 'keypoints3d_smpl' in results:
+ assert self.flip_pairs is not None
+ keypoints3d = results['keypoints3d_smpl'].copy()
+ keypoints3d = _flip_keypoints(keypoints3d, self.flip_pairs)
+ results['keypoints3d_smpl'] = keypoints3d
+
+ if 'bbox_xywh' in results:
+ width = results['img'].shape[1]
+ bbox_xywh = results['bbox_xywh'].copy()
+ bbox_xyxy = xywh2xyxy(bbox_xywh)
+
+ bbox_xyxy = bbox_xyxy[:, [2, 1, 0, 3]] * np.array(
+ [-1, 1, -1, 1]) + np.array([width, 0, width, 0])
+
+ # img = mmcv.imshow_bboxes(results['img'], bbox_xyxy, show=False)
+ # cv2.imwrite('test.png',img)
+ results['bbox_xywh'] = xyxy2xywh(bbox_xyxy)
+
+ # flip smpl
+ if 'smpl_body_pose' in results:
+ global_orient = results['smpl_global_orient'].copy()
+ body_pose = results['smpl_body_pose'].copy().reshape((-1, 23 * 3))
+ smpl_pose = np.concatenate((global_orient, body_pose), axis=-1)
+ smpl_pose_flipped = _flip_smpl_pose(smpl_pose)
+ global_orient = smpl_pose_flipped[..., :3]
+ body_pose = smpl_pose_flipped[..., 3:]
+ results['smpl_global_orient'] = global_orient
+ results['smpl_body_pose'] = body_pose.reshape((-1, 23, 3))
+
+ # TODO: to check multi-human for smplx
+ if 'smplx_body_pose' in results:
+
+ body_pose = results['smplx_body_pose'].copy().reshape((-1))
+ body_pose_flipped = _flip_smplx_pose(body_pose)
+ results['smplx_body_pose'] = body_pose_flipped
+
+ if 'smplx_global_orient' in results:
+ global_orient = results['smplx_global_orient'].copy().reshape((-1))
+ global_orient_flipped = _flip_axis_angle(global_orient)
+ results['smplx_global_orient'] = global_orient_flipped
+
+ if 'smplx_jaw_pose' in results:
+ jaw_pose = results['smplx_jaw_pose'].copy().reshape((-1))
+ jaw_pose_flipped = _flip_axis_angle(jaw_pose)
+ results['smplx_jaw_pose'] = jaw_pose_flipped
+
+ if 'smplx_right_hand_pose' in results:
+ right_hand_pose = results['smplx_right_hand_pose'].copy()
+ left_hand_pose = results['smplx_left_hand_pose'].copy()
+ results['smplx_right_hand_pose'], results[
+ 'smplx_left_hand_pose'] = _flip_hand_pose(
+ right_hand_pose, left_hand_pose)
+
+ # Expressions are not symmetric. Remove them when flipped.
+ if 'smplx_expression' in results:
+ results['smplx_expression'] = np.zeros(
+ (results['smplx_expression'].shape[0]), dtype=np.float32)
+ results['has_smplx_expression'] = 0
+
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(flip_prob={self.flip_prob})'
+
+
+def resize(ori_shape, size, max_size=None):
+ # size can be min_size (scalar) or (w, h) tuple
+ def get_size_with_aspect_ratio(image_size, size, max_size=None):
+ w, h = image_size
+ if max_size is not None:
+ min_original_size = float(min((w, h)))
+ max_original_size = float(max((w, h)))
+ if max_original_size / min_original_size * size > max_size:
+ size = int(
+ round(max_size * min_original_size / max_original_size))
+
+ if (w <= h and w == size) or (h <= w and h == size):
+ return (w, h)
+
+ if w < h:
+ ow = size
+ oh = int(size * h / w)
+ else:
+ oh = size
+ ow = int(size * w / h)
+
+ return (ow, oh)
+
+ def get_size(ori_shape, size, max_size=None):
+ if isinstance(size, (list, tuple)):
+ return size[::-1]
+ else:
+ return get_size_with_aspect_ratio(ori_shape, size, max_size)
+
+ size = get_size(ori_shape, size, max_size)
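+    # Worked example (illustrative): for ori_shape=(480, 640) given as
+    # (h, w), size=800 and max_size=1333, the shorter side is scaled to 800
+    # and the longer side to 1066, so this returns (800, 1066), which
+    # MeshAffineED treats as the new (h, w).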
+
+ return size
+
+
+@PIPELINES.register_module()
+class CenterCrop(object):
+ r"""Center crop the image.
+
+ Args:
+ crop_size (int | tuple): Expected size after cropping with the format
+ of (h, w).
+ efficientnet_style (bool): Whether to use efficientnet style center
+ crop. Defaults to False.
+ crop_padding (int): The crop padding parameter in efficientnet style
+ center crop. Only valid if efficientnet style is True. Defaults to
+ 32.
+ interpolation (str): Interpolation method, accepted values are
+ 'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Only valid if
+ efficientnet style is True. Defaults to 'bilinear'.
+        backend (str): The image resize backend type, accepted values are
+ `cv2` and `pillow`. Only valid if efficientnet style is True.
+ Defaults to `cv2`.
+
+
+ Notes:
+ If the image is smaller than the crop size, return the original image.
+        If efficientnet_style is set to False, the pipeline is a simple
+        center crop using the crop_size.
+        If efficientnet_style is set to True, the pipeline first performs
+        the center crop with an intermediate crop_size_ computed as:
+
+ .. math::
+ crop\_size\_ = crop\_size / (crop\_size + crop\_padding) * short\_edge
+
+ And then the pipeline resizes the img to the input crop size.
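+
+        For example, with ``crop_size=224``, ``crop_padding=32`` and a short
+        edge of 320 pixels, the intermediate crop size is
+        224 / 256 * 320 = 280 pixels, which is then resized back to 224.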
+ """
+ def __init__(self,
+ crop_size,
+ efficientnet_style=False,
+ crop_padding=32,
+ interpolation='bilinear',
+ backend='cv2'):
+ if efficientnet_style:
+ assert isinstance(crop_size, int)
+ assert crop_padding >= 0
+ assert interpolation in ('nearest', 'bilinear', 'bicubic', 'area',
+ 'lanczos')
+ if backend not in ['cv2', 'pillow']:
+ raise ValueError(
+ f'backend: {backend} is not supported for '
+ 'resize. Supported backends are "cv2", "pillow"')
+ else:
+ assert isinstance(crop_size, int) or (isinstance(crop_size, tuple)
+ and len(crop_size) == 2)
+ if isinstance(crop_size, int):
+ crop_size = (crop_size, crop_size)
+ assert crop_size[0] > 0 and crop_size[1] > 0
+ self.crop_size = crop_size
+ self.efficientnet_style = efficientnet_style
+ self.crop_padding = crop_padding
+ self.interpolation = interpolation
+ self.backend = backend
+
+ def __call__(self, results):
+ crop_height, crop_width = self.crop_size[0], self.crop_size[1]
+ for key in results.get('img_fields', ['img']):
+ img = results[key]
+ # img.shape has length 2 for grayscale, length 3 for color
+ img_height, img_width = img.shape[:2]
+
+ # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/preprocessing.py#L118 # noqa
+ if self.efficientnet_style:
+ img_short = min(img_height, img_width)
+ crop_height = crop_height / (crop_height +
+ self.crop_padding) * img_short
+ crop_width = crop_width / (crop_width +
+ self.crop_padding) * img_short
+
+ y1 = max(0, int(round((img_height - crop_height) / 2.)))
+ x1 = max(0, int(round((img_width - crop_width) / 2.)))
+ y2 = min(img_height, y1 + crop_height) - 1
+ x2 = min(img_width, x1 + crop_width) - 1
+
+ # crop the image
+ img = mmcv.imcrop(img, bboxes=np.array([x1, y1, x2, y2]))
+
+ if self.efficientnet_style:
+ img = mmcv.imresize(img,
+ tuple(self.crop_size[::-1]),
+ interpolation=self.interpolation,
+ backend=self.backend)
+ img_shape = img.shape
+ results[key] = img
+ results['img_shape'] = img_shape
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size}'
+ repr_str += f', efficientnet_style={self.efficientnet_style}'
+ repr_str += f', crop_padding={self.crop_padding}'
+ repr_str += f', interpolation={self.interpolation}'
+ repr_str += f', backend={self.backend})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class Normalize(object):
+ """Normalize the image.
+
+ Args:
+ mean (sequence): Mean values of 3 channels.
+ std (sequence): Std values of 3 channels.
+ to_rgb (bool): Whether to convert the image from BGR to RGB,
+ default is true.
+ """
+ def __init__(self, mean, std, to_rgb=True):
+ self.mean = np.array(mean, dtype=np.float32)
+ self.std = np.array(std, dtype=np.float32)
+ self.to_rgb = to_rgb
+
+ def __call__(self, results):
+ for key in results.get('img_fields', ['img']):
+ results[key] = mmcv.imnormalize(results[key], self.mean, self.std,
+ self.to_rgb)
+ results['img_norm_cfg'] = dict(mean=self.mean,
+ std=self.std,
+ to_rgb=self.to_rgb)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(mean={list(self.mean)}, '
+ repr_str += f'std={list(self.std)}, '
+ repr_str += f'to_rgb={self.to_rgb})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class ColorJitter(object):
+ """Randomly change the brightness, contrast and saturation of an image.
+
+ Args:
+ brightness (float): How much to jitter brightness.
+ brightness_factor is chosen uniformly from
+ [max(0, 1 - brightness), 1 + brightness].
+ contrast (float): How much to jitter contrast.
+ contrast_factor is chosen uniformly from
+ [max(0, 1 - contrast), 1 + contrast].
+ saturation (float): How much to jitter saturation.
+ saturation_factor is chosen uniformly from
+ [max(0, 1 - saturation), 1 + saturation].
+ """
+ def __init__(self, brightness, contrast, saturation):
+ self.brightness = brightness
+ self.contrast = contrast
+ self.saturation = saturation
+
+ def __call__(self, results):
+ brightness_factor = random.uniform(0, self.brightness)
+ contrast_factor = random.uniform(0, self.contrast)
+ saturation_factor = random.uniform(0, self.saturation)
+ color_jitter_transforms = [
+ dict(type='Brightness',
+ magnitude=brightness_factor,
+ prob=1.,
+ random_negative_prob=0.5),
+ dict(type='Contrast',
+ magnitude=contrast_factor,
+ prob=1.,
+ random_negative_prob=0.5),
+ dict(type='ColorTransform',
+ magnitude=saturation_factor,
+ prob=1.,
+ random_negative_prob=0.5)
+ ]
+ random.shuffle(color_jitter_transforms)
+ transform = Compose(color_jitter_transforms)
+ return transform(results)
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(brightness={self.brightness}, '
+ repr_str += f'contrast={self.contrast}, '
+ repr_str += f'saturation={self.saturation})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class Lighting(object):
+ """Adjust images lighting using AlexNet-style PCA jitter.
+
+ Args:
+        eigval (list): the eigenvalues of the covariance matrix of pixel
+            values.
+        eigvec (list[list]): the eigenvectors of the covariance matrix of
+            pixel values.
+ alphastd (float): The standard deviation for distribution of alpha.
+ Defaults to 0.1
+ to_rgb (bool): Whether to convert img to rgb.
+ """
+ def __init__(self, eigval, eigvec, alphastd=0.1, to_rgb=True):
+ assert isinstance(eigval, list), \
+ f'eigval must be of type list, got {type(eigval)} instead.'
+ assert isinstance(eigvec, list), \
+ f'eigvec must be of type list, got {type(eigvec)} instead.'
+ for vec in eigvec:
+ assert isinstance(vec, list) and len(vec) == len(eigvec[0]), \
+ 'eigvec must contains lists with equal length.'
+ self.eigval = np.array(eigval)
+ self.eigvec = np.array(eigvec)
+ self.alphastd = alphastd
+ self.to_rgb = to_rgb
+
+ def __call__(self, results):
+ for key in results.get('img_fields', ['img']):
+ img = results[key]
+ results[key] = mmcv.adjust_lighting(img,
+ self.eigval,
+ self.eigvec,
+ alphastd=self.alphastd,
+ to_rgb=self.to_rgb)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(eigval={self.eigval.tolist()}, '
+ repr_str += f'eigvec={self.eigvec.tolist()}, '
+ repr_str += f'alphastd={self.alphastd}, '
+ repr_str += f'to_rgb={self.to_rgb})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class RandomChannelNoise:
+ """Data augmentation with random channel noise.
+
+ Required keys: 'img'
+ Modifies key: 'img'
+ Args:
+        noise_factor (float): Multiply each channel with a factor sampled
+            uniformly from ``[1 - noise_factor, 1 + noise_factor]``.
+ """
+ def __init__(self, noise_factor=0.4):
+ self.noise_factor = noise_factor
+
+ def __call__(self, results):
+ """Perform data augmentation with random channel noise."""
+ img = results['img']
+
+ # Each channel is multiplied with a number
+ # in the area [1-self.noise_factor, 1+self.noise_factor]
+ pn = np.random.uniform(1 - self.noise_factor, 1 + self.noise_factor,
+ (1, 3))
+ img = cv2.multiply(img, pn)
+
+ results['img'] = img
+
+ if 'ori_img' in results:
+ img = results['ori_img']
+ img = cv2.multiply(img, pn)
+
+ results['ori_img'] = img
+
+ return results
+
+
+@PIPELINES.register_module()
+class GetRandomScaleRotation:
+ """Data augmentation with random scaling & rotating.
+
+ Required key: 'scale'. Modifies key: 'scale' and 'rotation'.
+ Args:
+ rot_factor (int): Rotating to ``[-2*rot_factor, 2*rot_factor]``.
+ scale_factor (float): Scaling to ``[1-scale_factor, 1+scale_factor]``.
+ rot_prob (float): Probability of random rotation.
+ """
+ def __init__(self, rot_factor=30, scale_factor=0.25, rot_prob=0.6):
+ self.rot_factor = rot_factor
+ self.scale_factor = scale_factor
+ self.rot_prob = rot_prob
+
+ def __call__(self, results):
+ """Perform data augmentation with random scaling & rotating."""
+ s = results['scale']
+
+ sf = self.scale_factor
+ rf = self.rot_factor
+
+ s_factor = np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
+ s = s * s_factor
+
+ r_factor = np.clip(np.random.randn() * rf, -rf * 2, rf * 2)
+ r = r_factor if np.random.rand() <= self.rot_prob else 0.0
+
+ results['scale'] = s
+ results['rotation'] = r
+
+ return results
+
+
+@PIPELINES.register_module()
+class SampleInstance:
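+    """Randomly sample a subset of person instances and crop to their union.
+
+    With probability ``sample_ratio`` a random subset of the bounding boxes
+    is selected; otherwise all instances are kept. The union of the selected
+    boxes defines the new square crop ('center' and 'scale').
+
+    Required key: 'bbox_xywh'. Added/modified keys: 'bbox_xyxy', 'center',
+    'scale'.
+    """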
+ def __init__(self, sample_ratio):
+ self.sample_ratio = sample_ratio
+
+ def __call__(self, results):
+ assert 'bbox_xywh' in results
+ bbox_xywh = results['bbox_xywh'].copy()
+ crop_person_number = len(bbox_xywh)
+ if random.random() < self.sample_ratio:
+ crop_person_number = np.random.randint(len(bbox_xywh)) + 1
+
+ sample_ids = np.array(
+ random.sample(list(range(len(bbox_xywh))), crop_person_number))
+
+ bbox_xyxy = xywh2xyxy(bbox_xywh)[sample_ids]
+
+ leftTop_ = bbox_xyxy[:, :2]
+ leftTop_ = np.array([np.min(leftTop_[:, 0]), np.min(leftTop_[:, 1])])
+ rightBottom_ = bbox_xyxy[:, 2:4]
+ rightBottom_ = np.array(
+ [np.max(rightBottom_[:, 0]),
+ np.max(rightBottom_[:, 1])])
+ bbox_xyxy = np.concatenate([leftTop_, rightBottom_])
+ results['bbox_xyxy'] = bbox_xyxy
+ center = (rightBottom_ + leftTop_) / 2
+ scale = (rightBottom_ - leftTop_)
+ scale[0] = scale[1] = max(scale)
+ results['center'] = center
+ results['scale'] = scale
+ return results
+
+
+@PIPELINES.register_module()
+class MeshAffine:
+ """Affine transform the image to get input image.
+
+    Affine transform the 2D keypoints and 3D keypoints. Required keys: 'img',
+    'pose', 'img_shape', 'rotation' and 'center'. Modifies keys: 'img',
+    'keypoints2d', 'keypoints3d', 'pose'.
+ """
+ def __init__(self, img_res, crop_with_bbox=True):
+ if isinstance(img_res, tuple):
+ self.image_size = img_res
+ else:
+ self.image_size = np.array([img_res, img_res])
+ self.img_res = img_res
+ self.crop_with_bbox = crop_with_bbox
+
+ def __call__(self, results):
+
+ c = results['center']
+ s = results['scale']
+ r = results['rotation']
+
+ trans = get_affine_transform(c, s, r, self.image_size)
+
+ if 'img' in results:
+ img = results['img'].copy()
+
+ # img before affine
+ ori_img = img.copy()
+ results['crop_transform'] = trans
+ results['ori_img'] = ori_img
+ results['img_fields'] = ['img', 'ori_img']
+
+ img = cv2.warpAffine(
+ img,
+ trans, (int(self.image_size[0]), int(self.image_size[1])),
+ flags=cv2.INTER_LINEAR)
+ results['img'] = img
+
+ if 'keypoints2d' in results:
+ keypoints2d = results['keypoints2d'].copy()
+
+ results['keypoints2d'][..., :2] = affine_transform(
+ keypoints2d, trans)
+ if 'bbox_xywh' in results:
+ bbox_xywh = results['bbox_xywh'].copy()
+
+ leftTop = bbox_xywh[..., :2]
+ rightTop = np.concatenate([
+ bbox_xywh[..., [0]] + bbox_xywh[..., [2]], bbox_xywh[..., [1]]
+ ], -1)
+ leftBottom = np.concatenate([
+ bbox_xywh[..., [0]], bbox_xywh[..., [1]] + bbox_xywh[..., [3]]
+ ], -1)
+ rightBottom = np.concatenate([
+ bbox_xywh[..., [0]] + bbox_xywh[..., [2]],
+ bbox_xywh[..., [1]] + bbox_xywh[..., [3]]
+ ], -1)
+
+ bbox_point = np.vstack(
+ [leftTop, rightTop, leftBottom, rightBottom])
+ bbox_point = np.concatenate(
+ [bbox_point, np.ones_like(bbox_point[..., [0]])], -1)
+ bbox_point = affine_transform(bbox_point, trans)
+            bbox_point = np.clip(bbox_point, 0, self.img_res)
+            results['bbox'] = bbox_point
+
+ # bbox_xyxy = xywh2xyxy(bbox_xywh)[:,:4].reshape(-1, 2, 2)
+ # bbox_xyxy = np.concatenate([bbox_xyxy, np.ones_like(bbox_xyxy[...,[0]])], -1)
+ # bbox_xyxy = np.concatenate([affine_transform(bbox_xyxy, trans).reshape(-1,4), bbox_xywh[...,[-1]]],-1)
+ # results['bbox_xywh'] = xyxy2xywh(bbox_xyxy)
+
+ if 'keypoints3d' in results:
+ keypoints3d = results['keypoints3d'].copy()
+ keypoints3d[..., :3] = _rotate_joints_3d(keypoints3d[..., :3], r)
+ results['keypoints3d'] = keypoints3d
+
+ if 'smpl_body_pose' in results:
+ global_orient = results['smpl_global_orient'].copy()
+ body_pose = results['smpl_body_pose'].copy().reshape((-1, 23 * 3))
+ pose = np.concatenate((global_orient, body_pose), axis=-1)
+ pose = _rotate_smpl_pose(pose, r)
+ results['smpl_global_orient'] = pose[..., :3]
+ results['smpl_body_pose'] = pose[..., 3:].reshape((-1, 23, 3))
+
+ if 'smplx_global_orient' in results:
+ global_orient = results['smplx_global_orient'].copy()
+ global_orient = _rotate_smpl_pose(global_orient, r)
+ results['smplx_global_orient'] = global_orient
+
+ return results
+
+
+@PIPELINES.register_module()
+class MeshAffineED:
+ """Affine transform the image to get input image.
+
+    Affine transform the 2D keypoints and 3D keypoints. Required keys: 'img',
+    'pose', 'img_shape', 'rotation' and 'center'. Modifies keys: 'img',
+    'keypoints2d', 'keypoints3d', 'pose'.
+ """
+ def __init__(self, sizes, max_size=None):
+ assert isinstance(sizes, (list, tuple))
+ self.sizes = sizes
+ self.max_size = max_size
+
+ def __call__(self, results):
+ ori_shape = np.array(results['ori_shape'])
+ # ori_shape = ori_shape[::-1]
+ size = random.choice(self.sizes)
+ reshape_size = resize(ori_shape, size, self.max_size)
+ c = (ori_shape / 2)[::-1]
+ s = ori_shape[::-1]
+ r = results['rotation']
+
+ trans = get_affine_transform(c, s, r, reshape_size[::-1])
+
+ results['img_shape'] = reshape_size
+ if 'img' in results:
+ img = results['img'].copy()
+
+ # img before affine
+ ori_img = img.copy()
+ results['crop_transform'] = trans
+ results['ori_img'] = ori_img
+ results['img_fields'] = ['img', 'ori_img']
+
+ img = cv2.warpAffine(img,
+ trans,
+ (int(reshape_size[1]), int(reshape_size[0])),
+ flags=cv2.INTER_LINEAR)
+ results['img'] = img
+
+ if 'keypoints2d_ori' in results:
+ keypoints2d_ori = results['keypoints2d_ori'].copy()
+
+ results['keypoints2d_ori'][..., :2] = affine_transform(
+ keypoints2d_ori, trans)
+
+ if 'keypoints2d_smpl' in results:
+ keypoints2d_smpl = results['keypoints2d_smpl'].copy()
+
+ results['keypoints2d_smpl'][..., :2] = affine_transform(
+ keypoints2d_smpl, trans)
+
+ if 'bbox_xywh' in results:
+ bbox_xywh = results['bbox_xywh'].copy()
+
+ leftTop = bbox_xywh[..., :2]
+ rightTop = np.concatenate([
+ bbox_xywh[..., [0]] + bbox_xywh[..., [2]], bbox_xywh[..., [1]]
+ ], -1)
+ leftBottom = np.concatenate([
+ bbox_xywh[..., [0]], bbox_xywh[..., [1]] + bbox_xywh[..., [3]]
+ ], -1)
+ rightBottom = np.concatenate([
+ bbox_xywh[..., [0]] + bbox_xywh[..., [2]],
+ bbox_xywh[..., [1]] + bbox_xywh[..., [3]]
+ ], -1)
+
+ bbox_point = np.vstack(
+ [leftTop, rightTop, leftBottom, rightBottom])
+ bbox_point = np.concatenate(
+ [bbox_point, np.ones_like(bbox_point[..., [0]])], -1)
+ bbox_point = affine_transform(bbox_point, trans)
+ # TODO:
+
+ bbox_point = np.clip(bbox_point, 0,
+ (int(reshape_size[1]), int(reshape_size[0])))
+ results['bbox'] = bbox_point
+
+ bbox_xyxy_t = bbox_xywh.copy()
+ num_sample = bbox_xywh.shape[0]
+ bbox_xyxy_t[..., :2] = bbox_point[:num_sample, :]
+ bbox_xyxy_t[...,
+ 2:4] = bbox_point[num_sample * 3:num_sample * 4, :]
+
+ results['bbox_xywh'] = xyxy2xywh(bbox_xyxy_t)
+ # bbox_xywh = results['bbox_xywh'].copy()
+ # bbox_xyxy = xywh2xyxy(bbox_xywh)[:,:4].reshape(-1, 2, 2)
+ # bbox_xyxy = np.concatenate([bbox_xyxy, np.ones_like(bbox_xyxy[...,[0]])], -1)
+ # bbox_xyxy = np.concatenate([affine_transform(bbox_xyxy, trans).reshape(-1,4), bbox_xywh[...,[-1]]],-1)
+ # results['bbox_xywh'] = xyxy2xywh(bbox_xyxy)
+
+ if 'keypoints3d_ori' in results:
+ keypoints3d_ori = results['keypoints3d_ori'].copy()
+ keypoints3d_ori[..., :3] = _rotate_joints_3d(
+ keypoints3d_ori[..., :3], r)
+ results['keypoints3d_ori'] = keypoints3d_ori
+
+ if 'keypoints3d_smpl' in results:
+ keypoints3d_smpl = results['keypoints3d_smpl'].copy()
+ keypoints3d_smpl[..., :3] = _rotate_joints_3d(
+ keypoints3d_smpl[..., :3], r)
+ results['keypoints3d_smpl'] = keypoints3d_smpl
+
+ if 'smpl_body_pose' in results:
+ global_orient = results['smpl_global_orient'].copy()
+ body_pose = results['smpl_body_pose'].copy().reshape((-1, 23 * 3))
+ pose = np.concatenate((global_orient, body_pose), axis=-1)
+ pose = _rotate_smpl_pose(pose, r)
+ results['smpl_global_orient'] = pose[..., :3]
+ results['smpl_body_pose'] = pose[..., 3:].reshape((-1, 23, 3))
+
+ if 'area' in results:
+ area = results['area'] * (trans[0, 0] * trans[1, 1])
+ results['area'] = area
+ # if 'smplx_global_orient' in results:
+ # global_orient = results['smplx_global_orient'].copy()
+ # global_orient = _rotate_smpl_pose(global_orient, r)
+ # results['smplx_global_orient'] = global_orient
+
+ return results
+
+
+@PIPELINES.register_module()
+class Rotation:
+ """Rotate the image with the given rotation.
+
+    Rotate the 2D keypoints, 3D keypoints and poses. Required keys: 'img',
+    'pose', 'rotation' and 'center'. Modifies keys: 'img',
+    'keypoints2d', 'keypoints3d', 'pose'.
+
+    To avoid conflicts with MeshAffine, 'rotation' is reset to 0.0 after
+    the image has been rotated.
+    The original rotation value is stored under 'ori_rotation'.
+ """
+ def __init__(self):
+ pass
+
+ def __call__(self, results):
+ r = results['rotation']
+ if r == 0.0:
+ return results
+ img = results['img']
+
+ # img before affine
+ (h, w) = img.shape[:2]
+ (cX, cY) = (w // 2, h // 2)
+ M = cv2.getRotationMatrix2D((cX, cY), r, 1.0)
+ cos = np.abs(M[0, 0])
+ sin = np.abs(M[0, 1])
+ # compute the new bounding dimensions of the image
+ nW = int((h * sin) + (w * cos))
+ nH = int((h * cos) + (w * sin))
+ # adjust the rotation matrix to take into account translation
+ M[0, 2] += (nW / 2) - cX
+ M[1, 2] += (nH / 2) - cY
+ # perform the actual rotation and return the image
+ img = cv2.warpAffine(img, M, (nW, nH))
+
+ results['img'] = img
+
+ c = results['center']
+ c = np.dot(M[:2, :2], c) + M[:2, 2]
+ results['center'] = c
+
+ if 'keypoints2d' in results:
+ keypoints2d = results['keypoints2d'].copy()
+ keypoints2d[:, :2] = (np.dot(keypoints2d[:, :2], M[:2, :2].T) +
+                                  M[:2, 2] + 1).astype(int)
+ results['keypoints2d'] = keypoints2d
+
+ if 'keypoints3d' in results:
+ keypoints3d = results['keypoints3d'].copy()
+ keypoints3d[:, :3] = _rotate_joints_3d(keypoints3d[:, :3], r)
+ results['keypoints3d'] = keypoints3d
+
+ if 'smpl_body_pose' in results:
+ global_orient = results['smpl_global_orient'].copy()
+ body_pose = results['smpl_body_pose'].copy().reshape((-1))
+ pose = np.concatenate((global_orient, body_pose), axis=-1)
+ pose = _rotate_smpl_pose(pose, r)
+ results['smpl_global_orient'] = pose[:3]
+ results['smpl_body_pose'] = pose[3:].reshape((-1, 3))
+
+ if 'smplx_global_orient' in results:
+ global_orient = results['smplx_global_orient'].copy()
+ global_orient = _rotate_smpl_pose(global_orient, r)
+ results['smplx_global_orient'] = global_orient
+
+ results['rotation'] = 0.0
+ results['ori_rotation'] = r
+ return results
+
+
+@PIPELINES.register_module()
+class BBoxCenterJitter(object):
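+    """Randomly jitter the bbox center by a fraction of the bbox size.
+
+    Args:
+        factor (float): jitter magnitude relative to the bbox size
+            (``results['scale'][0]``). Values <= 1e-3 disable jittering.
+        dist (str): sampling distribution, either 'normal' or 'uniform'.
+    """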
+ def __init__(self, factor=0.0, dist='normal'):
+ super(BBoxCenterJitter, self).__init__()
+ self.factor = factor
+ self.dist = dist
+ assert self.dist in [
+ 'normal', 'uniform'
+ ], (f'Distribution must be normal or uniform, not {self.dist}')
+
+ def __call__(self, results):
+ # body model: no process
+ if self.factor <= 1e-3:
+ return results
+
+ bbox_size = results['scale'][0]
+
+ jitter = bbox_size * self.factor
+
+ if self.dist == 'normal':
+ center_jitter = np.random.randn(2) * jitter
+ elif self.dist == 'uniform':
+ center_jitter = np.random.rand(2) * 2 * jitter - jitter
+
+ center = results['center']
+ H, W = results['img_shape']
+ new_center = center + center_jitter
+ new_center[0] = np.clip(new_center[0], 0, W)
+ new_center[1] = np.clip(new_center[1], 0, H)
+
+ results['center'] = new_center
+ return results
+
+
+@PIPELINES.register_module()
+class SimulateLowRes(object):
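+    """Simulate a low-resolution input by downsampling and re-upsampling.
+
+    The downsampling factor is drawn uniformly from
+    ``[factor_min, factor_max]`` when ``dist='uniform'``, or sampled from
+    ``cat_factors`` when ``dist='categorical'``.
+    """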
+ def __init__(self,
+ dist: str = 'categorical',
+ factor: float = 1.0,
+ cat_factors=(1.0, ),
+ factor_min: float = 1.0,
+ factor_max: float = 1.0) -> None:
+ self.factor_min = factor_min
+ self.factor_max = factor_max
+ self.dist = dist
+ self.cat_factors = cat_factors
+ assert dist in ['uniform', 'categorical']
+
+ def _sample_low_res(self, image: np.ndarray) -> np.ndarray:
+ """"""
+ if self.dist == 'uniform':
+ downsample = self.factor_min != self.factor_max
+ if not downsample:
+ return image
+ factor = np.random.rand() * (self.factor_max -
+ self.factor_min) + self.factor_min
+ elif self.dist == 'categorical':
+ if len(self.cat_factors) < 2:
+ return image
+ idx = np.random.randint(0, len(self.cat_factors))
+ factor = self.cat_factors[idx]
+
+ H, W, _ = image.shape
+        downsampled_image = cv2.resize(image,
+                                       (int(W // factor), int(H // factor)),
+                                       interpolation=cv2.INTER_NEAREST)
+        resized_image = cv2.resize(downsampled_image, (W, H),
+                                   interpolation=cv2.INTER_LINEAR_EXACT)
+ return resized_image
+
+ def __call__(self, results):
+ """"""
+ img = results['img']
+ img = self._sample_low_res(img)
+ results['img'] = img
+
+ return results
diff --git a/detrsmpl/data/datasets/samplers/__init__.py b/detrsmpl/data/datasets/samplers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cffe4dcb0fa663750bd88941e06ad336a43b527f
--- /dev/null
+++ b/detrsmpl/data/datasets/samplers/__init__.py
@@ -0,0 +1,3 @@
+from .distributed_sampler import DistributedSampler
+
+__all__ = ['DistributedSampler']
diff --git a/detrsmpl/data/datasets/samplers/distributed_sampler.py b/detrsmpl/data/datasets/samplers/distributed_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..2388e072cba588db134c224dd04e20ee20c9bbbd
--- /dev/null
+++ b/detrsmpl/data/datasets/samplers/distributed_sampler.py
@@ -0,0 +1,41 @@
+import torch
+from torch.utils.data import DistributedSampler as _DistributedSampler
+
+
+class DistributedSampler(_DistributedSampler):
+ def __init__(self,
+ dataset,
+ num_replicas=None,
+ rank=None,
+ shuffle=True,
+ round_up=True):
+ super().__init__(dataset, num_replicas=num_replicas, rank=rank)
+ self.shuffle = shuffle
+ self.round_up = round_up
+ if self.round_up:
+ self.total_size = self.num_samples * self.num_replicas
+ else:
+ self.total_size = len(self.dataset)
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ if self.shuffle:
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = torch.arange(len(self.dataset)).tolist()
+
+ # add extra samples to make it evenly divisible
+ if self.round_up:
+ indices = (
+ indices *
+ int(self.total_size / len(indices) + 1))[:self.total_size]
+ assert len(indices) == self.total_size
+
+ # subsample
+ indices = indices[self.rank:self.total_size:self.num_replicas]
+ if self.round_up:
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
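With `round_up=True` the shuffled index list is tiled up to `total_size`, so every rank draws the same number of samples and per-GPU batch counts stay equal. A hedged usage sketch (explicit `num_replicas`/`rank`, so no process group needs to be initialized; assumes the repo root is importable):

```python
import torch
from torch.utils.data import TensorDataset

from detrsmpl.data.datasets.samplers import DistributedSampler

dataset = TensorDataset(torch.arange(10))  # 10 samples, 4 replicas
for rank in range(4):
    sampler = DistributedSampler(dataset, num_replicas=4, rank=rank,
                                 shuffle=False, round_up=True)
    # Each rank yields ceil(10 / 4) = 3 indices; the tail is padded by
    # repeating the index list, so 4 * 3 = 12 indices are drawn in total.
    print(rank, list(sampler))
```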
diff --git a/detrsmpl/models/__init__.py b/detrsmpl/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/detrsmpl/models/architectures/DetrSMPL.py b/detrsmpl/models/architectures/DetrSMPL.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f5d8eaa960578f07f0e5b1a548dc34dfe0a0ec0
--- /dev/null
+++ b/detrsmpl/models/architectures/DetrSMPL.py
@@ -0,0 +1,771 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta
+from typing import Optional, Union
+
+import torch
+from scipy.optimize import linear_sum_assignment
+import numpy as np
+from detrsmpl.core.post_processing.bbox.assigners import build_assigner
+from detrsmpl.core.post_processing.bbox.samplers import build_sampler
+from detrsmpl.core.conventions.keypoints_mapping import (get_keypoint_idx,
+ convert_kps)
+from detrsmpl.utils.geometry import batch_rodrigues
+from detrsmpl.utils.geometry import project_points
+from detrsmpl.utils.misc import multi_apply
+from ..backbones.builder import build_backbone
+from ..body_models.builder import build_body_model
+from ..heads.builder import build_head
+from ..losses.builder import build_loss
+from ..necks.builder import build_neck
+from .base_architecture import BaseArchitecture
+
+# from mmdet.core import bbox2result
+
+
+class MultiBodyEstimator(BaseArchitecture, metaclass=ABCMeta):
+ def __init__(
+ self,
+ backbone: Optional[Union[dict, None]] = None,
+ neck: Optional[Union[dict, None]] = None,
+ head: Optional[Union[dict, None]] = None,
+ disc: Optional[Union[dict, None]] = None,
+ registration: Optional[Union[dict, None]] = None,
+ body_model_train: Optional[Union[dict, None]] = None,
+ body_model_test: Optional[Union[dict, None]] = None,
+ convention: Optional[str] = 'human_data',
+ loss_keypoints2d: Optional[Union[dict, None]] = None,
+ loss_keypoints3d: Optional[Union[dict, None]] = None,
+ loss_vertex: Optional[Union[dict, None]] = None,
+ loss_smpl_pose: Optional[Union[dict, None]] = None,
+ loss_smpl_betas: Optional[Union[dict, None]] = None,
+ loss_camera: Optional[Union[dict, None]] = None,
+ loss_cls: Optional[Union[dict,
+ None]] = dict(type='CrossEntropyLoss',
+ bg_cls_weight=0.1,
+ use_sigmoid=False,
+ loss_weight=1.0,
+ class_weight=1.0),
+ loss_bbox=dict(type='L1Loss', loss_weight=5.0),
+ loss_iou=dict(type='GIoULoss', loss_weight=2.0),
+ init_cfg: Optional[Union[list, dict, None]] = None,
+ train_cfg:
+ Optional[Union[dict, None]] = dict(assigner=dict(
+ type='HungarianAssigner',
+ kp3d_cost=dict(
+ type='Keypoints3DCost', convention='smpl_54', weight=5.0),
+ kp2d_cost=dict(
+ type='Keypoints2DCost', convention='smpl_54', weight=5.0),
+ # cls_cost=dict(type='ClassificationCost', weight=1.),
+ # reg_cost=dict(type='BBoxL1Cost', weight=5.0),
+ # iou_cost=dict(
+ # type='IoUCost', iou_mode='giou', weight=2.0))
+ )),
+ test_cfg: Optional[Union[dict, None]] = None):
+
+ super(MultiBodyEstimator, self).__init__(init_cfg)
+ self.backbone = build_backbone(backbone)
+        self.neck = build_neck(neck) if neck is not None else None
+ head.update(train_cfg=train_cfg)
+ head.update(test_cfg=test_cfg)
+ self.head = build_head(head)
+ # class_weight = loss_cls.get('class_weight', None)
+ if train_cfg:
+ assert 'assigner' in train_cfg, 'assigner should be provided '\
+ 'when train_cfg is set.'
+ assigner = train_cfg['assigner']
+ # TODO: update these
+ # assert loss_cls['loss_weight'] == assigner['kp3d_cost']['weight'], \
+ # 'The classification weight for loss and matcher should be' \
+ # 'exactly the same.'
+ # assert loss_bbox['loss_weight'] == assigner['kp3d_cost'][
+ # 'weight'], 'The regression L1 weight for loss and matcher ' \
+ # 'should be exactly the same.'
+ # assert loss_iou['loss_weight'] == assigner['kp3d_cost']['weight'], \
+ # 'The regression iou weight for loss and matcher should be' \
+ # 'exactly the same.'
+ self.assigner = build_assigner(assigner)
+ # DETR sampling=False, so use PseudoSampler
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ # build loss
+ self.loss_keypoints2d = build_loss(loss_keypoints2d)
+ self.loss_keypoints3d = build_loss(loss_keypoints3d)
+ self.loss_vertex = build_loss(loss_vertex)
+ self.loss_smpl_pose = build_loss(loss_smpl_pose)
+ self.loss_smpl_betas = build_loss(loss_smpl_betas)
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_bbox = build_loss(loss_bbox)
+ self.loss_iou = build_loss(loss_iou)
+
+ self.body_model_train = build_body_model(body_model_train)
+ self.body_model_test = build_body_model(body_model_test)
+ self.convention = convention
+
+ def extract_feat(self, img):
+ """Directly extract features from the backbone+neck."""
+ x = self.backbone(img)
+        if self.neck is not None:
+ x = self.neck(x)
+ return x
+
+ def forward_dummy(self, img):
+ """Used for computing network flops.
+
+ See `mmdetection/tools/analysis_tools/get_flops.py`
+ """
+ x = self.extract_feat(img)
+ outs = self.bbox_head(x)
+ return outs
+
+ def forward_train(self, img, img_metas, **kwargs):
+ """
+ Args:
+ img (Tensor): Input images of shape (N, C, H, W).
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): A List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ :class:`mmdet.datasets.pipelines.Collect`.
+ gt_bboxes (list[Tensor]): Each item are the truth boxes for each
+ image in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): Class indices corresponding to each box
+ gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ # super(SingleStageDetector, self).forward_train(img, img_metas)
+ # NOTE the batched image size information may be useful, e.g.
+ # in DETR, this is needed for the construction of masks, which is
+ # then used for the transformer_head.
+
+ has_smpl = kwargs['has_smpl']
+ gt_smpl_body_pose = kwargs[
+ 'smpl_body_pose'] # [bs_0: [ins_num, 23, 3]]
+ gt_smpl_global_orient = kwargs['smpl_global_orient']
+ gt_smpl_body_pose = \
+ [torch.cat((gt_smpl_global_orient[i].view(-1, 1, 3),
+ gt_smpl_body_pose[i]), dim=1).float()
+ for i in range(len(gt_smpl_body_pose))]
+ gt_smpl_betas = kwargs['smpl_betas']
+ gt_smpl_transl = kwargs['smpl_transl']
+ gt_keypoints2d = kwargs['keypoints2d']
+        gt_keypoints3d = kwargs['keypoints3d']  # [bs_0: [N, K, D], ...]
+
+ if 'has_keypoints3d' in kwargs:
+ has_keypoints3d = kwargs['has_keypoints3d']
+ else:
+ has_keypoints3d = None
+
+ if 'has_keypoints2d' in kwargs:
+ has_keypoints2d = kwargs['has_keypoints2d']
+ else:
+ has_keypoints2d = None
+
+ batch_input_shape = tuple(img[0].size()[-2:])
+ for img_meta in img_metas:
+ img_meta['batch_input_shape'] = batch_input_shape
+
+ # features = self.extract_feat(img)
+ features = self.backbone(img)
+
+ if self.neck is not None:
+ features = self.neck(features)
+
+ # outputs_classes, outputs_coords,
+ pred_pose, \
+ pred_betas, pred_cameras, _, _ = self.head(features, img_metas)
+
+ L, B, N = pred_pose.shape[:3]
+ if self.body_model_train is not None:
+ pred_output = self.body_model_train(
+ betas=pred_betas.reshape(L * B * N, 10),
+ body_pose=pred_pose.reshape(L * B * N, 24, 3, 3)[:, 1:],
+ global_orient=pred_pose.reshape(L * B * N, 24, 3,
+ 3)[:, 0].unsqueeze(1),
+ pose2rot=False,
+ num_joints=gt_keypoints2d[0].shape[1])
+ pred_keypoints3d = pred_output['joints'].reshape(L, B, N, -1, 3)
+ pred_vertices = pred_output['vertices'].reshape(L, B, N, 6890, 3)
+ # loss
+ num_dec_layers = pred_pose.shape[0]
+
+ all_gt_smpl_body_pose_list = [
+ gt_smpl_body_pose for _ in range(num_dec_layers)
+ ]
+ all_gt_smpl_global_orient_list = [
+ gt_smpl_global_orient for _ in range(num_dec_layers)
+ ]
+ all_gt_smpl_betas_list = [gt_smpl_betas for _ in range(num_dec_layers)]
+ all_gt_smpl_transl_list = [
+ gt_smpl_transl for _ in range(num_dec_layers)
+ ]
+ all_gt_keypoints2d_list = [
+ gt_keypoints2d for _ in range(num_dec_layers)
+ ]
+ all_gt_keypoints3d_list = [
+ gt_keypoints3d for _ in range(num_dec_layers)
+ ]
+ all_has_smpl_list = [has_smpl for _ in range(num_dec_layers)]
+ all_has_keypoints3d_list = [
+ has_keypoints3d for _ in range(num_dec_layers)
+ ]
+ all_has_keypoints2d_list = [
+ has_keypoints2d for _ in range(num_dec_layers)
+ ]
+ all_gt_ignore_list = [None for _ in range(num_dec_layers)]
+ img_metas_list = [img_metas for _ in range(num_dec_layers)]
+ # all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
+ # all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
+ # all_gt_bboxes_ignore_list = [
+ # gt_bboxes_ignore for _ in range(num_dec_layers)
+ # ]
+        # compute the loss for each decoder layer
+ (kp2d_loss, kp3d_loss, vert_loss, pose_loss, beta_loss) = multi_apply(
+ self.compute_losses, pred_pose, pred_betas, pred_keypoints3d,
+ pred_vertices, pred_cameras, all_gt_smpl_body_pose_list,
+ all_gt_smpl_betas_list, all_gt_keypoints2d_list,
+ all_gt_keypoints3d_list, all_has_keypoints2d_list,
+ all_has_keypoints3d_list, all_has_smpl_list, img_metas_list,
+ all_gt_ignore_list)
+
+ losses = {}
+ losses['keypoints2d_loss'] = kp2d_loss[-1]
+ losses['keypoints3d_loss'] = kp3d_loss[-1]
+ losses['vertex_loss'] = vert_loss[-1]
+ losses['smpl_pose_loss'] = pose_loss[-1]
+ losses['smpl_betas_loss'] = beta_loss[-1]
+
+ # loss from other decoder layers
+ num_dec_layer = 0
+ for (kp2d_loss_i, kp3d_loss_i, vert_loss_i, pose_loss_i,
+ beta_loss_i) in zip(kp2d_loss[:-1], kp3d_loss[:-1],
+ vert_loss[:-1], pose_loss[:-1],
+ beta_loss[:-1]):
+ losses[f'd{num_dec_layer}.keypoints2d_loss'] = kp2d_loss_i
+ losses[f'd{num_dec_layer}.keypoints3d_loss'] = kp3d_loss_i
+ losses[f'd{num_dec_layer}.vertex_loss'] = vert_loss_i
+ losses[f'd{num_dec_layer}.smpl_pose_loss'] = pose_loss_i
+ losses[f'd{num_dec_layer}.smpl_betas_loss'] = beta_loss_i
+ num_dec_layer += 1
+
+ return losses
+
+ def compute_losses(self,
+ outputs_poses,
+ outputs_shapes,
+ outputs_kp3ds,
+ outputs_verts,
+ outputs_cameras,
+ all_gt_smpl_body_pose_list,
+ all_gt_smpl_betas_list,
+ all_gt_kp2d_list,
+ all_gt_kp3d_list,
+ all_has_keypoints2d_list,
+ all_has_keypoints3d_list,
+ all_has_smpl_list,
+ img_metas_list,
+ all_gt_ignore_list=None):
+ """_summary_
+ loss_single
+ get_targets
+ Args:
+ outputs_poses (_type_): with shape [B, N, 24, 3, 3]
+ outputs_shapes (_type_): _description_
+ all_gt_smpl_body_pose_list (_type_): _description_
+ all_gt_smpl_betas_list (_type_): _description_
+ all_gt_kp2d_list (Torch.tensor):
+ all_gt_kp3d_list (list): with shape [B, N, K, D]
+ img_metas_list (_type_): _description_
+ all_gt_ignore_list (_type_): _description_
+ """
+ num_img = outputs_poses.size(0) # batch_size
+ all_pred_smpl_pose_list = [outputs_poses[i] for i in range(num_img)]
+ all_pred_smpl_shape_list = [outputs_shapes[i] for i in range(num_img)]
+ all_pred_kp3d_list = [outputs_kp3ds[i] for i in range(num_img)]
+ all_pred_vert_list = [outputs_verts[i] for i in range(num_img)]
+ all_pred_cam_list = [outputs_cameras[i] for i in range(num_img)]
+
+ gt_bboxes_ignore_list = [all_gt_ignore_list for _ in range(num_img)]
+
+ if all_has_keypoints2d_list is None:
+ all_has_keypoints2d_list = [
+ all_has_keypoints2d_list for _ in range(num_img)
+ ]
+
+ if all_has_keypoints3d_list is None:
+ all_has_keypoints3d_list = [
+ all_has_keypoints3d_list for _ in range(num_img)
+ ]
+
+ if all_has_smpl_list is None:
+ all_has_smpl_list = [all_has_smpl_list for _ in range(num_img)]
+
+ # for each batch data
+ (kp2d_list, kp2d_weight_list, kp3d_list, kp3d_weight_list,
+ smpl_pose_list, smpl_pose_weight_list, smpl_shape_list,
+ smpl_shape_weight_list, vert_list, vert_weight_list, has_smpl_list,
+ has_keypoints2d_list, has_keypoints3d_list, pos_inds_list,
+ neg_inds_list) = multi_apply(
+ self.prepare_targets,
+ all_pred_smpl_pose_list,
+ all_pred_smpl_shape_list,
+ all_pred_kp3d_list,
+ all_pred_vert_list,
+ all_pred_cam_list,
+ all_gt_smpl_body_pose_list,
+ all_gt_smpl_betas_list,
+ all_gt_kp2d_list,
+ all_gt_kp3d_list,
+ all_has_keypoints2d_list,
+ all_has_keypoints3d_list,
+ all_has_smpl_list,
+ img_metas_list,
+ gt_bboxes_ignore_list,
+ )
+ num_total_pos = sum((inds.numel() for inds in pos_inds_list))
+ num_total_neg = sum((inds.numel() for inds in neg_inds_list))
+
+ K = outputs_kp3ds.shape[-2]
+
+ gt_kp2d = torch.cat(kp2d_list, 0)
+ kp2d_weight = torch.cat(kp2d_weight_list, 0)
+ pred_cam = outputs_cameras.reshape(-1, 3)
+ # pred_kp2d = torch.cat()
+
+ gt_kp3d = torch.cat(kp3d_list, 0)
+ kp3d_weight = torch.cat(kp3d_weight_list, 0)
+ pred_kp3d = outputs_kp3ds.reshape(-1, K, 3)
+
+ gt_smpl_pose = torch.cat(smpl_pose_list, 0)
+ smpl_pose_weight = torch.cat(smpl_pose_weight_list, 0)
+ pred_smpl_pose = outputs_poses.reshape(-1, 24, 3, 3)
+
+ gt_smpl_shape = torch.cat(smpl_shape_list, 0)
+ smpl_shape_weight = torch.cat(smpl_shape_weight_list, 0)
+ pred_smpl_shape = outputs_shapes.reshape(-1, 10)
+
+ gt_vert = torch.cat(vert_list, 0)
+ vert_weight = torch.cat(vert_weight_list, 0)
+ pred_verts = outputs_verts.reshape(-1, 6890, 3)
+
+ has_smpl = torch.cat(has_smpl_list, 0).squeeze()
+ has_keypoints2d = torch.cat(has_keypoints2d_list, 0).squeeze()
+ has_keypoints3d = torch.cat(has_keypoints3d_list, 0).squeeze()
+
+ # losses = {}
+ if self.loss_keypoints2d is not None:
+ keypoints2d_loss = self.compute_keypoints2d_loss(
+ pred_kp3d, pred_cam, gt_kp2d, has_keypoints2d=has_keypoints2d)
+ else:
+ keypoints2d_loss = 0.0
+
+ if self.loss_keypoints3d is not None:
+ keypoints3d_loss = self.compute_keypoints3d_loss(
+ pred_kp3d,
+ gt_kp3d,
+ has_keypoints3d=has_keypoints3d,
+ )
+ else:
+ keypoints3d_loss = 0.0
+
+ if self.loss_vertex is not None:
+ vertex_loss = self.compute_vertex_loss(pred_verts,
+ gt_vert,
+ has_smpl=has_smpl)
+ else:
+ vertex_loss = 0.0
+
+ if self.loss_smpl_pose is not None:
+ smpl_pose_loss = self.compute_smpl_pose_loss(pred_smpl_pose,
+ gt_smpl_pose,
+ has_smpl=has_smpl)
+ else:
+ smpl_pose_loss = 0.0
+
+ if self.loss_smpl_betas is not None:
+ smpl_betas_loss = self.compute_smpl_betas_loss(pred_smpl_shape,
+ gt_smpl_shape,
+ has_smpl=has_smpl)
+ else:
+ smpl_betas_loss = 0.0
+ # if self.loss_iou is not None:
+ # losses['iou_loss'] = self.loss_iou()
+
+ # if self.loss_bbox is not None:
+ # losses['bbox_loss'] = self.loss_bbox()
+
+ # if self.loss_cls is not None:
+ # losses['cls_loss'] = self.loss_bbox()
+
+ return (keypoints2d_loss, keypoints3d_loss, vertex_loss,
+ smpl_pose_loss, smpl_betas_loss)
+
+ def prepare_targets(self, pred_smpl_pose, pred_smpl_shape, pred_kp3d,
+ pred_vert, pred_cam, gt_smpl_pose, gt_smpl_shape,
+ gt_kp2d, gt_kp3d, has_keypoints2d, has_keypoints3d,
+ has_smpl, img_meta, gt_bboxes_ignore):
+ """_summary_
+
+ Args:
+ all_pred_smpl_pose (_type_): _description_
+ all_pred_smpl_shape (_type_): _description_
+ all_pred_kp3d (_type_): _description_
+ all_pred_vert (_type_): _description_
+ all_gt_smpl_body_pose (_type_): _description_
+ all_gt_smpl_betas (_type_): _description_
+ all_gt_kp2d (_type_): _description_
+ all_gt_kp3d (_type_): with shape [N, K, D]
+ img_meta (_type_): _description_
+ gt_bboxes_ignore (_type_): _description_
+ """
+ num_query = pred_smpl_pose.shape[0]
+ assign_result = self.assigner.assign(pred_smpl_pose, pred_smpl_shape,
+ pred_kp3d, pred_vert, pred_cam,
+ gt_smpl_pose, gt_smpl_shape,
+ gt_kp2d, gt_kp3d, has_keypoints2d,
+ has_keypoints3d, has_smpl,
+ img_meta, gt_bboxes_ignore)
+
+ gt_smpl_pose = gt_smpl_pose.float()
+ gt_smpl_shape = gt_smpl_shape.float()
+ gt_kp2d = gt_kp2d.float()
+ gt_kp3d = gt_kp3d.float()
+ has_keypoints2d = has_keypoints2d.float()
+ has_keypoints3d = has_keypoints3d.float()
+ has_smpl = has_smpl.float()
+
+ sampling_result = self.sampler.sample(assign_result, pred_smpl_pose,
+ gt_smpl_pose)
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+
+ # img_h, img_w, _ = img_meta['img_shape']
+
+ # kp2d target
+ kp2d_targets = torch.zeros_like(pred_kp3d[..., :2])
+ kp2d_weights = torch.zeros_like(pred_kp3d[..., :2])
+ kp2d_targets[pos_inds] = gt_kp2d[sampling_result.pos_assigned_gt_inds][
+ ..., :2]
+ kp2d_weights[pos_inds] = gt_kp2d[sampling_result.pos_assigned_gt_inds][
+ ..., [2]].repeat(1, 1, 2)
+ kp2d_targets = torch.cat(
+ [kp2d_targets, kp2d_weights[..., 0].unsqueeze(-1)], dim=-1)
+ # kp3d target
+ kp3d_targets = torch.zeros_like(pred_kp3d)
+ kp3d_weights = torch.zeros_like(pred_kp3d)
+ kp3d_targets[pos_inds] = gt_kp3d[sampling_result.pos_assigned_gt_inds][
+ ..., :3]
+ kp3d_weights[pos_inds] = gt_kp3d[sampling_result.pos_assigned_gt_inds][
+ ..., [3]].repeat(1, 1, 3)
+ kp3d_targets = torch.cat(
+ [kp3d_targets, kp3d_weights[..., 0].unsqueeze(-1)], dim=-1)
+
+ # smpl_pose target
+ smpl_pose_targets = torch.zeros_like(pred_smpl_pose)
+ smpl_pose_weights = torch.zeros_like(pred_smpl_pose)
+ gt_smpl_pose_rotmat = batch_rodrigues(gt_smpl_pose.view(-1, 3)).view(
+ -1, 24, 3, 3)
+ smpl_pose_targets[pos_inds] = gt_smpl_pose_rotmat[
+ sampling_result.pos_assigned_gt_inds]
+ smpl_pose_weights[pos_inds] = 1.0
+
+ # smpl_beta target
+ smpl_shape_targets = torch.zeros_like(pred_smpl_shape)
+ smpl_shape_weights = torch.zeros_like(pred_smpl_shape)
+ smpl_shape_targets[pos_inds] = gt_smpl_shape[
+ sampling_result.pos_assigned_gt_inds]
+ smpl_shape_weights[pos_inds] = 1.0
+
+ # verts
+ if self.body_model_train is not None:
+ gt_output = self.body_model_train(
+ betas=gt_smpl_shape,
+ body_pose=gt_smpl_pose_rotmat[:, 1:],
+ global_orient=gt_smpl_pose_rotmat[:, 0].unsqueeze(1),
+ pose2rot=False)
+ gt_vertices = gt_output['vertices']
+ gt_model_joints = gt_output['joints']
+
+ vert_targets = torch.zeros_like(pred_vert)
+ vert_weights = torch.zeros_like(pred_vert)
+ vert_targets[pos_inds] = gt_vertices[
+ sampling_result.pos_assigned_gt_inds]
+ vert_weights[pos_inds] = 1.0
+
+ if has_keypoints2d is not None:
+ has_keypoints2d_ = torch.zeros(
+ (num_query, 1)).to(smpl_pose_targets.device)
+ has_keypoints2d_[pos_inds] = has_keypoints2d[
+ sampling_result.pos_assigned_gt_inds]
+ else:
+ has_keypoints2d_ = None
+
+ if has_keypoints3d is not None:
+ has_keypoints3d_ = torch.zeros(
+ (num_query, 1)).to(smpl_pose_targets.device)
+ has_keypoints3d_[pos_inds] = has_keypoints3d[
+ sampling_result.pos_assigned_gt_inds]
+ else:
+ has_keypoints3d_ = None
+
+ if has_smpl is not None:
+ has_smpl_ = torch.zeros(
+ (num_query, 1)).to(smpl_pose_targets.device)
+ # if len(sampling_result.pos_assigned_gt_inds) == 1:
+ # has_smpl_[pos_inds] = has_smpl
+ # else:
+ has_smpl_[pos_inds] = has_smpl[
+ sampling_result.pos_assigned_gt_inds]
+ else:
+ has_smpl_ = None
+ return (kp2d_targets, kp2d_weights, kp3d_targets, kp3d_weights,
+ smpl_pose_targets, smpl_pose_weights, smpl_shape_targets,
+ smpl_shape_weights, vert_targets, vert_weights, has_smpl_,
+ has_keypoints2d_, has_keypoints3d_, pos_inds, neg_inds)
+
+ def forward_test(self, img, img_metas, **kwargs):
+ batch_input_shape = tuple(img[0].size()[-2:])
+ for img_meta in img_metas:
+ img_meta['batch_input_shape'] = batch_input_shape
+ features = self.backbone(img)
+ if self.neck is not None:
+ features = self.neck(features)
+ pred_pose, pred_betas, pred_cam, _, _ = \
+ self.head(features, img_metas)
+
+ # pred_pose = pred_pose[-1]
+ # pred_betas = pred_betas[-1]
+ # pred_cam = pred_cam[-1]
+
+ L, B, N = pred_pose.shape[:3]
+ if self.body_model_test is not None:
+ pred_output = self.body_model_test(
+ betas=pred_betas.reshape(L * B * N, 10),
+ body_pose=pred_pose.reshape(L * B * N, 24, 3, 3)[:, 1:],
+ global_orient=pred_pose.reshape(L * B * N, 24, 3,
+ 3)[:, 0].unsqueeze(1),
+ pose2rot=False)
+ else:
+            raise ValueError('Please provide a built body model.')
+
+ pred_keypoints_3d = pred_output['joints'].reshape(L, B, N, -1, 3)
+ pred_keypoints_3d = (pred_keypoints_3d -
+ pred_keypoints_3d[..., [0], :])
+ pred_keypoints_3d = pred_keypoints_3d.detach().cpu().numpy()
+ # pred_vertices = pred_output['vertices'].reshape(L, B, N, 6890, 3)
+ pred_cam = pred_cam.detach().cpu().numpy()
+ pred_pose = pred_pose.detach().cpu().numpy()
+ pred_betas = pred_betas.detach().cpu().numpy()
+ # batch, instance_num, kp_num, 4
+ gt_keypoints3d = kwargs['keypoints3d'].repeat([1, N, 1, 1]).clone()
+ # keypoints3d_mask = kwargs['keypoints3d_mask']
+ gt_keypoints3d = gt_keypoints3d.detach().cpu().numpy()
+ # gt_keypoints3d, _ = convert_kps(
+ # gt_keypoints3d,
+ # src='human_data',
+ # dst='h36m')
+
+ cost = np.sum((pred_keypoints_3d[-1] - gt_keypoints3d[..., :3]),
+ axis=(2, 3))
+ index = np.argmin(abs(cost), -1)
+
+ pred_keypoints_3d_ = []
+ pred_pose_ = []
+ pred_betas_ = []
+ pred_cam_ = []
+
+ for batch_i in range(B):
+ ind = index[batch_i]
+ pred_keypoints_3d_.append(pred_keypoints_3d[-1, batch_i, ind])
+ pred_pose_.append(pred_pose[-1, batch_i, ind])
+ pred_betas_.append(pred_betas[-1, batch_i, ind])
+ pred_cam_.append(pred_cam[-1, batch_i, ind])
+
+ # for img_id in range(len(img_metas)):
+ # pred_pose_ = pred_pose[:, img_id]
+ # pred_betas_ = pred_betas[:, img_id]
+ # pred_cam_ = pred_cam[:, img_id]
+ # pred_keypoints_3d_ = pred_keypoints_3d[:, img_id]
+ # pred_vertices_ = pred_vertices[:, img_id]
+ # img_shape_ = img_metas[img_id]['img_shape']
+
+ # result_list.append()
+
+ all_preds = {}
+ all_preds['keypoints_3d'] = np.array(pred_keypoints_3d_)
+ all_preds['smpl_pose'] = np.array(pred_pose_)
+ all_preds['smpl_beta'] = np.array(pred_betas_)
+ all_preds['camera'] = np.array(pred_cam_)
+ # all_preds['vertices'] = pred_vertices.detach().cpu().numpy()
+
+ image_path = []
+ for img_meta in img_metas:
+ image_path.append(img_meta['image_path'])
+ all_preds['image_path'] = image_path
+ all_preds['image_idx'] = kwargs['sample_idx']
+ return all_preds
+ # loss
+
+ def compute_keypoints3d_loss(
+ self,
+ pred_keypoints3d: torch.Tensor,
+ gt_keypoints3d: torch.Tensor,
+ has_keypoints3d: Optional[torch.Tensor] = None):
+ """Compute loss for 3d keypoints."""
+ keypoints3d_conf = gt_keypoints3d[:, :, 3].float().unsqueeze(-1)
+ keypoints3d_conf = keypoints3d_conf.repeat(1, 1, 3)
+ pred_keypoints3d = pred_keypoints3d.float()
+ gt_keypoints3d = gt_keypoints3d[:, :, :3].float()
+
+ # currently, only mpi_inf_3dhp and h36m have 3d keypoints
+ # both datasets have right_hip_extra and left_hip_extra
+ right_hip_idx = get_keypoint_idx('right_hip_extra', self.convention)
+ left_hip_idx = get_keypoint_idx('left_hip_extra', self.convention)
+ gt_pelvis = (gt_keypoints3d[:, right_hip_idx, :] +
+ gt_keypoints3d[:, left_hip_idx, :]) / 2
+ pred_pelvis = (pred_keypoints3d[:, right_hip_idx, :] +
+ pred_keypoints3d[:, left_hip_idx, :]) / 2
+
+ gt_keypoints3d = gt_keypoints3d - gt_pelvis[:, None, :]
+ pred_keypoints3d = pred_keypoints3d - pred_pelvis[:, None, :]
+ loss = self.loss_keypoints3d(pred_keypoints3d,
+ gt_keypoints3d,
+ reduction_override='none')
+
+ # If has_keypoints3d is not None, then computes the losses on the
+ # instances that have ground-truth keypoints3d.
+ # But the zero confidence keypoints will be included in mean.
+ # Otherwise, only compute the keypoints3d
+ # which have positive confidence.
+
+ # has_keypoints3d is None when the key has_keypoints3d
+ # is not in the datasets
+ if has_keypoints3d is None:
+
+ valid_pos = keypoints3d_conf > 0
+ if keypoints3d_conf[valid_pos].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_keypoints3d)
+ loss = torch.sum(loss * keypoints3d_conf)
+ loss /= keypoints3d_conf[valid_pos].numel()
+ else:
+
+ keypoints3d_conf = keypoints3d_conf[has_keypoints3d == 1]
+ if keypoints3d_conf.shape[0] == 0:
+ return torch.Tensor([0]).type_as(gt_keypoints3d)
+ loss = loss[has_keypoints3d == 1]
+ loss = (loss * keypoints3d_conf).mean()
+ return loss
+
+ def compute_keypoints2d_loss(
+ self,
+ pred_keypoints3d: torch.Tensor,
+ pred_cam: torch.Tensor,
+ gt_keypoints2d: torch.Tensor,
+ img_res: Optional[int] = 512,
+ focal_length: Optional[int] = 5000.,
+ has_keypoints2d: Optional[torch.Tensor] = None):
+ """Compute loss for 2d keypoints."""
+ keypoints2d_conf = gt_keypoints2d[:, :, 2].float().unsqueeze(-1)
+ keypoints2d_conf = keypoints2d_conf.repeat(1, 1, 2)
+ gt_keypoints2d = gt_keypoints2d[:, :, :2].float()
+ pred_keypoints2d = project_points(pred_keypoints3d,
+ pred_cam,
+ focal_length=focal_length,
+ img_res=img_res)
+ # Normalize keypoints to [-1,1]
+ # The coordinate origin of pred_keypoints_2d is
+ # the center of the input image.
+ pred_keypoints2d = 2 * pred_keypoints2d / (img_res - 1)
+ # The coordinate origin of gt_keypoints_2d is
+ # the top left corner of the input image.
+ gt_keypoints2d = 2 * gt_keypoints2d / (img_res - 1) - 1
+ loss = self.loss_keypoints2d(pred_keypoints2d,
+ gt_keypoints2d,
+ reduction_override='none')
+
+ # If has_keypoints2d is not None, then computes the losses on the
+ # instances that have ground-truth keypoints2d.
+ # But the zero confidence keypoints will be included in mean.
+ # Otherwise, only compute the keypoints2d
+ # which have positive confidence.
+ # has_keypoints2d is None when the key has_keypoints2d
+ # is not in the datasets
+
+ if has_keypoints2d is None:
+ valid_pos = keypoints2d_conf > 0
+ if keypoints2d_conf[valid_pos].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_keypoints2d)
+ loss = torch.sum(loss * keypoints2d_conf)
+ loss /= keypoints2d_conf[valid_pos].numel()
+ else:
+ keypoints2d_conf = keypoints2d_conf[has_keypoints2d == 1]
+ if keypoints2d_conf.shape[0] == 0:
+ return torch.Tensor([0]).type_as(gt_keypoints2d)
+ loss = loss[has_keypoints2d == 1]
+ loss = (loss * keypoints2d_conf).mean()
+
+ return loss
+
+ def compute_vertex_loss(self, pred_vertices: torch.Tensor,
+ gt_vertices: torch.Tensor, has_smpl: torch.Tensor):
+ """Compute loss for vertices."""
+ gt_vertices = gt_vertices.float()
+ conf = has_smpl.float().view(-1, 1, 1)
+ conf = conf.repeat(1, gt_vertices.shape[1], gt_vertices.shape[2])
+ loss = self.loss_vertex(pred_vertices,
+ gt_vertices,
+ reduction_override='none')
+ valid_pos = conf > 0
+ if conf[valid_pos].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_vertices)
+ loss = torch.sum(loss * conf) / conf[valid_pos].numel()
+ return loss
+
+ def compute_smpl_pose_loss(self, pred_pose: torch.Tensor,
+ gt_pose: torch.Tensor, has_smpl: torch.Tensor):
+ """Compute loss for smpl pose."""
+ conf = has_smpl.float().view(-1)
+ valid_pos = conf > 0
+ if conf[valid_pos].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_pose)
+ pred_pose = pred_pose[valid_pos]
+ gt_pose = gt_pose[valid_pos]
+ conf = conf[valid_pos]
+ # gt_rotmat = batch_rodrigues(gt_pose.view(-1, 3)).view(-1, 24, 3, 3)
+ loss = self.loss_smpl_pose(pred_pose,
+ gt_pose,
+ reduction_override='none')
+ loss = loss.view(loss.shape[0], -1).mean(-1)
+ loss = torch.mean(loss * conf)
+ return loss
+
+ def compute_smpl_betas_loss(self, pred_betas: torch.Tensor,
+ gt_betas: torch.Tensor,
+ has_smpl: torch.Tensor):
+ """Compute loss for smpl betas."""
+ conf = has_smpl.float().view(-1)
+ valid_pos = conf > 0
+ if conf[valid_pos].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_betas)
+ pred_betas = pred_betas[valid_pos]
+ gt_betas = gt_betas[valid_pos]
+ conf = conf[valid_pos]
+ loss = self.loss_smpl_betas(pred_betas,
+ gt_betas,
+ reduction_override='none')
+ loss = loss.view(loss.shape[0], -1).mean(-1)
+ loss = torch.mean(loss * conf)
+ return loss
+
+ def compute_camera_loss(self, cameras: torch.Tensor):
+ """Compute loss for predicted camera parameters."""
+ loss = self.loss_camera(cameras)
+ return loss
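The keypoint losses above share one pattern: pelvis-center the predicted and ground-truth joints, weight the element-wise loss by per-keypoint confidence, and average over the visible entries only. A simplified, self-contained sketch of that computation follows (plain L1, made-up hip indices, not the configurable loss modules built by `build_loss`):

```python
import torch

def kp3d_l1(pred, gt_with_conf, right_hip_idx=2, left_hip_idx=3):
    # gt_with_conf: (B, K, 4) = xyz + confidence; pred: (B, K, 3)
    conf = gt_with_conf[..., 3:].repeat(1, 1, 3)
    gt = gt_with_conf[..., :3]
    gt_pelvis = (gt[:, right_hip_idx] + gt[:, left_hip_idx]) / 2
    pred_pelvis = (pred[:, right_hip_idx] + pred[:, left_hip_idx]) / 2
    # Pelvis-align both sets of joints, then confidence-weight the L1 error.
    loss = (pred - pred_pelvis[:, None] - (gt - gt_pelvis[:, None])).abs()
    return (loss * conf).sum() / (conf > 0).sum().clamp(min=1)

pred = torch.randn(2, 54, 3)
gt = torch.cat([torch.randn(2, 54, 3), torch.ones(2, 54, 1)], dim=-1)
print(kp3d_l1(pred, gt))
```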
diff --git a/detrsmpl/models/architectures/DetrSMPLloss.py b/detrsmpl/models/architectures/DetrSMPLloss.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac9b98dfd353c69f8bfdd503023c11910f54fe4
--- /dev/null
+++ b/detrsmpl/models/architectures/DetrSMPLloss.py
@@ -0,0 +1,739 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta
+from typing import Optional, Union
+
+import torch
+from scipy.optimize import linear_sum_assignment
+import numpy as np
+from detrsmpl.core.post_processing.bbox.assigners import build_assigner
+from detrsmpl.core.post_processing.bbox.samplers import build_sampler
+from detrsmpl.core.conventions.keypoints_mapping import (get_keypoint_idx,
+ convert_kps)
+from detrsmpl.utils.geometry import batch_rodrigues
+from detrsmpl.utils.geometry import project_points
+from detrsmpl.utils.misc import multi_apply
+from ..backbones.builder import build_backbone
+from ..body_models.builder import build_body_model
+from ..heads.builder import build_head
+from ..losses.builder import build_loss
+from ..necks.builder import build_neck
+from .base_architecture import BaseArchitecture
+
+# from mmdet.core import bbox2result
+
+
+class DETRLoss(BaseArchitecture, metaclass=ABCMeta):
+ def __init__(
+ self,
+ body_model_train: Optional[Union[dict, None]] = None,
+ body_model_test: Optional[Union[dict, None]] = None,
+ convention: Optional[str] = 'human_data',
+ loss_keypoints2d: Optional[Union[dict, None]] = None,
+ loss_keypoints3d: Optional[Union[dict, None]] = None,
+ loss_vertex: Optional[Union[dict, None]] = None,
+ loss_smpl_pose: Optional[Union[dict, None]] = None,
+ loss_smpl_betas: Optional[Union[dict, None]] = None,
+ loss_camera: Optional[Union[dict, None]] = None,
+ loss_cls: Optional[Union[dict,
+ None]] = dict(type='CrossEntropyLoss',
+ bg_cls_weight=0.1,
+ use_sigmoid=False,
+ loss_weight=1.0,
+ class_weight=1.0),
+ loss_bbox=dict(type='L1Loss', loss_weight=5.0),
+ loss_iou=dict(type='GIoULoss', loss_weight=2.0),
+ init_cfg: Optional[Union[list, dict, None]] = None,
+ train_cfg:
+ Optional[Union[dict, None]] = dict(assigner=dict(
+ type='HungarianAssigner',
+ kp3d_cost=dict(
+ type='Keypoints3DCost', convention='smpl_54', weight=5.0),
+ kp2d_cost=dict(
+ type='Keypoints2DCost', convention='smpl_54', weight=5.0),
+ # cls_cost=dict(type='ClassificationCost', weight=1.),
+ # reg_cost=dict(type='BBoxL1Cost', weight=5.0),
+ # iou_cost=dict(
+ # type='IoUCost', iou_mode='giou', weight=2.0))
+ )),
+ test_cfg: Optional[Union[dict, None]] = None):
+
+ super(DETRLoss, self).__init__(init_cfg)
+ if train_cfg:
+ assert 'assigner' in train_cfg, 'assigner should be provided '\
+ 'when train_cfg is set.'
+ assigner = train_cfg['assigner']
+ # TODO: update these
+ # assert loss_cls['loss_weight'] == assigner['kp3d_cost']['weight'], \
+ # 'The classification weight for loss and matcher should be' \
+ # 'exactly the same.'
+ # assert loss_bbox['loss_weight'] == assigner['kp3d_cost'][
+ # 'weight'], 'The regression L1 weight for loss and matcher ' \
+ # 'should be exactly the same.'
+ # assert loss_iou['loss_weight'] == assigner['kp3d_cost']['weight'], \
+ # 'The regression iou weight for loss and matcher should be' \
+ # 'exactly the same.'
+ self.assigner = build_assigner(assigner)
+ # DETR sampling=False, so use PseudoSampler
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ # build loss
+ self.loss_keypoints2d = build_loss(loss_keypoints2d)
+ self.loss_keypoints3d = build_loss(loss_keypoints3d)
+ self.loss_vertex = build_loss(loss_vertex)
+ self.loss_smpl_pose = build_loss(loss_smpl_pose)
+ self.loss_smpl_betas = build_loss(loss_smpl_betas)
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_bbox = build_loss(loss_bbox)
+ self.loss_iou = build_loss(loss_iou)
+
+ self.body_model_train = build_body_model(body_model_train)
+ self.body_model_test = build_body_model(body_model_test)
+ self.convention = convention
+
+ def forward_train(self, preds, targets):
+ pass
+
+ def forward(self, preds, targets):
+ """
+ Args:
+ img (Tensor): Input images of shape (N, C, H, W).
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): A List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ :class:`mmdet.datasets.pipelines.Collect`.
+ gt_bboxes (list[Tensor]): Each item are the truth boxes for each
+ image in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): Class indices corresponding to each box
+ gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ # super(SingleStageDetector, self).forward_train(img, img_metas)
+ # NOTE the batched image size information may be useful, e.g.
+ # in DETR, this is needed for the construction of masks, which is
+ # then used for the transformer_head.
+ pred_pose = preds['pred_pose']
+ pred_betas = preds['pred_betas']
+ pred_cameras = preds['pred_cameras']
+ has_smpl = targets['has_smpl']
+ gt_smpl_body_pose = targets[
+ 'smpl_body_pose'] # [bs_0: [ins_num, 23, 3]]
+ gt_smpl_global_orient = targets['smpl_global_orient']
+ gt_smpl_body_pose = \
+ [torch.cat((gt_smpl_global_orient[i].view(-1, 1, 3),
+ gt_smpl_body_pose[i]), dim=1).float()
+ for i in range(len(gt_smpl_body_pose))]
+ gt_smpl_betas = targets['smpl_betas']
+ gt_smpl_transl = targets['smpl_transl']
+ gt_keypoints2d = targets['keypoints2d']
+        gt_keypoints3d = targets['keypoints3d']  # [bs_0: [N, K, D], ...]
+ img_metas = targets['img_metas']
+ if 'has_keypoints3d' in targets:
+ has_keypoints3d = targets['has_keypoints3d']
+ else:
+ has_keypoints3d = None
+
+ if 'has_keypoints2d' in targets:
+ has_keypoints2d = targets['has_keypoints2d']
+ else:
+ has_keypoints2d = None
+
+ img = targets['img']
+
+ batch_input_shape = tuple(img[0].size()[-2:])
+ for img_meta in img_metas:
+ img_meta['batch_input_shape'] = batch_input_shape
+
+ L, B, N = pred_pose.shape[:3]
+ if self.body_model_train is not None:
+ pred_output = self.body_model_train(
+ betas=pred_betas.reshape(L * B * N, 10),
+ body_pose=pred_pose.reshape(L * B * N, 24, 3, 3)[:, 1:],
+ global_orient=pred_pose.reshape(L * B * N, 24, 3,
+ 3)[:, 0].unsqueeze(1),
+ pose2rot=False,
+ num_joints=gt_keypoints2d[0].shape[1])
+ pred_keypoints3d = pred_output['joints'].reshape(L, B, N, -1, 3)
+ pred_vertices = pred_output['vertices'].reshape(L, B, N, 6890, 3)
+ # loss
+ num_dec_layers = pred_pose.shape[0]
+
+ all_gt_smpl_body_pose_list = [
+ gt_smpl_body_pose for _ in range(num_dec_layers)
+ ]
+ all_gt_smpl_global_orient_list = [
+ gt_smpl_global_orient for _ in range(num_dec_layers)
+ ]
+ all_gt_smpl_betas_list = [gt_smpl_betas for _ in range(num_dec_layers)]
+ all_gt_smpl_transl_list = [
+ gt_smpl_transl for _ in range(num_dec_layers)
+ ]
+ all_gt_keypoints2d_list = [
+ gt_keypoints2d for _ in range(num_dec_layers)
+ ]
+ all_gt_keypoints3d_list = [
+ gt_keypoints3d for _ in range(num_dec_layers)
+ ]
+ all_has_smpl_list = [has_smpl for _ in range(num_dec_layers)]
+ all_has_keypoints3d_list = [
+ has_keypoints3d for _ in range(num_dec_layers)
+ ]
+ all_has_keypoints2d_list = [
+ has_keypoints2d for _ in range(num_dec_layers)
+ ]
+ all_gt_ignore_list = [None for _ in range(num_dec_layers)]
+ img_metas_list = [img_metas for _ in range(num_dec_layers)]
+ # all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
+ # all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
+ # all_gt_bboxes_ignore_list = [
+ # gt_bboxes_ignore for _ in range(num_dec_layers)
+ # ]
+        # compute the loss for each decoder layer
+ (kp2d_loss, kp3d_loss, vert_loss, pose_loss, beta_loss) = multi_apply(
+ self.compute_losses, pred_pose, pred_betas, pred_keypoints3d,
+ pred_vertices, pred_cameras, all_gt_smpl_body_pose_list,
+ all_gt_smpl_betas_list, all_gt_keypoints2d_list,
+ all_gt_keypoints3d_list, all_has_keypoints2d_list,
+ all_has_keypoints3d_list, all_has_smpl_list, img_metas_list,
+ all_gt_ignore_list)
+
+ losses = {}
+ losses['keypoints2d_loss'] = kp2d_loss[-1]
+ losses['keypoints3d_loss'] = kp3d_loss[-1]
+ losses['vertex_loss'] = vert_loss[-1]
+ losses['smpl_pose_loss'] = pose_loss[-1]
+ losses['smpl_betas_loss'] = beta_loss[-1]
+
+ # loss from other decoder layers
+ num_dec_layer = 0
+ for (kp2d_loss_i, kp3d_loss_i, vert_loss_i, pose_loss_i,
+ beta_loss_i) in zip(kp2d_loss[:-1], kp3d_loss[:-1],
+ vert_loss[:-1], pose_loss[:-1],
+ beta_loss[:-1]):
+ losses[f'd{num_dec_layer}.keypoints2d_loss'] = kp2d_loss_i
+ losses[f'd{num_dec_layer}.keypoints3d_loss'] = kp3d_loss_i
+ losses[f'd{num_dec_layer}.vertex_loss'] = vert_loss_i
+ losses[f'd{num_dec_layer}.smpl_pose_loss'] = pose_loss_i
+ losses[f'd{num_dec_layer}.smpl_betas_loss'] = beta_loss_i
+ num_dec_layer += 1
+
+ return losses
+
+ def compute_losses(self,
+ outputs_poses,
+ outputs_shapes,
+ outputs_kp3ds,
+ outputs_verts,
+ outputs_cameras,
+ all_gt_smpl_body_pose_list,
+ all_gt_smpl_betas_list,
+ all_gt_kp2d_list,
+ all_gt_kp3d_list,
+ all_has_keypoints2d_list,
+ all_has_keypoints3d_list,
+ all_has_smpl_list,
+ img_metas_list,
+ all_gt_ignore_list=None):
+ """_summary_
+ loss_single
+ get_targets
+ Args:
+ outputs_poses (_type_): with shape [B, N, 24, 3, 3]
+ outputs_shapes (_type_): _description_
+ all_gt_smpl_body_pose_list (_type_): _description_
+ all_gt_smpl_betas_list (_type_): _description_
+ all_gt_kp2d_list (Torch.tensor):
+ all_gt_kp3d_list (list): with shape [B, N, K, D]
+ img_metas_list (_type_): _description_
+ all_gt_ignore_list (_type_): _description_
+ """
+ num_img = outputs_poses.size(0) # batch_size
+ all_pred_smpl_pose_list = [outputs_poses[i] for i in range(num_img)]
+ all_pred_smpl_shape_list = [outputs_shapes[i] for i in range(num_img)]
+ all_pred_kp3d_list = [outputs_kp3ds[i] for i in range(num_img)]
+ all_pred_vert_list = [outputs_verts[i] for i in range(num_img)]
+ all_pred_cam_list = [outputs_cameras[i] for i in range(num_img)]
+
+ gt_bboxes_ignore_list = [all_gt_ignore_list for _ in range(num_img)]
+
+ if all_has_keypoints2d_list is None:
+ all_has_keypoints2d_list = [
+ all_has_keypoints2d_list for _ in range(num_img)
+ ]
+
+ if all_has_keypoints3d_list is None:
+ all_has_keypoints3d_list = [
+ all_has_keypoints3d_list for _ in range(num_img)
+ ]
+
+ if all_has_smpl_list is None:
+ all_has_smpl_list = [all_has_smpl_list for _ in range(num_img)]
+
+ # for each batch data
+ (kp2d_list, kp2d_weight_list, kp3d_list, kp3d_weight_list,
+ smpl_pose_list, smpl_pose_weight_list, smpl_shape_list,
+ smpl_shape_weight_list, vert_list, vert_weight_list, has_smpl_list,
+ has_keypoints2d_list, has_keypoints3d_list, pos_inds_list,
+ neg_inds_list) = multi_apply(
+ self.prepare_targets,
+ all_pred_smpl_pose_list,
+ all_pred_smpl_shape_list,
+ all_pred_kp3d_list,
+ all_pred_vert_list,
+ all_pred_cam_list,
+ all_gt_smpl_body_pose_list,
+ all_gt_smpl_betas_list,
+ all_gt_kp2d_list,
+ all_gt_kp3d_list,
+ all_has_keypoints2d_list,
+ all_has_keypoints3d_list,
+ all_has_smpl_list,
+ img_metas_list,
+ gt_bboxes_ignore_list,
+ )
+ num_total_pos = sum((inds.numel() for inds in pos_inds_list))
+ num_total_neg = sum((inds.numel() for inds in neg_inds_list))
+
+ K = outputs_kp3ds.shape[-2]
+
+ gt_kp2d = torch.cat(kp2d_list, 0)
+ kp2d_weight = torch.cat(kp2d_weight_list, 0)
+ pred_cam = outputs_cameras.reshape(-1, 3)
+ # pred_kp2d = torch.cat()
+
+ gt_kp3d = torch.cat(kp3d_list, 0)
+ kp3d_weight = torch.cat(kp3d_weight_list, 0)
+ pred_kp3d = outputs_kp3ds.reshape(-1, K, 3)
+
+ gt_smpl_pose = torch.cat(smpl_pose_list, 0)
+ smpl_pose_weight = torch.cat(smpl_pose_weight_list, 0)
+ pred_smpl_pose = outputs_poses.reshape(-1, 24, 3, 3)
+
+ gt_smpl_shape = torch.cat(smpl_shape_list, 0)
+ smpl_shape_weight = torch.cat(smpl_shape_weight_list, 0)
+ pred_smpl_shape = outputs_shapes.reshape(-1, 10)
+
+ gt_vert = torch.cat(vert_list, 0)
+ vert_weight = torch.cat(vert_weight_list, 0)
+ pred_verts = outputs_verts.reshape(-1, 6890, 3)
+
+ has_smpl = torch.cat(has_smpl_list, 0).squeeze()
+ has_keypoints2d = torch.cat(has_keypoints2d_list, 0).squeeze()
+ has_keypoints3d = torch.cat(has_keypoints3d_list, 0).squeeze()
+
+ # losses = {}
+ if self.loss_keypoints2d is not None:
+ keypoints2d_loss = self.compute_keypoints2d_loss(
+ pred_kp3d, pred_cam, gt_kp2d, has_keypoints2d=has_keypoints2d)
+ else:
+ keypoints2d_loss = 0.0
+
+ if self.loss_keypoints3d is not None:
+ keypoints3d_loss = self.compute_keypoints3d_loss(
+ pred_kp3d,
+ gt_kp3d,
+ has_keypoints3d=has_keypoints3d,
+ )
+ else:
+ keypoints3d_loss = 0.0
+
+ if self.loss_vertex is not None:
+ vertex_loss = self.compute_vertex_loss(pred_verts,
+ gt_vert,
+ has_smpl=has_smpl)
+ else:
+ vertex_loss = 0.0
+
+ if self.loss_smpl_pose is not None:
+ smpl_pose_loss = self.compute_smpl_pose_loss(pred_smpl_pose,
+ gt_smpl_pose,
+ has_smpl=has_smpl)
+ else:
+ smpl_pose_loss = 0.0
+
+ if self.loss_smpl_betas is not None:
+ smpl_betas_loss = self.compute_smpl_betas_loss(pred_smpl_shape,
+ gt_smpl_shape,
+ has_smpl=has_smpl)
+ else:
+ smpl_betas_loss = 0.0
+ # if self.loss_iou is not None:
+ # losses['iou_loss'] = self.loss_iou()
+
+ # if self.loss_bbox is not None:
+ # losses['bbox_loss'] = self.loss_bbox()
+
+ # if self.loss_cls is not None:
+ # losses['cls_loss'] = self.loss_bbox()
+
+ return (keypoints2d_loss, keypoints3d_loss, vertex_loss,
+ smpl_pose_loss, smpl_betas_loss)
+
+ def prepare_targets(self, pred_smpl_pose, pred_smpl_shape, pred_kp3d,
+ pred_vert, pred_cam, gt_smpl_pose, gt_smpl_shape,
+ gt_kp2d, gt_kp3d, has_keypoints2d, has_keypoints3d,
+ has_smpl, img_meta, gt_bboxes_ignore):
+ """_summary_
+
+ Args:
+ all_pred_smpl_pose (_type_): _description_
+ all_pred_smpl_shape (_type_): _description_
+ all_pred_kp3d (_type_): _description_
+ all_pred_vert (_type_): _description_
+ all_gt_smpl_body_pose (_type_): _description_
+ all_gt_smpl_betas (_type_): _description_
+ all_gt_kp2d (_type_): _description_
+ all_gt_kp3d (_type_): with shape [N, K, D]
+ img_meta (_type_): _description_
+ gt_bboxes_ignore (_type_): _description_
+ """
+ num_query = pred_smpl_pose.shape[0]
+ assign_result = self.assigner.assign(pred_smpl_pose, pred_smpl_shape,
+ pred_kp3d, pred_vert, pred_cam,
+ gt_smpl_pose, gt_smpl_shape,
+ gt_kp2d, gt_kp3d, has_keypoints2d,
+ has_keypoints3d, has_smpl,
+ img_meta, gt_bboxes_ignore)
+
+ gt_smpl_pose = gt_smpl_pose.float()
+ gt_smpl_shape = gt_smpl_shape.float()
+ gt_kp2d = gt_kp2d.float()
+ gt_kp3d = gt_kp3d.float()
+ has_keypoints2d = has_keypoints2d.float()
+ has_keypoints3d = has_keypoints3d.float()
+ has_smpl = has_smpl.float()
+
+ sampling_result = self.sampler.sample(assign_result, pred_smpl_pose,
+ gt_smpl_pose)
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+
+ # img_h, img_w, _ = img_meta['img_shape']
+
+ # kp2d target
+ kp2d_targets = torch.zeros_like(pred_kp3d[..., :2])
+ kp2d_weights = torch.zeros_like(pred_kp3d[..., :2])
+ kp2d_targets[pos_inds] = gt_kp2d[sampling_result.pos_assigned_gt_inds][
+ ..., :2]
+ kp2d_weights[pos_inds] = gt_kp2d[sampling_result.pos_assigned_gt_inds][
+ ..., [2]].repeat(1, 1, 2)
+ kp2d_targets = torch.cat(
+ [kp2d_targets, kp2d_weights[..., 0].unsqueeze(-1)], dim=-1)
+ # kp3d target
+ kp3d_targets = torch.zeros_like(pred_kp3d)
+ kp3d_weights = torch.zeros_like(pred_kp3d)
+ kp3d_targets[pos_inds] = gt_kp3d[sampling_result.pos_assigned_gt_inds][
+ ..., :3]
+ kp3d_weights[pos_inds] = gt_kp3d[sampling_result.pos_assigned_gt_inds][
+ ..., [3]].repeat(1, 1, 3)
+ kp3d_targets = torch.cat(
+ [kp3d_targets, kp3d_weights[..., 0].unsqueeze(-1)], dim=-1)
+ # smpl_pose target
+ smpl_pose_targets = torch.zeros_like(pred_smpl_pose)
+ smpl_pose_weights = torch.zeros_like(pred_smpl_pose)
+ gt_smpl_pose_rotmat = batch_rodrigues(gt_smpl_pose.view(-1, 3)).view(
+ -1, 24, 3, 3)
+ smpl_pose_targets[pos_inds] = gt_smpl_pose_rotmat[
+ sampling_result.pos_assigned_gt_inds]
+ smpl_pose_weights[pos_inds] = 1.0
+
+ # smpl_beta target
+ smpl_shape_targets = torch.zeros_like(pred_smpl_shape)
+ smpl_shape_weights = torch.zeros_like(pred_smpl_shape)
+ smpl_shape_targets[pos_inds] = gt_smpl_shape[
+ sampling_result.pos_assigned_gt_inds]
+ smpl_shape_weights[pos_inds] = 1.0
+
+ # verts
+ if self.body_model_train is not None:
+ gt_output = self.body_model_train(
+ betas=gt_smpl_shape,
+ body_pose=gt_smpl_pose_rotmat[:, 1:],
+ global_orient=gt_smpl_pose_rotmat[:, 0].unsqueeze(1),
+ pose2rot=False)
+ gt_vertices = gt_output['vertices']
+ gt_model_joints = gt_output['joints']
+
+ vert_targets = torch.zeros_like(pred_vert)
+ vert_weights = torch.zeros_like(pred_vert)
+ vert_targets[pos_inds] = gt_vertices[
+ sampling_result.pos_assigned_gt_inds]
+ vert_weights[pos_inds] = 1.0
+
+ if has_keypoints2d is not None:
+ has_keypoints2d_ = torch.zeros(
+ (num_query, 1)).to(smpl_pose_targets.device)
+ has_keypoints2d_[pos_inds] = has_keypoints2d[
+ sampling_result.pos_assigned_gt_inds]
+ else:
+ has_keypoints2d_ = None
+
+ if has_keypoints3d is not None:
+ has_keypoints3d_ = torch.zeros(
+ (num_query, 1)).to(smpl_pose_targets.device)
+ has_keypoints3d_[pos_inds] = has_keypoints3d[
+ sampling_result.pos_assigned_gt_inds]
+ else:
+ has_keypoints3d_ = None
+
+ if has_smpl is not None:
+ has_smpl_ = torch.zeros(
+ (num_query, 1)).to(smpl_pose_targets.device)
+ # if len(sampling_result.pos_assigned_gt_inds) == 1:
+ # has_smpl_[pos_inds] = has_smpl
+ # else:
+ has_smpl_[pos_inds] = has_smpl[
+ sampling_result.pos_assigned_gt_inds]
+ else:
+ has_smpl_ = None
+ return (kp2d_targets, kp2d_weights, kp3d_targets, kp3d_weights,
+ smpl_pose_targets, smpl_pose_weights, smpl_shape_targets,
+ smpl_shape_weights, vert_targets, vert_weights, has_smpl_,
+ has_keypoints2d_, has_keypoints3d_, pos_inds, neg_inds)
+
+ def forward_test(self, img, img_metas, **kwargs):
+ batch_input_shape = tuple(img[0].size()[-2:])
+ for img_meta in img_metas:
+ img_meta['batch_input_shape'] = batch_input_shape
+ features = self.backbone(img)
+ if self.neck is not None:
+ features = self.neck(features)
+ pred_pose, pred_betas, pred_cam, _, _ = \
+ self.head(features, img_metas)
+
+ # pred_pose = pred_pose[-1]
+ # pred_betas = pred_betas[-1]
+ # pred_cam = pred_cam[-1]
+
+ L, B, N = pred_pose.shape[:3]
+ if self.body_model_test is not None:
+ pred_output = self.body_model_test(
+ betas=pred_betas.reshape(L * B * N, 10),
+ body_pose=pred_pose.reshape(L * B * N, 24, 3, 3)[:, 1:],
+ global_orient=pred_pose.reshape(L * B * N, 24, 3,
+ 3)[:, 0].unsqueeze(1),
+ pose2rot=False)
+ else:
+            raise ValueError('Please provide a built body model.')
+
+ pred_keypoints_3d = pred_output['joints'].reshape(L, B, N, -1, 3)
+ pred_keypoints_3d = (pred_keypoints_3d -
+ pred_keypoints_3d[..., [0], :])
+ pred_keypoints_3d = pred_keypoints_3d.detach().cpu().numpy()
+ # pred_vertices = pred_output['vertices'].reshape(L, B, N, 6890, 3)
+ pred_cam = pred_cam.detach().cpu().numpy()
+ pred_pose = pred_pose.detach().cpu().numpy()
+ pred_betas = pred_betas.detach().cpu().numpy()
+ # batch, instance_num, kp_num, 4
+ gt_keypoints3d = kwargs['keypoints3d'].repeat([1, N, 1, 1]).clone()
+ # keypoints3d_mask = kwargs['keypoints3d_mask']
+ gt_keypoints3d = gt_keypoints3d.detach().cpu().numpy()
+ # gt_keypoints3d, _ = convert_kps(
+ # gt_keypoints3d,
+ # src='human_data',
+ # dst='h36m')
+
+ cost = np.sum((pred_keypoints_3d[-1] - gt_keypoints3d[..., :3]),
+ axis=(2, 3))
+ index = np.argmin(abs(cost), -1)
+
+ pred_keypoints_3d_ = []
+ pred_pose_ = []
+ pred_betas_ = []
+ pred_cam_ = []
+
+ for batch_i in range(B):
+ ind = index[batch_i]
+ pred_keypoints_3d_.append(pred_keypoints_3d[-1, batch_i, ind])
+ pred_pose_.append(pred_pose[-1, batch_i, ind])
+ pred_betas_.append(pred_betas[-1, batch_i, ind])
+ pred_cam_.append(pred_cam[-1, batch_i, ind])
+
+ # for img_id in range(len(img_metas)):
+ # pred_pose_ = pred_pose[:, img_id]
+ # pred_betas_ = pred_betas[:, img_id]
+ # pred_cam_ = pred_cam[:, img_id]
+ # pred_keypoints_3d_ = pred_keypoints_3d[:, img_id]
+ # pred_vertices_ = pred_vertices[:, img_id]
+ # img_shape_ = img_metas[img_id]['img_shape']
+
+ # result_list.append()
+
+ all_preds = {}
+ all_preds['keypoints_3d'] = np.array(pred_keypoints_3d_)
+ all_preds['smpl_pose'] = np.array(pred_pose_)
+ all_preds['smpl_beta'] = np.array(pred_betas_)
+ all_preds['camera'] = np.array(pred_cam_)
+ # all_preds['vertices'] = pred_vertices.detach().cpu().numpy()
+
+ image_path = []
+ for img_meta in img_metas:
+ image_path.append(img_meta['image_path'])
+ all_preds['image_path'] = image_path
+ all_preds['image_idx'] = kwargs['sample_idx']
+ return all_preds
+ # loss
+
+ def compute_keypoints3d_loss(
+ self,
+ pred_keypoints3d: torch.Tensor,
+ gt_keypoints3d: torch.Tensor,
+ has_keypoints3d: Optional[torch.Tensor] = None):
+ """Compute loss for 3d keypoints."""
+ keypoints3d_conf = gt_keypoints3d[:, :, 3].float().unsqueeze(-1)
+ keypoints3d_conf = keypoints3d_conf.repeat(1, 1, 3)
+ pred_keypoints3d = pred_keypoints3d.float()
+ gt_keypoints3d = gt_keypoints3d[:, :, :3].float()
+
+ # currently, only mpi_inf_3dhp and h36m have 3d keypoints
+ # both datasets have right_hip_extra and left_hip_extra
+ right_hip_idx = get_keypoint_idx('right_hip_extra', self.convention)
+ left_hip_idx = get_keypoint_idx('left_hip_extra', self.convention)
+ gt_pelvis = (gt_keypoints3d[:, right_hip_idx, :] +
+ gt_keypoints3d[:, left_hip_idx, :]) / 2
+ pred_pelvis = (pred_keypoints3d[:, right_hip_idx, :] +
+ pred_keypoints3d[:, left_hip_idx, :]) / 2
+
+ gt_keypoints3d = gt_keypoints3d - gt_pelvis[:, None, :]
+ pred_keypoints3d = pred_keypoints3d - pred_pelvis[:, None, :]
+ loss = self.loss_keypoints3d(pred_keypoints3d,
+ gt_keypoints3d,
+ reduction_override='none')
+
+ # If has_keypoints3d is not None, then computes the losses on the
+ # instances that have ground-truth keypoints3d.
+ # But the zero confidence keypoints will be included in mean.
+ # Otherwise, only compute the keypoints3d
+ # which have positive confidence.
+
+ # has_keypoints3d is None when the key has_keypoints3d
+ # is not in the datasets
+ if has_keypoints3d is None:
+
+ valid_pos = keypoints3d_conf > 0
+ if keypoints3d_conf[valid_pos].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_keypoints3d)
+ loss = torch.sum(loss * keypoints3d_conf)
+ loss /= keypoints3d_conf[valid_pos].numel()
+ else:
+
+ keypoints3d_conf = keypoints3d_conf[has_keypoints3d == 1]
+ if keypoints3d_conf.shape[0] == 0:
+ return torch.Tensor([0]).type_as(gt_keypoints3d)
+ loss = loss[has_keypoints3d == 1]
+ loss = (loss * keypoints3d_conf).mean()
+ return loss
+
+ def compute_keypoints2d_loss(
+ self,
+ pred_keypoints3d: torch.Tensor,
+ pred_cam: torch.Tensor,
+ gt_keypoints2d: torch.Tensor,
+ img_res: Optional[int] = 512,
+ focal_length: Optional[int] = 5000.,
+ has_keypoints2d: Optional[torch.Tensor] = None):
+ """Compute loss for 2d keypoints."""
+ keypoints2d_conf = gt_keypoints2d[:, :, 2].float().unsqueeze(-1)
+ keypoints2d_conf = keypoints2d_conf.repeat(1, 1, 2)
+ gt_keypoints2d = gt_keypoints2d[:, :, :2].float()
+ pred_keypoints2d = project_points(pred_keypoints3d,
+ pred_cam,
+ focal_length=focal_length,
+ img_res=img_res)
+ # Normalize keypoints to [-1,1]
+ # The coordinate origin of pred_keypoints_2d is
+ # the center of the input image.
+ pred_keypoints2d = 2 * pred_keypoints2d / (img_res - 1)
+ # The coordinate origin of gt_keypoints_2d is
+ # the top left corner of the input image.
+ gt_keypoints2d = 2 * gt_keypoints2d / (img_res - 1) - 1
+ loss = self.loss_keypoints2d(pred_keypoints2d,
+ gt_keypoints2d,
+ reduction_override='none')
+
+ # If has_keypoints2d is not None, then computes the losses on the
+ # instances that have ground-truth keypoints2d.
+ # But the zero confidence keypoints will be included in mean.
+ # Otherwise, only compute the keypoints2d
+ # which have positive confidence.
+ # has_keypoints2d is None when the key has_keypoints2d
+ # is not in the datasets
+
+ if has_keypoints2d is None:
+ valid_pos = keypoints2d_conf > 0
+ if keypoints2d_conf[valid_pos].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_keypoints2d)
+ loss = torch.sum(loss * keypoints2d_conf)
+ loss /= keypoints2d_conf[valid_pos].numel()
+ else:
+ keypoints2d_conf = keypoints2d_conf[has_keypoints2d == 1]
+ if keypoints2d_conf.shape[0] == 0:
+ return torch.Tensor([0]).type_as(gt_keypoints2d)
+ loss = loss[has_keypoints2d == 1]
+ loss = (loss * keypoints2d_conf).mean()
+
+ return loss
+
+ def compute_vertex_loss(self, pred_vertices: torch.Tensor,
+ gt_vertices: torch.Tensor, has_smpl: torch.Tensor):
+ """Compute loss for vertices."""
+ gt_vertices = gt_vertices.float()
+ conf = has_smpl.float().view(-1, 1, 1)
+ conf = conf.repeat(1, gt_vertices.shape[1], gt_vertices.shape[2])
+ loss = self.loss_vertex(pred_vertices,
+ gt_vertices,
+ reduction_override='none')
+ valid_pos = conf > 0
+ if conf[valid_pos].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_vertices)
+ loss = torch.sum(loss * conf) / conf[valid_pos].numel()
+ return loss
+
+ def compute_smpl_pose_loss(self, pred_pose: torch.Tensor,
+ gt_pose: torch.Tensor, has_smpl: torch.Tensor):
+ """Compute loss for smpl pose."""
+ conf = has_smpl.float().view(-1)
+ valid_pos = conf > 0
+ if conf[valid_pos].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_pose)
+ pred_pose = pred_pose[valid_pos]
+ gt_pose = gt_pose[valid_pos]
+ conf = conf[valid_pos]
+ # gt_rotmat = batch_rodrigues(gt_pose.view(-1, 3)).view(-1, 24, 3, 3)
+ loss = self.loss_smpl_pose(pred_pose,
+ gt_pose,
+ reduction_override='none')
+ loss = loss.view(loss.shape[0], -1).mean(-1)
+ loss = torch.mean(loss * conf)
+ return loss
+
+ def compute_smpl_betas_loss(self, pred_betas: torch.Tensor,
+ gt_betas: torch.Tensor,
+ has_smpl: torch.Tensor):
+ """Compute loss for smpl betas."""
+ conf = has_smpl.float().view(-1)
+ valid_pos = conf > 0
+ if conf[valid_pos].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_betas)
+ pred_betas = pred_betas[valid_pos]
+ gt_betas = gt_betas[valid_pos]
+ conf = conf[valid_pos]
+ loss = self.loss_smpl_betas(pred_betas,
+ gt_betas,
+ reduction_override='none')
+ loss = loss.view(loss.shape[0], -1).mean(-1)
+ loss = torch.mean(loss * conf)
+ return loss
+
+ def compute_camera_loss(self, cameras: torch.Tensor):
+ """Compute loss for predicted camera parameters."""
+ loss = self.loss_camera(cameras)
+ return loss
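One subtle point in `compute_keypoints2d_loss` above: predictions and ground truth use different pixel conventions. `project_points` returns coordinates relative to the image center, while annotated keypoints use a top-left origin, so only the latter gets the extra `- 1` during normalization. A small numeric check of that convention (values are purely illustrative):

```python
import torch

img_res = 512
pred_center_origin = torch.tensor([[0.0, 0.0]])      # center-origin frame
gt_top_left_origin = torch.tensor([[255.5, 255.5]])  # top-left pixel frame

pred_norm = 2 * pred_center_origin / (img_res - 1)
gt_norm = 2 * gt_top_left_origin / (img_res - 1) - 1
print(pred_norm, gt_norm)  # the same image-center point maps to (0, 0) in both
```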
diff --git a/detrsmpl/models/architectures/__init__.py b/detrsmpl/models/architectures/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/detrsmpl/models/architectures/base_architecture.py b/detrsmpl/models/architectures/base_architecture.py
new file mode 100644
index 0000000000000000000000000000000000000000..09c6e5130cdcef666b4407a5ae891370d877601a
--- /dev/null
+++ b/detrsmpl/models/architectures/base_architecture.py
@@ -0,0 +1,108 @@
+from abc import ABCMeta, abstractmethod
+from collections import OrderedDict
+
+import torch
+import torch.distributed as dist
+from mmcv.runner import BaseModule
+
+
+class BaseArchitecture(BaseModule, metaclass=ABCMeta):
+ """Base class for mmhuman3d architecture."""
+ def __init__(self, init_cfg=None):
+ super(BaseArchitecture, self).__init__(init_cfg)
+
+ @abstractmethod
+ def forward_train(self, **kwargs):
+ pass
+
+ @abstractmethod
+ def forward_test(self, **kwargs):
+ pass
+
+ def _parse_losses(self, losses):
+ """Parse the raw outputs (losses) of the network.
+
+ Args:
+ losses (dict): Raw output of the network, which usually contain
+ losses and other necessary information.
+ Returns:
+ tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor \
+ which may be a weighted sum of all losses, log_vars contains \
+ all the variables to be sent to the logger.
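+
+            Example (an illustrative sketch; the loss names below are
+            hypothetical)::
+
+                >>> losses = {'keypoints3d_loss': torch.tensor(0.6),
+                ...           'smpl_pose_loss': torch.tensor(0.4)}
+                >>> loss, log_vars = self._parse_losses(losses)
+                >>> # loss sums every entry whose key contains 'loss' (here 1.0)
+                >>> # and log_vars additionally stores that total under 'loss'.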
+ """
+ log_vars = OrderedDict()
+ for loss_name, loss_value in losses.items():
+ if isinstance(loss_value, torch.Tensor):
+ log_vars[loss_name] = loss_value.mean()
+ elif isinstance(loss_value, list):
+ log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
+ else:
+ raise TypeError(
+ f'{loss_name} is not a tensor or list of tensors')
+
+ loss = sum(_value for _key, _value in log_vars.items()
+ if 'loss' in _key)
+
+ log_vars['loss'] = loss
+ for loss_name, loss_value in log_vars.items():
+ # reduce loss when distributed training
+ if dist.is_available() and dist.is_initialized():
+ loss_value = loss_value.data.clone()
+ dist.all_reduce(loss_value.div_(dist.get_world_size()))
+ log_vars[loss_name] = loss_value.item()
+
+ return loss, log_vars
+
+ def train_step(self, data, optimizer):
+ """The iteration step during training.
+ This method defines an iteration step during training, except for the
+ back propagation and optimizer updating, which are done in an optimizer
+ hook. Note that in some complicated cases or models, the whole process
+ including back propagation and optimizer updating is also defined in
+ this method, such as GAN.
+ Args:
+ data (dict): The output of dataloader.
+ optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
+ runner is passed to ``train_step()``. This argument is unused
+ and reserved.
+ Returns:
+ dict: It should contain at least 3 keys: ``loss``, ``log_vars``, \
+ ``num_samples``.
+ - ``loss`` is a tensor for back propagation, which can be a
+ weighted sum of multiple losses.
+ - ``log_vars`` contains all the variables to be sent to the
+ logger.
+ - ``num_samples`` indicates the batch size (when the model is
+ DDP, it means the batch size on each GPU), which is used for
+ averaging the logs.
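+
+            A minimal sketch of the expected structure (all values are
+            placeholders)::
+
+                dict(loss=torch.tensor(1.0),
+                     log_vars={'keypoints3d_loss': 0.6, 'loss': 1.0},
+                     num_samples=32)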
+ """
+ losses = self(**data)
+ loss, log_vars = self._parse_losses(losses)
+
+ outputs = dict(loss=loss,
+ log_vars=log_vars,
+ num_samples=len(data['img_metas']))
+
+ return outputs
+
+ def val_step(self, data, optimizer=None):
+ """The iteration step during validation.
+
+ This method shares the same signature as :func:`train_step`, but used
+ during val epochs. Note that the evaluation after training epochs is
+ not implemented with this method, but an evaluation hook.
+ """
+ losses = self(**data)
+ loss, log_vars = self._parse_losses(losses)
+
+ outputs = dict(loss=loss,
+ log_vars=log_vars,
+ num_samples=len(data['img_metas']))
+
+ return outputs
+
+ def forward(self, **kwargs):
+ if self.training:
+ return self.forward_train(**kwargs)
+ else:
+ return self.forward_test(**kwargs)
diff --git a/detrsmpl/models/architectures/builder.py b/detrsmpl/models/architectures/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cd876521feb1c96a5a6fa07a7d5211d59574eed
--- /dev/null
+++ b/detrsmpl/models/architectures/builder.py
@@ -0,0 +1,37 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from mmcv.cnn import MODELS as MMCV_MODELS
+from mmcv.utils import Registry
+
+from .DetrSMPL import MultiBodyEstimator
+from .expressive_mesh_estimator import SMPLXImageBodyModelEstimator
+from .hybrik import HybrIK_trainer
+from .mesh_estimator import ImageBodyModelEstimator, VideoBodyModelEstimator
+from .DetrSMPLloss import DETRLoss
+
+
+def build_from_cfg(cfg, registry, default_args=None):
+ if cfg is None:
+ return None
+ return MMCV_MODELS.build_func(cfg, registry, default_args)
+
+
+ARCHITECTURES = Registry('architectures',
+ parent=MMCV_MODELS,
+ build_func=build_from_cfg)
+
+ARCHITECTURES.register_module(name='HybrIK_trainer', module=HybrIK_trainer)
+ARCHITECTURES.register_module(name='ImageBodyModelEstimator',
+ module=ImageBodyModelEstimator)
+ARCHITECTURES.register_module(name='VideoBodyModelEstimator',
+ module=VideoBodyModelEstimator)
+ARCHITECTURES.register_module(name='SMPLXImageBodyModelEstimator',
+ module=SMPLXImageBodyModelEstimator)
+ARCHITECTURES.register_module(name='MultiBodyEstimator',
+ module=MultiBodyEstimator)
+ARCHITECTURES.register_module(name='DETRLoss', module=DETRLoss)
+
+
+def build_architecture(cfg):
+ """Build framework."""
+ return ARCHITECTURES.build(cfg)
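+
+
+# Example usage (an illustrative sketch; the backbone config is a placeholder,
+# not a guaranteed module name):
+#   cfg = dict(type='MultiBodyEstimator', backbone=dict(type='ResNet', depth=50))
+#   model = build_architecture(cfg)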
diff --git a/detrsmpl/models/architectures/expressive_mesh_estimator.py b/detrsmpl/models/architectures/expressive_mesh_estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a45a0c06ae26db0336ff4c122d936ed5374f365
--- /dev/null
+++ b/detrsmpl/models/architectures/expressive_mesh_estimator.py
@@ -0,0 +1,848 @@
+from abc import ABCMeta, abstractmethod
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from detrsmpl.core.conventions.keypoints_mapping import (
+ get_keypoint_idx,
+ get_keypoint_idxs_by_part,
+)
+from detrsmpl.utils.geometry import (
+ batch_rodrigues,
+ weak_perspective_projection,
+)
+from ..backbones.builder import build_backbone
+from ..body_models.builder import build_body_model
+from ..heads.builder import build_head
+from ..losses.builder import build_loss
+from ..necks.builder import build_neck
+from ..utils import (
+ SMPLXFaceCropFunc,
+ SMPLXFaceMergeFunc,
+ SMPLXHandCropFunc,
+ SMPLXHandMergeFunc,
+)
+from .base_architecture import BaseArchitecture
+
+
+def set_requires_grad(nets, requires_grad=False):
+ """Set requies_grad for all the networks.
+
+ Args:
+ nets (nn.Module | list[nn.Module]): A list of networks or a single
+ network.
+ requires_grad (bool): Whether the networks require gradients or not
+ """
+ if not isinstance(nets, list):
+ nets = [nets]
+ for net in nets:
+ if net is not None:
+ for param in net.parameters():
+ param.requires_grad = requires_grad
+
+
+def pose2rotmat(pred_pose):
+ """aa2rotmat."""
+ if len(pred_pose.shape) == 3:
+ num_joints = pred_pose.shape[1]
+ pred_pose = batch_rodrigues(pred_pose.view(-1, 3)).view(
+ -1, num_joints, 3, 3)
+ return pred_pose
+
+
+class SMPLXBodyModelEstimator(BaseArchitecture, metaclass=ABCMeta):
+ """BodyModelEstimator Architecture.
+
+ Args:
+ backbone (dict | None, optional): Backbone config dict. Default: None.
+ neck (dict | None, optional): Neck config dict. Default: None
+ head (dict | None, optional): Regressor config dict. Default: None.
+ body_model_train (dict | None, optional): SMPL config dict during
+ training. Default: None.
+ body_model_test (dict | None, optional): SMPL config dict during
+ test. Default: None.
+ convention (str, optional): Keypoints convention. Default: "human_data"
+ loss_keypoints2d (dict | None, optional): Losses config dict for
+ 2D keypoints. Default: None.
+ loss_keypoints3d (dict | None, optional): Losses config dict for
+ 3D keypoints. Default: None.
+ loss_smplx_global_orient (dict | None, optional): Losses config dict
+ for smplx global orient. Default: None
+ loss_smplx_body_pose (dict | None, optional): Losses config dict
+ for smplx body pose. Default: None
+ loss_smplx_hand_pose (dict | None, optional): Losses config dict
+ for smplx hand pose. Default: None
+ loss_smplx_jaw_pose (dict | None, optional): Losses config dict
+ for smplx jaw pose. Default: None
+ loss_smplx_expression (dict | None, optional): Losses config dict
+ for smplx expression. Default: None
+ loss_smplx_betas (dict | None, optional): Losses config dict for smplx
+ betas. Default: None
+ loss_camera (dict | None, optional): Losses config dict for predicted
+ camera parameters. Default: None
+ extra_hand_model_cfg (dict | None, optional) : Hand model config for
+ refining body model prediction. Default: None
+ extra_face_model_cfg (dict | None, optional) : Face model config for
+ refining body model prediction. Default: None
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
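+
+    A hypothetical config sketch (the backbone/head ``type`` names are
+    placeholders, not guaranteed to ship with this repo)::
+
+        model = dict(
+            type='SMPLXImageBodyModelEstimator',
+            backbone=dict(type='SomeBackbone'),
+            head=dict(type='SomeSMPLXHead'),
+            body_model_train=dict(type='SMPLX'),
+            convention='human_data')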
+ """
+ def __init__(self,
+ backbone: Optional[Union[dict, None]] = None,
+ neck: Optional[Union[dict, None]] = None,
+ head: Optional[Union[dict, None]] = None,
+ body_model_train: Optional[Union[dict, None]] = None,
+ body_model_test: Optional[Union[dict, None]] = None,
+ convention: Optional[str] = 'human_data',
+ loss_keypoints2d: Optional[Union[dict, None]] = None,
+ loss_keypoints3d: Optional[Union[dict, None]] = None,
+ loss_smplx_global_orient: Optional[Union[dict, None]] = None,
+ loss_smplx_body_pose: Optional[Union[dict, None]] = None,
+ loss_smplx_hand_pose: Optional[Union[dict, None]] = None,
+ loss_smplx_jaw_pose: Optional[Union[dict, None]] = None,
+ loss_smplx_expression: Optional[Union[dict, None]] = None,
+ loss_smplx_betas: Optional[Union[dict, None]] = None,
+ loss_smplx_betas_prior: Optional[Union[dict, None]] = None,
+ loss_camera: Optional[Union[dict, None]] = None,
+ extra_hand_model_cfg: Optional[Union[dict, None]] = None,
+ extra_face_model_cfg: Optional[Union[dict, None]] = None,
+ frozen_batchnorm: bool = False,
+ init_cfg: Optional[Union[list, dict, None]] = None):
+ super(SMPLXBodyModelEstimator, self).__init__(init_cfg)
+ self.backbone = build_backbone(backbone)
+ self.neck = build_neck(neck)
+ self.head = build_head(head)
+
+ if frozen_batchnorm:
+ for param in self.backbone.parameters():
+ param.requires_grad = False
+ for param in self.head.parameters():
+ param.requires_grad = False
+
+ self.backbone = FrozenBatchNorm2d.convert_frozen_batchnorm(
+ self.backbone)
+ self.head = FrozenBatchNorm2d.convert_frozen_batchnorm(self.head)
+
+ self.body_model_train = build_body_model(body_model_train)
+ self.body_model_test = build_body_model(body_model_test)
+ self.convention = convention
+
+ self.apply_hand_model = False
+ self.apply_face_model = False
+ if extra_hand_model_cfg is not None:
+ self.hand_backbone = build_backbone(
+ extra_hand_model_cfg.get('backbone', None))
+ self.hand_neck = build_neck(extra_hand_model_cfg.get('neck', None))
+ self.hand_head = build_head(extra_hand_model_cfg.get('head', None))
+ crop_cfg = extra_hand_model_cfg.get('crop_cfg', None)
+ if crop_cfg is not None:
+ self.crop_hand_func = SMPLXHandCropFunc(
+ self.hand_head,
+ self.body_model_train,
+ convention=self.convention,
+ **crop_cfg)
+ self.hand_merge_func = SMPLXHandMergeFunc(
+ self.body_model_train, self.convention)
+ self.hand_crop_loss = build_loss(
+ extra_hand_model_cfg.get('loss_hand_crop', None))
+ self.apply_hand_model = True
+ self.left_hand_idxs = get_keypoint_idxs_by_part(
+ 'left_hand', self.convention)
+ self.left_hand_idxs.append(
+ get_keypoint_idx('left_wrist', self.convention))
+ self.left_hand_idxs = sorted(self.left_hand_idxs)
+ self.right_hand_idxs = get_keypoint_idxs_by_part(
+ 'right_hand', self.convention)
+ self.right_hand_idxs.append(
+ get_keypoint_idx('right_wrist', self.convention))
+ self.right_hand_idxs = sorted(self.right_hand_idxs)
+
+ if extra_face_model_cfg is not None:
+ self.face_backbone = build_backbone(
+ extra_face_model_cfg.get('backbone', None))
+ self.face_neck = build_neck(extra_face_model_cfg.get('neck', None))
+ self.face_head = build_head(extra_face_model_cfg.get('head', None))
+ crop_cfg = extra_face_model_cfg.get('crop_cfg', None)
+ if crop_cfg is not None:
+ self.crop_face_func = SMPLXFaceCropFunc(
+ self.face_head,
+ self.body_model_train,
+ convention=self.convention,
+ **crop_cfg)
+ self.face_merge_func = SMPLXFaceMergeFunc(
+ self.body_model_train, self.convention)
+ self.face_crop_loss = build_loss(
+ extra_face_model_cfg.get('loss_face_crop', None))
+ self.apply_face_model = True
+ self.face_idxs = get_keypoint_idxs_by_part('head', self.convention)
+ self.face_idxs = sorted(self.face_idxs)
+
+ self.loss_keypoints2d = build_loss(loss_keypoints2d)
+ self.loss_keypoints3d = build_loss(loss_keypoints3d)
+
+ self.loss_smplx_global_orient = build_loss(loss_smplx_global_orient)
+ self.loss_smplx_body_pose = build_loss(loss_smplx_body_pose)
+ self.loss_smplx_hand_pose = build_loss(loss_smplx_hand_pose)
+ self.loss_smplx_jaw_pose = build_loss(loss_smplx_jaw_pose)
+ self.loss_smplx_expression = build_loss(loss_smplx_expression)
+ self.loss_smplx_betas = build_loss(loss_smplx_betas)
+ self.loss_smplx_betas_piror = build_loss(loss_smplx_betas_prior)
+ self.loss_camera = build_loss(loss_camera)
+ set_requires_grad(self.body_model_train, False)
+ set_requires_grad(self.body_model_test, False)
+
+ def train_step(self, data_batch, optimizer, **kwargs):
+ """Train step function.
+
+ Args:
+ data_batch (torch.Tensor): Batch of data as input.
+ optimizer (dict[torch.optim.Optimizer]): Dict with optimizers for
+ generator.
+ Returns:
+ outputs (dict): Dict with loss, information for logger,
+ the number of samples.
+ """
+ if self.backbone is not None:
+ img = data_batch['img']
+ features = self.backbone(img)
+ else:
+ features = data_batch['features']
+
+ if self.neck is not None:
+ features = self.neck(features)
+
+ predictions = self.head(features)
+ if self.apply_hand_model:
+ hand_input_img, hand_mean, hand_crop_info = self.crop_hand_func(
+ predictions, data_batch['img_metas'])
+ hand_features = self.hand_backbone(hand_input_img)
+ if self.neck is not None:
+ hand_features = self.hand_neck(hand_features)
+ hand_predictions = self.hand_head(hand_features, cond=hand_mean)
+ predictions = self.hand_merge_func(predictions, hand_predictions)
+ predictions['hand_crop_info'] = hand_crop_info
+ if self.apply_face_model:
+ face_input_img, face_mean, face_crop_info = self.crop_face_func(
+ predictions, data_batch['img_metas'])
+ face_features = self.face_backbone(face_input_img)
+ if self.neck is not None:
+ face_features = self.face_neck(face_features)
+ face_predictions = self.face_head(face_features, cond=face_mean)
+ predictions = self.face_merge_func(predictions, face_predictions)
+ predictions['face_crop_info'] = face_crop_info
+
+ targets = self.prepare_targets(data_batch)
+
+ losses = self.compute_losses(predictions, targets)
+
+ loss, log_vars = self._parse_losses(losses)
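+        # Every sub-module (backbone/neck/head plus the optional hand and face
+        # branches) has its own optimizer; zero all of them before backward.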
+ if self.backbone is not None:
+ optimizer['backbone'].zero_grad()
+ if self.neck is not None:
+ optimizer['neck'].zero_grad()
+ if self.head is not None:
+ optimizer['head'].zero_grad()
+
+ if self.apply_hand_model:
+ if self.hand_backbone is not None:
+ optimizer['hand_backbone'].zero_grad()
+ if self.hand_neck is not None:
+ optimizer['hand_neck'].zero_grad()
+ if self.hand_head is not None:
+ optimizer['hand_head'].zero_grad()
+
+ if self.apply_face_model:
+ if self.face_backbone is not None:
+ optimizer['face_backbone'].zero_grad()
+ if self.face_neck is not None:
+ optimizer['face_neck'].zero_grad()
+ if self.face_head is not None:
+ optimizer['face_head'].zero_grad()
+
+ loss.backward()
+ if self.backbone is not None:
+ optimizer['backbone'].step()
+ if self.neck is not None:
+ optimizer['neck'].step()
+ if self.head is not None:
+ optimizer['head'].step()
+
+ if self.apply_hand_model:
+ if self.hand_backbone is not None:
+ optimizer['hand_backbone'].step()
+ if self.hand_neck is not None:
+ optimizer['hand_neck'].step()
+ if self.hand_head is not None:
+ optimizer['hand_head'].step()
+
+ if self.apply_face_model:
+ if self.face_backbone is not None:
+ optimizer['face_backbone'].step()
+ if self.face_neck is not None:
+ optimizer['face_neck'].step()
+ if self.face_head is not None:
+ optimizer['face_head'].step()
+
+ outputs = dict(loss=loss,
+ log_vars=log_vars,
+ num_samples=len(next(iter(data_batch.values()))))
+ return outputs
+
+ def compute_keypoints3d_loss(
+ self,
+ pred_keypoints3d: torch.Tensor,
+ gt_keypoints3d: torch.Tensor,
+ has_keypoints3d: Optional[torch.Tensor] = None):
+ """Compute loss for 3d keypoints."""
+ keypoints3d_conf = gt_keypoints3d[:, :, 3].float().unsqueeze(-1)
+ keypoints3d_conf = keypoints3d_conf.repeat(1, 1, 3)
+ pred_keypoints3d = pred_keypoints3d.float()
+ gt_keypoints3d = gt_keypoints3d[:, :, :3].float()
+
+ if has_keypoints3d is None:
+ has_keypoints3d = torch.ones((keypoints3d_conf.shape[0]))
+ if keypoints3d_conf[has_keypoints3d == 1].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_keypoints3d)
+ # Center the predictions using the pelvis
+ target_idxs = has_keypoints3d == 1
+ pred_keypoints3d = pred_keypoints3d[target_idxs]
+ gt_keypoints3d = gt_keypoints3d[target_idxs]
+ pred_pelvis = pred_keypoints3d[:, [1, 2], :].mean(dim=1, keepdim=True)
+ pred_keypoints3d = pred_keypoints3d - pred_pelvis
+ gt_pelvis = gt_keypoints3d[:, [1, 2], :].mean(dim=1, keepdim=True)
+ gt_keypoints3d = gt_keypoints3d - gt_pelvis
+
+ loss = self.loss_keypoints3d(pred_keypoints3d,
+ gt_keypoints3d,
+ weight=keypoints3d_conf[target_idxs])
+ loss /= gt_keypoints3d.shape[0]
+ return loss
+
+ def compute_keypoints2d_loss(
+ self,
+ pred_keypoints3d: torch.Tensor,
+ pred_cam: torch.Tensor,
+ gt_keypoints2d: torch.Tensor,
+ img_res: Optional[int] = 224,
+ focal_length: Optional[int] = 5000,
+ has_keypoints2d: Optional[torch.Tensor] = None):
+ """Compute loss for 2d keypoints."""
+ keypoints2d_conf = gt_keypoints2d[:, :, 2].float().unsqueeze(-1)
+ keypoints2d_conf = keypoints2d_conf.repeat(1, 1, 2)
+ gt_keypoints2d = gt_keypoints2d[:, :, :2].float()
+ if has_keypoints2d is None:
+ has_keypoints2d = torch.ones((keypoints2d_conf.shape[0]))
+ if keypoints2d_conf[has_keypoints2d == 1].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_keypoints2d)
+
+        # ExPose uses weak_perspective_projection
+ pred_keypoints2d = weak_perspective_projection(
+ pred_keypoints3d,
+ scale=pred_cam[:, 0],
+ translation=pred_cam[:, 1:3])
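+        # Map ground-truth pixel coordinates into the [-1, 1] range so they
+        # are comparable with the normalized projection above.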
+ gt_keypoints2d = 2 * gt_keypoints2d / (img_res - 1) - 1
+
+ target_idxs = has_keypoints2d == 1
+ pred_keypoints2d = pred_keypoints2d[target_idxs]
+ gt_keypoints2d = gt_keypoints2d[target_idxs]
+ loss = self.loss_keypoints2d(pred_keypoints2d,
+ gt_keypoints2d,
+ weight=keypoints2d_conf[target_idxs])
+ loss /= gt_keypoints2d.shape[0]
+ return loss
+
+ def compute_smplx_body_pose_loss(self, pred_rotmat: torch.Tensor,
+ gt_pose: torch.Tensor,
+ has_smplx_body_pose: torch.Tensor):
+ """Compute loss for smplx body pose."""
+ num_joints = pred_rotmat.shape[1]
+ target_idxs = has_smplx_body_pose == 1
+ if gt_pose[target_idxs].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_pose)
+
+ gt_rotmat = batch_rodrigues(gt_pose.view(-1, 3)).view(
+ -1, num_joints, 3, 3)
+
+ loss = self.loss_smplx_body_pose(pred_rotmat[target_idxs],
+ gt_rotmat[target_idxs])
+ return loss
+
+ def compute_smplx_global_orient_loss(
+ self, pred_rotmat: torch.Tensor, gt_global_orient: torch.Tensor,
+ has_smplx_global_orient: torch.Tensor):
+ """Compute loss for smplx global orient."""
+ target_idxs = has_smplx_global_orient == 1
+ if gt_global_orient[target_idxs].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_global_orient)
+
+ gt_rotmat = batch_rodrigues(gt_global_orient.view(-1, 3)).view(
+ -1, 1, 3, 3)
+
+ loss = self.loss_smplx_global_orient(pred_rotmat[target_idxs],
+ gt_rotmat[target_idxs])
+ return loss
+
+ def compute_smplx_jaw_pose_loss(self, pred_rotmat: torch.Tensor,
+ gt_jaw_pose: torch.Tensor,
+ has_smplx_jaw_pose: torch.Tensor,
+ face_conf: torch.Tensor):
+ """Compute loss for smplx jaw pose."""
+ target_idxs = has_smplx_jaw_pose == 1
+ if gt_jaw_pose[target_idxs].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_jaw_pose)
+
+ gt_rotmat = batch_rodrigues(gt_jaw_pose.view(-1, 3)).view(-1, 1, 3, 3)
+ conf = face_conf.mean(axis=1).float()
+ conf = conf.view(-1, 1, 1, 1)
+
+ loss = self.loss_smplx_jaw_pose(pred_rotmat[target_idxs],
+ gt_rotmat[target_idxs],
+ weight=conf[target_idxs])
+ return loss
+
+ def compute_smplx_hand_pose_loss(self, pred_rotmat: torch.Tensor,
+ gt_hand_pose: torch.Tensor,
+ has_smplx_hand_pose: torch.Tensor,
+ hand_conf: torch.Tensor):
+ """Compute loss for smplx left/right hand pose."""
+ joint_num = pred_rotmat.shape[1]
+ target_idxs = has_smplx_hand_pose == 1
+ if gt_hand_pose[target_idxs].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_hand_pose)
+ gt_rotmat = batch_rodrigues(gt_hand_pose.view(-1, 3)).view(
+ -1, joint_num, 3, 3)
+ conf = hand_conf.mean(axis=1,
+ keepdim=True).float().expand(-1, joint_num)
+ conf = conf.view(-1, joint_num, 1, 1)
+
+ loss = self.loss_smplx_hand_pose(pred_rotmat[target_idxs],
+ gt_rotmat[target_idxs],
+ weight=conf[target_idxs])
+ return loss
+
+ def compute_smplx_betas_loss(self, pred_betas: torch.Tensor,
+ gt_betas: torch.Tensor,
+ has_smplx_betas: torch.Tensor):
+ """Compute loss for smplx betas."""
+ target_idxs = has_smplx_betas == 1
+ if gt_betas[target_idxs].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_betas)
+
+ loss = self.loss_smplx_betas(pred_betas[target_idxs],
+ gt_betas[target_idxs])
+ loss = loss / gt_betas[target_idxs].shape[0]
+ return loss
+
+ def compute_smplx_betas_prior_loss(self, pred_betas: torch.Tensor):
+ """Compute prior loss for smplx betas."""
+ loss = self.loss_smplx_betas_piror(pred_betas)
+ return loss
+
+ def compute_smplx_expression_loss(self, pred_expression: torch.Tensor,
+ gt_expression: torch.Tensor,
+ has_smplx_expression: torch.Tensor,
+ face_conf: torch.Tensor):
+ """Compute loss for smplx betas."""
+ target_idxs = has_smplx_expression == 1
+ if gt_expression[target_idxs].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_expression)
+ conf = face_conf.mean(axis=1).float()
+ conf = conf.view(-1, 1)
+
+ loss = self.loss_smplx_expression(pred_expression[target_idxs],
+ gt_expression[target_idxs],
+ weight=conf[target_idxs])
+ loss = loss / gt_expression[target_idxs].shape[0]
+ return loss
+
+ def compute_camera_loss(self, cameras: torch.Tensor):
+ """Compute loss for predicted camera parameters."""
+ loss = self.loss_camera(cameras)
+ return loss
+
+ def compute_face_crop_loss(self,
+ pred_keypoints3d: torch.Tensor,
+ pred_cam: torch.Tensor,
+ gt_keypoints2d: torch.Tensor,
+ face_crop_info: dict,
+ img_res: Optional[int] = 224,
+ has_keypoints2d: Optional[torch.Tensor] = None):
+ """Compute face crop loss for 2d keypoints."""
+ keypoints2d_conf = gt_keypoints2d[:, :, 2].float().unsqueeze(-1)
+ keypoints2d_conf = keypoints2d_conf.repeat(1, 1, 2)
+ gt_keypoints2d = gt_keypoints2d[:, :, :2].float()
+ if has_keypoints2d is None:
+ has_keypoints2d = torch.ones((keypoints2d_conf.shape[0]))
+ if keypoints2d_conf[has_keypoints2d == 1].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_keypoints2d)
+
+        # ExPose uses weak_perspective_projection
+ pred_keypoints2d = weak_perspective_projection(
+ pred_keypoints3d,
+ scale=pred_cam[:, 0],
+ translation=pred_cam[:, 1:3])
+ target_idxs = has_keypoints2d == 1
+ pred_keypoints2d = pred_keypoints2d[target_idxs]
+ gt_keypoints2d = gt_keypoints2d[target_idxs]
+
+ pred_keypoints2d = (0.5 * pred_keypoints2d + 0.5) * (img_res - 1)
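+        # Apply the inverse crop transform (a per-sample 2x3 affine) to map
+        # crop-space pixel coordinates back to the original full-resolution image.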
+ face_inv_crop_transforms = face_crop_info['face_inv_crop_transforms']
+ pred_keypoints2d_hd = torch.einsum('bij,bkj->bki', [
+ face_inv_crop_transforms[:, :2, :2], pred_keypoints2d
+ ]) + face_inv_crop_transforms[:, :2, 2].unsqueeze(dim=1)
+ gt_keypoints2d_hd = torch.einsum('bij,bkj->bki', [
+ face_inv_crop_transforms[:, :2, :2], gt_keypoints2d
+ ]) + face_inv_crop_transforms[:, :2, 2].unsqueeze(dim=1)
+
+ pred_face_keypoints_hd = pred_keypoints2d_hd[:, self.face_idxs]
+ face_crop_transform = face_crop_info['face_crop_transform']
+ inv_face_crop_transf = torch.inverse(face_crop_transform)
+ face_img_keypoints = torch.einsum('bij,bkj->bki', [
+ inv_face_crop_transf[:, :2, :2], pred_face_keypoints_hd
+ ]) + inv_face_crop_transf[:, :2, 2].unsqueeze(dim=1)
+ gt_face_keypoints_hd = gt_keypoints2d_hd[:, self.face_idxs]
+ gt_face_keypoints = torch.einsum('bij,bkj->bki', [
+ inv_face_crop_transf[:, :2, :2], gt_face_keypoints_hd
+ ]) + inv_face_crop_transf[:, :2, 2].unsqueeze(dim=1)
+
+ loss = self.face_crop_loss(
+ face_img_keypoints,
+ gt_face_keypoints,
+ weight=keypoints2d_conf[target_idxs][:, self.face_idxs])
+ loss /= gt_face_keypoints.shape[0]
+ return loss
+
+ def compute_hand_crop_loss(self,
+ pred_keypoints3d: torch.Tensor,
+ pred_cam: torch.Tensor,
+ gt_keypoints2d: torch.Tensor,
+ hand_crop_info: dict,
+ img_res: Optional[int] = 224,
+ has_keypoints2d: Optional[torch.Tensor] = None):
+ """Compute hand crop loss for 2d keypoints."""
+ keypoints2d_conf = gt_keypoints2d[:, :, 2].float().unsqueeze(-1)
+ keypoints2d_conf = keypoints2d_conf.repeat(1, 1, 2)
+ gt_keypoints2d = gt_keypoints2d[:, :, :2].float()
+ if has_keypoints2d is None:
+ has_keypoints2d = torch.ones((keypoints2d_conf.shape[0]))
+ if keypoints2d_conf[has_keypoints2d == 1].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_keypoints2d)
+
+        # ExPose uses weak_perspective_projection
+ pred_keypoints2d = weak_perspective_projection(
+ pred_keypoints3d,
+ scale=pred_cam[:, 0],
+ translation=pred_cam[:, 1:3])
+ target_idxs = has_keypoints2d == 1
+ pred_keypoints2d = pred_keypoints2d[target_idxs]
+ gt_keypoints2d = gt_keypoints2d[target_idxs]
+
+ pred_keypoints2d = (0.5 * pred_keypoints2d + 0.5) * (img_res - 1)
+ hand_inv_crop_transforms = hand_crop_info['hand_inv_crop_transforms']
+ pred_keypoints2d_hd = torch.einsum('bij,bkj->bki', [
+ hand_inv_crop_transforms[:, :2, :2], pred_keypoints2d
+ ]) + hand_inv_crop_transforms[:, :2, 2].unsqueeze(dim=1)
+ gt_keypoints2d_hd = torch.einsum('bij,bkj->bki', [
+ hand_inv_crop_transforms[:, :2, :2], gt_keypoints2d
+ ]) + hand_inv_crop_transforms[:, :2, 2].unsqueeze(dim=1)
+
+ pred_left_hand_keypoints_hd = pred_keypoints2d_hd[:,
+ self.left_hand_idxs]
+ left_hand_crop_transform = hand_crop_info['left_hand_crop_transform']
+ inv_left_hand_crop_transf = torch.inverse(left_hand_crop_transform)
+ left_hand_img_keypoints = torch.einsum('bij,bkj->bki', [
+ inv_left_hand_crop_transf[:, :2, :2], pred_left_hand_keypoints_hd
+ ]) + inv_left_hand_crop_transf[:, :2, 2].unsqueeze(dim=1)
+ gt_left_hand_keypoints_hd = gt_keypoints2d_hd[:, self.left_hand_idxs]
+ gt_left_hand_keypoints = torch.einsum('bij,bkj->bki', [
+ inv_left_hand_crop_transf[:, :2, :2], gt_left_hand_keypoints_hd
+ ]) + inv_left_hand_crop_transf[:, :2, 2].unsqueeze(dim=1)
+
+ pred_right_hand_keypoints_hd = pred_keypoints2d_hd[:, self.
+ right_hand_idxs]
+ right_hand_crop_transform = hand_crop_info['right_hand_crop_transform']
+ inv_right_hand_crop_transf = torch.inverse(right_hand_crop_transform)
+ right_hand_img_keypoints = torch.einsum('bij,bkj->bki', [
+ inv_right_hand_crop_transf[:, :2, :2], pred_right_hand_keypoints_hd
+ ]) + inv_right_hand_crop_transf[:, :2, 2].unsqueeze(dim=1)
+ gt_right_hand_keypoints_hd = gt_keypoints2d_hd[:, self.right_hand_idxs]
+ gt_right_hand_keypoints = torch.einsum('bij,bkj->bki', [
+ inv_right_hand_crop_transf[:, :2, :2], gt_right_hand_keypoints_hd
+ ]) + inv_right_hand_crop_transf[:, :2, 2].unsqueeze(dim=1)
+
+ left_loss = self.hand_crop_loss(
+ left_hand_img_keypoints,
+ gt_left_hand_keypoints,
+ weight=keypoints2d_conf[target_idxs][:, self.left_hand_idxs])
+ left_loss /= gt_left_hand_keypoints.shape[0]
+
+ right_loss = self.hand_crop_loss(
+ right_hand_img_keypoints,
+ gt_right_hand_keypoints,
+ weight=keypoints2d_conf[target_idxs][:, self.right_hand_idxs])
+ right_loss /= gt_right_hand_keypoints.shape[0]
+
+ return left_loss + right_loss
+
+ def compute_losses(self, predictions: dict, targets: dict):
+ """Compute losses."""
+ pred_param = predictions['pred_param']
+ pred_cam = predictions['pred_cam']
+ gt_keypoints3d = targets['keypoints3d']
+ gt_keypoints2d = targets['keypoints2d']
+
+ if self.body_model_train is not None:
+ pred_output = self.body_model_train(**pred_param)
+ pred_keypoints3d = pred_output['joints']
+ if 'has_keypoints3d' in targets:
+ has_keypoints3d = targets['has_keypoints3d'].squeeze(-1)
+ else:
+ has_keypoints3d = None
+ if 'has_keypoints2d' in targets:
+ has_keypoints2d = targets['has_keypoints2d'].squeeze(-1)
+ else:
+ has_keypoints2d = None
+
+ losses = {}
+ if self.loss_keypoints3d is not None:
+ losses['keypoints3d_loss'] = self.compute_keypoints3d_loss(
+ pred_keypoints3d,
+ gt_keypoints3d,
+ has_keypoints3d=has_keypoints3d)
+ if self.loss_keypoints2d is not None:
+ losses['keypoints2d_loss'] = self.compute_keypoints2d_loss(
+ pred_keypoints3d,
+ pred_cam,
+ gt_keypoints2d,
+ img_res=targets['img'].shape[-1],
+ has_keypoints2d=has_keypoints2d)
+ if self.loss_smplx_global_orient is not None:
+ pred_global_orient = pred_param['global_orient']
+ pred_global_orient = pose2rotmat(pred_global_orient)
+ gt_global_orient = targets['smplx_global_orient']
+ has_smplx_global_orient = targets[
+ 'has_smplx_global_orient'].squeeze(-1)
+ losses['smplx_global_orient_loss'] = \
+ self.compute_smplx_global_orient_loss(
+ pred_global_orient, gt_global_orient,
+ has_smplx_global_orient)
+ if self.loss_smplx_body_pose is not None:
+ pred_pose = pred_param['body_pose']
+ pred_pose = pose2rotmat(pred_pose)
+ gt_pose = targets['smplx_body_pose']
+ has_smplx_body_pose = targets['has_smplx_body_pose'].squeeze(-1)
+ losses['smplx_body_pose_loss'] = \
+ self.compute_smplx_body_pose_loss(
+ pred_pose, gt_pose, has_smplx_body_pose)
+ if self.loss_smplx_jaw_pose is not None:
+ pred_jaw_pose = pred_param['jaw_pose']
+ pred_jaw_pose = pose2rotmat(pred_jaw_pose)
+ gt_jaw_pose = targets['smplx_jaw_pose']
+ face_conf = get_keypoint_idxs_by_part('head', self.convention)
+ has_smplx_jaw_pose = targets['has_smplx_jaw_pose'].squeeze(-1)
+ losses['smplx_jaw_pose_loss'] = self.compute_smplx_jaw_pose_loss(
+ pred_jaw_pose, gt_jaw_pose, has_smplx_jaw_pose,
+ gt_keypoints2d[:, face_conf, 2])
+ if self.loss_smplx_hand_pose is not None:
+ pred_right_hand_pose = pred_param['right_hand_pose']
+ pred_right_hand_pose = pose2rotmat(pred_right_hand_pose)
+ gt_right_hand_pose = targets['smplx_right_hand_pose']
+ right_hand_conf = get_keypoint_idxs_by_part(
+ 'right_hand', self.convention)
+ has_smplx_right_hand_pose = targets[
+ 'has_smplx_right_hand_pose'].squeeze(-1)
+ losses['smplx_right_hand_pose_loss'] = \
+ self.compute_smplx_hand_pose_loss(
+ pred_right_hand_pose, gt_right_hand_pose,
+ has_smplx_right_hand_pose,
+ gt_keypoints2d[:, right_hand_conf, 2])
+ if 'left_hand_pose' in pred_param:
+ pred_left_hand_pose = pred_param['left_hand_pose']
+ pred_left_hand_pose = pose2rotmat(pred_left_hand_pose)
+ gt_left_hand_pose = targets['smplx_left_hand_pose']
+ left_hand_conf = get_keypoint_idxs_by_part(
+ 'left_hand', self.convention)
+ has_smplx_left_hand_pose = targets[
+ 'has_smplx_left_hand_pose'].squeeze(-1)
+ losses['smplx_left_hand_pose_loss'] = \
+ self.compute_smplx_hand_pose_loss(
+ pred_left_hand_pose, gt_left_hand_pose,
+ has_smplx_left_hand_pose,
+ gt_keypoints2d[:, left_hand_conf, 2])
+ if self.loss_smplx_betas is not None:
+ pred_betas = pred_param['betas']
+ gt_betas = targets['smplx_betas']
+ has_smplx_betas = targets['has_smplx_betas'].squeeze(-1)
+ losses['smplx_betas_loss'] = self.compute_smplx_betas_loss(
+ pred_betas, gt_betas, has_smplx_betas)
+ if self.loss_smplx_expression is not None:
+ pred_expression = pred_param['expression']
+ gt_expression = targets['smplx_expression']
+ face_conf = get_keypoint_idxs_by_part('head', self.convention)
+ has_smplx_expression = targets['has_smplx_expression'].squeeze(-1)
+ losses[
+ 'smplx_expression_loss'] = self.compute_smplx_expression_loss(
+ pred_expression, gt_expression, has_smplx_expression,
+ gt_keypoints2d[:, face_conf, 2])
+ if self.loss_smplx_betas_piror is not None:
+ pred_betas = pred_param['betas']
+ losses['smplx_betas_prior_loss'] = \
+ self.compute_smplx_betas_prior_loss(
+ pred_betas)
+ if self.loss_camera is not None:
+ losses['camera_loss'] = self.compute_camera_loss(pred_cam)
+ if self.apply_hand_model and self.hand_crop_loss is not None:
+ losses['hand_crop_loss'] = self.compute_hand_crop_loss(
+ pred_keypoints3d, pred_cam, gt_keypoints2d,
+ predictions['hand_crop_info'], targets['img'].shape[-1],
+ has_keypoints2d)
+ if self.apply_face_model and self.face_crop_loss is not None:
+ losses['face_crop_loss'] = self.compute_face_crop_loss(
+ pred_keypoints3d, pred_cam, gt_keypoints2d,
+ predictions['face_crop_info'], targets['img'].shape[-1],
+ has_keypoints2d)
+ return losses
+
+ @abstractmethod
+ def prepare_targets(self, data_batch):
+ pass
+
+ def forward_train(self, **kwargs):
+ """Forward function for general training.
+
+ For mesh estimation, we do not use this interface.
+ """
+ raise NotImplementedError('This interface should not be used in '
+ 'current training schedule. Please use '
+ '`train_step` for training.')
+
+ @abstractmethod
+ def forward_test(self, img, img_metas, **kwargs):
+ """Defines the computation performed at every call when testing."""
+ pass
+
+
+class SMPLXImageBodyModelEstimator(SMPLXBodyModelEstimator):
+ def prepare_targets(self, data_batch: dict):
+        # The image-based estimator does not need extra processing of the ground truth
+ return data_batch
+
+ def forward_test(self, img: torch.Tensor, img_metas: dict, **kwargs):
+ """Defines the computation performed at every call when testing."""
+ if self.backbone is not None:
+ features = self.backbone(img)
+ else:
+ features = kwargs['features']
+
+ if self.neck is not None:
+ features = self.neck(features)
+
+ predictions = self.head(features)
+ if self.apply_hand_model:
+ hand_input_img, hand_mean, hand_crop_info = self.crop_hand_func(
+ predictions, img_metas)
+ hand_features = self.hand_backbone(hand_input_img)
+ if self.neck is not None:
+ hand_features = self.hand_neck(hand_features)
+ hand_predictions = self.hand_head(hand_features, cond=hand_mean)
+ predictions = self.hand_merge_func(predictions, hand_predictions)
+ predictions['hand_crop_info'] = hand_crop_info
+ if self.apply_face_model:
+ face_input_img, face_mean, face_crop_info = self.crop_face_func(
+ predictions, img_metas)
+ face_features = self.face_backbone(face_input_img)
+ if self.neck is not None:
+ face_features = self.face_neck(face_features)
+ face_predictions = self.face_head(face_features, cond=face_mean)
+ predictions = self.face_merge_func(predictions, face_predictions)
+ predictions['face_crop_info'] = face_crop_info
+
+ pred_param = predictions['pred_param']
+ pred_cam = predictions['pred_cam']
+
+ pred_output = self.body_model_test(**pred_param)
+
+ pred_vertices = pred_output['vertices']
+ pred_keypoints_3d = pred_output['joints']
+ all_preds = {}
+ all_preds['keypoints_3d'] = pred_keypoints_3d.detach().cpu().numpy()
+        # Convert every tensor in pred_param to numpy before storing it.
+        for key, value in pred_param.items():
+            if isinstance(value, torch.Tensor):
+                pred_param[key] = value.detach().cpu().numpy()
+        all_preds['param'] = pred_param
+ all_preds['camera'] = pred_cam.detach().cpu().numpy()
+ all_preds['vertices'] = pred_vertices.detach().cpu().numpy()
+ image_path = []
+ for img_meta in img_metas:
+ image_path.append(img_meta['image_path'])
+ all_preds['image_path'] = image_path
+ all_preds['image_idx'] = kwargs['sample_idx']
+ return all_preds
+
+
+class FrozenBatchNorm2d(nn.Module):
+ """BatchNorm2d where the batch statistics and the affine parameters are
+ fixed."""
+ def __init__(self, n):
+ super(FrozenBatchNorm2d, self).__init__()
+ self.register_buffer('weight', torch.ones(n))
+ self.register_buffer('bias', torch.zeros(n))
+ self.register_buffer('running_mean', torch.zeros(n))
+ self.register_buffer('running_var', torch.ones(n))
+
+ @staticmethod
+ def from_bn(module: nn.BatchNorm2d):
+ """Initializes a frozen batch norm module from a batch norm module."""
+ dim = len(module.weight.data)
+
+ frozen_module = FrozenBatchNorm2d(dim)
+ frozen_module.weight.data = module.weight.data
+
+ missing, not_found = frozen_module.load_state_dict(module.state_dict(),
+ strict=False)
+ return frozen_module
+
+ @classmethod
+ def convert_frozen_batchnorm(cls, module):
+ """Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.
+
+ Args:
+ module (torch.nn.Module):
+
+ Returns:
+ If module is BatchNorm/SyncBatchNorm, returns a new module.
+ Otherwise, in-place convert module and return it.
+
+ Similar to convert_sync_batchnorm in
+ https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
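+
+        A minimal usage sketch (``backbone`` here stands for any module that
+        may contain BatchNorm2d layers)::
+
+            backbone = FrozenBatchNorm2d.convert_frozen_batchnorm(backbone)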
+ """
+ bn_module = nn.modules.batchnorm
+ bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
+ res = module
+ if isinstance(module, bn_module):
+ res = cls(module.num_features)
+ if module.affine:
+ res.weight.data = module.weight.data.clone().detach()
+ res.bias.data = module.bias.data.clone().detach()
+ res.running_mean.data = module.running_mean.data
+ res.running_var.data = module.running_var.data
+ res.eps = module.eps
+ else:
+ for name, child in module.named_children():
+ new_child = cls.convert_frozen_batchnorm(child)
+ if new_child is not child:
+ res.add_module(name, new_child)
+ return res
+
+ def forward(self, x):
+ # Cast all fixed parameters to half() if necessary
+ if x.dtype == torch.float16:
+ self.weight = self.weight.half()
+ self.bias = self.bias.half()
+ self.running_mean = self.running_mean.half()
+ self.running_var = self.running_var.half()
+
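+        # training=False: always normalize with the stored running statistics,
+        # so the layer stays frozen regardless of the module's train/eval mode.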
+ return F.batch_norm(x, self.running_mean, self.running_var,
+ self.weight, self.bias, False)
diff --git a/detrsmpl/models/architectures/hybrik.py b/detrsmpl/models/architectures/hybrik.py
new file mode 100644
index 0000000000000000000000000000000000000000..5866b051054d537aeb3b41113e964c16630d14bc
--- /dev/null
+++ b/detrsmpl/models/architectures/hybrik.py
@@ -0,0 +1,276 @@
+# isort: skip_file
+from abc import ABCMeta
+
+import torch
+
+from detrsmpl.data.datasets.pipelines.hybrik_transforms import heatmap2coord
+from detrsmpl.utils.transforms import rotmat_to_quat
+from ..backbones.builder import build_backbone
+from ..body_models.builder import build_body_model
+from ..heads.builder import build_head
+from ..losses.builder import build_loss
+from ..necks.builder import build_neck
+from .base_architecture import BaseArchitecture
+
+
+def set_requires_grad(nets, requires_grad=False):
+ """Set requies_grad for all the networks.
+
+ Args:
+ nets (nn.Module | list[nn.Module]): A list of networks or a single
+ network.
+ requires_grad (bool): Whether the networks require gradients or not
+ """
+ if not isinstance(nets, list):
+ nets = [nets]
+ for net in nets:
+ if net is not None:
+ for param in net.parameters():
+ param.requires_grad = requires_grad
+
+
+class HybrIK_trainer(BaseArchitecture, metaclass=ABCMeta):
+ """Hybrik_trainer Architecture.
+
+ Args:
+ backbone (dict | None, optional): Backbone config dict. Default: None.
+ neck (dict | None, optional): Neck config dict. Default: None
+ head (dict | None, optional): Regressor config dict. Default: None.
+ body_model (dict | None, optional): SMPL config dict. Default: None.
+ loss_beta (dict | None, optional): Losses config dict for
+ beta (shape parameters) estimation. Default: None
+ loss_theta (dict | None, optional): Losses config dict for
+ theta (pose parameters) estimation. Default: None
+ loss_twist (dict | None, optional): Losses config dict
+ for twist angles estimation. Default: None
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+ def __init__(self,
+ backbone=None,
+ neck=None,
+ head=None,
+ body_model=None,
+ loss_beta=None,
+ loss_theta=None,
+ loss_twist=None,
+ loss_uvd=None,
+ init_cfg=None):
+ super(HybrIK_trainer, self).__init__(init_cfg)
+
+ self.backbone = build_backbone(backbone)
+
+ self.neck = build_neck(neck)
+ self.head = build_head(head)
+ self.smpl = build_body_model(body_model)
+
+ self.loss_beta = build_loss(loss_beta)
+ self.loss_theta = build_loss(loss_theta)
+ self.loss_twist = build_loss(loss_twist)
+ self.loss_uvd = build_loss(loss_uvd)
+
+ self.head._initialize()
+
+ def forward_train(self, img, img_metas, **kwargs):
+ """Train step function.
+
+        In this function, the train step is carried out
+        following this pipeline:
+        1. extract features with the backbone
+        2. feed the extracted features into the head to
+        predict beta, theta, twist angles, and the heatmap (uvd map)
+        3. compute regression losses on the predictions
+        and optimize the backbone and head
+ Args:
+ img (torch.Tensor): Batch of data as input.
+ kwargs (dict): Dict with ground-truth
+ Returns:
+            losses (dict): Dict with the computed regression losses.
+ """
+ labels = {}
+ labels['trans_inv'] = kwargs['trans_inv']
+ labels['intrinsic_param'] = kwargs['intrinsic_param']
+ labels['joint_root'] = kwargs['joint_root']
+ labels['depth_factor'] = kwargs['depth_factor']
+ labels['target_uvd_29'] = kwargs['target_uvd_29']
+ labels['target_xyz_24'] = kwargs['target_xyz_24']
+ labels['target_weight_24'] = kwargs['target_weight_24']
+ labels['target_weight_29'] = kwargs['target_weight_29']
+ labels['target_xyz_17'] = kwargs['target_xyz_17']
+ labels['target_weight_17'] = kwargs['target_weight_17']
+ labels['target_theta'] = kwargs['target_theta']
+ labels['target_beta'] = kwargs['target_beta']
+ labels['target_smpl_weight'] = kwargs['target_smpl_weight']
+ labels['target_theta_weight'] = kwargs['target_theta_weight']
+ labels['target_twist'] = kwargs['target_twist']
+ labels['target_twist_weight'] = kwargs['target_twist_weight']
+ # flip_output = kwargs.pop('is_flipped', None)
+
+ for k, _ in labels.items():
+ labels[k] = labels[k].cuda()
+
+ trans_inv = labels.pop('trans_inv')
+ intrinsic_param = labels.pop('intrinsic_param')
+ joint_root = labels.pop('joint_root')
+ depth_factor = labels.pop('depth_factor')
+
+ if self.backbone is not None:
+ img = img.cuda().requires_grad_()
+ features = self.backbone(img)
+ features = features[0]
+ else:
+ features = img['features']
+
+ if self.neck is not None:
+ features = self.neck(features)
+
+ predictions = self.head(features, trans_inv, intrinsic_param,
+ joint_root, depth_factor, self.smpl)
+
+ losses = self.compute_losses(predictions, labels)
+
+ return losses
+
+ def compute_losses(self, predictions, targets):
+ """Compute regression losses for beta, theta, twist and uvd."""
+ smpl_weight = targets['target_smpl_weight']
+
+ losses = {}
+ if self.loss_beta is not None:
+ losses['loss_beta'] = self.loss_beta(
+ predictions['pred_shape'] * smpl_weight,
+ targets['target_beta'] * smpl_weight)
+ if self.loss_theta is not None:
+ pred_pose = rotmat_to_quat(predictions['pred_pose']).reshape(
+ -1, 96)
+ losses['loss_theta'] = self.loss_theta(
+ pred_pose * smpl_weight * targets['target_theta_weight'],
+ targets['target_theta'] * smpl_weight *
+ targets['target_theta_weight'])
+ if self.loss_twist is not None:
+ losses['loss_twist'] = self.loss_twist(
+ predictions['pred_phi'] * targets['target_twist_weight'],
+ targets['target_twist'] * targets['target_twist_weight'])
+ if self.loss_uvd is not None:
+ pred_uvd = predictions['pred_uvd_jts']
+ target_uvd = targets['target_uvd_29'][:, :pred_uvd.shape[1]]
+ target_uvd_weight = targets['target_weight_29'][:, :pred_uvd.
+ shape[1]]
+ losses['loss_uvd'] = self.loss_uvd(
+ 64 * predictions['pred_uvd_jts'],
+ 64 * target_uvd,
+ target_uvd_weight,
+ avg_factor=target_uvd_weight.sum())
+
+ return losses
+
+ def forward_test(self, img, img_metas, **kwargs):
+ """Test step function.
+
+        In this function, the test step is carried out
+        following this pipeline:
+        1. extract features with the backbone
+        2. feed the extracted features into the head to
+        predict beta, theta, twist angles, and the heatmap (uvd map)
+ 3. store predictions for evaluation
+ Args:
+ img (torch.Tensor): Batch of data as input.
+ img_metas (dict): Dict with image metas i.e. path
+ kwargs (dict): Dict with ground-truth
+ Returns:
+ all_preds (dict): Dict with image_path, vertices, xyz_17, uvd_jts,
+ xyz_24 for predictions.
+ """
+ labels = {}
+ labels['trans_inv'] = kwargs['trans_inv']
+ labels['intrinsic_param'] = kwargs['intrinsic_param']
+ labels['joint_root'] = kwargs['joint_root']
+ labels['depth_factor'] = kwargs['depth_factor']
+ labels['target_uvd_29'] = kwargs['target_uvd_29']
+ labels['target_xyz_24'] = kwargs['target_xyz_24']
+ labels['target_weight_24'] = kwargs['target_weight_24']
+ labels['target_weight_29'] = kwargs['target_weight_29']
+ labels['target_xyz_17'] = kwargs['target_xyz_17']
+ labels['target_weight_17'] = kwargs['target_weight_17']
+ labels['target_theta'] = kwargs['target_theta']
+ labels['target_beta'] = kwargs['target_beta']
+ labels['target_smpl_weight'] = kwargs['target_smpl_weight']
+ labels['target_theta_weight'] = kwargs['target_theta_weight']
+ labels['target_twist'] = kwargs['target_twist']
+ labels['target_twist_weight'] = kwargs['target_twist_weight']
+
+ bboxes = kwargs['bbox']
+
+ for k, _ in labels.items():
+ labels[k] = labels[k].cuda()
+
+ trans_inv = labels.pop('trans_inv')
+ intrinsic_param = labels.pop('intrinsic_param')
+ joint_root = labels.pop('joint_root')
+ depth_factor = labels.pop('depth_factor')
+ if len(depth_factor.shape) != 2:
+ depth_factor = torch.unsqueeze(depth_factor, dim=1)
+
+ if self.backbone is not None:
+ img = img.cuda().requires_grad_()
+ features = self.backbone(img)
+ features = features[0]
+ else:
+ features = img['features']
+
+ if self.neck is not None:
+ features = self.neck(features)
+
+ output = self.head(features, trans_inv, intrinsic_param, joint_root,
+ depth_factor, self.smpl)
+
+ pred_uvd_jts = output['pred_uvd_jts']
+ batch_num = pred_uvd_jts.shape[0]
+ pred_xyz_jts_24 = output['pred_xyz_jts_24'].reshape(batch_num, -1,
+ 3)[:, :24, :]
+ pred_xyz_jts_24_struct = output['pred_xyz_jts_24_struct'].reshape(
+ batch_num, 24, 3)
+ pred_xyz_jts_17 = output['pred_xyz_jts_17'].reshape(batch_num, 17, 3)
+ pred_mesh = output['pred_vertices'].reshape(batch_num, -1, 3)
+
+ pred_xyz_jts_24 = pred_xyz_jts_24.cpu().data.numpy()
+ pred_xyz_jts_24_struct = pred_xyz_jts_24_struct.cpu().data.numpy()
+ pred_xyz_jts_17 = pred_xyz_jts_17.cpu().data.numpy()
+ pred_uvd_jts = pred_uvd_jts.cpu().data
+ pred_mesh = pred_mesh.cpu().data.numpy()
+ pred_pose = output['pred_pose'].cpu().data.numpy()
+ pred_beta = output['pred_shape'].cpu().data.numpy()
+
+ assert pred_xyz_jts_17.ndim in [2, 3]
+ pred_xyz_jts_17 = pred_xyz_jts_17.reshape(pred_xyz_jts_17.shape[0], 17,
+ 3)
+ pred_uvd_jts = pred_uvd_jts.reshape(pred_uvd_jts.shape[0], -1, 3)
+ pred_xyz_jts_24 = pred_xyz_jts_24.reshape(pred_xyz_jts_24.shape[0], 24,
+ 3)
+ pred_scores = output['maxvals'].cpu().data[:, :29]
+
+ hm_shape = [64, 64]
+ pose_coords_list = []
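+        # Convert each sample's predicted uvd joints to image-space 2D
+        # coordinates using its bounding box.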
+ for i in range(pred_xyz_jts_17.shape[0]):
+ bbox = bboxes[i].tolist()
+ pose_coords, _ = heatmap2coord(pred_uvd_jts[i],
+ pred_scores[i],
+ hm_shape,
+ bbox,
+ mean_bbox_scale=None)
+ pose_coords_list.append(pose_coords)
+
+ all_preds = {}
+ all_preds['vertices'] = pred_mesh
+ all_preds['smpl_pose'] = pred_pose
+ all_preds['smpl_beta'] = pred_beta
+ all_preds['xyz_17'] = pred_xyz_jts_17
+        all_preds['uvd_jts'] = pose_coords_list
+ all_preds['xyz_24'] = pred_xyz_jts_24_struct
+ image_path = []
+ for img_meta in img_metas:
+ image_path.append(img_meta['image_path'])
+ all_preds['image_path'] = image_path
+ all_preds['image_idx'] = kwargs['sample_idx']
+ return all_preds
diff --git a/detrsmpl/models/architectures/mesh_estimator.py b/detrsmpl/models/architectures/mesh_estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f2b1b294fbda665407af7a4776c59f64f81e11c
--- /dev/null
+++ b/detrsmpl/models/architectures/mesh_estimator.py
@@ -0,0 +1,865 @@
+from abc import ABCMeta, abstractmethod
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+
+import detrsmpl.core.visualization.visualize_smpl as visualize_smpl
+from detrsmpl.core.conventions.keypoints_mapping import get_keypoint_idx
+from detrsmpl.models.utils import FitsDict
+from detrsmpl.utils.geometry import (
+ batch_rodrigues,
+ estimate_translation,
+ project_points,
+ rotation_matrix_to_angle_axis,
+)
+from ..backbones.builder import build_backbone
+from ..body_models.builder import build_body_model
+from ..discriminators.builder import build_discriminator
+from ..heads.builder import build_head
+from ..losses.builder import build_loss
+from ..necks.builder import build_neck
+from ..registrants.builder import build_registrant
+from .base_architecture import BaseArchitecture
+
+
+def set_requires_grad(nets, requires_grad=False):
+ """Set requies_grad for all the networks.
+
+ Args:
+ nets (nn.Module | list[nn.Module]): A list of networks or a single
+ network.
+ requires_grad (bool): Whether the networks require gradients or not
+ """
+ if not isinstance(nets, list):
+ nets = [nets]
+ for net in nets:
+ if net is not None:
+ for param in net.parameters():
+ param.requires_grad = requires_grad
+
+
+class BodyModelEstimator(BaseArchitecture, metaclass=ABCMeta):
+ """BodyModelEstimator Architecture.
+
+ Args:
+ backbone (dict | None, optional): Backbone config dict. Default: None.
+ neck (dict | None, optional): Neck config dict. Default: None
+ head (dict | None, optional): Regressor config dict. Default: None.
+ disc (dict | None, optional): Discriminator config dict.
+ Default: None.
+ registration (dict | None, optional): Registration config dict.
+ Default: None.
+ body_model_train (dict | None, optional): SMPL config dict during
+ training. Default: None.
+ body_model_test (dict | None, optional): SMPL config dict during
+ test. Default: None.
+ convention (str, optional): Keypoints convention. Default: "human_data"
+ loss_keypoints2d (dict | None, optional): Losses config dict for
+ 2D keypoints. Default: None.
+ loss_keypoints3d (dict | None, optional): Losses config dict for
+ 3D keypoints. Default: None.
+ loss_vertex (dict | None, optional): Losses config dict for mesh
+ vertices. Default: None
+ loss_smpl_pose (dict | None, optional): Losses config dict for smpl
+ pose. Default: None
+ loss_smpl_betas (dict | None, optional): Losses config dict for smpl
+ betas. Default: None
+ loss_camera (dict | None, optional): Losses config dict for predicted
+ camera parameters. Default: None
+        loss_adv (dict | None, optional): Losses config for adversarial
+ training. Default: None.
+ loss_segm_mask (dict | None, optional): Losses config for predicted
+ part segmentation. Default: None.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
+ """
+ def __init__(self,
+ backbone: Optional[Union[dict, None]] = None,
+ neck: Optional[Union[dict, None]] = None,
+ head: Optional[Union[dict, None]] = None,
+ disc: Optional[Union[dict, None]] = None,
+ registration: Optional[Union[dict, None]] = None,
+ body_model_train: Optional[Union[dict, None]] = None,
+ body_model_test: Optional[Union[dict, None]] = None,
+ convention: Optional[str] = 'human_data',
+ loss_keypoints2d: Optional[Union[dict, None]] = None,
+ loss_keypoints3d: Optional[Union[dict, None]] = None,
+ loss_vertex: Optional[Union[dict, None]] = None,
+ loss_smpl_pose: Optional[Union[dict, None]] = None,
+ loss_smpl_betas: Optional[Union[dict, None]] = None,
+ loss_camera: Optional[Union[dict, None]] = None,
+ loss_adv: Optional[Union[dict, None]] = None,
+ loss_segm_mask: Optional[Union[dict, None]] = None,
+ init_cfg: Optional[Union[list, dict, None]] = None):
+ super(BodyModelEstimator, self).__init__(init_cfg)
+ self.backbone = build_backbone(backbone)
+ self.neck = build_neck(neck)
+ self.head = build_head(head)
+ self.disc = build_discriminator(disc)
+
+ self.body_model_train = build_body_model(body_model_train)
+ self.body_model_test = build_body_model(body_model_test)
+ self.convention = convention
+
+ # TODO: support HMR+
+
+ self.registration = registration
+ if registration is not None:
+ self.fits_dict = FitsDict(fits='static')
+ self.registration_mode = self.registration['mode']
+ self.registrant = build_registrant(registration['registrant'])
+ else:
+ self.registrant = None
+
+ self.loss_keypoints2d = build_loss(loss_keypoints2d)
+ self.loss_keypoints3d = build_loss(loss_keypoints3d)
+
+ self.loss_vertex = build_loss(loss_vertex)
+ self.loss_smpl_pose = build_loss(loss_smpl_pose)
+ self.loss_smpl_betas = build_loss(loss_smpl_betas)
+ self.loss_adv = build_loss(loss_adv)
+ self.loss_camera = build_loss(loss_camera)
+ self.loss_segm_mask = build_loss(loss_segm_mask)
+ set_requires_grad(self.body_model_train, False)
+ set_requires_grad(self.body_model_test, False)
+
+ def train_step(self, data_batch, optimizer, **kwargs):
+ """Train step function.
+
+        In this function, the estimator will finish the train step following
+        this pipeline:
+        1. get fake and real SMPL parameters
+        2. optimize the discriminator (if present)
+        3. optimize the generator
+        If `self.train_cfg.disc_step > 1`, the train step will contain multiple
+        iterations for optimizing the discriminator with different input data
+        and only one iteration for optimizing the generator after `disc_step`
+        iterations for the discriminator.
+ Args:
+ data_batch (torch.Tensor): Batch of data as input.
+ optimizer (dict[torch.optim.Optimizer]): Dict with optimizers for
+            the generator and the discriminator (if present).
+ Returns:
+ outputs (dict): Dict with loss, information for logger,
+ the number of samples.
+ """
+ if self.backbone is not None:
+ img = data_batch['img']
+ features = self.backbone(img)
+ else:
+ features = data_batch['features']
+
+ if self.neck is not None:
+ features = self.neck(features)
+
+ predictions = self.head(features)
+ targets = self.prepare_targets(data_batch)
+
+        # optimize the discriminator (if present)
+ if self.disc is not None:
+ self.optimize_discrinimator(predictions, data_batch, optimizer)
+
+ if self.registration is not None:
+ targets = self.run_registration(predictions, targets)
+
+ losses = self.compute_losses(predictions, targets)
+        # optimize the generator
+ if self.disc is not None:
+ adv_loss = self.optimize_generator(predictions)
+ losses.update(adv_loss)
+
+ loss, log_vars = self._parse_losses(losses)
+ for key in optimizer.keys():
+ optimizer[key].zero_grad()
+ loss.backward()
+ for key in optimizer.keys():
+ optimizer[key].step()
+
+ outputs = dict(loss=loss,
+ log_vars=log_vars,
+ num_samples=len(next(iter(data_batch.values()))))
+ return outputs
+
+ def run_registration(
+ self,
+ predictions: dict,
+ targets: dict,
+ threshold: Optional[float] = 10.0,
+ focal_length: Optional[float] = 5000.0,
+ img_res: Optional[Union[Tuple[int], int]] = 224) -> dict:
+ """Run registration on 2D keypoinst in predictions to obtain SMPL
+ parameters as pseudo ground truth.
+
+ Args:
+ predictions (dict): predicted SMPL parameters are used for
+ initialization.
+ targets (dict): existing ground truths with 2D keypoints
+ threshold (float, optional): the threshold to update fits
+ dictionary. Default: 10.0.
+ focal_length (tuple(int) | int, optional): camera focal_length
+ img_res (int, optional): image resolution
+
+ Returns:
+ targets: contains additional SMPL parameters
+ """
+
+ img_metas = targets['img_metas']
+ dataset_name = [meta['dataset_name'] for meta in img_metas
+ ] # name of the dataset the image comes from
+
+ indices = targets['sample_idx'].squeeze()
+ is_flipped = targets['is_flipped'].squeeze().bool(
+ ) # flag that indicates whether image was flipped
+ # during data augmentation
+ rot_angle = targets['rotation'].squeeze(
+        ) # rotation angle used for data augmentation
+ gt_betas = targets['smpl_betas'].float()
+ gt_global_orient = targets['smpl_global_orient'].float()
+ gt_pose = targets['smpl_body_pose'].float().view(-1, 69)
+
+ pred_rotmat = predictions['pred_pose'].detach().clone()
+ pred_betas = predictions['pred_shape'].detach().clone()
+ pred_cam = predictions['pred_cam'].detach().clone()
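+        # Convert the weak-perspective camera (s, tx, ty) to a 3D translation;
+        # depth is approximated as 2 * focal_length / (s * img_res).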
+ pred_cam_t = torch.stack([
+ pred_cam[:, 1], pred_cam[:, 2], 2 * focal_length /
+ (img_res * pred_cam[:, 0] + 1e-9)
+ ],
+ dim=-1)
+
+ gt_keypoints_2d = targets['keypoints2d'].float()
+ num_keypoints = gt_keypoints_2d.shape[1]
+
+ has_smpl = targets['has_smpl'].view(
+ -1).bool() # flag that indicates whether SMPL parameters are valid
+ batch_size = has_smpl.shape[0]
+ device = has_smpl.device
+
+ # Get GT vertices and model joints
+ # Note that gt_model_joints is different from gt_joints as
+ # it comes from SMPL
+ gt_out = self.body_model_train(betas=gt_betas,
+ body_pose=gt_pose,
+ global_orient=gt_global_orient)
+        # TODO: support more conventions
+ assert num_keypoints == 49
+ gt_model_joints = gt_out['joints']
+ gt_vertices = gt_out['vertices']
+
+ # Get current best fits from the dictionary
+ opt_pose, opt_betas = self.fits_dict[(dataset_name, indices.cpu(),
+ rot_angle.cpu(),
+ is_flipped.cpu())]
+
+ opt_pose = opt_pose.to(device)
+ opt_betas = opt_betas.to(device)
+ opt_output = self.body_model_train(betas=opt_betas,
+ body_pose=opt_pose[:, 3:],
+ global_orient=opt_pose[:, :3])
+ opt_joints = opt_output['joints']
+ opt_vertices = opt_output['vertices']
+
+ gt_keypoints_2d_orig = gt_keypoints_2d.clone()
+ # Estimate camera translation given the model joints and 2D keypoints
+ # by minimizing a weighted least squares loss
+ gt_cam_t = estimate_translation(gt_model_joints,
+ gt_keypoints_2d_orig,
+ focal_length=focal_length,
+ img_size=img_res)
+
+ opt_cam_t = estimate_translation(opt_joints,
+ gt_keypoints_2d_orig,
+ focal_length=focal_length,
+ img_size=img_res)
+
+ with torch.no_grad():
+ loss_dict = self.registrant.evaluate(
+ global_orient=opt_pose[:, :3],
+ body_pose=opt_pose[:, 3:],
+ betas=opt_betas,
+ transl=opt_cam_t,
+ keypoints2d=gt_keypoints_2d_orig[:, :, :2],
+ keypoints2d_conf=gt_keypoints_2d_orig[:, :, 2],
+ reduction_override='none')
+ opt_joint_loss = loss_dict['keypoint2d_loss'].sum(dim=-1).sum(dim=-1)
+
+ if self.registration_mode == 'in_the_loop':
+ # Convert predicted rotation matrices to axis-angle
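+            # rotation_matrix_to_angle_axis expects (N, 3, 4) inputs, so each
+            # 3x3 rotation is first padded with a constant fourth column.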
+ pred_rotmat_hom = torch.cat([
+ pred_rotmat.detach().view(-1, 3, 3).detach(),
+ torch.tensor([0, 0, 1], dtype=torch.float32,
+ device=device).view(1, 3, 1).expand(
+ batch_size * 24, -1, -1)
+ ],
+ dim=-1)
+ pred_pose = rotation_matrix_to_angle_axis(
+ pred_rotmat_hom).contiguous().view(batch_size, -1)
+ # tgm.rotation_matrix_to_angle_axis returns NaN for 0 rotation,
+ # so manually hack it
+ pred_pose[torch.isnan(pred_pose)] = 0.0
+
+ registrant_output = self.registrant(
+ keypoints2d=gt_keypoints_2d_orig[:, :, :2],
+ keypoints2d_conf=gt_keypoints_2d_orig[:, :, 2],
+ init_global_orient=pred_pose[:, :3],
+ init_transl=pred_cam_t,
+ init_body_pose=pred_pose[:, 3:],
+ init_betas=pred_betas,
+ return_joints=True,
+ return_verts=True,
+ return_losses=True)
+ new_opt_vertices = registrant_output[
+ 'vertices'] - pred_cam_t.unsqueeze(1)
+ new_opt_joints = registrant_output[
+ 'joints'] - pred_cam_t.unsqueeze(1)
+
+ new_opt_global_orient = registrant_output['global_orient']
+ new_opt_body_pose = registrant_output['body_pose']
+ new_opt_pose = torch.cat(
+ [new_opt_global_orient, new_opt_body_pose], dim=1)
+
+ new_opt_betas = registrant_output['betas']
+ new_opt_cam_t = registrant_output['transl']
+ new_opt_joint_loss = registrant_output['keypoint2d_loss'].sum(
+ dim=-1).sum(dim=-1)
+
+ # Will update the dictionary for the examples where the new loss
+ # is less than the current one
+ update = (new_opt_joint_loss < opt_joint_loss)
+
+ opt_joint_loss[update] = new_opt_joint_loss[update]
+ opt_vertices[update, :] = new_opt_vertices[update, :]
+ opt_joints[update, :] = new_opt_joints[update, :]
+ opt_pose[update, :] = new_opt_pose[update, :]
+ opt_betas[update, :] = new_opt_betas[update, :]
+ opt_cam_t[update, :] = new_opt_cam_t[update, :]
+
+ self.fits_dict[(dataset_name, indices.cpu(), rot_angle.cpu(),
+ is_flipped.cpu(),
+ update.cpu())] = (opt_pose.cpu(), opt_betas.cpu())
+
+ # Replace extreme betas with zero betas
+ opt_betas[(opt_betas.abs() > 3).any(dim=-1)] = 0.
+
+ # Replace the optimized parameters with the ground truth parameters,
+ # if available
+ opt_vertices[has_smpl, :, :] = gt_vertices[has_smpl, :, :]
+ opt_cam_t[has_smpl, :] = gt_cam_t[has_smpl, :]
+ opt_joints[has_smpl, :, :] = gt_model_joints[has_smpl, :, :]
+ opt_pose[has_smpl, 3:] = gt_pose[has_smpl, :]
+ opt_pose[has_smpl, :3] = gt_global_orient[has_smpl, :]
+ opt_betas[has_smpl, :] = gt_betas[has_smpl, :]
+
+        # Determine whether a fit is valid by comparing the joint loss
+        # with the threshold
+ valid_fit = (opt_joint_loss < threshold).to(device)
+ valid_fit = valid_fit | has_smpl
+ targets['valid_fit'] = valid_fit
+
+ targets['opt_vertices'] = opt_vertices
+ targets['opt_cam_t'] = opt_cam_t
+ targets['opt_joints'] = opt_joints
+ targets['opt_pose'] = opt_pose
+ targets['opt_betas'] = opt_betas
+
+ return targets
+
+ def optimize_discrinimator(self, predictions: dict, data_batch: dict,
+ optimizer: dict):
+ """Optimize discrinimator during adversarial training."""
+ set_requires_grad(self.disc, True)
+ fake_data = self.make_fake_data(predictions, requires_grad=False)
+ real_data = self.make_real_data(data_batch)
+ fake_score = self.disc(fake_data)
+ real_score = self.disc(real_data)
+
+ disc_losses = {}
+ disc_losses['real_loss'] = self.loss_adv(real_score,
+ target_is_real=True,
+ is_disc=True)
+ disc_losses['fake_loss'] = self.loss_adv(fake_score,
+ target_is_real=False,
+ is_disc=True)
+ loss_disc, log_vars_d = self._parse_losses(disc_losses)
+
+ optimizer['disc'].zero_grad()
+ loss_disc.backward()
+ optimizer['disc'].step()
+
+ def optimize_generator(self, predictions: dict):
+ """Optimize generator during adversarial training."""
+ set_requires_grad(self.disc, False)
+ fake_data = self.make_fake_data(predictions, requires_grad=True)
+ pred_score = self.disc(fake_data)
+ loss_adv = self.loss_adv(pred_score,
+ target_is_real=True,
+ is_disc=False)
+ loss = dict(adv_loss=loss_adv)
+ return loss
+
+ def compute_keypoints3d_loss(
+ self,
+ pred_keypoints3d: torch.Tensor,
+ gt_keypoints3d: torch.Tensor,
+ has_keypoints3d: Optional[torch.Tensor] = None):
+ """Compute loss for 3d keypoints."""
+ keypoints3d_conf = gt_keypoints3d[:, :, 3].float().unsqueeze(-1)
+ keypoints3d_conf = keypoints3d_conf.repeat(1, 1, 3)
+ pred_keypoints3d = pred_keypoints3d.float()
+ gt_keypoints3d = gt_keypoints3d[:, :, :3].float()
+
+ # currently, only mpi_inf_3dhp and h36m have 3d keypoints
+ # both datasets have right_hip_extra and left_hip_extra
+ right_hip_idx = get_keypoint_idx('right_hip_extra', self.convention)
+ left_hip_idx = get_keypoint_idx('left_hip_extra', self.convention)
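+        # Center both predictions and ground truth on the pelvis (mid-hip)
+        # so that the 3D keypoint loss is invariant to global translation.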
+ gt_pelvis = (gt_keypoints3d[:, right_hip_idx, :] +
+ gt_keypoints3d[:, left_hip_idx, :]) / 2
+ pred_pelvis = (pred_keypoints3d[:, right_hip_idx, :] +
+ pred_keypoints3d[:, left_hip_idx, :]) / 2
+
+ gt_keypoints3d = gt_keypoints3d - gt_pelvis[:, None, :]
+ pred_keypoints3d = pred_keypoints3d - pred_pelvis[:, None, :]
+ loss = self.loss_keypoints3d(pred_keypoints3d,
+ gt_keypoints3d,
+ reduction_override='none')
+
+        # If has_keypoints3d is not None, compute the loss only on instances
+        # that have ground-truth keypoints3d; zero-confidence keypoints are
+        # still included in the mean.
+        # Otherwise, average only over keypoints with positive confidence.
+
+        # has_keypoints3d is None when the key has_keypoints3d
+        # is not in the datasets
+ if has_keypoints3d is None:
+
+ valid_pos = keypoints3d_conf > 0
+ if keypoints3d_conf[valid_pos].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_keypoints3d)
+ loss = torch.sum(loss * keypoints3d_conf)
+ loss /= keypoints3d_conf[valid_pos].numel()
+ else:
+
+ keypoints3d_conf = keypoints3d_conf[has_keypoints3d == 1]
+ if keypoints3d_conf.shape[0] == 0:
+ return torch.Tensor([0]).type_as(gt_keypoints3d)
+ loss = loss[has_keypoints3d == 1]
+ loss = (loss * keypoints3d_conf).mean()
+ return loss
+
+ def compute_keypoints2d_loss(
+ self,
+ pred_keypoints3d: torch.Tensor,
+ pred_cam: torch.Tensor,
+ gt_keypoints2d: torch.Tensor,
+ img_res: Optional[int] = 224,
+ focal_length: Optional[int] = 5000,
+ has_keypoints2d: Optional[torch.Tensor] = None):
+ """Compute loss for 2d keypoints."""
+ keypoints2d_conf = gt_keypoints2d[:, :, 2].float().unsqueeze(-1)
+ keypoints2d_conf = keypoints2d_conf.repeat(1, 1, 2)
+ gt_keypoints2d = gt_keypoints2d[:, :, :2].float()
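+        # Project the predicted 3D keypoints onto the image plane with the
+        # predicted camera before comparing them with the 2D ground truth.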
+ pred_keypoints2d = project_points(pred_keypoints3d,
+ pred_cam,
+ focal_length=focal_length,
+ img_res=img_res)
+ # Normalize keypoints to [-1,1]
+ # The coordinate origin of pred_keypoints_2d is
+ # the center of the input image.
+ pred_keypoints2d = 2 * pred_keypoints2d / (img_res - 1)
+ # The coordinate origin of gt_keypoints_2d is
+ # the top left corner of the input image.
+ gt_keypoints2d = 2 * gt_keypoints2d / (img_res - 1) - 1
+ loss = self.loss_keypoints2d(pred_keypoints2d,
+ gt_keypoints2d,
+ reduction_override='none')
+
+        # If has_keypoints2d is not None, compute the loss only on instances
+        # that have ground-truth keypoints2d; zero-confidence keypoints are
+        # still included in the mean.
+        # Otherwise, average only over keypoints with positive confidence.
+        # has_keypoints2d is None when the key has_keypoints2d
+        # is not in the datasets
+
+ if has_keypoints2d is None:
+ valid_pos = keypoints2d_conf > 0
+ if keypoints2d_conf[valid_pos].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_keypoints2d)
+ loss = torch.sum(loss * keypoints2d_conf)
+ loss /= keypoints2d_conf[valid_pos].numel()
+ else:
+ keypoints2d_conf = keypoints2d_conf[has_keypoints2d == 1]
+ if keypoints2d_conf.shape[0] == 0:
+ return torch.Tensor([0]).type_as(gt_keypoints2d)
+ loss = loss[has_keypoints2d == 1]
+ loss = (loss * keypoints2d_conf).mean()
+
+ return loss
+
+ def compute_vertex_loss(self, pred_vertices: torch.Tensor,
+ gt_vertices: torch.Tensor, has_smpl: torch.Tensor):
+ """Compute loss for vertices."""
+ gt_vertices = gt_vertices.float()
+ conf = has_smpl.float().view(-1, 1, 1)
+ conf = conf.repeat(1, gt_vertices.shape[1], gt_vertices.shape[2])
+ loss = self.loss_vertex(pred_vertices,
+ gt_vertices,
+ reduction_override='none')
+ valid_pos = conf > 0
+ if conf[valid_pos].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_vertices)
+ loss = torch.sum(loss * conf) / conf[valid_pos].numel()
+ return loss
+
+ def compute_smpl_pose_loss(self, pred_rotmat: torch.Tensor,
+ gt_pose: torch.Tensor, has_smpl: torch.Tensor):
+ """Compute loss for smpl pose."""
+ conf = has_smpl.float().view(-1)
+ valid_pos = conf > 0
+ if conf[valid_pos].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_pose)
+ pred_rotmat = pred_rotmat[valid_pos]
+ gt_pose = gt_pose[valid_pos]
+ conf = conf[valid_pos]
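+        # Convert the axis-angle ground-truth pose (N, 72) to rotation
+        # matrices (N, 24, 3, 3) to match the predicted rotation matrices.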
+ gt_rotmat = batch_rodrigues(gt_pose.view(-1, 3)).view(-1, 24, 3, 3)
+ loss = self.loss_smpl_pose(pred_rotmat,
+ gt_rotmat,
+ reduction_override='none')
+ loss = loss.view(loss.shape[0], -1).mean(-1)
+ loss = torch.mean(loss * conf)
+ return loss
+
+ def compute_smpl_betas_loss(self, pred_betas: torch.Tensor,
+ gt_betas: torch.Tensor,
+ has_smpl: torch.Tensor):
+ """Compute loss for smpl betas."""
+ conf = has_smpl.float().view(-1)
+ valid_pos = conf > 0
+ if conf[valid_pos].numel() == 0:
+ return torch.Tensor([0]).type_as(gt_betas)
+ pred_betas = pred_betas[valid_pos]
+ gt_betas = gt_betas[valid_pos]
+ conf = conf[valid_pos]
+ loss = self.loss_smpl_betas(pred_betas,
+ gt_betas,
+ reduction_override='none')
+ loss = loss.view(loss.shape[0], -1).mean(-1)
+ loss = torch.mean(loss * conf)
+ return loss
+
+ def compute_camera_loss(self, cameras: torch.Tensor):
+ """Compute loss for predicted camera parameters."""
+ loss = self.loss_camera(cameras)
+ return loss
+
+ def compute_part_segmentation_loss(self,
+ pred_heatmap: torch.Tensor,
+ gt_vertices: torch.Tensor,
+ gt_keypoints2d: torch.Tensor,
+ gt_model_joints: torch.Tensor,
+ has_smpl: torch.Tensor,
+ img_res: Optional[int] = 224,
+ focal_length: Optional[int] = 500):
+ """Compute loss for part segmentations."""
+ device = gt_keypoints2d.device
+ gt_keypoints2d_valid = gt_keypoints2d[has_smpl == 1]
+ batch_size = gt_keypoints2d_valid.shape[0]
+
+ gt_vertices_valid = gt_vertices[has_smpl == 1]
+ gt_model_joints_valid = gt_model_joints[has_smpl == 1]
+
+ if batch_size == 0:
+ return torch.Tensor([0]).type_as(gt_keypoints2d)
+ gt_cam_t = estimate_translation(
+ gt_model_joints_valid,
+ gt_keypoints2d_valid,
+ focal_length=focal_length,
+ img_size=img_res,
+ )
+
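+        # Build pinhole intrinsics with the principal point at the image
+        # center; the same focal length is used for the translation
+        # estimate above and the part-silhouette rendering below.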
+ K = torch.eye(3)
+ K[0, 0] = focal_length
+ K[1, 1] = focal_length
+ K[2, 2] = 1
+ K[0, 2] = img_res / 2.
+ K[1, 2] = img_res / 2.
+ K = K[None, :, :]
+
+ R = torch.eye(3)[None, :, :]
+ device = gt_keypoints2d.device
+ gt_sem_mask = visualize_smpl.render_smpl(
+ verts=gt_vertices_valid,
+ R=R,
+ K=K,
+ T=gt_cam_t,
+ render_choice='part_silhouette',
+ resolution=img_res,
+ return_tensor=True,
+ body_model=self.body_model_train,
+ device=device,
+ in_ndc=False,
+ convention='pytorch3d',
+ projection='perspective',
+ no_grad=True,
+ batch_size=batch_size,
+ verbose=False,
+ )
+ gt_sem_mask = torch.flip(gt_sem_mask, [1, 2]).squeeze(-1).detach()
+ pred_heatmap_valid = pred_heatmap[has_smpl == 1]
+ ph, pw = pred_heatmap_valid.size(2), pred_heatmap_valid.size(3)
+ h, w = gt_sem_mask.size(1), gt_sem_mask.size(2)
+ if ph != h or pw != w:
+ pred_heatmap_valid = F.interpolate(input=pred_heatmap_valid,
+ size=(h, w),
+ mode='bilinear')
+
+ loss = self.loss_segm_mask(pred_heatmap_valid, gt_sem_mask)
+ return loss
+
+ def compute_losses(self, predictions: dict, targets: dict):
+ """Compute losses."""
+ pred_betas = predictions['pred_shape'].view(-1, 10)
+ pred_pose = predictions['pred_pose'].view(-1, 24, 3, 3)
+ pred_cam = predictions['pred_cam'].view(-1, 3)
+
+ gt_keypoints3d = targets['keypoints3d']
+ gt_keypoints2d = targets['keypoints2d']
+ # pred_pose N, 24, 3, 3
+ if self.body_model_train is not None:
+ pred_output = self.body_model_train(
+ betas=pred_betas,
+ body_pose=pred_pose[:, 1:],
+ global_orient=pred_pose[:, 0].unsqueeze(1),
+ pose2rot=False,
+ num_joints=gt_keypoints2d.shape[1])
+ pred_keypoints3d = pred_output['joints']
+ pred_vertices = pred_output['vertices']
+
+ # # TODO: temp. Should we multiply confs here?
+ # pred_keypoints3d_mask = pred_output['joint_mask']
+ # keypoints3d_mask = keypoints3d_mask * pred_keypoints3d_mask
+
+ # TODO: temp solution
+ if 'valid_fit' in targets:
+ has_smpl = targets['valid_fit'].view(-1)
+ # global_orient = targets['opt_pose'][:, :3].view(-1, 1, 3)
+ gt_pose = targets['opt_pose']
+ gt_betas = targets['opt_betas']
+ gt_vertices = targets['opt_vertices']
+ else:
+ has_smpl = targets['has_smpl'].view(-1)
+ gt_pose = targets['smpl_body_pose']
+ global_orient = targets['smpl_global_orient'].view(-1, 1, 3)
+ gt_pose = torch.cat((global_orient, gt_pose), dim=1).float()
+ gt_betas = targets['smpl_betas'].float()
+
+ # gt_pose N, 72
+ if self.body_model_train is not None:
+ gt_output = self.body_model_train(
+ betas=gt_betas,
+ body_pose=gt_pose[:, 3:],
+ global_orient=gt_pose[:, :3],
+ num_joints=gt_keypoints2d.shape[1])
+ gt_vertices = gt_output['vertices']
+ gt_model_joints = gt_output['joints']
+ if 'has_keypoints3d' in targets:
+ has_keypoints3d = targets['has_keypoints3d'].squeeze(-1)
+ else:
+ has_keypoints3d = None
+ if 'has_keypoints2d' in targets:
+ has_keypoints2d = targets['has_keypoints2d'].squeeze(-1)
+ else:
+ has_keypoints2d = None
+ if 'pred_segm_mask' in predictions:
+ pred_segm_mask = predictions['pred_segm_mask']
+ losses = {}
+ if self.loss_keypoints3d is not None:
+ losses['keypoints3d_loss'] = self.compute_keypoints3d_loss(
+ pred_keypoints3d,
+ gt_keypoints3d,
+ has_keypoints3d=has_keypoints3d)
+ if self.loss_keypoints2d is not None:
+ losses['keypoints2d_loss'] = self.compute_keypoints2d_loss(
+ pred_keypoints3d,
+ pred_cam,
+ gt_keypoints2d,
+ has_keypoints2d=has_keypoints2d)
+ if self.loss_vertex is not None:
+ losses['vertex_loss'] = self.compute_vertex_loss(
+ pred_vertices, gt_vertices, has_smpl)
+ if self.loss_smpl_pose is not None:
+ losses['smpl_pose_loss'] = self.compute_smpl_pose_loss(
+ pred_pose, gt_pose, has_smpl)
+ if self.loss_smpl_betas is not None:
+ losses['smpl_betas_loss'] = self.compute_smpl_betas_loss(
+ pred_betas, gt_betas, has_smpl)
+ if self.loss_camera is not None:
+ losses['camera_loss'] = self.compute_camera_loss(pred_cam)
+ if self.loss_segm_mask is not None:
+ losses['loss_segm_mask'] = self.compute_part_segmentation_loss(
+ pred_segm_mask, gt_vertices, gt_keypoints2d, gt_model_joints,
+ has_smpl)
+
+ return losses
+
+ @abstractmethod
+ def make_fake_data(self, predictions, requires_grad):
+ pass
+
+ @abstractmethod
+ def make_real_data(self, data_batch):
+ pass
+
+ @abstractmethod
+ def prepare_targets(self, data_batch):
+ pass
+
+ def forward_train(self, **kwargs):
+ """Forward function for general training.
+
+ For mesh estimation, we do not use this interface.
+ """
+ raise NotImplementedError('This interface should not be used in '
+ 'current training schedule. Please use '
+ '`train_step` for training.')
+
+ @abstractmethod
+ def forward_test(self, img, img_metas, **kwargs):
+ """Defines the computation performed at every call when testing."""
+ pass
+
+
+class ImageBodyModelEstimator(BodyModelEstimator):
+ def make_fake_data(self, predictions: dict, requires_grad: bool):
+ pred_cam = predictions['pred_cam']
+ pred_pose = predictions['pred_pose']
+ pred_betas = predictions['pred_shape']
+ if requires_grad:
+ fake_data = (pred_cam, pred_pose, pred_betas)
+ else:
+ fake_data = (pred_cam.detach(), pred_pose.detach(),
+ pred_betas.detach())
+ return fake_data
+
+ def make_real_data(self, data_batch: dict):
+ transl = data_batch['adv_smpl_transl'].float()
+ global_orient = data_batch['adv_smpl_global_orient']
+ body_pose = data_batch['adv_smpl_body_pose']
+ betas = data_batch['adv_smpl_betas'].float()
+ pose = torch.cat((global_orient, body_pose), dim=-1).float()
+ real_data = (transl, pose, betas)
+ return real_data
+
+ def prepare_targets(self, data_batch: dict):
+        # Image Mesh Estimator does not need extra processing of the
+        # ground truth
+ return data_batch
+
+ def forward_test(self, img: torch.Tensor, img_metas: dict, **kwargs):
+ """Defines the computation performed at every call when testing."""
+ if self.backbone is not None:
+ features = self.backbone(img)
+ else:
+ features = kwargs['features']
+
+ if self.neck is not None:
+ features = self.neck(features)
+ predictions = self.head(features)
+ pred_pose = predictions['pred_pose']
+ pred_betas = predictions['pred_shape']
+ pred_cam = predictions['pred_cam']
+ pred_output = self.body_model_test(
+ betas=pred_betas,
+ body_pose=pred_pose[:, 1:],
+ global_orient=pred_pose[:, 0].unsqueeze(1),
+ pose2rot=False)
+
+ pred_vertices = pred_output['vertices']
+ pred_keypoints_3d = pred_output['joints']
+ all_preds = {}
+ all_preds['keypoints_3d'] = pred_keypoints_3d.detach().cpu().numpy()
+ all_preds['smpl_pose'] = pred_pose.detach().cpu().numpy()
+ all_preds['smpl_beta'] = pred_betas.detach().cpu().numpy()
+ all_preds['camera'] = pred_cam.detach().cpu().numpy()
+ all_preds['vertices'] = pred_vertices.detach().cpu().numpy()
+ image_path = []
+ for img_meta in img_metas:
+ image_path.append(img_meta['image_path'])
+ all_preds['image_path'] = image_path
+ all_preds['image_idx'] = kwargs['sample_idx']
+ return all_preds
+
+
+class VideoBodyModelEstimator(BodyModelEstimator):
+ def make_fake_data(self, predictions: dict, requires_grad: bool):
+ B, T = predictions['pred_cam'].shape[:2]
+ pred_cam_vec = predictions['pred_cam']
+ pred_betas_vec = predictions['pred_shape']
+ pred_pose = predictions['pred_pose']
+ pred_pose_vec = rotation_matrix_to_angle_axis(pred_pose.view(-1, 3, 3))
+ pred_pose_vec = pred_pose_vec.contiguous().view(B, T, -1)
+ pred_theta_vec = (pred_cam_vec, pred_pose_vec, pred_betas_vec)
+ pred_theta_vec = torch.cat(pred_theta_vec, dim=-1)
+
+ if not requires_grad:
+ pred_theta_vec = pred_theta_vec.detach()
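+        # Keep only the 69-dim body pose: the concatenated vector is
+        # [cam (3), pose (72), betas (10)], so [:, :, 6:75] drops the camera,
+        # the global orientation and the betas.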
+ return pred_theta_vec[:, :, 6:75]
+
+ def make_real_data(self, data_batch: dict):
+ B, T = data_batch['adv_smpl_transl'].shape[:2]
+ transl = data_batch['adv_smpl_transl'].view(B, T, -1)
+ global_orient = \
+ data_batch['adv_smpl_global_orient'].view(B, T, -1)
+ body_pose = data_batch['adv_smpl_body_pose'].view(B, T, -1)
+ betas = data_batch['adv_smpl_betas'].view(B, T, -1)
+ real_data = (transl, global_orient, body_pose, betas)
+ real_data = torch.cat(real_data, dim=-1).float()
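+        # Same 69-dim body pose slice as in make_fake_data:
+        # [transl (3), global_orient (3), body_pose (69), betas (10)][6:75].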
+ return real_data[:, :, 6:75]
+
+ def prepare_targets(self, data_batch: dict):
+        # Video Mesh Estimator needs to merge the first two (batch and
+        # temporal) dimensions
+ B, T = data_batch['smpl_body_pose'].shape[:2]
+
+ output = {
+ 'smpl_body_pose': data_batch['smpl_body_pose'].view(-1, 23, 3),
+ 'smpl_global_orient': data_batch['smpl_global_orient'].view(-1, 3),
+ 'smpl_betas': data_batch['smpl_betas'].view(-1, 10),
+ 'has_smpl': data_batch['has_smpl'].view(-1),
+ 'keypoints3d': data_batch['keypoints3d'].view(B * T, -1, 4),
+ 'keypoints2d': data_batch['keypoints2d'].view(B * T, -1, 3)
+ }
+ return output
+
+ def forward_test(self, img_metas: dict, **kwargs):
+ """Defines the computation performed at every call when testing."""
+ if self.backbone is not None:
+ features = self.backbone(kwargs['img'])
+ else:
+ features = kwargs['features']
+
+ if self.neck is not None:
+ features = self.neck(features)
+
+ B, T = features.shape[:2]
+ predictions = self.head(features)
+ pred_pose = predictions['pred_pose'].view(-1, 24, 3, 3)
+ pred_betas = predictions['pred_shape'].view(-1, 10)
+ pred_cam = predictions['pred_cam'].view(-1, 3)
+
+ pred_output = self.body_model_test(
+ betas=pred_betas,
+ body_pose=pred_pose[:, 1:],
+ global_orient=pred_pose[:, 0].unsqueeze(1),
+ pose2rot=False)
+
+ pred_vertices = pred_output['vertices']
+ pred_keypoints_3d = pred_output['joints']
+ all_preds = {}
+ all_preds['keypoints_3d'] = pred_keypoints_3d.detach().cpu().numpy()
+ all_preds['smpl_pose'] = pred_pose.detach().cpu().numpy()
+ all_preds['smpl_beta'] = pred_betas.detach().cpu().numpy()
+ all_preds['camera'] = pred_cam.detach().cpu().numpy()
+ all_preds['vertices'] = pred_vertices.detach().cpu().numpy()
+ all_preds['image_idx'] = \
+ kwargs['sample_idx'].detach().cpu().numpy().reshape((-1))
+ return all_preds
diff --git a/detrsmpl/models/backbones/__init__.py b/detrsmpl/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/detrsmpl/models/backbones/builder.py b/detrsmpl/models/backbones/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..aced32c469d8b18001877178ef9d11e8db5d21c0
--- /dev/null
+++ b/detrsmpl/models/backbones/builder.py
@@ -0,0 +1,22 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from mmcv.utils import Registry
+
+from .hrnet import PoseHighResolutionNet, PoseHighResolutionNetExpose
+from .resnet import ResNet, ResNetV1d
+
+BACKBONES = Registry('backbones')
+
+BACKBONES.register_module(name='ResNet', module=ResNet)
+BACKBONES.register_module(name='ResNetV1d', module=ResNetV1d)
+BACKBONES.register_module(name='PoseHighResolutionNet',
+ module=PoseHighResolutionNet)
+BACKBONES.register_module(name='PoseHighResolutionNetExpose',
+ module=PoseHighResolutionNetExpose)
+
+
+def build_backbone(cfg):
+ """Build backbone."""
+ if cfg is None:
+ return None
+ return BACKBONES.build(cfg)
diff --git a/detrsmpl/models/backbones/hrnet.py b/detrsmpl/models/backbones/hrnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5f950fbc1be7399ac70453df58559cf938519c5
--- /dev/null
+++ b/detrsmpl/models/backbones/hrnet.py
@@ -0,0 +1,754 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from mmcv.runner import BaseModule, ModuleList, Sequential
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from .resnet import BasicBlock, Bottleneck
+
+
+class HRModule(BaseModule):
+ """High-Resolution Module for HRNet.
+
+ In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
+ is in this module.
+ """
+ def __init__(self,
+ num_branches,
+ blocks,
+ num_blocks,
+ in_channels,
+ num_channels,
+ multiscale_output=True,
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ block_init_cfg=None,
+ init_cfg=None):
+ super(HRModule, self).__init__(init_cfg)
+ self.block_init_cfg = block_init_cfg
+ self._check_branches(num_branches, num_blocks, in_channels,
+ num_channels)
+
+ self.in_channels = in_channels
+ self.num_branches = num_branches
+
+ self.multiscale_output = multiscale_output
+ self.norm_cfg = norm_cfg
+ self.conv_cfg = conv_cfg
+ self.with_cp = with_cp
+ self.branches = self._make_branches(num_branches, blocks, num_blocks,
+ num_channels)
+ self.fuse_layers = self._make_fuse_layers()
+ self.relu = nn.ReLU(inplace=False)
+
+ def _check_branches(self, num_branches, num_blocks, in_channels,
+ num_channels):
+ if num_branches != len(num_blocks):
+ error_msg = f'NUM_BRANCHES({num_branches}) ' \
+ f'!= NUM_BLOCKS({len(num_blocks)})'
+ raise ValueError(error_msg)
+
+ if num_branches != len(num_channels):
+ error_msg = f'NUM_BRANCHES({num_branches}) ' \
+ f'!= NUM_CHANNELS({len(num_channels)})'
+ raise ValueError(error_msg)
+
+ if num_branches != len(in_channels):
+ error_msg = f'NUM_BRANCHES({num_branches}) ' \
+ f'!= NUM_INCHANNELS({len(in_channels)})'
+ raise ValueError(error_msg)
+
+ def _make_one_branch(self,
+ branch_index,
+ block,
+ num_blocks,
+ num_channels,
+ stride=1):
+ downsample = None
+ if stride != 1 or \
+ self.in_channels[branch_index] != \
+ num_channels[branch_index] * block.expansion:
+ downsample = nn.Sequential(
+ build_conv_layer(self.conv_cfg,
+ self.in_channels[branch_index],
+ num_channels[branch_index] * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ build_norm_layer(self.norm_cfg, num_channels[branch_index] *
+ block.expansion)[1])
+
+ layers = []
+ layers.append(
+ block(self.in_channels[branch_index],
+ num_channels[branch_index],
+ stride,
+ downsample=downsample,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg,
+ init_cfg=self.block_init_cfg))
+ self.in_channels[branch_index] = \
+ num_channels[branch_index] * block.expansion
+ for i in range(1, num_blocks[branch_index]):
+ layers.append(
+ block(self.in_channels[branch_index],
+ num_channels[branch_index],
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg,
+ init_cfg=self.block_init_cfg))
+
+ return Sequential(*layers)
+
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
+ branches = []
+
+ for i in range(num_branches):
+ branches.append(
+ self._make_one_branch(i, block, num_blocks, num_channels))
+
+ return ModuleList(branches)
+
+ def _make_fuse_layers(self):
+ if self.num_branches == 1:
+ return None
+
+ num_branches = self.num_branches
+ in_channels = self.in_channels
+ fuse_layers = []
+ num_out_branches = num_branches if self.multiscale_output else 1
+ for i in range(num_out_branches):
+ fuse_layer = []
+ for j in range(num_branches):
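+                # Branch j feeds output branch i: lower-resolution branches
+                # (j > i) are 1x1-projected and upsampled, higher-resolution
+                # branches (j < i) are downsampled with strided 3x3 convs,
+                # and j == i is an identity connection.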
+ if j > i:
+ fuse_layer.append(
+ nn.Sequential(
+ build_conv_layer(self.conv_cfg,
+ in_channels[j],
+ in_channels[i],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False),
+ build_norm_layer(self.norm_cfg, in_channels[i])[1],
+ nn.Upsample(scale_factor=2**(j - i),
+ mode='nearest')))
+ elif j == i:
+ fuse_layer.append(None)
+ else:
+ conv_downsamples = []
+ for k in range(i - j):
+ if k == i - j - 1:
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(self.conv_cfg,
+ in_channels[j],
+ in_channels[i],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ in_channels[i])[1]))
+ else:
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(self.conv_cfg,
+ in_channels[j],
+ in_channels[j],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ in_channels[j])[1],
+ nn.ReLU(inplace=False)))
+ fuse_layer.append(nn.Sequential(*conv_downsamples))
+ fuse_layers.append(nn.ModuleList(fuse_layer))
+
+ return nn.ModuleList(fuse_layers)
+
+ def forward(self, x):
+ """Forward function."""
+ if self.num_branches == 1:
+ return [self.branches[0](x[0])]
+
+ for i in range(self.num_branches):
+ x[i] = self.branches[i](x[i])
+
+ x_fuse = []
+ for i in range(len(self.fuse_layers)):
+ y = 0
+ for j in range(self.num_branches):
+ if i == j:
+ y += x[j]
+ else:
+ y += self.fuse_layers[i][j](x[j])
+ x_fuse.append(self.relu(y))
+ return x_fuse
+
+
+class PoseHighResolutionNet(BaseModule):
+ """HRNet backbone.
+    `High-Resolution Representations for Labeling Pixels and Regions
+    arXiv: <https://arxiv.org/abs/1904.04514>`_.
+ Args:
+ extra (dict): Detailed configuration for each stage of HRNet.
+ There must be 4 stages, the configuration for each stage must have
+ 5 keys:
+ - num_modules(int): The number of HRModule in this stage.
+ - num_branches(int): The number of branches in the HRModule.
+ - block(str): The type of convolution block.
+ - num_blocks(tuple): The number of blocks in each branch.
+ The length must be equal to num_branches.
+ - num_channels(tuple): The number of channels in each branch.
+ The length must be equal to num_branches.
+ in_channels (int): Number of input image channels. Default: 3.
+ conv_cfg (dict): Dictionary to construct and config conv layer.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: True.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity. Default: False.
+ multiscale_output (bool): Whether to output multi-level features
+ produced by multiple branches. If False, only the first level
+ feature will be output. Default: True.
+        num_joints (int): Number of output channels of the final conv
+            layer. Default: 24.
+ pretrained (str, optional): Model pretrained path. Default: None.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
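+
+    Example:
+        A minimal sketch of an ``extra`` config (HRNet-W32-like values,
+        illustrative rather than a tuned setting):
+
+        >>> extra = dict(
+        ...     stage1=dict(num_modules=1, num_branches=1,
+        ...                 block='BOTTLENECK', num_blocks=(4, ),
+        ...                 num_channels=(64, )),
+        ...     stage2=dict(num_modules=1, num_branches=2, block='BASIC',
+        ...                 num_blocks=(4, 4), num_channels=(32, 64)),
+        ...     stage3=dict(num_modules=4, num_branches=3, block='BASIC',
+        ...                 num_blocks=(4, 4, 4),
+        ...                 num_channels=(32, 64, 128)),
+        ...     stage4=dict(num_modules=3, num_branches=4, block='BASIC',
+        ...                 num_blocks=(4, 4, 4, 4),
+        ...                 num_channels=(32, 64, 128, 256)),
+        ...     final_conv_kernel=1, downsample=False, use_conv=False,
+        ...     return_list=False)
+        >>> self = PoseHighResolutionNet(extra=extra)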
+ """
+
+ blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
+
+ def __init__(self,
+ extra,
+ in_channels=3,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ norm_eval=True,
+ with_cp=False,
+ num_joints=24,
+ zero_init_residual=False,
+ multiscale_output=True,
+ pretrained=None,
+ init_cfg=None):
+ super(PoseHighResolutionNet, self).__init__(init_cfg)
+
+ self.pretrained = pretrained
+ assert not (init_cfg and pretrained), \
+ 'init_cfg and pretrained cannot be specified at the same time'
+ if isinstance(pretrained, str):
+ warnings.warn('DeprecationWarning: pretrained is deprecated, '
+ 'please use "init_cfg" instead')
+ self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+ elif pretrained is None:
+ if init_cfg is None:
+ self.init_cfg = [
+ dict(type='Kaiming', layer='Conv2d'),
+ dict(type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ # Assert configurations of 4 stages are in extra
+ assert 'stage1' in extra and 'stage2' in extra \
+ and 'stage3' in extra and 'stage4' in extra
+ # Assert whether the length of `num_blocks` and `num_channels` are
+ # equal to `num_branches`
+ for i in range(4):
+ cfg = extra[f'stage{i + 1}']
+ assert len(cfg['num_blocks']) == cfg['num_branches'] and \
+ len(cfg['num_channels']) == cfg['num_branches']
+
+ self.extra = extra
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+ self.zero_init_residual = zero_init_residual
+
+ # stem net
+ self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)
+
+ self.conv1 = build_conv_layer(self.conv_cfg,
+ in_channels,
+ 64,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False)
+
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(self.conv_cfg,
+ 64,
+ 64,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.relu = nn.ReLU(inplace=True)
+
+ # stage 1
+ self.stage1_cfg = self.extra['stage1']
+ num_channels = self.stage1_cfg['num_channels'][0]
+ block_type = self.stage1_cfg['block']
+ num_blocks = self.stage1_cfg['num_blocks'][0]
+
+ block = self.blocks_dict[block_type]
+ stage1_out_channels = num_channels * block.expansion
+ self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
+
+ # stage 2
+ self.stage2_cfg = self.extra['stage2']
+ num_channels = self.stage2_cfg['num_channels']
+ block_type = self.stage2_cfg['block']
+
+ block = self.blocks_dict[block_type]
+ num_channels = [channel * block.expansion for channel in num_channels]
+ self.transition1 = self._make_transition_layer([stage1_out_channels],
+ num_channels)
+ self.stage2, pre_stage_channels = self._make_stage(
+ self.stage2_cfg, num_channels)
+
+ # stage 3
+ self.stage3_cfg = self.extra['stage3']
+ num_channels = self.stage3_cfg['num_channels']
+ block_type = self.stage3_cfg['block']
+
+ block = self.blocks_dict[block_type]
+ num_channels = [channel * block.expansion for channel in num_channels]
+ self.transition2 = self._make_transition_layer(pre_stage_channels,
+ num_channels)
+ self.stage3, pre_stage_channels = self._make_stage(
+ self.stage3_cfg, num_channels)
+
+ # stage 4
+ self.stage4_cfg = self.extra['stage4']
+ num_channels = self.stage4_cfg['num_channels']
+ block_type = self.stage4_cfg['block']
+
+ block = self.blocks_dict[block_type]
+ num_channels = [channel * block.expansion for channel in num_channels]
+ self.transition3 = self._make_transition_layer(pre_stage_channels,
+ num_channels)
+ self.stage4, pre_stage_channels = self._make_stage(
+ self.stage4_cfg, num_channels, multiscale_output=multiscale_output)
+ # self.pretrained_layers = extra['pretrained_layers']
+ self.final_layer = build_conv_layer(
+ cfg=self.conv_cfg,
+ in_channels=pre_stage_channels[0],
+ out_channels=num_joints,
+ kernel_size=extra['final_conv_kernel'],
+ stride=1,
+ padding=1 if extra['final_conv_kernel'] == 3 else 0)
+ if extra['downsample'] and extra['use_conv']:
+ self.downsample_stage_1 = self._make_downsample_layer(
+ 3, num_channel=self.stage2_cfg['num_channels'][0])
+ self.downsample_stage_2 = self._make_downsample_layer(
+ 2, num_channel=self.stage2_cfg['num_channels'][-1])
+ self.downsample_stage_3 = self._make_downsample_layer(
+ 1, num_channel=self.stage3_cfg['num_channels'][-1])
+ elif not extra['downsample'] and extra['use_conv']:
+ self.upsample_stage_2 = self._make_upsample_layer(
+ 1, num_channel=self.stage2_cfg['num_channels'][-1])
+ self.upsample_stage_3 = self._make_upsample_layer(
+ 2, num_channel=self.stage3_cfg['num_channels'][-1])
+ self.upsample_stage_4 = self._make_upsample_layer(
+ 3, num_channel=self.stage4_cfg['num_channels'][-1])
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: the normalization layer named "norm2" """
+ return getattr(self, self.norm2_name)
+
+ def _make_transition_layer(self, num_channels_pre_layer,
+ num_channels_cur_layer):
+ num_branches_cur = len(num_channels_cur_layer)
+ num_branches_pre = len(num_channels_pre_layer)
+
+ transition_layers = []
+ for i in range(num_branches_cur):
+ if i < num_branches_pre:
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+ transition_layers.append(
+ nn.Sequential(
+ build_conv_layer(self.conv_cfg,
+ num_channels_pre_layer[i],
+ num_channels_cur_layer[i],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ num_channels_cur_layer[i])[1],
+ nn.ReLU(inplace=True)))
+ else:
+ transition_layers.append(None)
+ else:
+ conv_downsamples = []
+ for j in range(i + 1 - num_branches_pre):
+ in_channels = num_channels_pre_layer[-1]
+ out_channels = num_channels_cur_layer[i] \
+ if j == i - num_branches_pre else in_channels
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(self.conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, out_channels)[1],
+ nn.ReLU(inplace=True)))
+ transition_layers.append(nn.Sequential(*conv_downsamples))
+
+ return nn.ModuleList(transition_layers)
+
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ build_conv_layer(self.conv_cfg,
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
+
+ layers = []
+ block_init_cfg = None
+ if self.pretrained is None and not hasattr(
+ self, 'init_cfg') and self.zero_init_residual:
+ if block is BasicBlock:
+ block_init_cfg = dict(type='Constant',
+ val=0,
+ override=dict(name='norm2'))
+ elif block is Bottleneck:
+ block_init_cfg = dict(type='Constant',
+ val=0,
+ override=dict(name='norm3'))
+ layers.append(
+ block(
+ inplanes,
+ planes,
+ stride,
+ downsample=downsample,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg,
+ init_cfg=block_init_cfg,
+ ))
+ inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(
+ block(inplanes,
+ planes,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg,
+ init_cfg=block_init_cfg))
+
+ return Sequential(*layers)
+
+ def _make_stage(self, layer_config, in_channels, multiscale_output=True):
+ num_modules = layer_config['num_modules']
+ num_branches = layer_config['num_branches']
+ num_blocks = layer_config['num_blocks']
+ num_channels = layer_config['num_channels']
+ block = self.blocks_dict[layer_config['block']]
+
+ hr_modules = []
+ block_init_cfg = None
+ if self.pretrained is None and not hasattr(
+ self, 'init_cfg') and self.zero_init_residual:
+ if block is BasicBlock:
+ block_init_cfg = dict(type='Constant',
+ val=0,
+ override=dict(name='norm2'))
+ elif block is Bottleneck:
+ block_init_cfg = dict(type='Constant',
+ val=0,
+ override=dict(name='norm3'))
+
+ for i in range(num_modules):
+ # multi_scale_output is only used for the last module
+ if not multiscale_output and i == num_modules - 1:
+ reset_multiscale_output = False
+ else:
+ reset_multiscale_output = True
+
+ hr_modules.append(
+ HRModule(num_branches,
+ block,
+ num_blocks,
+ in_channels,
+ num_channels,
+ reset_multiscale_output,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg,
+ block_init_cfg=block_init_cfg))
+
+ return Sequential(*hr_modules), in_channels
+
+ def _make_upsample_layer(self, num_layers, num_channel, kernel_size=3):
+ layers = []
+ for i in range(num_layers):
+ layers.append(
+ nn.Upsample(scale_factor=2,
+ mode='bilinear',
+ align_corners=True))
+ layers.append(
+ build_conv_layer(
+ cfg=self.conv_cfg,
+ in_channels=num_channel,
+ out_channels=num_channel,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=1,
+ bias=False,
+ ))
+ layers.append(build_norm_layer(self.norm_cfg, num_channel)[1])
+ layers.append(nn.ReLU(inplace=True))
+
+ return nn.Sequential(*layers)
+
+ def _make_downsample_layer(self, num_layers, num_channel, kernel_size=3):
+ layers = []
+ for i in range(num_layers):
+ layers.append(
+ build_conv_layer(
+ cfg=self.conv_cfg,
+ in_channels=num_channel,
+ out_channels=num_channel,
+ kernel_size=kernel_size,
+ stride=2,
+ padding=1,
+ bias=False,
+ ))
+ layers.append(build_norm_layer(self.norm_cfg, num_channel)[1])
+ layers.append(nn.ReLU(inplace=True))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ """Forward function."""
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = self.norm2(x)
+ x = self.relu(x)
+ x = self.layer1(x)
+
+ x_list = []
+ for i in range(self.stage2_cfg['num_branches']):
+ if self.transition1[i] is not None:
+ x_list.append(self.transition1[i](x))
+ else:
+ x_list.append(x)
+ y_list = self.stage2(x_list)
+
+ x_list = []
+ for i in range(self.stage3_cfg['num_branches']):
+ if self.transition2[i] is not None:
+ x_list.append(self.transition2[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage3(x_list)
+
+ x_list = []
+ for i in range(self.stage4_cfg['num_branches']):
+ if self.transition3[i] is not None:
+ x_list.append(self.transition3[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage4(x_list)
+ if self.extra['return_list']:
+ return y_list
+ elif self.extra['downsample']:
+ if self.extra['use_conv']:
+ # Downsampling with strided convolutions
+ x1 = self.downsample_stage_1(y_list[0])
+ x2 = self.downsample_stage_2(y_list[1])
+ x3 = self.downsample_stage_3(y_list[2])
+ x = torch.cat([x1, x2, x3, y_list[3]], 1)
+ else:
+ # Downsampling with interpolation
+ x0_h, x0_w = y_list[3].size(2), y_list[3].size(3)
+ x1 = F.interpolate(y_list[0],
+ size=(x0_h, x0_w),
+ mode='bilinear',
+ align_corners=True)
+ x2 = F.interpolate(y_list[1],
+ size=(x0_h, x0_w),
+ mode='bilinear',
+ align_corners=True)
+ x3 = F.interpolate(y_list[2],
+ size=(x0_h, x0_w),
+ mode='bilinear',
+ align_corners=True)
+ x = torch.cat([x1, x2, x3, y_list[3]], 1)
+ else:
+ if self.extra['use_conv']:
+ # Upsampling with interpolations + convolutions
+ x1 = self.upsample_stage_2(y_list[1])
+ x2 = self.upsample_stage_3(y_list[2])
+ x3 = self.upsample_stage_4(y_list[3])
+ x = torch.cat([y_list[0], x1, x2, x3], 1)
+ else:
+ # Upsampling with interpolation
+ x0_h, x0_w = y_list[0].size(2), y_list[0].size(3)
+ x1 = F.interpolate(y_list[1],
+ size=(x0_h, x0_w),
+ mode='bilinear',
+ align_corners=True)
+ x2 = F.interpolate(y_list[2],
+ size=(x0_h, x0_w),
+ mode='bilinear',
+ align_corners=True)
+ x3 = F.interpolate(y_list[3],
+ size=(x0_h, x0_w),
+ mode='bilinear',
+ align_corners=True)
+ x = torch.cat([y_list[0], x1, x2, x3], 1)
+ return x
+
+ def train(self, mode=True):
+ """Convert the model into training mode will keeping the normalization
+ layer freezed."""
+ super(PoseHighResolutionNet, self).train(mode)
+ if mode and self.norm_eval:
+ for m in self.modules():
+ # trick: eval have effect on BatchNorm only
+ if isinstance(m, _BatchNorm):
+ m.eval()
+
+
+class PoseHighResolutionNetExpose(PoseHighResolutionNet):
+ """HRNet backbone for expose."""
+ def __init__(self,
+ extra,
+ in_channels=3,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ norm_eval=True,
+ with_cp=False,
+ num_joints=24,
+ zero_init_residual=False,
+ multiscale_output=True,
+ pretrained=None,
+ init_cfg=None):
+ super().__init__(extra, in_channels, conv_cfg, norm_cfg, norm_eval,
+ with_cp, num_joints, zero_init_residual,
+ multiscale_output, pretrained, init_cfg)
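+        # Channel count of the concatenated feature map: the stage-2 and
+        # stage-3 branches are subsampled by 4x and 2x respectively (each
+        # stride-2 conv doubles the channels) before being concatenated
+        # with the lowest-resolution branch.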
+ in_dims = (2**2 * self.stage2_cfg['num_channels'][-1] +
+ 2**1 * self.stage3_cfg['num_channels'][-1] +
+ self.stage4_cfg['num_channels'][-1])
+ self.conv_layers = self._make_conv_layer(in_channels=in_dims,
+ num_layers=5)
+ self.subsample_3 = self._make_subsample_layer(
+ in_channels=self.stage2_cfg['num_channels'][-1], num_layers=2)
+ self.subsample_2 = self._make_subsample_layer(
+ in_channels=self.stage3_cfg['num_channels'][-1], num_layers=1)
+
+ def _make_conv_layer(self,
+ in_channels=2048,
+ num_layers=3,
+ num_filters=2048,
+ stride=1):
+
+ layers = []
+ for i in range(num_layers):
+
+ downsample = nn.Conv2d(in_channels,
+ num_filters,
+ stride=1,
+ kernel_size=1,
+ bias=False)
+ layers.append(
+ Bottleneck(in_channels,
+ num_filters // 4,
+ downsample=downsample))
+ in_channels = num_filters
+
+ return nn.Sequential(*layers)
+
+ def _make_subsample_layer(self, in_channels=96, num_layers=3, stride=2):
+
+ layers = []
+ for i in range(num_layers):
+
+ layers.append(
+ nn.Conv2d(in_channels=in_channels,
+ out_channels=2 * in_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=1))
+ in_channels = 2 * in_channels
+ layers.append(nn.BatchNorm2d(in_channels, momentum=0.1))
+ layers.append(nn.ReLU(inplace=True))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ """Forward function."""
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = self.norm2(x)
+ x = self.relu(x)
+ x = self.layer1(x)
+
+ x_list = []
+ for i in range(self.stage2_cfg['num_branches']):
+ if self.transition1[i] is not None:
+ x_list.append(self.transition1[i](x))
+ else:
+ x_list.append(x)
+ y_list = self.stage2(x_list)
+
+ x_list = []
+ for i in range(self.stage3_cfg['num_branches']):
+ if self.transition2[i] is not None:
+ x_list.append(self.transition2[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage3(x_list)
+
+ x_list = []
+ for i in range(self.stage4_cfg['num_branches']):
+ if self.transition3[i] is not None:
+ x_list.append(self.transition3[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ x3 = self.subsample_3(x_list[1])
+ x2 = self.subsample_2(x_list[2])
+ x1 = x_list[3]
+ xf = self.conv_layers(torch.cat([x3, x2, x1], dim=1))
+ xf = xf.mean(dim=(2, 3))
+ xf = xf.view(xf.size(0), -1)
+ return xf
diff --git a/detrsmpl/models/backbones/resnet.py b/detrsmpl/models/backbones/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..7333646625b104a35eccf4e1405caa6fb81aaa88
--- /dev/null
+++ b/detrsmpl/models/backbones/resnet.py
@@ -0,0 +1,662 @@
+import warnings
+
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer
+from mmcv.runner import BaseModule
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from ..utils import ResLayer
+
+
+class BasicBlock(BaseModule):
+ expansion = 1
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ dcn=None,
+ plugins=None,
+ init_cfg=None):
+ super(BasicBlock, self).__init__(init_cfg)
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+
+ self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
+
+ self.conv1 = build_conv_layer(conv_cfg,
+ inplanes,
+ planes,
+ 3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(conv_cfg,
+ planes,
+ planes,
+ 3,
+ padding=1,
+ bias=False)
+ self.add_module(self.norm2_name, norm2)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ self.with_cp = with_cp
+
+ @property
+ def norm1(self):
+ """nn.Module: normalization layer after the first convolution layer"""
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: normalization layer after the second convolution layer"""
+ return getattr(self, self.norm2_name)
+
+ def forward(self, x):
+ """Forward function."""
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(BaseModule):
+ expansion = 4
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ dcn=None,
+ plugins=None,
+ init_cfg=None):
+ """Bottleneck block for ResNet.
+
+ If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
+ it is "caffe", the stride-two layer is the first 1x1 conv layer.
+ """
+ super(Bottleneck, self).__init__(init_cfg)
+ assert style in ['pytorch', 'caffe']
+ assert dcn is None or isinstance(dcn, dict)
+ assert plugins is None or isinstance(plugins, list)
+ if plugins is not None:
+ allowed_position = ['after_conv1', 'after_conv2', 'after_conv3']
+ assert all(p['position'] in allowed_position for p in plugins)
+
+ self.inplanes = inplanes
+ self.planes = planes
+ self.stride = stride
+ self.dilation = dilation
+ self.style = style
+ self.with_cp = with_cp
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.dcn = dcn
+ self.with_dcn = dcn is not None
+ self.plugins = plugins
+ self.with_plugins = plugins is not None
+
+ if self.with_plugins:
+ # collect plugins for conv1/conv2/conv3
+ self.after_conv1_plugins = [
+ plugin['cfg'] for plugin in plugins
+ if plugin['position'] == 'after_conv1'
+ ]
+ self.after_conv2_plugins = [
+ plugin['cfg'] for plugin in plugins
+ if plugin['position'] == 'after_conv2'
+ ]
+ self.after_conv3_plugins = [
+ plugin['cfg'] for plugin in plugins
+ if plugin['position'] == 'after_conv3'
+ ]
+
+ if self.style == 'pytorch':
+ self.conv1_stride = 1
+ self.conv2_stride = stride
+ else:
+ self.conv1_stride = stride
+ self.conv2_stride = 1
+
+ self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
+ self.norm3_name, norm3 = build_norm_layer(norm_cfg,
+ planes * self.expansion,
+ postfix=3)
+
+ self.conv1 = build_conv_layer(conv_cfg,
+ inplanes,
+ planes,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ fallback_on_stride = False
+ if self.with_dcn:
+ fallback_on_stride = dcn.pop('fallback_on_stride', False)
+ if not self.with_dcn or fallback_on_stride:
+ self.conv2 = build_conv_layer(conv_cfg,
+ planes,
+ planes,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+ else:
+ assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
+ self.conv2 = build_conv_layer(dcn,
+ planes,
+ planes,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.conv3 = build_conv_layer(conv_cfg,
+ planes,
+ planes * self.expansion,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+
+ if self.with_plugins:
+ self.after_conv1_plugin_names = self.make_block_plugins(
+ planes, self.after_conv1_plugins)
+ self.after_conv2_plugin_names = self.make_block_plugins(
+ planes, self.after_conv2_plugins)
+ self.after_conv3_plugin_names = self.make_block_plugins(
+ planes * self.expansion, self.after_conv3_plugins)
+
+ def make_block_plugins(self, in_channels, plugins):
+ """make plugins for block.
+
+ Args:
+ in_channels (int): Input channels of plugin.
+ plugins (list[dict]): List of plugins cfg to build.
+ Returns:
+ list[str]: List of the names of plugin.
+ """
+ assert isinstance(plugins, list)
+ plugin_names = []
+ for plugin in plugins:
+ plugin = plugin.copy()
+ name, layer = build_plugin_layer(plugin,
+ in_channels=in_channels,
+ postfix=plugin.pop('postfix', ''))
+ assert not hasattr(self, name), f'duplicate plugin {name}'
+ self.add_module(name, layer)
+ plugin_names.append(name)
+ return plugin_names
+
+ def forward_plugin(self, x, plugin_names):
+ out = x
+ for name in plugin_names:
+ out = getattr(self, name)(x)
+ return out
+
+ @property
+ def norm1(self):
+ """nn.Module: normalization layer after the first convolution layer"""
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: normalization layer after the second convolution layer"""
+ return getattr(self, self.norm2_name)
+
+ @property
+ def norm3(self):
+ """nn.Module: normalization layer after the third convolution layer"""
+ return getattr(self, self.norm3_name)
+
+ def forward(self, x):
+ """Forward function."""
+ def _inner_forward(x):
+ identity = x
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv1_plugin_names)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv2_plugin_names)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv3_plugin_names)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(BaseModule):
+ """ResNet backbone.
+ Args:
+ depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
+ stem_channels (int | None): Number of stem channels. If not specified,
+ it will be the same as `base_channels`. Default: None.
+ base_channels (int): Number of base channels of res layer. Default: 64.
+ in_channels (int): Number of input image channels. Default: 3.
+ num_stages (int): Resnet stages. Default: 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+        deep_stem (bool): Replace the 7x7 conv in the input stem with
+            three 3x3 convs.
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ plugins (list[dict]): List of plugins for stages, each dict contains:
+ - cfg (dict, required): Cfg dict to build plugin.
+ - position (str, required): Position inside block to insert
+ plugin, options are 'after_conv1', 'after_conv2', 'after_conv3'.
+ - stages (tuple[bool], optional): Stages to apply plugin, length
+ should be same as 'num_stages'.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity.
+ pretrained (str, optional): model pretrained path. Default: None
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ Example:
+ >>> from detrsmpl.models.backbones.resnet import ResNet
+ >>> import torch
+ >>> self = ResNet(depth=18)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 64, 8, 8)
+ (1, 128, 4, 4)
+ (1, 256, 2, 2)
+ (1, 512, 1, 1)
+ """
+
+ arch_settings = {
+ 18: (BasicBlock, (2, 2, 2, 2)),
+ 34: (BasicBlock, (3, 4, 6, 3)),
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self,
+ depth,
+ in_channels=3,
+ stem_channels=None,
+ base_channels=64,
+ num_stages=4,
+ strides=(1, 2, 2, 2),
+ dilations=(1, 1, 1, 1),
+ out_indices=(0, 1, 2, 3),
+ style='pytorch',
+ deep_stem=False,
+ avg_down=False,
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=True,
+ dcn=None,
+ stage_with_dcn=(False, False, False, False),
+ plugins=None,
+ with_cp=False,
+ zero_init_residual=True,
+ pretrained=None,
+ init_cfg=None):
+ super(ResNet, self).__init__(init_cfg)
+ self.zero_init_residual = zero_init_residual
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for resnet')
+
+ block_init_cfg = None
+ assert not (init_cfg and pretrained), \
+            'init_cfg and pretrained cannot be specified at the same time'
+ if isinstance(pretrained, str):
+ warnings.warn('DeprecationWarning: pretrained is deprecated, '
+ 'please use "init_cfg" instead')
+ self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+ elif pretrained is None:
+ if init_cfg is None:
+ self.init_cfg = [
+ dict(type='Kaiming', layer='Conv2d'),
+ dict(type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]
+ block = self.arch_settings[depth][0]
+ if self.zero_init_residual:
+ if block is BasicBlock:
+ block_init_cfg = dict(type='Constant',
+ val=0,
+ override=dict(name='norm2'))
+ elif block is Bottleneck:
+ block_init_cfg = dict(type='Constant',
+ val=0,
+ override=dict(name='norm3'))
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ self.depth = depth
+ if stem_channels is None:
+ stem_channels = base_channels
+ self.stem_channels = stem_channels
+ self.base_channels = base_channels
+ self.num_stages = num_stages
+ assert num_stages >= 1 and num_stages <= 4
+ self.strides = strides
+ self.dilations = dilations
+ assert len(strides) == len(dilations) == num_stages
+ self.out_indices = out_indices
+ assert max(out_indices) < num_stages
+ self.style = style
+ self.deep_stem = deep_stem
+ self.avg_down = avg_down
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.with_cp = with_cp
+ self.norm_eval = norm_eval
+ self.dcn = dcn
+ self.stage_with_dcn = stage_with_dcn
+ if dcn is not None:
+ assert len(stage_with_dcn) == num_stages
+ self.plugins = plugins
+ self.block, stage_blocks = self.arch_settings[depth]
+ self.stage_blocks = stage_blocks[:num_stages]
+ self.inplanes = stem_channels
+
+ self._make_stem_layer(in_channels, stem_channels)
+
+ self.res_layers = []
+ for i, num_blocks in enumerate(self.stage_blocks):
+ stride = strides[i]
+ dilation = dilations[i]
+ dcn = self.dcn if self.stage_with_dcn[i] else None
+ if plugins is not None:
+ stage_plugins = self.make_stage_plugins(plugins, i)
+ else:
+ stage_plugins = None
+ planes = base_channels * 2**i
+ res_layer = self.make_res_layer(block=self.block,
+ inplanes=self.inplanes,
+ planes=planes,
+ num_blocks=num_blocks,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ avg_down=self.avg_down,
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ dcn=dcn,
+ plugins=stage_plugins,
+ init_cfg=block_init_cfg)
+ self.inplanes = planes * self.block.expansion
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ self._freeze_stages()
+
+ self.feat_dim = self.block.expansion * base_channels * 2**(
+ len(self.stage_blocks) - 1)
+
+ def make_stage_plugins(self, plugins, stage_idx):
+        """Make plugins for the ResNet ``stage_idx``-th stage.
+        Currently we support inserting ``context_block``,
+        ``empirical_attention_block`` and ``nonlocal_block`` into backbones
+        like ResNet/ResNeXt. They can be inserted after conv1/conv2/conv3 of
+        a Bottleneck block.
+ An example of plugins format could be:
+ Examples:
+ >>> plugins=[
+ ... dict(cfg=dict(type='xxx', arg1='xxx'),
+ ... stages=(False, True, True, True),
+ ... position='after_conv2'),
+ ... dict(cfg=dict(type='yyy'),
+ ... stages=(True, True, True, True),
+ ... position='after_conv3'),
+ ... dict(cfg=dict(type='zzz', postfix='1'),
+ ... stages=(True, True, True, True),
+ ... position='after_conv3'),
+ ... dict(cfg=dict(type='zzz', postfix='2'),
+ ... stages=(True, True, True, True),
+ ... position='after_conv3')
+ ... ]
+ >>> self = ResNet(depth=18)
+ >>> stage_plugins = self.make_stage_plugins(plugins, 0)
+ >>> assert len(stage_plugins) == 3
+ Suppose ``stage_idx=0``, the structure of blocks in the stage would be:
+ .. code-block:: none
+ conv1-> conv2->conv3->yyy->zzz1->zzz2
+        Suppose ``stage_idx=1``, the structure of blocks in the stage would be:
+ .. code-block:: none
+ conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2
+ If stages is missing, the plugin would be applied to all stages.
+ Args:
+            plugins (list[dict]): List of plugin cfgs to build. The postfix is
+                required if multiple plugins of the same type are inserted.
+ stage_idx (int): Index of stage to build
+ Returns:
+ list[dict]: Plugins for current stage
+ """
+ stage_plugins = []
+ for plugin in plugins:
+ plugin = plugin.copy()
+ stages = plugin.pop('stages', None)
+ assert stages is None or len(stages) == self.num_stages
+ # whether to insert plugin into current stage
+ if stages is None or stages[stage_idx]:
+ stage_plugins.append(plugin)
+
+ return stage_plugins
+
+ def make_res_layer(self, **kwargs):
+ """Pack all blocks in a stage into a ``ResLayer``."""
+ return ResLayer(**kwargs)
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ def _make_stem_layer(self, in_channels, stem_channels):
+ if self.deep_stem:
+ self.stem = nn.Sequential(
+ build_conv_layer(self.conv_cfg,
+ in_channels,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+ nn.ReLU(inplace=True),
+ build_conv_layer(self.conv_cfg,
+ stem_channels // 2,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+ nn.ReLU(inplace=True),
+ build_conv_layer(self.conv_cfg,
+ stem_channels // 2,
+ stem_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, stem_channels)[1],
+ nn.ReLU(inplace=True))
+ else:
+ self.conv1 = build_conv_layer(self.conv_cfg,
+ in_channels,
+ stem_channels,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ bias=False)
+ self.norm1_name, norm1 = build_norm_layer(self.norm_cfg,
+ stem_channels,
+ postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
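+        # Descriptive note: either stem halves the spatial resolution once,
+        # and the 3x3 max-pool above halves it again, so features entering
+        # layer1 have an overall stride of 4 with respect to the input image.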
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ if self.deep_stem:
+ self.stem.eval()
+ for param in self.stem.parameters():
+ param.requires_grad = False
+ else:
+ self.norm1.eval()
+ for m in [self.conv1, self.norm1]:
+ for param in m.parameters():
+ param.requires_grad = False
+
+ for i in range(1, self.frozen_stages + 1):
+ m = getattr(self, f'layer{i}')
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def forward(self, x):
+ """Forward function."""
+ if self.deep_stem:
+ x = self.stem(x)
+ else:
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ x = res_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ return tuple(outs)
+
+ def train(self, mode=True):
+        """Switch the model to training mode while keeping the normalization
+        layers frozen."""
+ super(ResNet, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+                # trick: eval() only has an effect on BatchNorm layers
+ if isinstance(m, _BatchNorm):
+ m.eval()
+
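+
+# Usage sketch (documentation only, never called by the library): builds the
+# backbone above and runs a dummy image through it. The depth and input size
+# are illustrative assumptions, not values taken from this repository's
+# configs.
+def _resnet_usage_example():
+    import torch
+    backbone = ResNet(depth=50, out_indices=(0, 1, 2, 3))
+    backbone.eval()
+    with torch.no_grad():
+        feats = backbone(torch.rand(1, 3, 224, 224))
+    # Four feature maps with strides 4/8/16/32 and 256/512/1024/2048 channels.
+    return [f.shape for f in feats]
+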
+
+class ResNetV1d(ResNet):
+    r"""ResNetV1d variant described in `Bag of Tricks
+    <https://arxiv.org/abs/1812.01187>`_.
+    Compared with the default ResNet (ResNetV1b), ResNetV1d replaces the 7x7
+    conv in the input stem with three 3x3 convs, and in the downsampling block
+    a 2x2 avg_pool with stride 2 is added before the conv, whose stride is
+    changed to 1.
+ """
+ def __init__(self, **kwargs):
+ super(ResNetV1d, self).__init__(deep_stem=True,
+ avg_down=True,
+ **kwargs)
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding."""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ groups=groups,
+ bias=False,
+ dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution."""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False)
diff --git a/detrsmpl/models/body_models/__init__.py b/detrsmpl/models/body_models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/detrsmpl/models/body_models/builder.py b/detrsmpl/models/body_models/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..4afc217da15ca64b7c4948c804a6c0f8f5dbbad3
--- /dev/null
+++ b/detrsmpl/models/body_models/builder.py
@@ -0,0 +1,33 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from mmcv.utils import Registry
+
+from .flame import FLAME, FLAMELayer
+from .mano import MANO, MANOLayer
+from .smpl import SMPL, GenderedSMPL, HybrIKSMPL
+from .smplx import SMPLX, SMPLXLayer
+from .star import STAR
+
+BODY_MODELS = Registry('body_models')
+
+BODY_MODELS.register_module(name=['SMPL', 'smpl'], module=SMPL)
+BODY_MODELS.register_module(name='GenderedSMPL', module=GenderedSMPL)
+BODY_MODELS.register_module(name=['STAR', 'star'], module=STAR)
+BODY_MODELS.register_module(
+ name=['HybrIKSMPL', 'HybrIKsmpl', 'hybriksmpl', 'hybrik', 'hybrIK'],
+ module=HybrIKSMPL)
+BODY_MODELS.register_module(name=['SMPLX', 'smplx'], module=SMPLX)
+BODY_MODELS.register_module(name=['flame', 'FLAME'], module=FLAME)
+BODY_MODELS.register_module(name=['MANO', 'mano'], module=MANO)
+BODY_MODELS.register_module(name=['SMPLXLayer', 'smplxlayer'],
+ module=SMPLXLayer)
+BODY_MODELS.register_module(name=['MANOLayer', 'manolayer'], module=MANOLayer)
+BODY_MODELS.register_module(name=['FLAMELayer', 'flamelayer'],
+ module=FLAMELayer)
+
+
+def build_body_model(cfg):
+ """Build body_models."""
+ if cfg is None:
+ return None
+ return BODY_MODELS.build(cfg)
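+
+
+# Usage sketch (documentation only): the registry maps the aliases registered
+# above to their classes, so a plain config dict is enough to build a body
+# model. The model_path below is a hypothetical local path, not a file shipped
+# with this repository.
+def _example_build_smpl():
+    cfg = dict(type='SMPL',
+               model_path='data/body_models/smpl',
+               keypoint_src='smpl_45',
+               keypoint_dst='human_data')
+    return build_body_model(cfg)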
diff --git a/detrsmpl/models/body_models/flame.py b/detrsmpl/models/body_models/flame.py
new file mode 100644
index 0000000000000000000000000000000000000000..323e1a030985db109916490dad9400f911fa03fa
--- /dev/null
+++ b/detrsmpl/models/body_models/flame.py
@@ -0,0 +1,187 @@
+import numpy as np
+import torch
+from smplx import FLAME as _FLAME
+from smplx import FLAMELayer as _FLAMELayer
+
+from detrsmpl.core.conventions.keypoints_mapping import (
+ convert_kps,
+ get_keypoint_num,
+)
+
+
+class FLAME(_FLAME):
+ """Extension of the official FLAME implementation."""
+ head_pose_keys = {'global_orient', 'jaw_pose'}
+ full_pose_keys = {
+ 'global_orient', 'neck_pose', 'jaw_pose', 'leye_pose', 'reye_pose'
+ }
+
+ NUM_VERTS = 5023
+ NUM_FACES = 9976
+
+ def __init__(self,
+ *args,
+ keypoint_src: str = 'flame',
+ keypoint_dst: str = 'human_data',
+ keypoint_approximate: bool = False,
+ **kwargs):
+ """
+ Args:
+ *args: extra arguments for FLAME initialization.
+ keypoint_src: source convention of keypoints. This convention
+ is used for keypoints obtained from joint regressors.
+ Keypoints then undergo conversion into keypoint_dst
+ convention.
+ keypoint_dst: destination convention of keypoints. This convention
+ is used for keypoints in the output.
+ keypoint_approximate: whether to use approximate matching in
+ convention conversion for keypoints.
+ **kwargs: extra keyword arguments for FLAME initialization.
+
+ Returns:
+ None
+ """
+ super(FLAME, self).__init__(*args, **kwargs)
+ self.keypoint_src = keypoint_src
+ self.keypoint_dst = keypoint_dst
+ self.keypoint_approximate = keypoint_approximate
+
+ self.num_verts = self.get_num_verts()
+ self.num_faces = self.get_num_faces()
+ self.num_joints = get_keypoint_num(convention=self.keypoint_dst)
+
+ def forward(self,
+ *args,
+ return_verts: bool = True,
+ return_full_pose: bool = False,
+ **kwargs) -> dict:
+ """Forward function.
+
+ Args:
+ *args: extra arguments for FLAME
+ return_verts: whether to return vertices
+ return_full_pose: whether to return full pose parameters
+ **kwargs: extra arguments for FLAME
+
+ Returns:
+ output: contains output parameters and attributes
+ """
+ flame_output = super(FLAME, self).forward(*args, **kwargs)
+ joints = flame_output.joints
+ joints, joint_mask = convert_kps(joints,
+ src=self.keypoint_src,
+ dst=self.keypoint_dst,
+ approximate=self.keypoint_approximate)
+ if isinstance(joint_mask, np.ndarray):
+ joint_mask = torch.tensor(joint_mask,
+ dtype=torch.uint8,
+ device=joints.device)
+
+ batch_size = joints.shape[0]
+ joint_mask = joint_mask.reshape(1, -1).expand(batch_size, -1)
+
+ output = dict(global_orient=flame_output.global_orient,
+ neck_pose=flame_output.neck_pose,
+ jaw_pose=flame_output.jaw_pose,
+ joints=joints,
+ joint_mask=joint_mask,
+ keypoints=torch.cat([joints, joint_mask[:, :, None]],
+ dim=-1),
+ betas=flame_output.betas,
+ expression=flame_output.expression)
+
+ if return_verts:
+ output['vertices'] = flame_output.vertices
+ if return_full_pose:
+ output['full_pose'] = flame_output.full_pose
+
+ return output
+
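+# Note on the dict returned by FLAME.forward (the same layout is used by the
+# other body models in this package): 'joints' are already converted to the
+# keypoint_dst convention, 'joint_mask' marks which of those joints the source
+# convention actually provides, and 'keypoints' concatenates the two along the
+# last axis as (x, y, z, conf).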
+
+class FLAMELayer(_FLAMELayer):
+ """Extension of the official FLAME implementation."""
+ head_pose_keys = {'global_orient', 'jaw_pose'}
+ full_pose_keys = {
+ 'global_orient', 'neck_pose', 'jaw_pose', 'leye_pose', 'reye_pose'
+ }
+
+ NUM_VERTS = 5023
+ NUM_FACES = 9976
+
+ def __init__(self,
+ *args,
+ keypoint_src: str = 'flame',
+ keypoint_dst: str = 'human_data',
+ keypoint_approximate: bool = False,
+ **kwargs):
+ """
+ Args:
+ *args: extra arguments for FLAME initialization.
+ keypoint_src: source convention of keypoints. This convention
+ is used for keypoints obtained from joint regressors.
+ Keypoints then undergo conversion into keypoint_dst
+ convention.
+ keypoint_dst: destination convention of keypoints. This convention
+ is used for keypoints in the output.
+ keypoint_approximate: whether to use approximate matching in
+ convention conversion for keypoints.
+ **kwargs: extra keyword arguments for FLAME initialization.
+
+ Returns:
+ None
+ """
+ super(FLAMELayer, self).__init__(*args, **kwargs)
+ self.keypoint_src = keypoint_src
+ self.keypoint_dst = keypoint_dst
+ self.keypoint_approximate = keypoint_approximate
+
+ self.num_verts = self.get_num_verts()
+ self.num_faces = self.get_num_faces()
+ self.num_joints = get_keypoint_num(convention=self.keypoint_dst)
+
+ def forward(self,
+ *args,
+ return_verts: bool = True,
+ return_full_pose: bool = False,
+ **kwargs) -> dict:
+ """Forward function.
+
+ Args:
+ *args: extra arguments for FLAME
+ return_verts: whether to return vertices
+ return_full_pose: whether to return full pose parameters
+ **kwargs: extra arguments for FLAME
+
+ Returns:
+ output: contains output parameters and attributes
+ """
+ flame_output = super(FLAMELayer, self).forward(*args, **kwargs)
+ joints = flame_output.joints
+ joints, joint_mask = convert_kps(joints,
+ src=self.keypoint_src,
+ dst=self.keypoint_dst,
+ approximate=self.keypoint_approximate)
+ if isinstance(joint_mask, np.ndarray):
+ joint_mask = torch.tensor(joint_mask,
+ dtype=torch.uint8,
+ device=joints.device)
+
+ batch_size = joints.shape[0]
+ joint_mask = joint_mask.reshape(1, -1).expand(batch_size, -1)
+
+ output = dict(global_orient=flame_output.global_orient,
+ neck_pose=flame_output.neck_pose,
+ jaw_pose=flame_output.jaw_pose,
+ joints=joints,
+ joint_mask=joint_mask,
+ keypoints=torch.cat([joints, joint_mask[:, :, None]],
+ dim=-1),
+ betas=flame_output.betas,
+ expression=flame_output.expression)
+
+ if return_verts:
+ output['vertices'] = flame_output.vertices
+ if return_full_pose:
+ output['full_pose'] = flame_output.full_pose
+
+ return output
diff --git a/detrsmpl/models/body_models/mano.py b/detrsmpl/models/body_models/mano.py
new file mode 100644
index 0000000000000000000000000000000000000000..124d95d051dcaffa51fe1e34fd737a0c22d658bf
--- /dev/null
+++ b/detrsmpl/models/body_models/mano.py
@@ -0,0 +1,271 @@
+import numpy as np
+import torch
+from smplx import MANO as _MANO
+from smplx import MANOLayer as _MANOLayer
+
+from detrsmpl.core.conventions.keypoints_mapping import (
+ convert_kps,
+ get_keypoint_num,
+)
+
+
+class MANO(_MANO):
+ """Extension of the official MANO implementation."""
+ full_pose_keys = {'global_orient', 'hand_pose'}
+
+ NUM_VERTS = 776
+ NUM_FACES = 9976
+
+ KpId2manokps = {
+ 0: 0, # Wrist
+ 1: 5,
+ 2: 6,
+ 3: 7, # Index
+ 4: 9,
+ 5: 10,
+ 6: 11, # Middle
+ 7: 17,
+ 8: 18,
+ 9: 19, # Pinky
+ 10: 13,
+ 11: 14,
+ 12: 15, # Ring
+ 13: 1,
+ 14: 2,
+ 15: 3
+ } # Thumb
+ kpId2vertices = {
+ 4: 744, # Thumb
+ 8: 320, # Index
+ 12: 443, # Middle
+ 16: 555, # Ring
+        20: 672  # Pinky
+ }
+
+ def __init__(self,
+ *args,
+ keypoint_src: str = 'mano',
+ keypoint_dst: str = 'human_data',
+ keypoint_approximate: bool = False,
+ **kwargs):
+ """
+ Args:
+ *args: extra arguments for MANO initialization.
+ keypoint_src: source convention of keypoints. This convention
+ is used for keypoints obtained from joint regressors.
+ Keypoints then undergo conversion into keypoint_dst
+ convention.
+ keypoint_dst: destination convention of keypoints. This convention
+ is used for keypoints in the output.
+ keypoint_approximate: whether to use approximate matching in
+ convention conversion for keypoints.
+ **kwargs: extra keyword arguments for MANO initialization.
+
+ Returns:
+ None
+ """
+ super(MANO, self).__init__(*args, **kwargs)
+ self.keypoint_src = keypoint_src
+ self.keypoint_dst = keypoint_dst
+ self.keypoint_approximate = keypoint_approximate
+
+ self.num_verts = self.get_num_verts()
+ self.num_faces = self.get_num_faces()
+ self.num_joints = get_keypoint_num(convention=self.keypoint_dst)
+
+ def forward(self,
+ *args,
+ return_verts: bool = True,
+ return_full_pose: bool = False,
+ **kwargs) -> dict:
+ """Forward function.
+
+ Args:
+ *args: extra arguments for MANO
+ return_verts: whether to return vertices
+ return_full_pose: whether to return full pose parameters
+ **kwargs: extra arguments for MANO
+
+ Returns:
+ output: contains output parameters and attributes
+ """
+ if 'right_hand_pose' in kwargs:
+ kwargs['hand_pose'] = kwargs['right_hand_pose']
+ mano_output = super(MANO, self).forward(*args, **kwargs)
+ joints = mano_output.joints
+
+ joints = self.get_keypoints_from_mesh(mano_output.vertices, joints)
+
+ joints, joint_mask = convert_kps(joints,
+ src=self.keypoint_src,
+ dst=self.keypoint_dst,
+ approximate=self.keypoint_approximate)
+ if isinstance(joint_mask, np.ndarray):
+ joint_mask = torch.tensor(joint_mask,
+ dtype=torch.uint8,
+ device=joints.device)
+
+ batch_size = joints.shape[0]
+ joint_mask = joint_mask.reshape(1, -1).expand(batch_size, -1)
+
+ output = dict(
+ global_orient=mano_output.global_orient,
+ hand_pose=mano_output.hand_pose,
+ joints=joints,
+ joint_mask=joint_mask,
+ keypoints=torch.cat([joints, joint_mask[:, :, None]], dim=-1),
+ betas=mano_output.betas,
+ )
+
+ if return_verts:
+ output['vertices'] = mano_output.vertices
+ if return_full_pose:
+ output['full_pose'] = mano_output.full_pose
+
+ return output
+
+ def get_keypoints_from_mesh(self, mesh_vertices, keypoints_regressed):
+        """Assemble the full 21-keypoint set from the 16 MANO keypoints and 5
+        mesh vertices for the fingertips."""
+ batch_size = keypoints_regressed.shape[0]
+        keypoints = keypoints_regressed.new_zeros((batch_size, 21, 3))
+
+ # fill keypoints which are regressed
+ for manoId, myId in self.KpId2manokps.items():
+ keypoints[:, myId, :] = keypoints_regressed[:, manoId, :]
+ # get other keypoints from mesh
+ for myId, meshId in self.kpId2vertices.items():
+ keypoints[:, myId, :] = mesh_vertices[:, meshId, :]
+
+ return keypoints
+
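+# Note on the 21-keypoint layout assembled above: 0 is the wrist, 1-4 thumb,
+# 5-8 index, 9-12 middle, 13-16 ring and 17-20 pinky; 16 joints come from the
+# MANO regressor (KpId2manokps) and the 5 fingertips are taken directly from
+# mesh vertices (kpId2vertices).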
+
+class MANOLayer(_MANOLayer):
+ """Extension of the official MANO implementation."""
+ full_pose_keys = {'global_orient', 'hand_pose'}
+
+ NUM_VERTS = 776
+ NUM_FACES = 9976
+
+ KpId2manokps = {
+ 0: 0, # Wrist
+ 1: 5,
+ 2: 6,
+ 3: 7, # Index
+ 4: 9,
+ 5: 10,
+ 6: 11, # Middle
+ 7: 17,
+ 8: 18,
+ 9: 19, # Pinky
+ 10: 13,
+ 11: 14,
+ 12: 15, # Ring
+ 13: 1,
+ 14: 2,
+ 15: 3
+ } # Thumb
+ kpId2vertices = {
+ 4: 744, # Thumb
+ 8: 320, # Index
+ 12: 443, # Middle
+ 16: 555, # Ring
+        20: 672  # Pinky
+ }
+
+ def __init__(self,
+ *args,
+ keypoint_src: str = 'mano',
+ keypoint_dst: str = 'human_data',
+ keypoint_approximate: bool = False,
+ **kwargs):
+ """
+ Args:
+ *args: extra arguments for MANO initialization.
+ keypoint_src: source convention of keypoints. This convention
+ is used for keypoints obtained from joint regressors.
+ Keypoints then undergo conversion into keypoint_dst
+ convention.
+ keypoint_dst: destination convention of keypoints. This convention
+ is used for keypoints in the output.
+ keypoint_approximate: whether to use approximate matching in
+ convention conversion for keypoints.
+ **kwargs: extra keyword arguments for MANO initialization.
+
+ Returns:
+ None
+ """
+ super(MANOLayer, self).__init__(*args, **kwargs)
+ self.keypoint_src = keypoint_src
+ self.keypoint_dst = keypoint_dst
+ self.keypoint_approximate = keypoint_approximate
+
+ self.num_verts = self.get_num_verts()
+ self.num_faces = self.get_num_faces()
+ self.num_joints = get_keypoint_num(convention=self.keypoint_dst)
+
+ def forward(self,
+ *args,
+ return_verts: bool = True,
+ return_full_pose: bool = False,
+ **kwargs) -> dict:
+ """Forward function.
+
+ Args:
+ *args: extra arguments for MANO
+ return_verts: whether to return vertices
+ return_full_pose: whether to return full pose parameters
+ **kwargs: extra arguments for MANO
+
+ Returns:
+ output: contains output parameters and attributes
+ """
+ if 'right_hand_pose' in kwargs:
+ kwargs['hand_pose'] = kwargs['right_hand_pose']
+ mano_output = super(MANOLayer, self).forward(*args, **kwargs)
+ joints = mano_output.joints
+
+ joints = self.get_keypoints_from_mesh(mano_output.vertices, joints)
+
+ joints, joint_mask = convert_kps(joints,
+ src=self.keypoint_src,
+ dst=self.keypoint_dst,
+ approximate=self.keypoint_approximate)
+ if isinstance(joint_mask, np.ndarray):
+ joint_mask = torch.tensor(joint_mask,
+ dtype=torch.uint8,
+ device=joints.device)
+
+ batch_size = joints.shape[0]
+ joint_mask = joint_mask.reshape(1, -1).expand(batch_size, -1)
+
+ output = dict(
+ global_orient=mano_output.global_orient,
+ hand_pose=mano_output.hand_pose,
+ joints=joints,
+ joint_mask=joint_mask,
+ keypoints=torch.cat([joints, joint_mask[:, :, None]], dim=-1),
+ betas=mano_output.betas,
+ )
+
+ if return_verts:
+ output['vertices'] = mano_output.vertices
+ if return_full_pose:
+ output['full_pose'] = mano_output.full_pose
+
+ return output
+
+ def get_keypoints_from_mesh(self, mesh_vertices, keypoints_regressed):
+        """Assemble the full 21-keypoint set from the 16 MANO keypoints and 5
+        mesh vertices for the fingertips."""
+ batch_size = keypoints_regressed.shape[0]
+        keypoints = keypoints_regressed.new_zeros((batch_size, 21, 3))
+
+ # fill keypoints which are regressed
+ for manoId, myId in self.KpId2manokps.items():
+ keypoints[:, myId, :] = keypoints_regressed[:, manoId, :]
+ # get other keypoints from mesh
+ for myId, meshId in self.kpId2vertices.items():
+ keypoints[:, myId, :] = mesh_vertices[:, meshId, :]
+
+ return keypoints
diff --git a/detrsmpl/models/body_models/smpl.py b/detrsmpl/models/body_models/smpl.py
new file mode 100644
index 0000000000000000000000000000000000000000..215e4e2f9ef0134eb69422e73f01206c8ee5741f
--- /dev/null
+++ b/detrsmpl/models/body_models/smpl.py
@@ -0,0 +1,610 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from typing import Optional
+
+import numpy as np
+import torch
+from smplx import SMPL as _SMPL
+from smplx.lbs import batch_rigid_transform, blend_shapes, vertices2joints
+
+from detrsmpl.core.conventions.keypoints_mapping import (
+ convert_kps,
+ get_keypoint_num,
+)
+from detrsmpl.core.conventions.segmentation import body_segmentation
+from detrsmpl.models.utils import batch_inverse_kinematics_transform
+from detrsmpl.utils.transforms import quat_to_rotmat
+
+
+class SMPL(_SMPL):
+ """Extension of the official SMPL implementation."""
+
+ body_pose_keys = {
+ 'global_orient',
+ 'body_pose',
+ }
+ full_pose_keys = {
+ 'global_orient',
+ 'body_pose',
+ }
+ NUM_VERTS = 6890
+ NUM_FACES = 13776
+
+ def __init__(self,
+ *args,
+ keypoint_src: str = 'smpl_45',
+ keypoint_dst: str = 'human_data',
+ keypoint_approximate: bool = False,
+ joints_regressor: str = None,
+ extra_joints_regressor: str = None,
+ **kwargs) -> None:
+ """
+ Args:
+ *args: extra arguments for SMPL initialization.
+ keypoint_src: source convention of keypoints. This convention
+ is used for keypoints obtained from joint regressors.
+ Keypoints then undergo conversion into keypoint_dst
+ convention.
+ keypoint_dst: destination convention of keypoints. This convention
+ is used for keypoints in the output.
+ keypoint_approximate: whether to use approximate matching in
+ convention conversion for keypoints.
+ joints_regressor: path to joint regressor. Should be a .npy
+ file. If provided, replaces the official J_regressor of SMPL.
+ extra_joints_regressor: path to extra joint regressor. Should be
+ a .npy file. If provided, extra joints are regressed and
+ concatenated after the joints regressed with the official
+ J_regressor or joints_regressor.
+ **kwargs: extra keyword arguments for SMPL initialization.
+
+ Returns:
+ None
+ """
+ super(SMPL, self).__init__(*args, **kwargs)
+ # joints = [JOINT_MAP[i] for i in JOINT_NAMES]
+ self.keypoint_src = keypoint_src
+ self.keypoint_dst = keypoint_dst
+ self.keypoint_approximate = keypoint_approximate
+ # override the default SMPL joint regressor if available
+ if joints_regressor is not None:
+ joints_regressor = torch.tensor(np.load(joints_regressor),
+ dtype=torch.float)
+ self.register_buffer('joints_regressor', joints_regressor)
+
+ # allow for extra joints to be regressed if available
+ if extra_joints_regressor is not None:
+ joints_regressor_extra = torch.tensor(
+ np.load(extra_joints_regressor), dtype=torch.float)
+ self.register_buffer('joints_regressor_extra',
+ joints_regressor_extra)
+
+ self.num_verts = self.get_num_verts()
+ self.num_joints = get_keypoint_num(convention=self.keypoint_dst)
+ self.body_part_segmentation = body_segmentation('smpl')
+
+ def forward(self,
+ *args,
+ return_verts: bool = True,
+ return_full_pose: bool = False,
+ **kwargs) -> dict:
+ """Forward function.
+
+ Args:
+ *args: extra arguments for SMPL
+ return_verts: whether to return vertices
+ return_full_pose: whether to return full pose parameters
+ **kwargs: extra arguments for SMPL
+
+ Returns:
+ output: contains output parameters and attributes
+ """
+
+ kwargs['get_skin'] = True
+ smpl_output = super(SMPL, self).forward(*args, **kwargs)
+
+ if not hasattr(self, 'joints_regressor'):
+ joints = smpl_output.joints
+ else:
+ joints = vertices2joints(self.joints_regressor,
+ smpl_output.vertices)
+
+ if hasattr(self, 'joints_regressor_extra'):
+ extra_joints = vertices2joints(self.joints_regressor_extra,
+ smpl_output.vertices)
+ joints = torch.cat([joints, extra_joints], dim=1)
+
+ joints, joint_mask = convert_kps(joints,
+ src=self.keypoint_src,
+ dst=self.keypoint_dst,
+ approximate=self.keypoint_approximate)
+ if isinstance(joint_mask, np.ndarray):
+ joint_mask = torch.tensor(joint_mask,
+ dtype=torch.uint8,
+ device=joints.device)
+
+ batch_size = joints.shape[0]
+ joint_mask = joint_mask.reshape(1, -1).expand(batch_size, -1)
+
+ output = dict(global_orient=smpl_output.global_orient,
+ body_pose=smpl_output.body_pose,
+ joints=joints,
+ joint_mask=joint_mask,
+ keypoints=torch.cat([joints, joint_mask[:, :, None]],
+ dim=-1),
+ betas=smpl_output.betas)
+
+ if return_verts:
+ output['vertices'] = smpl_output.vertices
+ if return_full_pose:
+ output['full_pose'] = smpl_output.full_pose
+
+ return output
+
+ @classmethod
+ def tensor2dict(cls,
+ full_pose: torch.Tensor,
+ betas: Optional[torch.Tensor] = None,
+ transl: Optional[torch.Tensor] = None):
+ """Convert full pose tensor to pose dict.
+
+ Args:
+            full_pose (torch.Tensor): shape should be (..., 72) or
+                (..., 24, 3). All zeros for T-pose.
+ betas (Optional[torch.Tensor], optional): shape should be
+ (..., 10). The batch num should be 1 or corresponds with
+ full_pose.
+ Defaults to None.
+ transl (Optional[torch.Tensor], optional): shape should be
+ (..., 3). The batch num should be 1 or corresponds with
+ full_pose.
+ Defaults to None.
+ Returns:
+ dict: dict of smpl pose containing transl & betas.
+ """
+ full_pose = full_pose.view(-1, (cls.NUM_BODY_JOINTS + 1) * 3)
+ body_pose = full_pose[:, 3:]
+ global_orient = full_pose[:, :3]
+ batch_size = full_pose.shape[0]
+ if betas is not None:
+ # squeeze or unsqueeze betas to 2 dims
+ betas = betas.view(-1, betas.shape[-1])
+ if betas.shape[0] == 1:
+ betas = betas.repeat(batch_size, 1)
+ else:
+ betas = betas
+ transl = transl.view(batch_size, -1) if transl is not None else transl
+ return {
+ 'betas': betas,
+ 'body_pose': body_pose,
+ 'global_orient': global_orient,
+ 'transl': transl,
+ }
+
+ @classmethod
+ def dict2tensor(cls, smpl_dict: dict) -> torch.Tensor:
+ """Convert smpl pose dict to full pose tensor.
+
+ Args:
+ smpl_dict (dict): smpl pose dict.
+
+ Returns:
+            torch.Tensor: full pose tensor.
+ """
+ assert cls.body_pose_keys.issubset(smpl_dict)
+ for k in smpl_dict:
+ if isinstance(smpl_dict[k], np.ndarray):
+ smpl_dict[k] = torch.Tensor(smpl_dict[k])
+ global_orient = smpl_dict['global_orient'].view(-1, 3)
+ body_pose = smpl_dict['body_pose'].view(-1, 3 * cls.NUM_BODY_JOINTS)
+ full_pose = torch.cat([global_orient, body_pose], dim=1)
+ return full_pose
+
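+
+# Documentation sketch: tensor2dict and dict2tensor are inverse views of the
+# same 72-dim axis-angle pose (3 values for global_orient plus 23 * 3 for
+# body_pose). The shapes below are only an illustration.
+def _example_smpl_pose_roundtrip():
+    import torch
+    full_pose = torch.zeros(2, 72)  # a batch of two T-poses
+    pose_dict = SMPL.tensor2dict(full_pose)
+    restored = SMPL.dict2tensor(pose_dict)
+    assert restored.shape == (2, 72)
+    return pose_dict['body_pose'].shape  # (2, 69)
+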
+
+class GenderedSMPL(torch.nn.Module):
+ """A wrapper of SMPL to handle gendered inputs."""
+ def __init__(self,
+ *args,
+ keypoint_src: str = 'smpl_45',
+ keypoint_dst: str = 'human_data',
+ keypoint_approximate: bool = False,
+ joints_regressor: str = None,
+ extra_joints_regressor: str = None,
+ **kwargs) -> None:
+ """
+ Args:
+ *args: extra arguments for SMPL initialization.
+ keypoint_src: source convention of keypoints. This convention
+ is used for keypoints obtained from joint regressors.
+ Keypoints then undergo conversion into keypoint_dst
+ convention.
+ keypoint_dst: destination convention of keypoints. This convention
+ is used for keypoints in the output.
+ keypoint_approximate: whether to use approximate matching in
+ convention conversion for keypoints.
+ joints_regressor: path to joint regressor. Should be a .npy
+ file. If provided, replaces the official J_regressor of SMPL.
+ extra_joints_regressor: path to extra joint regressor. Should be
+ a .npy file. If provided, extra joints are regressed and
+ concatenated after the joints regressed with the official
+ J_regressor or joints_regressor.
+ **kwargs: extra keyword arguments for SMPL initialization.
+
+ Returns:
+ None
+ """
+ super(GenderedSMPL, self).__init__()
+
+ assert 'gender' not in kwargs, \
+ self.__class__.__name__ + \
+            ' does not need \'gender\' for initialization.'
+
+ self.smpl_neutral = SMPL(*args,
+ gender='neutral',
+ keypoint_src=keypoint_src,
+ keypoint_dst=keypoint_dst,
+ keypoint_approximate=keypoint_approximate,
+ joints_regressor=joints_regressor,
+ extra_joints_regressor=extra_joints_regressor,
+ **kwargs)
+
+ self.smpl_male = SMPL(*args,
+ gender='male',
+ keypoint_src=keypoint_src,
+ keypoint_dst=keypoint_dst,
+ keypoint_approximate=keypoint_approximate,
+ joints_regressor=joints_regressor,
+ extra_joints_regressor=extra_joints_regressor,
+ **kwargs)
+
+ self.smpl_female = SMPL(*args,
+ gender='female',
+ keypoint_src=keypoint_src,
+ keypoint_dst=keypoint_dst,
+ keypoint_approximate=keypoint_approximate,
+ joints_regressor=joints_regressor,
+ extra_joints_regressor=extra_joints_regressor,
+ **kwargs)
+
+ self.num_verts = self.smpl_neutral.num_verts
+ self.num_joints = self.smpl_neutral.num_joints
+ self.faces = self.smpl_neutral.faces
+
+ def forward(self,
+ *args,
+ betas: torch.Tensor = None,
+ body_pose: torch.Tensor = None,
+ global_orient: torch.Tensor = None,
+ transl: torch.Tensor = None,
+ return_verts: bool = True,
+ return_full_pose: bool = False,
+ gender: torch.Tensor = None,
+ device=None,
+ **kwargs):
+ """Forward function.
+ Note:
+ B: batch size
+ J: number of joints of model, J = 23 (SMPL)
+ K: number of keypoints
+ Args:
+ *args: extra arguments
+ betas: Tensor([B, 10]), human body shape parameters of SMPL model.
+            body_pose: Tensor([B, J*3] or [B, J, 3, 3]), human body pose
+                parameters of SMPL model. It should be an axis-angle vector
+                ([B, J*3]) or a rotation matrix ([B, J, 3, 3]).
+            global_orient: Tensor([B, 3] or [B, 1, 3, 3]), global orientation
+                of the human body. It should be an axis-angle vector ([B, 3])
+                or a rotation matrix ([B, 1, 3, 3]).
+ transl: Tensor([B, 3]), global translation of human body.
+            gender: Tensor([B]), gender parameters of the human body. -1 for
+                neutral, 0 for male, 1 for female.
+ device: the device of the output
+ **kwargs: extra keyword arguments
+ Returns:
+ outputs (dict): Dict with mesh vertices and joints.
+ - vertices: Tensor([B, V, 3]), mesh vertices
+ - joints: Tensor([B, K, 3]), 3d keypoints regressed from
+ mesh vertices.
+ """
+
+ batch_size = None
+ for attr in [betas, body_pose, global_orient, transl]:
+ if attr is not None:
+ if device is None:
+ device = attr.device
+ if batch_size is None:
+ batch_size = attr.shape[0]
+ else:
+ assert batch_size == attr.shape[0]
+
+ if gender is not None:
+ output = {
+ 'vertices':
+ torch.zeros([batch_size, self.num_verts, 3], device=device),
+ 'joints':
+ torch.zeros([batch_size, self.num_joints, 3], device=device),
+ 'joint_mask':
+ torch.zeros([batch_size, self.num_joints],
+ dtype=torch.uint8,
+ device=device)
+ }
+
+ for body_model, gender_label in \
+ [(self.smpl_neutral, -1),
+ (self.smpl_male, 0),
+ (self.smpl_female, 1)]:
+ gender_idxs = gender == gender_label
+
+ # skip if no such gender is present
+ if gender_idxs.sum() == 0:
+ continue
+
+ output_model = body_model(
+ betas=betas[gender_idxs] if betas is not None else None,
+ body_pose=body_pose[gender_idxs]
+ if body_pose is not None else None,
+ global_orient=global_orient[gender_idxs]
+ if global_orient is not None else None,
+ transl=transl[gender_idxs] if transl is not None else None,
+ **kwargs)
+
+ output['joints'][gender_idxs] = output_model['joints']
+
+ # TODO: quick fix
+ if 'joint_mask' in output_model:
+ output['joint_mask'][gender_idxs] = output_model[
+ 'joint_mask']
+
+ if return_verts:
+ output['vertices'][gender_idxs] = output_model['vertices']
+ if return_full_pose:
+ output['full_pose'][gender_idxs] = output_model[
+ 'full_pose']
+ else:
+ output = self.smpl_neutral(betas=betas,
+ body_pose=body_pose,
+ global_orient=global_orient,
+ transl=transl,
+ **kwargs)
+
+ return output
+
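+
+# Documentation sketch: per-sample gender selection. The gender tensor follows
+# the encoding documented above (-1 neutral, 0 male, 1 female); the function
+# arguments here are placeholders and their batch size must match the length
+# of the gender tensor.
+def _example_gendered_forward(gendered_smpl, betas, body_pose, global_orient):
+    import torch
+    gender = torch.tensor([0, 1, -1])  # male, female, neutral
+    return gendered_smpl(betas=betas,
+                         body_pose=body_pose,
+                         global_orient=global_orient,
+                         gender=gender)
+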
+
+def to_tensor(array, dtype=torch.float32):
+    if not isinstance(array, torch.Tensor):
+        return torch.tensor(array, dtype=dtype)
+    return array
+
+
+def to_np(array, dtype=np.float32):
+ if 'scipy.sparse' in str(type(array)):
+ array = array.todense()
+ return np.array(array, dtype=dtype)
+
+
+class HybrIKSMPL(SMPL):
+ """Extension of the SMPL for HybrIK."""
+
+ NUM_JOINTS = 23
+ NUM_BODY_JOINTS = 23
+ NUM_BETAS = 10
+ JOINT_NAMES = [
+ 'pelvis',
+ 'left_hip',
+ 'right_hip', # 2
+ 'spine1',
+ 'left_knee',
+ 'right_knee', # 5
+ 'spine2',
+ 'left_ankle',
+ 'right_ankle', # 8
+ 'spine3',
+ 'left_foot',
+ 'right_foot', # 11
+ 'neck',
+ 'left_collar',
+ 'right_collar', # 14
+ 'jaw', # 15
+ 'left_shoulder',
+ 'right_shoulder', # 17
+ 'left_elbow',
+ 'right_elbow', # 19
+ 'left_wrist',
+ 'right_wrist', # 21
+ 'left_thumb',
+ 'right_thumb', # 23
+ 'head',
+ 'left_middle',
+ 'right_middle', # 26
+ 'left_bigtoe',
+ 'right_bigtoe' # 28
+ ]
+ LEAF_NAMES = [
+ 'head', 'left_middle', 'right_middle', 'left_bigtoe', 'right_bigtoe'
+ ]
+ root_idx_17 = 0
+ root_idx_smpl = 0
+
+ def __init__(self, *args, extra_joints_regressor=None, **kwargs):
+ """
+ Args:
+ *args: extra arguments for SMPL initialization.
+ extra_joints_regressor: path to extra joint regressor. Should be
+ a .npy file. If provided, extra joints are regressed and
+ concatenated after the joints regressed with the official
+ J_regressor or joints_regressor.
+ **kwargs: extra keyword arguments for SMPL initialization.
+
+ Returns:
+ None
+ """
+ super(HybrIKSMPL,
+ self).__init__(*args,
+ extra_joints_regressor=extra_joints_regressor,
+ create_betas=False,
+ create_global_orient=False,
+ create_body_pose=False,
+ create_transl=False,
+ **kwargs)
+
+ self.dtype = torch.float32
+ self.num_joints = 29
+
+ self.ROOT_IDX = self.JOINT_NAMES.index('pelvis')
+ self.LEAF_IDX = [
+ self.JOINT_NAMES.index(name) for name in self.LEAF_NAMES
+ ]
+ self.SPINE3_IDX = 9
+        # indices of the parent for each joint
+ parents = torch.zeros(len(self.JOINT_NAMES), dtype=torch.long)
+ # extend kinematic tree
+ parents[:24] = self.parents
+ parents[24] = 15
+ parents[25] = 22
+ parents[26] = 23
+ parents[27] = 10
+ parents[28] = 11
+ if parents.shape[0] > self.num_joints:
+ parents = parents[:24]
+ self.register_buffer('children_map',
+ self._parents_to_children(parents))
+ self.parents = parents
+
+ def _parents_to_children(self, parents):
+ children = torch.ones_like(parents) * -1
+ for i in range(self.num_joints):
+ if children[parents[i]] < 0:
+ children[parents[i]] = i
+ for i in self.LEAF_IDX:
+ if i < children.shape[0]:
+ children[i] = -1
+
+ children[self.SPINE3_IDX] = -3
+ children[0] = 3
+ children[self.SPINE3_IDX] = self.JOINT_NAMES.index('neck')
+
+ return children
+
+ def forward(self,
+ pose_skeleton,
+ betas,
+ phis,
+ global_orient,
+ transl=None,
+ return_verts=True,
+ leaf_thetas=None):
+ """Inverse pass for the SMPL model.
+
+ Args:
+ pose_skeleton: torch.tensor, optional, shape Bx(J*3)
+ It should be a tensor that contains joint locations in
+ (img, Y, Z) format. (default=None)
+ betas: torch.tensor, optional, shape Bx10
+                It can be used if the shape parameters `betas` are predicted
+                from some external model.
+ (default=None)
+ phis: torch.tensor, shape Bx23x2
+ Rotation on bone axis parameters
+ global_orient: torch.tensor, optional, shape Bx3
+ Global Orientations.
+ transl: torch.tensor, optional, shape Bx3
+ Global Translations.
+ return_verts: bool, optional
+ Return the vertices. (default=True)
+ leaf_thetas: torch.tensor, optional, shape Bx5x4
+ Quaternions of 5 leaf joints. (default=None)
+
+        Returns:
+            outputs: output dictionary.
+ """
+ batch_size = pose_skeleton.shape[0]
+
+ if leaf_thetas is not None:
+ leaf_thetas = leaf_thetas.reshape(batch_size * 5, 4)
+ leaf_thetas = quat_to_rotmat(leaf_thetas)
+
+ batch_size = max(betas.shape[0], pose_skeleton.shape[0])
+ device = betas.device
+
+ # 1. Add shape contribution
+ v_shaped = self.v_template + blend_shapes(betas, self.shapedirs)
+
+ # 2. Get the rest joints
+ # NxJx3 array
+ if leaf_thetas is not None:
+ rest_J = vertices2joints(self.J_regressor, v_shaped)
+ else:
+ rest_J = torch.zeros((v_shaped.shape[0], 29, 3),
+ dtype=self.dtype,
+ device=device)
+ rest_J[:, :24] = vertices2joints(self.J_regressor, v_shaped)
+
+ leaf_number = [411, 2445, 5905, 3216, 6617]
+ leaf_vertices = v_shaped[:, leaf_number].clone()
+ rest_J[:, 24:] = leaf_vertices
+
+        # 3. Get the rotation matrices
+ rot_mats, rotate_rest_pose = batch_inverse_kinematics_transform(
+ pose_skeleton,
+ global_orient,
+ phis,
+ rest_J.clone(),
+ self.children_map,
+ self.parents,
+ dtype=self.dtype,
+ train=self.training,
+ leaf_thetas=leaf_thetas)
+
+ test_joints = True
+ if test_joints:
+ new_joints, A = batch_rigid_transform(rot_mats,
+ rest_J[:, :24].clone(),
+ self.parents[:24],
+ dtype=self.dtype)
+ else:
+ new_joints = None
+
+ # assert torch.mean(torch.abs(rotate_rest_pose - new_joints)) < 1e-5
+ # 4. Add pose blend shapes
+ # rot_mats: N x (J + 1) x 3 x 3
+ ident = torch.eye(3, dtype=self.dtype, device=device)
+ pose_feature = (rot_mats[:, 1:] - ident).view([batch_size, -1])
+ pose_offsets = torch.matmul(pose_feature, self.posedirs) \
+ .view(batch_size, -1, 3)
+
+ v_posed = pose_offsets + v_shaped
+
+ # 5. Do skinning:
+ # W is N x V x (J + 1)
+ W = self.lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
+ # (N x V x (J + 1)) x (N x (J + 1) x 16)
+ num_joints = self.J_regressor.shape[0]
+ T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \
+ .view(batch_size, -1, 4, 4)
+
+ homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1],
+ dtype=self.dtype,
+ device=device)
+ v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
+ v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
+
+ vertices = v_homo[:, :, :3, 0]
+ joints_from_verts = vertices2joints(self.joints_regressor_extra,
+ vertices)
+
+ # rot_mats = rot_mats.reshape(batch_size * 24, 3, 3)
+ if transl is not None:
+ new_joints += transl.unsqueeze(dim=1)
+ vertices += transl.unsqueeze(dim=1)
+ joints_from_verts += transl.unsqueeze(dim=1)
+ else:
+ new_joints = new_joints - \
+ new_joints[:, self.root_idx_smpl, :].unsqueeze(1).detach()
+ joints_from_verts = joints_from_verts - \
+ joints_from_verts[:, self.root_idx_17, :].unsqueeze(1).detach()
+
+ output = {
+ 'vertices': vertices,
+ 'joints': new_joints,
+ 'poses': rot_mats,
+ 'joints_from_verts': joints_from_verts,
+ }
+ return output
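+
+
+# Documentation sketch of the inverse-kinematics forward above (B = batch
+# size, shapes as documented in the docstring): pose_skeleton (B, J*3),
+# betas (B, 10), phis (B, 23, 2) and optional leaf_thetas (B, 5, 4) are turned
+# into per-joint rotation matrices, which then drive standard linear blend
+# skinning. The returned dict holds 'vertices', 'joints', 'poses' (the
+# rotation matrices) and 'joints_from_verts'.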
diff --git a/detrsmpl/models/body_models/smplx.py b/detrsmpl/models/body_models/smplx.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfe1fb723bf46aeca17cc5f1f0a8dae5ec0df5f2
--- /dev/null
+++ b/detrsmpl/models/body_models/smplx.py
@@ -0,0 +1,375 @@
+from typing import Optional
+
+import numpy as np
+import torch
+from smplx import SMPLX as _SMPLX
+from smplx import SMPLXLayer as _SMPLXLayer
+from smplx.lbs import vertices2joints
+
+from detrsmpl.core.conventions.keypoints_mapping import (
+ convert_kps,
+ get_keypoint_num,
+)
+from detrsmpl.core.conventions.segmentation import body_segmentation
+
+
+class SMPLX(_SMPLX):
+ """Extension of the official SMPL-X implementation."""
+
+ body_pose_keys = {'global_orient', 'body_pose'}
+ full_pose_keys = {
+ 'global_orient', 'body_pose', 'left_hand_pose', 'right_hand_pose',
+ 'jaw_pose', 'leye_pose', 'reye_pose'
+ }
+ NUM_VERTS = 10475
+ NUM_FACES = 20908
+
+ def __init__(self,
+ *args,
+ keypoint_src: str = 'smplx',
+ keypoint_dst: str = 'human_data',
+ keypoint_approximate: bool = False,
+ joints_regressor: str = None,
+ extra_joints_regressor: str = None,
+ **kwargs):
+ """
+ Args:
+ *args: extra arguments for SMPL initialization.
+ keypoint_src: source convention of keypoints. This convention
+ is used for keypoints obtained from joint regressors.
+ Keypoints then undergo conversion into keypoint_dst
+ convention.
+ keypoint_dst: destination convention of keypoints. This convention
+ is used for keypoints in the output.
+ keypoint_approximate: whether to use approximate matching in
+ convention conversion for keypoints.
+ joints_regressor: path to joint regressor. Should be a .npy
+ file. If provided, replaces the official J_regressor of SMPL.
+ extra_joints_regressor: path to extra joint regressor. Should be
+ a .npy file. If provided, extra joints are regressed and
+ concatenated after the joints regressed with the official
+ J_regressor or joints_regressor.
+ **kwargs: extra keyword arguments for SMPL initialization.
+
+ Returns:
+ None
+ """
+ super(SMPLX, self).__init__(*args, **kwargs)
+ # joints = [JOINT_MAP[i] for i in JOINT_NAMES]
+ self.keypoint_src = keypoint_src
+ self.keypoint_dst = keypoint_dst
+ self.keypoint_approximate = keypoint_approximate
+
+ # override the default SMPL joint regressor if available
+ if joints_regressor is not None:
+ joints_regressor = torch.tensor(np.load(joints_regressor),
+ dtype=torch.float)
+ self.register_buffer('joints_regressor', joints_regressor)
+
+ # allow for extra joints to be regressed if available
+ if extra_joints_regressor is not None:
+ joints_regressor_extra = torch.tensor(
+ np.load(extra_joints_regressor), dtype=torch.float)
+ self.register_buffer('joints_regressor_extra',
+ joints_regressor_extra)
+
+ self.num_verts = self.get_num_verts()
+ self.num_joints = get_keypoint_num(convention=self.keypoint_dst)
+ self.body_part_segmentation = body_segmentation('smplx')
+
+ def forward(self,
+ *args,
+ return_verts: bool = True,
+ return_full_pose: bool = False,
+ **kwargs) -> dict:
+ """Forward function.
+
+ Args:
+ *args: extra arguments for SMPL
+ return_verts: whether to return vertices
+ return_full_pose: whether to return full pose parameters
+ **kwargs: extra arguments for SMPL
+
+ Returns:
+ output: contains output parameters and attributes
+ """
+
+ kwargs['get_skin'] = True
+ smplx_output = super(SMPLX, self).forward(*args, **kwargs)
+
+ if not hasattr(self, 'joints_regressor'):
+ joints = smplx_output.joints
+ else:
+ joints = vertices2joints(self.joints_regressor,
+ smplx_output.vertices)
+
+ if hasattr(self, 'joints_regressor_extra'):
+ extra_joints = vertices2joints(self.joints_regressor_extra,
+ smplx_output.vertices)
+ joints = torch.cat([joints, extra_joints], dim=1)
+
+ joints, joint_mask = convert_kps(joints,
+ src=self.keypoint_src,
+ dst=self.keypoint_dst,
+ approximate=self.keypoint_approximate)
+ if isinstance(joint_mask, np.ndarray):
+ joint_mask = torch.tensor(joint_mask,
+ dtype=torch.uint8,
+ device=joints.device)
+
+ batch_size = joints.shape[0]
+ joint_mask = joint_mask.reshape(1, -1).expand(batch_size, -1)
+
+ output = dict(global_orient=smplx_output.global_orient,
+ body_pose=smplx_output.body_pose,
+ joints=joints,
+ joint_mask=joint_mask,
+ keypoints=torch.cat([joints, joint_mask[:, :, None]],
+ dim=-1),
+ betas=smplx_output.betas)
+
+ if return_verts:
+ output['vertices'] = smplx_output.vertices
+ if return_full_pose:
+ output['full_pose'] = smplx_output.full_pose
+
+ return output
+
+ @classmethod
+ def tensor2dict(cls,
+ full_pose: torch.Tensor,
+ betas: Optional[torch.Tensor] = None,
+ transl: Optional[torch.Tensor] = None,
+ expression: Optional[torch.Tensor] = None) -> dict:
+ """Convert full pose tensor to pose dict.
+
+ Args:
+ full_pose (torch.Tensor): shape should be (..., 165) or
+ (..., 55, 3). All zeros for T-pose.
+ betas (Optional[torch.Tensor], optional): shape should be
+ (..., 10). The batch num should be 1 or corresponds with
+ full_pose.
+ Defaults to None.
+ transl (Optional[torch.Tensor], optional): shape should be
+ (..., 3). The batch num should be 1 or corresponds with
+ full_pose.
+ Defaults to None.
+ expression (Optional[torch.Tensor], optional): shape should
+ be (..., 10). The batch num should be 1 or corresponds with
+ full_pose.
+ Defaults to None.
+
+ Returns:
+ dict: dict of smplx pose containing transl & betas.
+ """
+ NUM_BODY_JOINTS = cls.NUM_BODY_JOINTS
+ NUM_HAND_JOINTS = cls.NUM_HAND_JOINTS
+ NUM_FACE_JOINTS = cls.NUM_FACE_JOINTS
+ NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS + NUM_FACE_JOINTS
+ full_pose = full_pose.view(-1, (NUM_JOINTS + 1), 3)
+ global_orient = full_pose[:, :1]
+ body_pose = full_pose[:, 1:NUM_BODY_JOINTS + 1]
+ jaw_pose = full_pose[:, NUM_BODY_JOINTS + 1:NUM_BODY_JOINTS + 2]
+ leye_pose = full_pose[:, NUM_BODY_JOINTS + 2:NUM_BODY_JOINTS + 3]
+ reye_pose = full_pose[:, NUM_BODY_JOINTS + 3:NUM_BODY_JOINTS + 4]
+ left_hand_pose = full_pose[:, NUM_BODY_JOINTS + 4:NUM_BODY_JOINTS + 19]
+ right_hand_pose = full_pose[:,
+ NUM_BODY_JOINTS + 19:NUM_BODY_JOINTS + 34]
+ batch_size = body_pose.shape[0]
+ if betas is not None:
+ # squeeze or unsqueeze betas to 2 dims
+ betas = betas.view(-1, betas.shape[-1])
+ if betas.shape[0] == 1:
+ betas = betas.repeat(batch_size, 1)
+ else:
+ betas = betas
+ transl = transl.view(batch_size, -1) if transl is not None else transl
+ expression = expression.view(
+ batch_size, -1) if expression is not None else torch.zeros(
+ batch_size, 10).to(body_pose.device)
+ return {
+ 'betas':
+ betas,
+ 'global_orient':
+ global_orient.view(batch_size, 3),
+ 'body_pose':
+ body_pose.view(batch_size, NUM_BODY_JOINTS * 3),
+ 'left_hand_pose':
+ left_hand_pose.view(batch_size, NUM_HAND_JOINTS * 3),
+ 'right_hand_pose':
+ right_hand_pose.view(batch_size, NUM_HAND_JOINTS * 3),
+ 'transl':
+ transl,
+ 'expression':
+ expression,
+ 'jaw_pose':
+ jaw_pose.view(batch_size, 3),
+ 'leye_pose':
+ leye_pose.view(batch_size, 3),
+ 'reye_pose':
+ reye_pose.view(batch_size, 3),
+ }
+
+ @classmethod
+ def dict2tensor(cls, smplx_dict: dict) -> torch.Tensor:
+ """Convert smplx pose dict to full pose tensor.
+
+ Args:
+ smplx_dict (dict): smplx pose dict.
+
+ Returns:
+            torch.Tensor: full pose tensor.
+ """
+ assert cls.body_pose_keys.issubset(smplx_dict)
+ for k in smplx_dict:
+ if isinstance(smplx_dict[k], np.ndarray):
+ smplx_dict[k] = torch.Tensor(smplx_dict[k])
+ NUM_BODY_JOINTS = cls.NUM_BODY_JOINTS
+ NUM_HAND_JOINTS = cls.NUM_HAND_JOINTS
+ NUM_FACE_JOINTS = cls.NUM_FACE_JOINTS
+ NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS + NUM_FACE_JOINTS
+ global_orient = smplx_dict['global_orient'].reshape(-1, 1, 3)
+ body_pose = smplx_dict['body_pose'].reshape(-1, NUM_BODY_JOINTS, 3)
+ batch_size = global_orient.shape[0]
+ jaw_pose = smplx_dict.get('jaw_pose', torch.zeros((batch_size, 1, 3)))
+ leye_pose = smplx_dict.get('leye_pose', torch.zeros(
+ (batch_size, 1, 3)))
+ reye_pose = smplx_dict.get('reye_pose', torch.zeros(
+ (batch_size, 1, 3)))
+ left_hand_pose = smplx_dict.get(
+ 'left_hand_pose', torch.zeros((batch_size, NUM_HAND_JOINTS, 3)))
+ right_hand_pose = smplx_dict.get(
+ 'right_hand_pose', torch.zeros((batch_size, NUM_HAND_JOINTS, 3)))
+ full_pose = torch.cat([
+ global_orient, body_pose,
+ jaw_pose.reshape(-1, 1, 3),
+ leye_pose.reshape(-1, 1, 3),
+ reye_pose.reshape(-1, 1, 3),
+ left_hand_pose.reshape(-1, 15, 3),
+ right_hand_pose.reshape(-1, 15, 3)
+ ],
+ dim=1).reshape(-1, (NUM_JOINTS + 1) * 3)
+ return full_pose
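+
+# Note: the full pose handled by tensor2dict/dict2tensor above is laid out as
+# 1 (global_orient) + 21 (body) + 1 (jaw) + 1 (leye) + 1 (reye)
+# + 15 (left hand) + 15 (right hand) = 55 joints, i.e. 165 axis-angle values,
+# matching the (..., 165) / (..., 55, 3) shapes documented in tensor2dict.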
+
+
+class SMPLXLayer(_SMPLXLayer):
+ """Extension of the official SMPL-X implementation."""
+
+ body_pose_keys = {'global_orient', 'body_pose'}
+ full_pose_keys = {
+ 'global_orient', 'body_pose', 'left_hand_pose', 'right_hand_pose',
+ 'jaw_pose', 'leye_pose', 'reye_pose'
+ }
+ NUM_VERTS = 10475
+ NUM_FACES = 20908
+
+ def __init__(self,
+ *args,
+ keypoint_src: str = 'smplx',
+ keypoint_dst: str = 'human_data',
+ keypoint_approximate: bool = False,
+ joints_regressor: str = None,
+ extra_joints_regressor: str = None,
+ **kwargs):
+ """
+ Args:
+ *args: extra arguments for SMPL initialization.
+ keypoint_src: source convention of keypoints. This convention
+ is used for keypoints obtained from joint regressors.
+ Keypoints then undergo conversion into keypoint_dst
+ convention.
+ keypoint_dst: destination convention of keypoints. This convention
+ is used for keypoints in the output.
+ keypoint_approximate: whether to use approximate matching in
+ convention conversion for keypoints.
+ joints_regressor: path to joint regressor. Should be a .npy
+ file. If provided, replaces the official J_regressor of SMPL.
+ extra_joints_regressor: path to extra joint regressor. Should be
+ a .npy file. If provided, extra joints are regressed and
+ concatenated after the joints regressed with the official
+ J_regressor or joints_regressor.
+ **kwargs: extra keyword arguments for SMPL initialization.
+
+ Returns:
+ None
+ """
+ super(SMPLXLayer, self).__init__(*args, **kwargs)
+ # joints = [JOINT_MAP[i] for i in JOINT_NAMES]
+ self.keypoint_src = keypoint_src
+ self.keypoint_dst = keypoint_dst
+ self.keypoint_approximate = keypoint_approximate
+
+ # override the default SMPL joint regressor if available
+ if joints_regressor is not None:
+ joints_regressor = torch.tensor(np.load(joints_regressor),
+ dtype=torch.float)
+ self.register_buffer('joints_regressor', joints_regressor)
+
+ # allow for extra joints to be regressed if available
+ if extra_joints_regressor is not None:
+ joints_regressor_extra = torch.tensor(
+ np.load(extra_joints_regressor), dtype=torch.float)
+ self.register_buffer('joints_regressor_extra',
+ joints_regressor_extra)
+
+ self.num_verts = self.get_num_verts()
+ self.num_joints = get_keypoint_num(convention=self.keypoint_dst)
+ self.body_part_segmentation = body_segmentation('smplx')
+
+ def forward(self,
+ *args,
+ return_verts: bool = True,
+ return_full_pose: bool = False,
+ **kwargs) -> dict:
+ """Forward function.
+
+ Args:
+ *args: extra arguments for SMPL
+ return_verts: whether to return vertices
+ return_full_pose: whether to return full pose parameters
+ **kwargs: extra arguments for SMPL
+
+ Returns:
+ output: contains output parameters and attributes
+ """
+
+ kwargs['get_skin'] = True
+ smplx_output = super(SMPLXLayer, self).forward(*args, **kwargs)
+
+ if not hasattr(self, 'joints_regressor'):
+ joints = smplx_output.joints
+ else:
+ joints = vertices2joints(self.joints_regressor,
+ smplx_output.vertices)
+
+ if hasattr(self, 'joints_regressor_extra'):
+ extra_joints = vertices2joints(self.joints_regressor_extra,
+ smplx_output.vertices)
+ joints = torch.cat([joints, extra_joints], dim=1)
+
+ joints, joint_mask = convert_kps(joints,
+ src=self.keypoint_src,
+ dst=self.keypoint_dst,
+ approximate=self.keypoint_approximate)
+ if isinstance(joint_mask, np.ndarray):
+ joint_mask = torch.tensor(joint_mask,
+ dtype=torch.uint8,
+ device=joints.device)
+
+ batch_size = joints.shape[0]
+ joint_mask = joint_mask.reshape(1, -1).expand(batch_size, -1)
+
+ output = dict(global_orient=smplx_output.global_orient,
+ body_pose=smplx_output.body_pose,
+ joints=joints,
+ joint_mask=joint_mask,
+ keypoints=torch.cat([joints, joint_mask[:, :, None]],
+ dim=-1),
+ betas=smplx_output.betas)
+
+ if return_verts:
+ output['vertices'] = smplx_output.vertices
+ if return_full_pose:
+ output['full_pose'] = smplx_output.full_pose
+
+ return output
diff --git a/detrsmpl/models/body_models/star.py b/detrsmpl/models/body_models/star.py
new file mode 100644
index 0000000000000000000000000000000000000000..b40531c77091d81c1586eb53d34d38a68bcc0d91
--- /dev/null
+++ b/detrsmpl/models/body_models/star.py
@@ -0,0 +1,333 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import os
+from typing import Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from detrsmpl.core.conventions.keypoints_mapping import convert_kps
+from detrsmpl.utils.transforms import (
+ aa_to_rotmat,
+ make_homegeneous_rotmat_batch,
+)
+
+
+class STAR(nn.Module):
+
+ NUM_BODY_JOINTS = 24
+
+ def __init__(self,
+ model_path: str,
+ gender: str = 'neutral',
+ keypoint_src: str = 'star',
+ keypoint_dst: str = 'human_data',
+ keypoint_approximate: bool = False,
+ create_global_orient: bool = True,
+ global_orient: Optional[torch.Tensor] = None,
+ create_body_pose: bool = True,
+ body_pose: torch.Tensor = None,
+ num_betas: int = 10,
+ create_betas: bool = True,
+ betas: torch.Tensor = None,
+ create_transl: bool = True,
+ transl: torch.Tensor = None,
+ batch_size: int = 1,
+ dtype: torch.dtype = torch.float32) -> None:
+ """STAR model constructor.
+
+ Args:
+ model_path: str
+ The path to the folder or to the file where the model
+ parameters are stored.
+ gender: str, optional
+ Which gender to load.
+ keypoint_src: str
+ Source convention of keypoints. This convention is used for
+ keypoints obtained from joint regressors. Keypoints then
+ undergo conversion into keypoint_dst convention.
+ keypoint_dst: destination convention of keypoints. This convention
+ is used for keypoints in the output.
+ keypoint_approximate: whether to use approximate matching in
+ convention conversion for keypoints.
+ create_global_orient: bool, optional
+ Flag for creating a member variable for the global orientation
+ of the body. (default = True)
+ global_orient: torch.tensor, optional, Bx3
+ The default value for the global orientation variable.
+ (default = None)
+ create_body_pose: bool, optional
+ Flag for creating a member variable for the pose of the body.
+ (default = True)
+ body_pose: torch.tensor, optional, Bx(3*24)
+ The default value for the body pose variable.
+ (default = None)
+ num_betas: int, optional
+ Number of shape components to use
+ (default = 10).
+ create_betas: bool, optional
+ Flag for creating a member variable for the shape space
+ (default = True).
+ betas: torch.tensor, optional, Bx10
+ The default value for the shape member variable.
+ (default = None)
+ create_transl: bool, optional
+ Flag for creating a member variable for the translation
+ of the body. (default = True)
+ transl: torch.tensor, optional, Bx3
+ The default value for the transl variable.
+ (default = None)
+ batch_size: int, optional
+ The batch size used for creating the member variables.
+ dtype: torch.dtype, optional
+ The data type for the created variables.
+ """
+ if gender not in ['male', 'female', 'neutral']:
+ raise RuntimeError('Invalid gender! Should be one of '
+ '[\'male\', \'female\', or \'neutral\']!')
+ self.gender = gender
+
+ model_fname = 'STAR_{}.npz'.format(gender.upper())
+ if not os.path.exists(model_path):
+ raise RuntimeError('Path {} does not exist!'.format(model_path))
+ elif os.path.isdir(model_path):
+ star_path = os.path.join(model_path, model_fname)
+ else:
+ if os.path.split(model_path)[-1] != model_fname:
+ raise RuntimeError(
+ f'Model filename ({model_fname}) and gender '
+ f'({gender}) are incompatible!')
+ star_path = model_path
+
+ super(STAR, self).__init__()
+
+ self.keypoint_src = keypoint_src
+ self.keypoint_dst = keypoint_dst
+ self.keypoint_approximate = keypoint_approximate
+
+ star_model = np.load(star_path, allow_pickle=True)
+ J_regressor = star_model['J_regressor']
+ self.num_betas = num_betas
+
+        # Model sparse joint regressor; regresses joint locations from a mesh
+ self.register_buffer('J_regressor',
+ torch.tensor(J_regressor, dtype=torch.float))
+
+ # Model skinning weights
+ self.register_buffer(
+ 'weights', torch.tensor(star_model['weights'], dtype=torch.float))
+
+ # Model pose corrective blend shapes
+ self.register_buffer(
+ 'posedirs',
+ torch.tensor(star_model['posedirs'].reshape((-1, 93)),
+ dtype=torch.float))
+
+ # Mean Shape
+ self.register_buffer(
+ 'v_template',
+ torch.tensor(star_model['v_template'], dtype=torch.float))
+
+ # Shape corrective blend shapes
+ self.register_buffer(
+ 'shapedirs',
+ torch.tensor(star_model['shapedirs'][:, :, :num_betas],
+ dtype=torch.float))
+
+        # Mesh triangles
+ self.register_buffer(
+ 'faces', torch.from_numpy(star_model['f'].astype(np.int64)))
+ self.f = star_model['f']
+
+ # Kinematic tree of the model
+ self.register_buffer(
+ 'kintree_table',
+ torch.from_numpy(star_model['kintree_table'].astype(np.int64)))
+
+ id_to_col = {
+ self.kintree_table[1, i].item(): i
+ for i in range(self.kintree_table.shape[1])
+ }
+ self.register_buffer(
+ 'parent',
+ torch.tensor([
+ id_to_col[self.kintree_table[0, it].item()]
+ for it in range(1, self.kintree_table.shape[1])
+ ],
+ dtype=torch.int64))
+
+ if create_global_orient:
+ if global_orient is None:
+ default_global_orient = torch.zeros([batch_size, 3],
+ dtype=dtype)
+ else:
+ if torch.is_tensor(global_orient):
+ default_global_orient = global_orient.clone().detach()
+ else:
+ default_global_orient = torch.tensor(global_orient,
+ dtype=dtype)
+
+ global_orient = nn.Parameter(default_global_orient,
+ requires_grad=True)
+ self.register_parameter('global_orient', global_orient)
+
+ if create_body_pose:
+ if body_pose is None:
+ default_body_pose = torch.zeros(
+ [batch_size, self.NUM_BODY_JOINTS * 3], dtype=dtype)
+ else:
+ if torch.is_tensor(body_pose):
+ default_body_pose = body_pose.clone().detach()
+ else:
+ default_body_pose = torch.tensor(body_pose, dtype=dtype)
+ self.register_parameter(
+ 'body_pose', nn.Parameter(default_body_pose,
+ requires_grad=True))
+
+ if create_betas:
+ if betas is None:
+ default_betas = torch.zeros([batch_size, self.num_betas],
+ dtype=dtype)
+ else:
+ if torch.is_tensor(betas):
+ default_betas = betas.clone().detach()
+ else:
+ default_betas = torch.tensor(betas, dtype=dtype)
+
+ self.register_parameter(
+ 'betas', nn.Parameter(default_betas, requires_grad=True))
+
+ if create_transl:
+ if transl is None:
+ default_transl = torch.zeros([batch_size, 3],
+ dtype=dtype,
+ requires_grad=True)
+ else:
+ default_transl = torch.tensor(transl, dtype=dtype)
+ self.register_parameter(
+ 'transl', nn.Parameter(default_transl, requires_grad=True))
+
+ self.verts = None
+ self.J = None
+ self.R = None
+
+ def forward(self,
+ global_orient: Optional[torch.Tensor] = None,
+ body_pose: Optional[torch.Tensor] = None,
+ betas: Optional[torch.Tensor] = None,
+ transl: Optional[torch.Tensor] = None,
+ return_verts: bool = True,
+ return_full_pose: bool = True) -> torch.Tensor:
+ """Forward pass for the STAR model.
+
+ Args:
+ global_orient: torch.tensor, optional, shape Bx3
+ Global orientation (rotation) of the body. If given, ignore the
+ member variable and use it as the global rotation of the body.
+ Useful if someone wishes to predict this with an external
+ model. (default=None)
+ body_pose: torch.Tensor, shape Bx(J*3)
+ Pose parameters for the STAR model. It should be a tensor that
+ contains joint rotations in axis-angle format. If given, ignore
+ the member variable and use it as the body parameters.
+ (default=None)
+ betas: torch.Tensor, shape Bx10
+ Shape parameters for the STAR model. If given, ignore the
+ member variable and use it as shape parameters. (default=None)
+ transl: torch.Tensor, shape Bx3
+ Translation vector for the STAR model. If given, ignore the
+ member variable and use it as the translation of the body.
+ (default=None)
+ Returns:
+ output: Contains output parameters and attributes corresponding
+ to other body models.
+ """
+ global_orient = (global_orient
+ if global_orient is not None else self.global_orient)
+ body_pose = body_pose if body_pose is not None else self.body_pose
+ betas = betas if betas is not None else self.betas
+ apply_transl = transl is not None or hasattr(self, 'transl')
+ if transl is None and hasattr(self, 'transl'):
+ transl = self.transl
+
+ batch_size = body_pose.shape[0]
+ v_template = self.v_template[None, :]
+ shapedirs = self.shapedirs.view(-1, self.num_betas)[None, :].expand(
+ batch_size, -1, -1)
+ beta = betas[:, :, None]
+ v_shaped = torch.matmul(shapedirs, beta).view(-1, 6890, 3) + v_template
+ J = torch.einsum('bik,ji->bjk', [v_shaped, self.J_regressor])
+
+ pose_quat = self.normalize_quaternion(body_pose.view(-1, 3)).view(
+ batch_size, -1)
+ pose_feat = torch.cat((pose_quat[:, 4:], beta[:, 1]), 1)
+
+ R = aa_to_rotmat(body_pose.view(-1, 3)).view(batch_size, 24, 3, 3)
+ R = R.view(batch_size, 24, 3, 3)
+
+ posedirs = self.posedirs[None, :].expand(batch_size, -1, -1)
+ v_posed = v_shaped + torch.matmul(
+ posedirs, pose_feat[:, :, None]).view(-1, 6890, 3)
+
+ root_transform = make_homegeneous_rotmat_batch(
+ torch.cat((R[:, 0], J[:, 0][:, :, None]), 2))
+ results = [root_transform]
+ for i in range(0, self.parent.shape[0]):
+ transform_i = make_homegeneous_rotmat_batch(
+ torch.cat((R[:, i + 1], J[:, i + 1][:, :, None] -
+ J[:, self.parent[i]][:, :, None]), 2))
+ curr_res = torch.matmul(results[self.parent[i]], transform_i)
+ results.append(curr_res)
+ results = torch.stack(results, dim=1)
+ posed_joints = results[:, :, :3, 3]
+
+ if apply_transl:
+ posed_joints += transl[:, None, :]
+ v_posed += transl[:, None, :]
+
+ joints, joint_mask = convert_kps(posed_joints,
+ src=self.keypoint_src,
+ dst=self.keypoint_dst,
+ approximate=self.keypoint_approximate)
+
+ joint_mask = torch.tensor(joint_mask,
+ dtype=torch.uint8,
+ device=joints.device)
+ joint_mask = joint_mask.reshape(1, -1).expand(batch_size, -1)
+
+ output = dict(global_orient=global_orient,
+ body_pose=body_pose,
+ joints=posed_joints,
+ joint_mask=joint_mask,
+ keypoints=torch.cat([joints, joint_mask[:, :, None]],
+ dim=-1),
+ betas=beta)
+
+ if return_verts:
+ output['vertices'] = v_posed
+ if return_full_pose:
+ output['full_pose'] = torch.cat([global_orient, body_pose], dim=1)
+
+ return output
+
+ @classmethod
+ def normalize_quaternion(cls, theta: torch.Tensor) -> torch.Tensor:
+ """Computes a normalized quaternion ([0,0,0,0] when the body is in rest
+ pose) given joint angles.
+
+ Args:
+ theta (torch.Tensor): A tensor of joint axis-angles with shape
+ (batch_size * num_joints, 3)
+
+ Returns:
+ quat (torch.Tensor)
+ """
+ l1norm = torch.norm(theta + 1e-8, p=2, dim=1)
+ angle = torch.unsqueeze(l1norm, -1)
+ normalized = torch.div(theta, angle)
+ angle = angle * 0.5
+ v_cos = torch.cos(angle)
+ v_sin = torch.sin(angle)
+ quat = torch.cat([v_sin * normalized, v_cos - 1], dim=1)
+ return quat
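+
+
+ # --- Illustrative sketch (added for clarity; not part of the upstream STAR
+ # implementation). It mirrors the axis-angle -> quaternion conversion done
+ # by `normalize_quaternion` above with plain tensor ops and checks that a
+ # rest-pose joint (zero rotation) maps to an (almost) all-zero quaternion
+ # offset, as stated in the docstring. Only the module's existing `torch`
+ # import is assumed.
+ if __name__ == '__main__':
+     aa = torch.zeros(24, 3)  # rest pose: 24 joints, zero rotation
+     aa[1] = torch.tensor([0.0, 0.5, 0.0])  # rotate one joint for contrast
+     angle = torch.norm(aa + 1e-8, p=2, dim=1, keepdim=True)
+     axis = aa / angle
+     quat = torch.cat([torch.sin(angle * 0.5) * axis,
+                       torch.cos(angle * 0.5) - 1], dim=1)
+     # Joint 0 is untouched, so its quaternion offset stays ~[0, 0, 0, 0].
+     assert torch.allclose(quat[0], torch.zeros(4), atol=1e-4)
+     assert not torch.allclose(quat[1], torch.zeros(4), atol=1e-4)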
diff --git a/detrsmpl/models/body_models/utils.py b/detrsmpl/models/body_models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..da3a4f8dce68a0f1284012453b0136b425ae96d0
--- /dev/null
+++ b/detrsmpl/models/body_models/utils.py
@@ -0,0 +1,116 @@
+import numpy as np
+
+from detrsmpl.utils.transforms import aa_to_rotmat, rotmat_to_aa
+
+
+def transform_to_camera_frame(global_orient, transl, pelvis, extrinsic):
+ """Transform body model parameters to camera frame.
+
+ Args:
+ global_orient (numpy.ndarray): shape (3, ). Only global_orient and
+ transl need to be updated in the rigid transformation.
+ transl (numpy.ndarray): shape (3, ).
+ pelvis (numpy.ndarray): shape (3, ). 3D joint location of the pelvis.
+ This is necessary to eliminate the offset from the SMPL
+ canonical space origin to the pelvis, because the global orient
+ is applied around the pelvis, not the canonical space origin.
+ extrinsic (numpy.ndarray): shape (4, 4). Transformation matrix
+ from the world frame to the camera frame.
+ Returns:
+ (new_global_orient, new_transl)
+ new_global_orient: transformed global orient
+ new_transl: transformed transl
+ """
+
+ # take out the small offset from smpl origin to pelvis
+ transl_offset = pelvis - transl
+ T_p2w = np.eye(4)
+ T_p2w[:3, 3] = transl_offset
+
+ # camera extrinsic: transformation from world frame to camera frame
+ T_w2c = extrinsic
+
+ # smpl transformation: from vertex frame to world frame
+ T_v2p = np.eye(4)
+ global_orient_mat = aa_to_rotmat(global_orient)
+ T_v2p[:3, :3] = global_orient_mat
+ T_v2p[:3, 3] = transl
+
+ # compute combined transformation from vertex to world
+ T_v2w = T_p2w @ T_v2p
+
+ # compute transformation from vertex to camera
+ T_v2c = T_w2c @ T_v2w
+
+ # decompose vertex to camera transformation
+ # np: new pelvis frame
+ # T_v2c = T_np2c x T_v2np
+ T_np2c = T_p2w
+ T_v2np = np.linalg.inv(T_np2c) @ T_v2c
+
+ # decompose into new global orient and new transl
+ new_global_orient_mat = T_v2np[:3, :3]
+ new_global_orient = rotmat_to_aa(new_global_orient_mat)
+ new_transl = T_v2np[:3, 3]
+
+ return new_global_orient, new_transl
+
+
+def batch_transform_to_camera_frame(global_orient, transl, pelvis, extrinsic):
+ """Transform body model parameters to camera frame by batch.
+
+ Args:
+ global_orient (np.ndarray): shape (N, 3). Only global_orient and
+ transl need to be updated in the rigid transformation.
+ transl (np.ndarray): shape (N, 3).
+ pelvis (np.ndarray): shape (N, 3). 3D joint location of the pelvis.
+ This is necessary to eliminate the offset from the SMPL
+ canonical space origin to the pelvis, because the global orient
+ is applied around the pelvis, not the canonical space origin.
+ extrinsic (np.ndarray): shape (4, 4). Transformation matrix
+ from the world frame to the camera frame.
+ Returns:
+ (new_global_orient, new_transl)
+ new_global_orient: transformed global orient
+ new_transl: transformed transl
+ """
+ N = len(global_orient)
+ assert global_orient.shape == (N, 3)
+ assert transl.shape == (N, 3)
+ assert pelvis.shape == (N, 3)
+
+ # take out the small offset from smpl origin to pelvis
+ transl_offset = pelvis - transl
+ T_p2w = np.eye(4).reshape(1, 4, 4).repeat(N, axis=0)
+ T_p2w[:, :3, 3] = transl_offset
+
+ # camera extrinsic: transformation from world frame to camera frame
+ T_w2c = extrinsic
+
+ # smpl transformation: from vertex frame to world frame
+ T_v2p = np.eye(4).reshape(1, 4, 4).repeat(N, axis=0)
+ global_orient_mat = aa_to_rotmat(global_orient)
+ T_v2p[:, :3, :3] = global_orient_mat
+ T_v2p[:, :3, 3] = transl
+
+ # compute combined transformation from vertex to world
+ T_v2w = T_p2w @ T_v2p
+
+ # compute transformation from vertex to camera
+ T_v2c = T_w2c @ T_v2w
+
+ # decompose vertex to camera transformation
+ # np: new pelvis frame
+ # T_v2c = T_np2c x T_v2np
+ T_np2c = T_p2w
+ T_v2np = np.linalg.inv(T_np2c) @ T_v2c
+
+ # decompose into new global orient and new transl
+ new_global_orient_mat = T_v2np[:, :3, :3]
+ new_global_orient = rotmat_to_aa(new_global_orient_mat)
+ new_transl = T_v2np[:, :3, 3]
+
+ assert new_global_orient.shape == (N, 3)
+ assert new_transl.shape == (N, 3)
+
+ return new_global_orient, new_transl
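+
+
+ # --- Illustrative sketch (added for clarity; not part of the original
+ # module). It runs the same random parameters through the single-sample and
+ # batched transforms above and checks that they agree; with an identity
+ # extrinsic the parameters should come back unchanged up to the axis-angle
+ # round trip. Only the module's existing `numpy` import is assumed.
+ if __name__ == '__main__':
+     rng = np.random.default_rng(0)
+     N = 4
+     global_orient = 0.5 * rng.normal(size=(N, 3))  # keep angles < pi
+     transl = rng.normal(size=(N, 3))
+     pelvis = transl + 0.05 * rng.normal(size=(N, 3))
+     extrinsic = np.eye(4)
+
+     batched = batch_transform_to_camera_frame(global_orient, transl, pelvis,
+                                               extrinsic)
+     for i in range(N):
+         single = transform_to_camera_frame(global_orient[i], transl[i],
+                                            pelvis[i], extrinsic)
+         assert np.allclose(single[0], batched[0][i], atol=1e-4)
+         assert np.allclose(single[1], batched[1][i], atol=1e-4)
+     # With an identity extrinsic, global_orient and transl are unchanged.
+     assert np.allclose(batched[0], global_orient, atol=1e-4)
+     assert np.allclose(batched[1], transl, atol=1e-4)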
diff --git a/detrsmpl/models/discriminators/__init__.py b/detrsmpl/models/discriminators/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/detrsmpl/models/discriminators/builder.py b/detrsmpl/models/discriminators/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c981eef44bf68d097a6f6b15a157b3d1ad714de
--- /dev/null
+++ b/detrsmpl/models/discriminators/builder.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from mmcv.utils import Registry
+
+from .pose_discriminator import (
+ FullPoseDiscriminator,
+ PoseDiscriminator,
+ ShapeDiscriminator,
+ SMPLDiscriminator,
+)
+
+DISCRIMINATORS = Registry('discriminators')
+
+DISCRIMINATORS.register_module(name='ShapeDiscriminator',
+ module=ShapeDiscriminator)
+DISCRIMINATORS.register_module(name='PoseDiscriminator',
+ module=PoseDiscriminator)
+DISCRIMINATORS.register_module(name='FullPoseDiscriminator',
+ module=FullPoseDiscriminator)
+DISCRIMINATORS.register_module(name='SMPLDiscriminator',
+ module=SMPLDiscriminator)
+
+
+def build_discriminator(cfg):
+ """Build discriminator."""
+ if cfg is None:
+ return None
+ return DISCRIMINATORS.build(cfg)
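+
+
+ # --- Illustrative sketch (added for clarity; not part of the original
+ # builder). It shows how a discriminator could be built from a config dict
+ # through the registry above; the channel/dropout tuples are example values,
+ # not defaults taken from any AiOS config.
+ if __name__ == '__main__':
+     import torch
+
+     cfg = dict(type='ShapeDiscriminator',
+                fc_layers=(10, 5, 1),
+                use_dropout=(False, False),
+                drop_prob=(0.5, 0.5),
+                use_activation=(True, False))
+     shape_disc = build_discriminator(cfg)
+     betas = torch.zeros(2, 10)  # a dummy batch of SMPL betas
+     print(shape_disc(betas).shape)  # -> torch.Size([2, 1])
+     assert build_discriminator(None) is None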
diff --git a/detrsmpl/models/discriminators/pose_discriminator.py b/detrsmpl/models/discriminators/pose_discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..1233302688167b426e88a434d0595783566f8e8b
--- /dev/null
+++ b/detrsmpl/models/discriminators/pose_discriminator.py
@@ -0,0 +1,302 @@
+# ------------------------------------------------------------------------------
+# Adapted from https://github.com/akanazawa/hmr
+# Original licence: Copyright (c) 2018 akanazawa, under the MIT License.
+# ------------------------------------------------------------------------------
+
+from abc import abstractmethod
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import normal_init, xavier_init
+
+from detrsmpl.utils.geometry import batch_rodrigues
+
+
+class BaseDiscriminator(nn.Module):
+ """Base linear module for SMPL parameter discriminator.
+
+ Args:
+ fc_layers (Tuple): Tuple of neuron counts,
+ such as (9, 32, 32, 1)
+ use_dropout (Tuple): Tuple of bools defining whether to use dropout
+ for each layer, such as (True, True, False)
+ drop_prob (Tuple): Tuple of floats defining the dropout probability
+ of each layer, such as (0.5, 0.5, 0)
+ use_activation (Tuple): Tuple of bools defining whether to apply an
+ activation function for each layer, such as (True, True, False)
+ """
+ def __init__(self, fc_layers, use_dropout, drop_prob, use_activation):
+ super().__init__()
+ self.fc_layers = fc_layers
+ self.use_dropout = use_dropout
+ self.drop_prob = drop_prob
+ self.use_activation = use_activation
+ self._check()
+ self.create_layers()
+
+ def _check(self):
+ """Check input to avoid ValueError."""
+ if not isinstance(self.fc_layers, tuple):
+ raise TypeError(f'fc_layers requires a tuple, '
+ f'got {type(self.fc_layers)}')
+
+ if not isinstance(self.use_dropout, tuple):
+ raise TypeError(f'use_dropout requires a tuple, '
+ f'got {type(self.use_dropout)}')
+
+ if not isinstance(self.drop_prob, tuple):
+ raise TypeError(f'drop_prob requires a tuple, '
+ f'got {type(self.drop_prob)}')
+
+ if not isinstance(self.use_activation, tuple):
+ raise TypeError(f'use_activation requires a tuple, '
+ f'got {type(self.use_activation)}')
+
+ l_fc_layer = len(self.fc_layers)
+ l_use_drop = len(self.use_dropout)
+ l_drop_prob = len(self.drop_prob)
+ l_use_activation = len(self.use_activation)
+
+ pass_check = (l_fc_layer >= 2 and l_use_drop < l_fc_layer
+ and l_drop_prob < l_fc_layer
+ and l_use_activation < l_fc_layer
+ and l_drop_prob == l_use_drop)
+
+ if not pass_check:
+ msg = 'Wrong BaseDiscriminator parameters!'
+ raise ValueError(msg)
+
+ def create_layers(self):
+ """Create layers."""
+ l_fc_layer = len(self.fc_layers)
+ l_use_drop = len(self.use_dropout)
+ l_use_activation = len(self.use_activation)
+
+ self.fc_blocks = nn.Sequential()
+
+ for i in range(l_fc_layer - 1):
+ self.fc_blocks.add_module(name=f'regressor_fc_{i}',
+ module=nn.Linear(
+ in_features=self.fc_layers[i],
+ out_features=self.fc_layers[i + 1]))
+
+ if i < l_use_activation and self.use_activation[i]:
+ self.fc_blocks.add_module(name=f'regressor_af_{i}',
+ module=nn.ReLU())
+
+ if i < l_use_drop and self.use_dropout[i]:
+ self.fc_blocks.add_module(
+ name=f'regressor_fc_dropout_{i}',
+ module=nn.Dropout(p=self.drop_prob[i]))
+
+ @abstractmethod
+ def forward(self, inputs):
+ """Forward function."""
+ msg = 'the base class [BaseDiscriminator] is not callable!'
+ raise NotImplementedError(msg)
+
+ def init_weights(self):
+ """Initialize model weights."""
+ # named_modules() yields (name, module) tuples, so iterate modules()
+ # to make the isinstance check actually match the Linear layers.
+ for m in self.fc_blocks.modules():
+ if isinstance(m, nn.Linear):
+ xavier_init(m, gain=0.01)
+
+
+class ShapeDiscriminator(BaseDiscriminator):
+ """Discriminator for SMPL shape parameters, the inputs is (batch_size x 10)
+ Args:
+ fc_layers (Tuple): Tuple of neuron count,
+ such as (10, 5, 1)
+ use_dropout (Tuple): Tuple of bool define use dropout or
+ not for each layer, such as (True, True, False)
+ drop_prob (Tuple): Tuple of float defined the drop prob,
+ such as (0.5, 0)
+ use_activation(Tuple): Tuple of bool define use active
+ function or not, such as (True, False)
+ """
+ def __init__(self, fc_layers, use_dropout, drop_prob, use_activation):
+ if fc_layers[-1] != 1:
+ msg = f'the neuron count of the last layer ' \
+ f'must be 1, but got {fc_layers[-1]}'
+ raise ValueError(msg)
+
+ super().__init__(fc_layers, use_dropout, drop_prob, use_activation)
+
+ def forward(self, inputs):
+ """Forward function."""
+ return self.fc_blocks(inputs)
+
+
+class PoseDiscriminator(nn.Module):
+ """Discriminator for SMPL pose parameters of each joint.
+
+ It is composed of
+ discriminators for each joints. The inputs is (batch_size x joint_count x
+ 9)
+ Args:
+ channels (Tuple): Tuple of channel number,
+ such as (9, 32, 32, 1)
+ joint_count (int): Joint number, such as 23
+ """
+ def __init__(self, channels, joint_count):
+ super().__init__()
+ if channels[-1] != 1:
+ msg = f'the neuron count of the last layer ' \
+ f'must be 1, but got {channels[-1]}'
+ raise ValueError(msg)
+ self.joint_count = joint_count
+
+ self.conv_blocks = nn.Sequential()
+ len_channels = len(channels)
+ for idx in range(len_channels - 2):
+ self.conv_blocks.add_module(name=f'conv_{idx}',
+ module=nn.Conv2d(
+ in_channels=channels[idx],
+ out_channels=channels[idx + 1],
+ kernel_size=1,
+ stride=1))
+
+ self.fc_layer = nn.ModuleList()
+ for idx in range(joint_count):
+ self.fc_layer.append(
+ nn.Linear(in_features=channels[len_channels - 2],
+ out_features=1))
+
+ def forward(self, inputs):
+ """Forward function.
+
+ The input is (batch_size x joint_count x 9)
+ """
+ # shape: batch_size x 9 x 1 x joint_count
+ inputs = inputs.transpose(1, 2).unsqueeze(2).contiguous()
+ # shape: batch_size x c x 1 x joint_count
+ internal_outputs = self.conv_blocks(inputs)
+ outputs = []
+ for idx in range(self.joint_count):
+ outputs.append(self.fc_layer[idx](internal_outputs[:, :, 0, idx]))
+
+ return torch.cat(outputs, 1), internal_outputs
+
+ def init_weights(self):
+ """Initialize model weights."""
+ for m in self.conv_blocks:
+ if isinstance(m, nn.Conv2d):
+ normal_init(m, std=0.001, bias=0)
+ # named_modules() yields (name, module) tuples, so iterate modules()
+ # to make the isinstance check actually match the Linear layers.
+ for m in self.fc_layer.modules():
+ if isinstance(m, nn.Linear):
+ xavier_init(m, gain=0.01)
+
+
+class FullPoseDiscriminator(BaseDiscriminator):
+ """Discriminator for SMPL pose parameters of all joints.
+
+ Args:
+ fc_layers (Tuple): Tuple of neuron counts,
+ such as (736, 1024, 1024, 1)
+ use_dropout (Tuple): Tuple of bools defining whether to use dropout
+ for each layer, such as (True, True, False)
+ drop_prob (Tuple): Tuple of floats defining the dropout probability
+ of each layer, such as (0.5, 0.5, 0)
+ use_activation (Tuple): Tuple of bools defining whether to apply an
+ activation function for each layer, such as (True, True, False)
+ """
+ def __init__(self, fc_layers, use_dropout, drop_prob, use_activation):
+ if fc_layers[-1] != 1:
+ msg = f'the neuron count of the last layer must be 1,' \
+ f' but got {fc_layers[-1]}'
+ raise ValueError(msg)
+
+ super().__init__(fc_layers, use_dropout, drop_prob, use_activation)
+
+ def forward(self, inputs):
+ """Forward function."""
+ return self.fc_blocks(inputs)
+
+
+class SMPLDiscriminator(nn.Module):
+ """Discriminator for SMPL pose and shape parameters.
+
+ It is composed of a
+ discriminator for SMPL shape parameters, a discriminator for SMPL pose
+ parameters of all joints and a discriminator for SMPL pose parameters of
+ each joint.
+ Args:
+ beta_channel (tuple of int): Tuple of neuron count of the
+ discriminator of shape parameters. Defaults to (10, 5, 1)
+ per_joint_channel (tuple of int): Tuple of neuron count of the
+ discriminator of each joint. Defaults to (9, 32, 32, 1)
+ full_pose_channel (tuple of int): Tuple of neuron count of the
+ discriminator of full pose. Defaults to (23*32, 1024, 1024, 1)
+ """
+ def __init__(self,
+ beta_channel=(10, 5, 1),
+ per_joint_channel=(9, 32, 32, 1),
+ full_pose_channel=(23 * 32, 1024, 1024, 1)):
+ super().__init__()
+ self.joint_count = 23
+ # The count of SMPL shape parameter is 10.
+ assert beta_channel[0] == 10
+ # Use 3 x 3 rotation matrix as the pose parameters
+ # of each joint, so the input channel is 9.
+ assert per_joint_channel[0] == 9
+ assert self.joint_count * per_joint_channel[-2] \
+ == full_pose_channel[0]
+
+ self.beta_channel = beta_channel
+ self.per_joint_channel = per_joint_channel
+ self.full_pose_channel = full_pose_channel
+ self._create_sub_modules()
+
+ def _create_sub_modules(self):
+ """Create sub discriminators."""
+
+ # create theta discriminator for each joint
+ self.pose_discriminator = PoseDiscriminator(self.per_joint_channel,
+ self.joint_count)
+
+ # create full pose discriminator for all joints
+ fc_layers = self.full_pose_channel
+ use_dropout = tuple([False] * (len(fc_layers) - 1))
+ drop_prob = tuple([0.5] * (len(fc_layers) - 1))
+ use_activation = tuple([True] * (len(fc_layers) - 2) + [False])
+
+ self.full_pose_discriminator = FullPoseDiscriminator(
+ fc_layers, use_dropout, drop_prob, use_activation)
+
+ # create shape discriminator for betas
+ fc_layers = self.beta_channel
+ use_dropout = tuple([False] * (len(fc_layers) - 1))
+ drop_prob = tuple([0.5] * (len(fc_layers) - 1))
+ use_activation = tuple([True] * (len(fc_layers) - 2) + [False])
+ self.shape_discriminator = ShapeDiscriminator(fc_layers, use_dropout,
+ drop_prob,
+ use_activation)
+
+ def forward(self, thetas):
+ """Forward function."""
+ _, poses, shapes = thetas
+
+ batch_size = poses.shape[0]
+ shape_disc_value = self.shape_discriminator(shapes)
+
+ # The first rotation matrix is global rotation
+ # and is NOT used in discriminator.
+ if poses.dim() == 2:
+ rotate_matrixs = \
+ batch_rodrigues(poses.contiguous().view(-1, 3)
+ ).view(batch_size, 24, 9)[:, 1:, :]
+ else:
+ rotate_matrixs = poses.contiguous().view(batch_size, 24,
+ 9)[:, 1:, :].contiguous()
+ pose_disc_value, pose_inter_disc_value \
+ = self.pose_discriminator(rotate_matrixs)
+ full_pose_disc_value = self.full_pose_discriminator(
+ pose_inter_disc_value.contiguous().view(batch_size, -1))
+ return torch.cat(
+ (pose_disc_value, full_pose_disc_value, shape_disc_value), 1)
+
+ def init_weights(self):
+ """Initialize model weights."""
+ self.full_pose_discriminator.init_weights()
+ self.pose_discriminator.init_weights()
+ self.shape_discriminator.init_weights()
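+
+
+ # --- Illustrative sketch (added for clarity; not part of the original
+ # module). It runs the default SMPLDiscriminator on dummy inputs to show the
+ # expected shapes: `thetas` is a (cams, poses, shapes) tuple, poses hold 24
+ # axis-angle rotations per sample and shapes hold 10 betas. The output
+ # concatenates 23 per-joint scores, 1 full-pose score and 1 shape score.
+ if __name__ == '__main__':
+     disc = SMPLDiscriminator()
+     disc.init_weights()
+     batch_size = 2
+     cams = torch.zeros(batch_size, 3)  # unused by the discriminator
+     poses = torch.zeros(batch_size, 24 * 3)  # axis-angle body pose
+     shapes = torch.zeros(batch_size, 10)  # SMPL betas
+     out = disc((cams, poses, shapes))
+     print(out.shape)  # -> torch.Size([2, 25])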
diff --git a/detrsmpl/models/heads/__init__.py b/detrsmpl/models/heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/detrsmpl/models/heads/builder.py b/detrsmpl/models/heads/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8df6333e9ed58bb7608a27be3ae77fa2b327389
--- /dev/null
+++ b/detrsmpl/models/heads/builder.py
@@ -0,0 +1,27 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from mmcv.utils import Registry
+
+from .detr_head import DeformableDETRHead, DETRHead
+from .expose_head import ExPoseBodyHead, ExPoseFaceHead, ExPoseHandHead
+from .hmr_head import HMRHead
+from .hybrik_head import HybrIKHead
+from .pare_head import PareHead
+
+HEADS = Registry('heads')
+
+HEADS.register_module(name='HybrIKHead', module=HybrIKHead)
+HEADS.register_module(name='HMRHead', module=HMRHead)
+HEADS.register_module(name='PareHead', module=PareHead)
+HEADS.register_module(name='ExPoseBodyHead', module=ExPoseBodyHead)
+HEADS.register_module(name='ExPoseHandHead', module=ExPoseHandHead)
+HEADS.register_module(name='ExPoseFaceHead', module=ExPoseFaceHead)
+HEADS.register_module(name='DETRHead', module=DETRHead)
+HEADS.register_module(name='DeformableDETRHead', module=DeformableDETRHead)
+
+
+def build_head(cfg):
+ """Build head."""
+ if cfg is None:
+ return None
+ return HEADS.build(cfg)
diff --git a/detrsmpl/models/heads/detr_head.py b/detrsmpl/models/heads/detr_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ef62b2b7651b98ba9b2359888ff033674fc1b0e
--- /dev/null
+++ b/detrsmpl/models/heads/detr_head.py
@@ -0,0 +1,1504 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import warnings
+from abc import ABCMeta
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import (
+ Conv2d,
+ ConvModule,
+ Linear,
+ bias_init_with_prob,
+ build_activation_layer,
+ constant_init,
+)
+from mmcv.cnn.bricks.transformer import FFN
+from mmcv.ops import batched_nms
+from mmcv.runner import BaseModule, force_fp32
+
+from detrsmpl.core.post_processing.bbox.assigners import build_assigner
+# from detrsmpl.core.post_processing.bbox.coder import build_bbox_coder
+from detrsmpl.core.post_processing.bbox.samplers import build_sampler
+from detrsmpl.core.post_processing.bbox.transforms import (
+ bbox_cxcywh_to_xyxy,
+ bbox_xyxy_to_cxcywh,
+)
+# from mmdet.core.anchor.point_generator import MlvlPointGenerator
+# from mmdet.core.utils import filter_scores_and_topk, select_single_mlvl
+from detrsmpl.models.utils import (
+ build_positional_encoding,
+ build_transformer,
+ inverse_sigmoid,
+)
+from detrsmpl.utils.dist_utils import reduce_mean
+from detrsmpl.utils.geometry import rot6d_to_rotmat
+# from utils.misc import multi_apply
+from detrsmpl.utils.misc import multi_apply
+from ..losses.builder import build_loss
+
+
+class DETRHead(BaseModule, metaclass=ABCMeta):
+ """Implements the DETR transformer head.
+
+ See `paper: End-to-End Object Detection with Transformers
+ <https://arxiv.org/abs/2005.12872>`_ for details.
+
+ Args:
+ num_classes (int): Number of categories excluding the background.
+ in_channels (int): Number of channels in the input feature map.
+ num_query (int): Number of queries in the Transformer.
+ num_reg_fcs (int, optional): Number of fully-connected layers used in
+ `FFN`, which is then used for the regression head. Default 2.
+ transformer (obj:`mmcv.ConfigDict`|dict): Config for transformer.
+ Default: None.
+ sync_cls_avg_factor (bool): Whether to sync the avg_factor of
+ all ranks. Default to False.
+ positional_encoding (obj:`mmcv.ConfigDict`|dict):
+ Config for position encoding.
+ loss_cls (obj:`mmcv.ConfigDict`|dict): Config of the
+ classification loss. Default `CrossEntropyLoss`.
+ loss_bbox (obj:`mmcv.ConfigDict`|dict): Config of the
+ regression loss. Default `L1Loss`.
+ loss_iou (obj:`mmcv.ConfigDict`|dict): Config of the
+ regression iou loss. Default `GIoULoss`.
+ train_cfg (obj:`mmcv.ConfigDict`|dict): Training config of
+ transformer head.
+ test_cfg (obj:`mmcv.ConfigDict`|dict): Testing config of
+ transformer head.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ _version = 2
+
+ def __init__(
+ self,
+ num_classes,
+ in_channels,
+ # anchor free
+ feat_channels=256,
+ stacked_convs=4,
+ strides=(4, 8, 16, 32, 64),
+ dcn_on_last_conv=False,
+ conv_bias='auto',
+ num_query=100,
+ num_reg_fcs=2,
+ transformer=None,
+ sync_cls_avg_factor=False,
+ positional_encoding=dict(type='SinePositionalEncoding',
+ num_feats=128,
+ normalize=True),
+ loss_cls=dict(type='CrossEntropyLoss',
+ bg_cls_weight=0.1,
+ use_sigmoid=False,
+ loss_weight=1.0,
+ class_weight=1.0),
+ loss_bbox=dict(type='L1Loss', loss_weight=5.0),
+ loss_iou=dict(type='GIoULoss', loss_weight=2.0),
+ # anchor free
+ bbox_coder=dict(type='DistancePointBBoxCoder'),
+ conv_cfg=None,
+ norm_cfg=None,
+ train_cfg=dict(assigner=dict(
+ type='HungarianAssigner',
+ # cls_cost=dict(type='ClassificationCost', weight=1.),
+ # reg_cost=dict(type='BBoxL1Cost', weight=5.0),
+ # iou_cost=dict(type='IoUCost', iou_mode='giou',
+ # weight=2.0)
+ kp3d_cost=dict(
+ type='Keypoints3DCost', convention='smpl_54', weight=5.0),
+ kp2d_cost=dict(
+ type='Keypoints2DCost', convention='smpl_54', weight=5.0),
+ )),
+ test_cfg=dict(max_per_img=100),
+ init_cfg=dict(type='Normal',
+ layer='Conv2d',
+ std=0.01,
+ override=dict(type='Normal',
+ name='conv_cls',
+ std=0.01,
+ bias_prob=0.01)),
+ **kwargs):
+ # NOTE here use `AnchorFreeHead` instead of `TransformerHead`,
+ # since it brings inconvenience when the initialization of
+ # `AnchorFreeHead` is called.
+ super(DETRHead, self).__init__(init_cfg)
+ self.bg_cls_weight = 0
+ self.sync_cls_avg_factor = sync_cls_avg_factor
+ class_weight = loss_cls.get('class_weight', None)
+ if class_weight is not None and (self.__class__ is DETRHead):
+ assert isinstance(class_weight, float), 'Expected ' \
+ 'class_weight to have type float. Found ' \
+ f'{type(class_weight)}.'
+ # NOTE following the official DETR repo, bg_cls_weight means
+ # relative classification weight of the no-object class.
+ bg_cls_weight = loss_cls.get('bg_cls_weight', class_weight)
+ assert isinstance(bg_cls_weight, float), 'Expected ' \
+ 'bg_cls_weight to have type float. Found ' \
+ f'{type(bg_cls_weight)}.'
+ class_weight = torch.ones(num_classes + 1) * class_weight
+ # set background class as the last index
+ class_weight[num_classes] = bg_cls_weight
+ loss_cls.update({'class_weight': class_weight})
+ if 'bg_cls_weight' in loss_cls:
+ loss_cls.pop('bg_cls_weight')
+ self.bg_cls_weight = bg_cls_weight
+
+ if train_cfg:
+ assert 'assigner' in train_cfg, 'assigner should be provided '\
+ 'when train_cfg is set.'
+ assigner = train_cfg['assigner']
+ # TODO: update these
+ # assert loss_cls['loss_weight'] == assigner['kp3d_cost']['weight'], \
+ # 'The classification weight for loss and matcher should be' \
+ # 'exactly the same.'
+ # assert loss_bbox['loss_weight'] == assigner['kp3d_cost'][
+ # 'weight'], 'The regression L1 weight for loss and matcher ' \
+ # 'should be exactly the same.'
+ # assert loss_iou['loss_weight'] == assigner['kp3d_cost']['weight'], \
+ # 'The regression iou weight for loss and matcher should be' \
+ # 'exactly the same.'
+ self.assigner = build_assigner(assigner)
+ # DETR sampling=False, so use PseudoSampler
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+
+ self.num_query = num_query
+ self.num_classes = num_classes
+ self.in_channels = in_channels
+ self.num_reg_fcs = num_reg_fcs
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ self.fp16_enabled = False
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_bbox = build_loss(loss_bbox)
+ self.loss_iou = build_loss(loss_iou)
+
+ if self.loss_cls.use_sigmoid:
+ self.cls_out_channels = num_classes
+ else:
+ self.cls_out_channels = num_classes + 1
+ self.act_cfg = transformer.get('act_cfg',
+ dict(type='ReLU', inplace=True))
+ self.activate = build_activation_layer(self.act_cfg)
+ self.positional_encoding = build_positional_encoding(
+ positional_encoding)
+ self.transformer = build_transformer(transformer)
+ self.embed_dims = self.transformer.embed_dims
+ assert 'num_feats' in positional_encoding
+ num_feats = positional_encoding['num_feats']
+ assert num_feats * 2 == self.embed_dims, 'embed_dims should' \
+ f' be exactly 2 times of num_feats. Found {self.embed_dims}' \
+ f' and {num_feats}.'
+ self._init_layers()
+
+ def _init_layers(self):
+ """Initialize layers of the transformer head."""
+ self.input_proj = Conv2d(self.in_channels,
+ self.embed_dims,
+ kernel_size=1)
+ self.fc_cls = Linear(self.embed_dims, self.cls_out_channels)
+ self.reg_ffn = FFN(self.embed_dims,
+ self.embed_dims,
+ self.num_reg_fcs,
+ self.act_cfg,
+ dropout=0.0,
+ add_residual=False)
+ self.fc_reg = Linear(self.embed_dims, 4)
+ self.query_embedding = nn.Embedding(self.num_query, self.embed_dims)
+
+ def init_weights(self):
+ """Initialize weights of the transformer head."""
+ # The initialization for transformer is important
+ self.transformer.init_weights()
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ """load checkpoints."""
+ # NOTE here use `AnchorFreeHead` instead of `TransformerHead`,
+ # since `AnchorFreeHead._load_from_state_dict` should not be
+ # called here. Invoking the default `Module._load_from_state_dict`
+ # is enough.
+
+ # Names of some parameters have been changed.
+ version = local_metadata.get('version', None)
+ if (version is None or version < 2) and self.__class__ is DETRHead:
+ convert_dict = {
+ '.self_attn.': '.attentions.0.',
+ '.ffn.': '.ffns.0.',
+ '.multihead_attn.': '.attentions.1.',
+ '.decoder.norm.': '.decoder.post_norm.'
+ }
+ state_dict_keys = list(state_dict.keys())
+ for k in state_dict_keys:
+ for ori_key, convert_key in convert_dict.items():
+ if ori_key in k:
+ convert_key = k.replace(ori_key, convert_key)
+ state_dict[convert_key] = state_dict[k]
+ del state_dict[k]
+
+ super()._load_from_state_dict(state_dict, prefix, local_metadata,
+ strict, missing_keys, unexpected_keys,
+ error_msgs)
+
+ def forward(self, feats, img_metas):
+ """Forward function.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+ img_metas (list[dict]): List of image information.
+
+ Returns:
+ tuple[list[Tensor], list[Tensor]]: Outputs for all scale levels.
+
+ - all_cls_scores_list (list[Tensor]): Classification scores \
+ for each scale level. Each is a 4D-tensor with shape \
+ [nb_dec, bs, num_query, cls_out_channels]. Note \
+ `cls_out_channels` should include the background class.
+ - all_bbox_preds_list (list[Tensor]): Sigmoid regression \
+ outputs for each scale level. Each is a 4D-tensor with \
+ normalized coordinate format (cx, cy, w, h) and shape \
+ [nb_dec, bs, num_query, 4].
+ """
+ num_levels = len(feats)
+ img_metas_list = [img_metas for _ in range(num_levels)]
+ return multi_apply(self.forward_single, feats, img_metas_list)
+
+ def forward_single(self, x, img_metas):
+ """"Forward function for a single feature level.
+
+ Args:
+ x (Tensor): Input feature from backbone's single stage, shape
+ [bs, c, h, w].
+ img_metas (list[dict]): List of image information.
+
+ Returns:
+ all_cls_scores (Tensor): Outputs from the classification head,
+ shape [nb_dec, bs, num_query, cls_out_channels]. Note
+ cls_out_channels should include the background class.
+ all_bbox_preds (Tensor): Sigmoid outputs from the regression
+ head with normalized coordinate format (cx, cy, w, h).
+ Shape [nb_dec, bs, num_query, 4].
+ """
+ # construct binary masks which are used for the transformer.
+ # NOTE following the official DETR repo, non-zero values represent
+ # ignored positions, while zero values mean valid positions.
+ batch_size = x.size(0)
+ input_img_h, input_img_w = img_metas[0]['batch_input_shape']
+ masks = x.new_ones((batch_size, input_img_h, input_img_w))
+ for img_id in range(batch_size):
+ img_h, img_w, _ = img_metas[img_id]['img_shape']
+ masks[img_id, :img_h, :img_w] = 0
+
+ x = self.input_proj(x)
+ # interpolate masks to have the same spatial shape with x
+ masks = F.interpolate(masks.unsqueeze(1),
+ size=x.shape[-2:]).to(torch.bool).squeeze(1)
+ # position encoding
+ pos_embed = self.positional_encoding(masks) # [bs, embed_dim, h, w]
+ # outs_dec: [nb_dec, bs, num_query, embed_dim]
+ outs_dec, _ = self.transformer(x, masks, self.query_embedding.weight,
+ pos_embed)
+
+ all_cls_scores = self.fc_cls(outs_dec)
+ all_bbox_preds = self.fc_reg(self.activate(
+ self.reg_ffn(outs_dec))).sigmoid()
+ return all_cls_scores, all_bbox_preds
+
+ @force_fp32(apply_to=('all_cls_scores_list', 'all_bbox_preds_list'))
+ def loss(self,
+ all_cls_scores_list,
+ all_bbox_preds_list,
+ gt_bboxes_list,
+ gt_labels_list,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """"Loss function.
+
+ Only outputs from the last feature level are used for computing
+ losses by default.
+
+ Args:
+ all_cls_scores_list (list[Tensor]): Classification outputs
+ for each feature level. Each is a 4D-tensor with shape
+ [nb_dec, bs, num_query, cls_out_channels].
+ all_bbox_preds_list (list[Tensor]): Sigmoid regression
+ outputs for each feature level. Each is a 4D-tensor with
+ normalized coordinate format (cx, cy, w, h) and shape
+ [nb_dec, bs, num_query, 4].
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
+ with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels_list (list[Tensor]): Ground truth class indices for each
+ image with shape (num_gts, ).
+ img_metas (list[dict]): List of image meta information.
+ gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
+ which can be ignored for each image. Default None.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ # NOTE by default only the outputs from the last feature scale are used.
+ all_cls_scores = all_cls_scores_list[-1]
+ all_bbox_preds = all_bbox_preds_list[-1]
+ assert gt_bboxes_ignore is None, \
+ 'Only supports for gt_bboxes_ignore setting to None.'
+
+ num_dec_layers = len(all_cls_scores)
+ all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
+ all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
+ all_gt_bboxes_ignore_list = [
+ gt_bboxes_ignore for _ in range(num_dec_layers)
+ ]
+ img_metas_list = [img_metas for _ in range(num_dec_layers)]
+
+ losses_cls, losses_bbox, losses_iou = multi_apply(
+ self.loss_single, all_cls_scores, all_bbox_preds,
+ all_gt_bboxes_list, all_gt_labels_list, img_metas_list,
+ all_gt_bboxes_ignore_list)
+
+ loss_dict = dict()
+ # loss from the last decoder layer
+ loss_dict['loss_cls'] = losses_cls[-1]
+ loss_dict['loss_bbox'] = losses_bbox[-1]
+ loss_dict['loss_iou'] = losses_iou[-1]
+ # loss from other decoder layers
+ num_dec_layer = 0
+ for loss_cls_i, loss_bbox_i, loss_iou_i in zip(losses_cls[:-1],
+ losses_bbox[:-1],
+ losses_iou[:-1]):
+ loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
+ loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
+ loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i
+ num_dec_layer += 1
+ return loss_dict
+
+ def loss_single(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes_list,
+ gt_labels_list,
+ img_metas,
+ gt_bboxes_ignore_list=None):
+ """"Loss function for outputs from a single decoder layer of a single
+ feature level.
+
+ Args:
+ cls_scores (Tensor): Box score logits from a single decoder layer
+ for all images. Shape [bs, num_query, cls_out_channels].
+ bbox_preds (Tensor): Sigmoid outputs from a single decoder layer
+ for all images, with normalized coordinate (cx, cy, w, h) and
+ shape [bs, num_query, 4].
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
+ with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels_list (list[Tensor]): Ground truth class indices for each
+ image with shape (num_gts, ).
+ img_metas (list[dict]): List of image meta information.
+ gt_bboxes_ignore_list (list[Tensor], optional): Bounding
+ boxes which can be ignored for each image. Default None.
+
+ Returns:
+ tuple[Tensor]: The classification, bbox regression and IoU losses
+ for outputs from a single decoder layer.
+ """
+ num_imgs = cls_scores.size(0)
+ cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
+ bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
+ cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,
+ gt_bboxes_list, gt_labels_list,
+ img_metas, gt_bboxes_ignore_list)
+ (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+ num_total_pos, num_total_neg) = cls_reg_targets
+ labels = torch.cat(labels_list, 0)
+ label_weights = torch.cat(label_weights_list, 0)
+ bbox_targets = torch.cat(bbox_targets_list, 0)
+ bbox_weights = torch.cat(bbox_weights_list, 0)
+
+ # classification loss
+ cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
+ # construct weighted avg_factor to match with the official DETR repo
+ cls_avg_factor = num_total_pos * 1.0 + \
+ num_total_neg * self.bg_cls_weight
+ if self.sync_cls_avg_factor:
+ cls_avg_factor = reduce_mean(
+ cls_scores.new_tensor([cls_avg_factor]))
+ cls_avg_factor = max(cls_avg_factor, 1)
+
+ loss_cls = self.loss_cls(cls_scores,
+ labels,
+ label_weights,
+ avg_factor=cls_avg_factor)
+
+ # Compute the average number of gt boxes across all gpus, for
+ # normalization purposes
+ num_total_pos = loss_cls.new_tensor([num_total_pos])
+ num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()
+
+ # construct factors used for rescale bboxes
+ factors = []
+ for img_meta, bbox_pred in zip(img_metas, bbox_preds):
+ img_h, img_w, _ = img_meta['img_shape']
+ factor = bbox_pred.new_tensor([img_w, img_h, img_w,
+ img_h]).unsqueeze(0).repeat(
+ bbox_pred.size(0), 1)
+ factors.append(factor)
+ factors = torch.cat(factors, 0)
+
+ # DETR regresses the relative position of boxes (cxcywh) in the image,
+ # thus the learning target is normalized by the image size. So here
+ # we need to re-scale them for calculating the IoU loss.
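+ # Worked example (illustrative values): for a 100 x 200 (h x w) image,
+ # factor = [200, 100, 200, 100]; a normalized cxcywh prediction
+ # [0.5, 0.5, 0.2, 0.4] becomes xyxy [0.4, 0.3, 0.6, 0.7] and, after
+ # multiplying by factor, the pixel box [80, 30, 120, 70] on which the
+ # GIoU loss is computed.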
+ bbox_preds = bbox_preds.reshape(-1, 4)
+ bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
+ bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors
+
+ # regression IoU loss, GIoU loss by default
+ loss_iou = self.loss_iou(bboxes,
+ bboxes_gt,
+ bbox_weights,
+ avg_factor=num_total_pos)
+
+ # regression L1 loss
+ loss_bbox = self.loss_bbox(bbox_preds,
+ bbox_targets,
+ bbox_weights,
+ avg_factor=num_total_pos)
+ return loss_cls, loss_bbox, loss_iou
+
+ def get_targets(self,
+ cls_scores_list,
+ bbox_preds_list,
+ gt_bboxes_list,
+ gt_labels_list,
+ img_metas,
+ gt_bboxes_ignore_list=None):
+ """"Compute regression and classification targets for a batch image.
+
+ Outputs from a single decoder layer of a single feature level are used.
+
+ Args:
+ cls_scores_list (list[Tensor]): Box score logits from a single
+ decoder layer for each image with shape [num_query,
+ cls_out_channels].
+ bbox_preds_list (list[Tensor]): Sigmoid outputs from a single
+ decoder layer for each image, with normalized coordinate
+ (cx, cy, w, h) and shape [num_query, 4].
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
+ with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels_list (list[Tensor]): Ground truth class indices for each
+ image with shape (num_gts, ).
+ img_metas (list[dict]): List of image meta information.
+ gt_bboxes_ignore_list (list[Tensor], optional): Bounding
+ boxes which can be ignored for each image. Default None.
+
+ Returns:
+ tuple: a tuple containing the following targets.
+
+ - labels_list (list[Tensor]): Labels for all images.
+ - label_weights_list (list[Tensor]): Label weights for all \
+ images.
+ - bbox_targets_list (list[Tensor]): BBox targets for all \
+ images.
+ - bbox_weights_list (list[Tensor]): BBox weights for all \
+ images.
+ - num_total_pos (int): Number of positive samples in all \
+ images.
+ - num_total_neg (int): Number of negative samples in all \
+ images.
+ """
+ assert gt_bboxes_ignore_list is None, \
+ 'Only supports for gt_bboxes_ignore setting to None.'
+ num_imgs = len(cls_scores_list)
+ gt_bboxes_ignore_list = [
+ gt_bboxes_ignore_list for _ in range(num_imgs)
+ ]
+
+ (labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, pos_inds_list, neg_inds_list) = multi_apply(
+ self._get_target_single, cls_scores_list, bbox_preds_list,
+ gt_bboxes_list, gt_labels_list, img_metas, gt_bboxes_ignore_list)
+ num_total_pos = sum((inds.numel() for inds in pos_inds_list))
+ num_total_neg = sum((inds.numel() for inds in neg_inds_list))
+ return (labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, num_total_pos, num_total_neg)
+
+ def _get_target_single(self,
+ cls_score,
+ bbox_pred,
+ gt_bboxes,
+ gt_labels,
+ img_meta,
+ gt_bboxes_ignore=None):
+ """"Compute regression and classification targets for one image.
+
+ Outputs from a single decoder layer of a single feature level are used.
+
+ Args:
+ cls_score (Tensor): Box score logits from a single decoder layer
+ for one image. Shape [num_query, cls_out_channels].
+ bbox_pred (Tensor): Sigmoid outputs from a single decoder layer
+ for one image, with normalized coordinate (cx, cy, w, h) and
+ shape [num_query, 4].
+ gt_bboxes (Tensor): Ground truth bboxes for one image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (Tensor): Ground truth class indices for one image
+ with shape (num_gts, ).
+ img_meta (dict): Meta information for one image.
+ gt_bboxes_ignore (Tensor, optional): Bounding boxes
+ which can be ignored. Default None.
+
+ Returns:
+ tuple[Tensor]: a tuple containing the following for one image.
+
+ - labels (Tensor): Labels of each image.
+ - label_weights (Tensor): Label weights of each image.
+ - bbox_targets (Tensor): BBox targets of each image.
+ - bbox_weights (Tensor): BBox weights of each image.
+ - pos_inds (Tensor): Sampled positive indices for each image.
+ - neg_inds (Tensor): Sampled negative indices for each image.
+ """
+
+ num_bboxes = bbox_pred.size(0)
+ # assigner and sampler
+ assign_result = self.assigner.assign(bbox_pred, cls_score, gt_bboxes,
+ gt_labels, img_meta,
+ gt_bboxes_ignore)
+ sampling_result = self.sampler.sample(assign_result, bbox_pred,
+ gt_bboxes)
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+
+ # label targets
+ labels = gt_bboxes.new_full((num_bboxes, ),
+ self.num_classes,
+ dtype=torch.long)
+ labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
+ label_weights = gt_bboxes.new_ones(num_bboxes)
+
+ # bbox targets
+ bbox_targets = torch.zeros_like(bbox_pred)
+ bbox_weights = torch.zeros_like(bbox_pred)
+ bbox_weights[pos_inds] = 1.0
+ img_h, img_w, _ = img_meta['img_shape']
+
+ # DETR regresses the relative position of boxes (cxcywh) in the image.
+ # Thus the learning target should be normalized by the image size, and
+ # the box format should be converted from the default x1y1x2y2 to cxcywh.
+ factor = bbox_pred.new_tensor([img_w, img_h, img_w,
+ img_h]).unsqueeze(0)
+ pos_gt_bboxes_normalized = sampling_result.pos_gt_bboxes / factor
+ pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized)
+ bbox_targets[pos_inds] = pos_gt_bboxes_targets
+ return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
+ neg_inds)
+
+ # over-write because img_metas are needed as inputs for bbox_head.
+ def forward_train(self,
+ x,
+ img_metas,
+ gt_bboxes,
+ gt_labels=None,
+ gt_bboxes_ignore=None,
+ proposal_cfg=None,
+ **kwargs):
+ """Forward function for training mode.
+
+ Args:
+ x (list[Tensor]): Features from backbone.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes (Tensor): Ground truth bboxes of the image,
+ shape (num_gts, 4).
+ gt_labels (Tensor): Ground truth labels of each box,
+ shape (num_gts,).
+ gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, 4).
+ proposal_cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ assert proposal_cfg is None, '"proposal_cfg" must be None'
+ outs = self(x, img_metas)
+ if gt_labels is None:
+ loss_inputs = outs + (gt_bboxes, img_metas)
+ else:
+ loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
+ losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
+ return losses
+
+ @force_fp32(apply_to=('all_cls_scores_list', 'all_bbox_preds_list'))
+ def get_bboxes(self,
+ all_cls_scores_list,
+ all_bbox_preds_list,
+ img_metas,
+ rescale=False):
+ """Transform network outputs for a batch into bbox predictions.
+
+ Args:
+ all_cls_scores_list (list[Tensor]): Classification outputs
+ for each feature level. Each is a 4D-tensor with shape
+ [nb_dec, bs, num_query, cls_out_channels].
+ all_bbox_preds_list (list[Tensor]): Sigmoid regression
+ outputs for each feature level. Each is a 4D-tensor with
+ normalized coordinate format (cx, cy, w, h) and shape
+ [nb_dec, bs, num_query, 4].
+ img_metas (list[dict]): Meta information of each image.
+ rescale (bool, optional): If True, return boxes in original
+ image space. Default False.
+
+ Returns:
+ list[list[Tensor, Tensor]]: Each item in result_list is 2-tuple. \
+ The first item is an (n, 5) tensor, where the first 4 columns \
+ are bounding box positions (tl_x, tl_y, br_x, br_y) and the \
+ 5-th column is a score between 0 and 1. The second item is a \
+ (n,) tensor where each item is the predicted class label of \
+ the corresponding box.
+ """
+ # NOTE by default only outputs from the last feature level are used,
+ # and only the outputs from the last decoder layer are used.
+ cls_scores = all_cls_scores_list[-1][-1]
+ bbox_preds = all_bbox_preds_list[-1][-1]
+
+ result_list = []
+ for img_id in range(len(img_metas)):
+ cls_score = cls_scores[img_id]
+ bbox_pred = bbox_preds[img_id]
+ img_shape = img_metas[img_id]['img_shape']
+ scale_factor = img_metas[img_id]['scale_factor']
+ proposals = self._get_bboxes_single(cls_score, bbox_pred,
+ img_shape, scale_factor,
+ rescale)
+ result_list.append(proposals)
+
+ return result_list
+
+ def _get_bboxes_single(self,
+ cls_score,
+ bbox_pred,
+ img_shape,
+ scale_factor,
+ rescale=False):
+ """Transform outputs from the last decoder layer into bbox predictions
+ for each image.
+
+ Args:
+ cls_score (Tensor): Box score logits from the last decoder layer
+ for each image. Shape [num_query, cls_out_channels].
+ bbox_pred (Tensor): Sigmoid outputs from the last decoder layer
+ for each image, with coordinate format (cx, cy, w, h) and
+ shape [num_query, 4].
+ img_shape (tuple[int]): Shape of input image, (height, width, 3).
+ scale_factor (ndarray, optional): Scale factor of the image arranged
+ as (w_scale, h_scale, w_scale, h_scale).
+ rescale (bool, optional): If True, return boxes in original image
+ space. Default False.
+
+ Returns:
+ tuple[Tensor]: Results of detected bboxes and labels.
+
+ - det_bboxes: Predicted bboxes with shape [num_query, 5], \
+ where the first 4 columns are bounding box positions \
+ (tl_x, tl_y, br_x, br_y) and the 5-th column are scores \
+ between 0 and 1.
+ - det_labels: Predicted labels of the corresponding box with \
+ shape [num_query].
+ """
+ assert len(cls_score) == len(bbox_pred)
+ max_per_img = self.test_cfg.get('max_per_img', self.num_query)
+ # exclude background
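+ # The sigmoid branch takes a joint top-k over all (query, class) pairs
+ # of the flattened score map; e.g. with num_classes = 3 (illustrative
+ # value), a flattened index of 7 decodes to query 7 // 3 = 2 and
+ # label 7 % 3 = 1.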
+ if self.loss_cls.use_sigmoid:
+ cls_score = cls_score.sigmoid()
+ scores, indexes = cls_score.view(-1).topk(max_per_img)
+ det_labels = indexes % self.num_classes
+ bbox_index = indexes // self.num_classes
+ bbox_pred = bbox_pred[bbox_index]
+ else:
+ scores, det_labels = F.softmax(cls_score, dim=-1)[..., :-1].max(-1)
+ scores, bbox_index = scores.topk(max_per_img)
+ bbox_pred = bbox_pred[bbox_index]
+ det_labels = det_labels[bbox_index]
+
+ det_bboxes = bbox_cxcywh_to_xyxy(bbox_pred)
+ det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1]
+ det_bboxes[:, 1::2] = det_bboxes[:, 1::2] * img_shape[0]
+ det_bboxes[:, 0::2].clamp_(min=0, max=img_shape[1])
+ det_bboxes[:, 1::2].clamp_(min=0, max=img_shape[0])
+ if rescale:
+ det_bboxes /= det_bboxes.new_tensor(scale_factor)
+ det_bboxes = torch.cat((det_bboxes, scores.unsqueeze(1)), -1)
+
+ return det_bboxes, det_labels
+
+ def simple_test_bboxes(self, feats, img_metas, rescale=False):
+ """Test det bboxes without test-time augmentation.
+
+ Args:
+ feats (tuple[torch.Tensor]): Multi-level features from the
+ upstream network, each is a 4D-tensor.
+ img_metas (list[dict]): List of image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+ list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
+ The first item is ``bboxes`` with shape (n, 5),
+ where 5 represent (tl_x, tl_y, br_x, br_y, score).
+ The shape of the second tensor in the tuple is ``labels``
+ with shape (n,)
+ """
+ # forward of this head requires img_metas
+ outs = self.forward(feats, img_metas)
+ results_list = self.get_bboxes(*outs, img_metas, rescale=rescale)
+ return results_list
+
+ def forward_onnx(self, feats, img_metas):
+ """Forward function for exporting to ONNX.
+
+ Over-write `forward` because: `masks` is directly created with
+ zero (valid position tag) and has the same spatial size as `x`.
+ Thus the construction of `masks` is different from that in `forward`.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+ img_metas (list[dict]): List of image information.
+
+ Returns:
+ tuple[list[Tensor], list[Tensor]]: Outputs for all scale levels.
+
+ - all_cls_scores_list (list[Tensor]): Classification scores \
+ for each scale level. Each is a 4D-tensor with shape \
+ [nb_dec, bs, num_query, cls_out_channels]. Note \
+ `cls_out_channels` should include the background class.
+ - all_bbox_preds_list (list[Tensor]): Sigmoid regression \
+ outputs for each scale level. Each is a 4D-tensor with \
+ normalized coordinate format (cx, cy, w, h) and shape \
+ [nb_dec, bs, num_query, 4].
+ """
+ num_levels = len(feats)
+ img_metas_list = [img_metas for _ in range(num_levels)]
+ return multi_apply(self.forward_single_onnx, feats, img_metas_list)
+
+ def forward_single_onnx(self, x, img_metas):
+ """"Forward function for a single feature level with ONNX exportation.
+
+ Args:
+ x (Tensor): Input feature from backbone's single stage, shape
+ [bs, c, h, w].
+ img_metas (list[dict]): List of image information.
+
+ Returns:
+ all_cls_scores (Tensor): Outputs from the classification head,
+ shape [nb_dec, bs, num_query, cls_out_channels]. Note
+ cls_out_channels should include the background class.
+ all_bbox_preds (Tensor): Sigmoid outputs from the regression
+ head with normalized coordinate format (cx, cy, w, h).
+ Shape [nb_dec, bs, num_query, 4].
+ """
+ # Note `img_shape` is not dynamically traceable to ONNX,
+ # since the related augmentation was done with numpy under
+ # CPU. Thus `masks` is directly created with zeros (valid tag)
+ # and the same spatial shape as `x`.
+ # The difference between torch and exported ONNX model may be
+ # ignored, since the same performance is achieved (e.g.
+ # 40.1 vs 40.1 for DETR)
+ batch_size = x.size(0)
+ h, w = x.size()[-2:]
+ masks = x.new_zeros((batch_size, h, w)) # [B,h,w]
+
+ x = self.input_proj(x)
+ # interpolate masks to have the same spatial shape with x
+ masks = F.interpolate(masks.unsqueeze(1),
+ size=x.shape[-2:]).to(torch.bool).squeeze(1)
+ pos_embed = self.positional_encoding(masks)
+ outs_dec, _ = self.transformer(x, masks, self.query_embedding.weight,
+ pos_embed)
+
+ all_cls_scores = self.fc_cls(outs_dec)
+ all_bbox_preds = self.fc_reg(self.activate(
+ self.reg_ffn(outs_dec))).sigmoid()
+ return all_cls_scores, all_bbox_preds
+
+ def onnx_export(self, all_cls_scores_list, all_bbox_preds_list, img_metas):
+ """Transform network outputs into bbox predictions, with ONNX
+ exportation.
+
+ Args:
+ all_cls_scores_list (list[Tensor]): Classification outputs
+ for each feature level. Each is a 4D-tensor with shape
+ [nb_dec, bs, num_query, cls_out_channels].
+ all_bbox_preds_list (list[Tensor]): Sigmoid regression
+ outputs for each feature level. Each is a 4D-tensor with
+ normalized coordinate format (cx, cy, w, h) and shape
+ [nb_dec, bs, num_query, 4].
+ img_metas (list[dict]): Meta information of each image.
+
+ Returns:
+ tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
+ and class labels of shape [N, num_det].
+ """
+ assert len(img_metas) == 1, \
+ 'Only support one input image while in exporting to ONNX'
+
+ cls_scores = all_cls_scores_list[-1][-1]
+ bbox_preds = all_bbox_preds_list[-1][-1]
+
+ # Note `img_shape` is not dynamically traceable to ONNX,
+ # here `img_shape_for_onnx` (padded shape of image tensor)
+ # is used.
+ img_shape = img_metas[0]['img_shape_for_onnx']
+ max_per_img = self.test_cfg.get('max_per_img', self.num_query)
+ batch_size = cls_scores.size(0)
+ # `batch_index_offset` is used for gathering from the concatenated tensor
+ batch_index_offset = torch.arange(batch_size).to(
+ cls_scores.device) * max_per_img
+ batch_index_offset = batch_index_offset.unsqueeze(1).expand(
+ batch_size, max_per_img)
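+ # e.g. with batch_size = 2 and max_per_img = 100 (illustrative values),
+ # sample 0 keeps its indices while sample 1 has 100 added to each of its
+ # top-k indices before gathering from the flattened prediction tensor.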
+
+ # supports dynamic batch inference
+ if self.loss_cls.use_sigmoid:
+ cls_scores = cls_scores.sigmoid()
+ scores, indexes = cls_scores.view(batch_size, -1).topk(max_per_img,
+ dim=1)
+ det_labels = indexes % self.num_classes
+ bbox_index = indexes // self.num_classes
+ bbox_index = (bbox_index + batch_index_offset).view(-1)
+ bbox_preds = bbox_preds.view(-1, 4)[bbox_index]
+ bbox_preds = bbox_preds.view(batch_size, -1, 4)
+ else:
+ scores, det_labels = F.softmax(cls_scores,
+ dim=-1)[..., :-1].max(-1)
+ scores, bbox_index = scores.topk(max_per_img, dim=1)
+ bbox_index = (bbox_index + batch_index_offset).view(-1)
+ bbox_preds = bbox_preds.view(-1, 4)[bbox_index]
+ det_labels = det_labels.view(-1)[bbox_index]
+ bbox_preds = bbox_preds.view(batch_size, -1, 4)
+ det_labels = det_labels.view(batch_size, -1)
+
+ det_bboxes = bbox_cxcywh_to_xyxy(bbox_preds)
+ # use `img_shape_tensor` for dynamically exporting to ONNX
+ img_shape_tensor = img_shape.flip(0).repeat(2) # [w,h,w,h]
+ img_shape_tensor = img_shape_tensor.unsqueeze(0).unsqueeze(0).expand(
+ batch_size, det_bboxes.size(1), 4)
+ det_bboxes = det_bboxes * img_shape_tensor
+ # dynamically clip bboxes
+ x1, y1, x2, y2 = det_bboxes.split((1, 1, 1, 1), dim=-1)
+ from mmdet.core.export import dynamic_clip_for_onnx
+ x1, y1, x2, y2 = dynamic_clip_for_onnx(x1, y1, x2, y2, img_shape)
+ det_bboxes = torch.cat([x1, y1, x2, y2], dim=-1)
+ det_bboxes = torch.cat((det_bboxes, scores.unsqueeze(-1)), -1)
+
+ return det_bboxes, det_labels
+
+ # BaseDenseHead
+ def _bbox_post_process(self,
+ mlvl_scores,
+ mlvl_labels,
+ mlvl_bboxes,
+ scale_factor,
+ cfg,
+ rescale=False,
+ with_nms=True,
+ mlvl_score_factors=None,
+ **kwargs):
+ """bbox post-processing method.
+
+ The boxes are rescaled to the original image scale and NMS is
+ applied. Usually `with_nms` is set to False for aug tests.
+
+ Args:
+ mlvl_scores (list[Tensor]): Box scores from all scale
+ levels of a single image, each item has shape
+ (num_bboxes, ).
+ mlvl_labels (list[Tensor]): Box class labels from all scale
+ levels of a single image, each item has shape
+ (num_bboxes, ).
+ mlvl_bboxes (list[Tensor]): Decoded bboxes from all scale
+ levels of a single image, each item has shape (num_bboxes, 4).
+ scale_factor (ndarray, optional): Scale factor of the image arranged
+ as (w_scale, h_scale, w_scale, h_scale).
+ cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+ mlvl_score_factors (list[Tensor], optional): Score factor from
+ all scale levels of a single image, each item has shape
+ (num_bboxes, ). Default: None.
+
+ Returns:
+ tuple[Tensor]: Results of detected bboxes and labels. If with_nms
+ is False and mlvl_score_factor is None, return mlvl_bboxes and
+ mlvl_scores, else return mlvl_bboxes, mlvl_scores and
+ mlvl_score_factor. Usually with_nms is set to False for aug
+ tests. If with_nms is True, then return the following format
+
+ - det_bboxes (Tensor): Predicted bboxes with shape \
+ [num_bboxes, 5], where the first 4 columns are bounding \
+ box positions (tl_x, tl_y, br_x, br_y) and the 5-th \
+ column are scores between 0 and 1.
+ - det_labels (Tensor): Predicted labels of the corresponding \
+ box with shape [num_bboxes].
+ """
+ assert len(mlvl_scores) == len(mlvl_bboxes) == len(mlvl_labels)
+
+ mlvl_bboxes = torch.cat(mlvl_bboxes)
+ if rescale:
+ mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
+ mlvl_scores = torch.cat(mlvl_scores)
+ mlvl_labels = torch.cat(mlvl_labels)
+
+ if mlvl_score_factors is not None:
+ # TODO: Add sqrt operation in order to be consistent with
+ # the paper.
+ mlvl_score_factors = torch.cat(mlvl_score_factors)
+ mlvl_scores = mlvl_scores * mlvl_score_factors
+
+ if with_nms:
+ if mlvl_bboxes.numel() == 0:
+ det_bboxes = torch.cat([mlvl_bboxes, mlvl_scores[:, None]], -1)
+ return det_bboxes, mlvl_labels
+
+ det_bboxes, keep_idxs = batched_nms(mlvl_bboxes, mlvl_scores,
+ mlvl_labels, cfg.nms)
+ det_bboxes = det_bboxes[:cfg.max_per_img]
+ det_labels = mlvl_labels[keep_idxs][:cfg.max_per_img]
+ return det_bboxes, det_labels
+ else:
+ return mlvl_bboxes, mlvl_scores, mlvl_labels
+
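+    # Illustrative only: `_bbox_post_process` expects an mmdet-style test
+    # cfg roughly of the form
+    #     cfg = dict(nms=dict(type='nms', iou_threshold=0.5), max_per_img=100)
+    # where `cfg.nms` is forwarded to `batched_nms` and `cfg.max_per_img`
+    # caps the number of kept detections.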
+ def simple_test(self, feats, img_metas, rescale=False):
+ """Test function without test-time augmentation.
+
+ Args:
+ feats (tuple[torch.Tensor]): Multi-level features from the
+ upstream network, each is a 4D-tensor.
+ img_metas (list[dict]): List of image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+            list[tuple[Tensor, Tensor]]: Each item in result_list is a
+                2-tuple. The first item is ``bboxes`` with shape (n, 5),
+                where the 5 columns are (tl_x, tl_y, br_x, br_y, score).
+                The second item is ``labels`` with shape (n, ).
+ """
+ return self.simple_test_bboxes(feats, img_metas, rescale=rescale)
+
+ # AnchorfreeHead
+
+ def _init_cls_convs(self):
+ """Initialize classification conv layers of the head."""
+ self.cls_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ if self.dcn_on_last_conv and i == self.stacked_convs - 1:
+ conv_cfg = dict(type='DCNv2')
+ else:
+ conv_cfg = self.conv_cfg
+ self.cls_convs.append(
+ ConvModule(chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=self.norm_cfg,
+ bias=self.conv_bias))
+
+ def _init_reg_convs(self):
+ """Initialize bbox regression conv layers of the head."""
+ self.reg_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ if self.dcn_on_last_conv and i == self.stacked_convs - 1:
+ conv_cfg = dict(type='DCNv2')
+ else:
+ conv_cfg = self.conv_cfg
+ self.reg_convs.append(
+ ConvModule(chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=self.norm_cfg,
+ bias=self.conv_bias))
+
+ def _init_predictor(self):
+ """Initialize predictor layers of the head."""
+ self.conv_cls = nn.Conv2d(self.feat_channels,
+ self.cls_out_channels,
+ 3,
+ padding=1)
+ self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
+
+ def _get_points_single(self,
+ featmap_size,
+ stride,
+ dtype,
+ device,
+ flatten=False):
+ """Get points of a single scale level.
+
+ This function will be deprecated soon.
+ """
+
+ warnings.warn(
+            '`_get_points_single` in `AnchorFreeHead` will be '
+            'deprecated soon. We support a multi-level point generator now; '
+            'you can get points of a single-level feature map '
+            'with `self.prior_generator.single_level_grid_priors`.')
+
+ h, w = featmap_size
+        # First create the range with the default dtype, then convert to the
+ # target `dtype` for onnx exporting.
+ x_range = torch.arange(w, device=device).to(dtype)
+ y_range = torch.arange(h, device=device).to(dtype)
+ y, x = torch.meshgrid(y_range, x_range)
+ if flatten:
+ y = y.flatten()
+ x = x.flatten()
+ return y, x
+
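+    # Sketch of the recommended replacement (assuming an mmdet-2.x style
+    # point generator is configured on this head):
+    #     priors = self.prior_generator.single_level_grid_priors(
+    #         featmap_size, level_idx, dtype=dtype, device=device)
+    # which returns the prior points of a single feature level directly.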
+ def get_points(self, featmap_sizes, dtype, device, flatten=False):
+ """Get points according to feature map sizes.
+
+ Args:
+ featmap_sizes (list[tuple]): Multi-level feature map sizes.
+ dtype (torch.dtype): Type of points.
+ device (torch.device): Device of points.
+
+ Returns:
+ tuple: points of each image.
+ """
+ warnings.warn(
+ '`get_points` in `AnchorFreeHead` will be '
+            'deprecated soon. We support a multi-level point generator now; '
+            'you can get points of all levels '
+            'with `self.prior_generator.grid_priors`.')
+
+ mlvl_points = []
+ for i in range(len(featmap_sizes)):
+ mlvl_points.append(
+ self._get_points_single(featmap_sizes[i], self.strides[i],
+ dtype, device, flatten))
+ return mlvl_points
+
+ def aug_test(self, feats, img_metas, rescale=False):
+ """Test function with test time augmentation.
+
+ Args:
+ feats (list[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains features for all images in the batch.
+ img_metas (list[list[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch. each dict has image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+ list[ndarray]: bbox results of each class
+ """
+ return self.aug_test_bboxes(feats, img_metas, rescale=rescale)
+
+
+class DeformableDETRHead(DETRHead):
+ """Head of DeformDETR: Deformable DETR: Deformable Transformers for End-to-
+ End Object Detection.
+
+ Code is modified from the `official github repo
+ `_.
+
+ More details can be found in the `paper
+ `_ .
+
+ Args:
+ with_box_refine (bool): Whether to refine the reference points
+ in the decoder. Defaults to False.
+ as_two_stage (bool) : Whether to generate the proposal from
+ the outputs of encoder.
+ transformer (obj:`ConfigDict`): ConfigDict is used for building
+ the Encoder and Decoder.
+ """
+ def __init__(
+ self,
+ *args,
+ with_box_refine=False,
+ as_two_stage=False,
+ transformer=None,
+ npose=144,
+ nbeta=10,
+ ncam=3,
+ hdim=256, # TODO: choose proper hdim
+ niter=3,
+ smpl_mean_params=None,
+ **kwargs):
+ self.with_box_refine = with_box_refine
+ self.as_two_stage = as_two_stage
+ self.npose = npose
+ self.nbeta = nbeta
+ self.ncam = ncam
+ self.hdim = hdim
+ self.niter = niter
+
+ if self.as_two_stage:
+ transformer['as_two_stage'] = self.as_two_stage
+
+ super(DeformableDETRHead, self).__init__(*args,
+ transformer=transformer,
+ **kwargs)
+
+ if smpl_mean_params is None:
+ init_pose = torch.zeros([1, npose])
+ init_shape = torch.zeros([1, nbeta])
+ init_cam = torch.FloatTensor([[1, 0, 0]])
+ else:
+ mean_params = np.load(smpl_mean_params)
+ init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
+ init_shape = torch.from_numpy(
+ mean_params['shape'][:].astype('float32')).unsqueeze(0)
+ init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0)
+ self.register_buffer('init_pose', init_pose)
+ self.register_buffer('init_shape', init_shape)
+ self.register_buffer('init_cam', init_cam)
+
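+    # Note (assumption): when `smpl_mean_params` is provided it is expected
+    # to be an .npz archive exposing 'pose' (144-dim, 6D per joint), 'shape'
+    # (10-dim) and 'cam' (3-dim) arrays, as in the commonly used SPIN/HMR
+    # smpl_mean_params.npz file.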
+ def _init_layers(self):
+ """Initialize classification branch and regression branch of head."""
+
+ fc_cls = Linear(self.embed_dims, self.cls_out_channels)
+ reg_branch = []
+ for _ in range(self.num_reg_fcs):
+ reg_branch.append(Linear(self.embed_dims, self.embed_dims))
+ reg_branch.append(nn.ReLU())
+ reg_branch.append(Linear(self.embed_dims, 4))
+ reg_branch = nn.Sequential(*reg_branch)
+
+ # smpl branch
+ smpl_branch = nn.ModuleList([
+ nn.Linear(self.embed_dims + self.npose + self.nbeta + self.ncam,
+ self.hdim), # fc1
+ nn.Dropout(),
+ nn.Linear(self.hdim, self.hdim), # fc2
+ nn.Dropout(),
+ nn.Linear(self.hdim, self.npose), # regress pose
+ nn.Linear(self.hdim, self.nbeta), # regress beta
+ nn.Linear(self.hdim, self.ncam) # regress cam
+ ])
+
+ def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+ # last reg_branch is used to generate proposal from
+ # encode feature map when as_two_stage is True.
+ num_pred = (self.transformer.decoder.num_layers + 1) if \
+ self.as_two_stage else self.transformer.decoder.num_layers
+
+ if self.with_box_refine:
+ self.cls_branches = _get_clones(fc_cls, num_pred)
+ self.reg_branches = _get_clones(reg_branch, num_pred)
+ self.smpl_branches = _get_clones(smpl_branch, num_pred)
+ else:
+
+ self.cls_branches = nn.ModuleList(
+ [fc_cls for _ in range(num_pred)])
+ self.reg_branches = nn.ModuleList(
+ [reg_branch for _ in range(num_pred)])
+ self.smpl_branches = nn.ModuleList(
+ [smpl_branch for _ in range(num_pred)])
+ if not self.as_two_stage:
+ self.query_embedding = nn.Embedding(self.num_query,
+ self.embed_dims * 2)
+
+ def regress_smpl(self,
+ lvl,
+ feature,
+ init_pose=None,
+ init_shape=None,
+ init_cam=None,
+ n_iter=3):
+ batch_size = feature.shape[0]
+ num_query = feature.shape[1]
+ if init_pose is None:
+ init_pose = self.init_pose.expand(batch_size, num_query, -1)
+ if init_shape is None:
+ init_shape = self.init_shape.expand(batch_size, num_query, -1)
+ if init_cam is None:
+ init_cam = self.init_cam.expand(batch_size, num_query, -1)
+
+ pred_pose = init_pose
+ pred_shape = init_shape
+ pred_cam = init_cam
+
+ for _ in range(n_iter):
+ xc = torch.cat([feature, pred_pose, pred_shape, pred_cam], -1)
+ xc = self.smpl_branches[lvl][0](xc) # fc1
+ xc = self.smpl_branches[lvl][1](xc) # drop
+ xc = self.smpl_branches[lvl][2](xc) # fc2
+ xc = self.smpl_branches[lvl][3](xc) # drop
+ pred_pose = self.smpl_branches[lvl][4](xc) + pred_pose # reg pose
+ pred_shape = self.smpl_branches[lvl][5](
+                xc) + pred_shape  # reg beta
+ pred_cam = self.smpl_branches[lvl][6](xc) + pred_cam # reg cam
+
+ pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, num_query,
+ 24, 3, 3)
+ return pred_rotmat, pred_shape, pred_cam
+
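+    # Illustrative shapes for `regress_smpl` (assuming the defaults
+    # npose=144, nbeta=10, ncam=3 and embed_dims=256):
+    #     feature:     (bs, num_query, 256)
+    #     pred_rotmat: (bs, num_query, 24, 3, 3)  # 144 = 24 joints x 6D
+    #     pred_shape:  (bs, num_query, 10)
+    #     pred_cam:    (bs, num_query, 3)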
+ def init_weights(self):
+ """Initialize weights of the DeformDETR head."""
+ self.transformer.init_weights()
+ if self.loss_cls.use_sigmoid:
+ bias_init = bias_init_with_prob(0.01)
+ for m in self.cls_branches:
+ nn.init.constant_(m.bias, bias_init)
+ for m in self.reg_branches:
+ constant_init(m[-1], 0, bias=0)
+ nn.init.constant_(self.reg_branches[0][-1].bias.data[2:], -2.0)
+ if self.as_two_stage:
+ for m in self.reg_branches:
+ nn.init.constant_(m[-1].bias.data[2:], 0.0)
+
+ def forward(self, mlvl_feats, img_metas):
+ """Forward function.
+
+ Args:
+ mlvl_feats (tuple[Tensor]): Features from the upstream
+ network, each is a 4D-tensor with shape
+ (N, C, H, W).
+ img_metas (list[dict]): List of image information.
+
+ Returns:
+ all_cls_scores (Tensor): Outputs from the classification head, \
+ shape [nb_dec, bs, num_query, cls_out_channels]. Note \
+                cls_out_channels should include background.
+ all_bbox_preds (Tensor): Sigmoid outputs from the regression \
+ head with normalized coordinate format (cx, cy, w, h). \
+ Shape [nb_dec, bs, num_query, 4].
+            enc_outputs_class (Tensor): The score of each point on the \
+                encoder feature map, has shape (N, h*w, num_class). Only \
+                returned when as_two_stage is True, otherwise `None`.
+            enc_outputs_coord (Tensor): The proposals generated from the \
+                encoder feature map, has shape (N, h*w, 4). Only returned \
+                when as_two_stage is True, otherwise `None`.
+ """
+
+ batch_size = mlvl_feats[0].size(0)
+ input_img_h, input_img_w = img_metas[0]['batch_input_shape']
+ img_masks = mlvl_feats[0].new_ones(
+ (batch_size, input_img_h, input_img_w))
+ for img_id in range(batch_size):
+ img_h, img_w = img_metas[img_id]['img_shape']
+ img_masks[img_id, :img_h, :img_w] = 0
+
+ mlvl_masks = []
+ mlvl_positional_encodings = []
+ for feat in mlvl_feats:
+ mlvl_masks.append(
+ F.interpolate(img_masks[None],
+ size=feat.shape[-2:]).to(torch.bool).squeeze(0))
+ mlvl_positional_encodings.append(
+ self.positional_encoding(mlvl_masks[-1]))
+
+ query_embeds = None
+ if not self.as_two_stage:
+ query_embeds = self.query_embedding.weight
+ hs, init_reference, inter_references, \
+ enc_outputs_class, enc_outputs_coord = self.transformer(
+ mlvl_feats,
+ mlvl_masks,
+ query_embeds,
+ mlvl_positional_encodings,
+ reg_branches=self.reg_branches if self.with_box_refine else None, # noqa:E501
+ cls_branches=self.cls_branches if self.as_two_stage else None, # noqa:E501
+ smpl_branches=self.smpl_branches if self.with_box_refine else None # noqa: E501
+ )
+ hs = hs.permute(0, 2, 1, 3)
+ outputs_classes = []
+ outputs_coords = []
+ outputs_poses = []
+ outputs_shapes = []
+ outputs_cams = []
+ for lvl in range(hs.shape[0]):
+ if lvl == 0:
+ reference = init_reference
+ else:
+ reference = inter_references[lvl - 1]
+ reference = inverse_sigmoid(reference)
+ outputs_class = self.cls_branches[lvl](hs[lvl])
+ tmp = self.reg_branches[lvl](hs[lvl])
+ if reference.shape[-1] == 4:
+ tmp += reference
+ else:
+ assert reference.shape[-1] == 2
+ tmp[..., :2] += reference
+ outputs_coord = tmp.sigmoid()
+
+ # smpl
+ pred_pose, pred_betas, pred_cam = \
+ self.regress_smpl(lvl, hs[lvl], n_iter=self.niter)
+ outputs_poses.append(pred_pose)
+ outputs_shapes.append(pred_betas)
+ outputs_cams.append(pred_cam)
+ outputs_classes.append(outputs_class)
+ outputs_coords.append(outputs_coord)
+
+ outputs_classes = torch.stack(outputs_classes)
+ outputs_coords = torch.stack(outputs_coords)
+ outputs_poses = torch.stack(outputs_poses)
+ outputs_shapes = torch.stack(outputs_shapes)
+ outputs_cams = torch.stack(outputs_cams)
+ if self.as_two_stage:
+ return outputs_classes, outputs_coords, \
+ outputs_poses, outputs_shapes, outputs_cams, \
+ enc_outputs_class, enc_outputs_coord.sigmoid()
+        else:
+            return outputs_classes, outputs_coords, \
+                outputs_poses, outputs_shapes, outputs_cams, \
+                None, None
+
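+    # Illustrative output shapes for `forward` above (bs = batch size,
+    # nq = num_query, nd = number of decoder layers):
+    #     outputs_classes: (nd, bs, nq, cls_out_channels)
+    #     outputs_coords:  (nd, bs, nq, 4)   # sigmoid (cx, cy, w, h)
+    #     outputs_poses:   (nd, bs, nq, 24, 3, 3)
+    #     outputs_shapes:  (nd, bs, nq, 10)
+    #     outputs_cams:    (nd, bs, nq, 3)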
+ @force_fp32(apply_to=('all_cls_scores_list', 'all_bbox_preds_list'))
+ def loss(self,
+ all_cls_scores,
+ all_bbox_preds,
+ enc_cls_scores,
+ enc_bbox_preds,
+ gt_bboxes_list,
+ gt_labels_list,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """"Loss function.
+
+ Args:
+ all_cls_scores (Tensor): Classification score of all
+ decoder layers, has shape
+ [nb_dec, bs, num_query, cls_out_channels].
+ all_bbox_preds (Tensor): Sigmoid regression
+ outputs of all decode layers. Each is a 4D-tensor with
+ normalized coordinate format (cx, cy, w, h) and shape
+ [nb_dec, bs, num_query, 4].
+            enc_cls_scores (Tensor): Classification scores of
+                points on the encoder feature map, has shape
+                (N, h*w, num_classes). Only passed when as_two_stage is
+                True, otherwise None.
+            enc_bbox_preds (Tensor): Regression results of each point
+                on the encoder feature map, has shape (N, h*w, 4). Only
+                passed when as_two_stage is True, otherwise None.
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
+ with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels_list (list[Tensor]): Ground truth class indices for each
+ image with shape (num_gts, ).
+ img_metas (list[dict]): List of image meta information.
+ gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
+ which can be ignored for each image. Default None.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+        assert gt_bboxes_ignore is None, \
+            f'{self.__class__.__name__} only supports ' \
+            f'gt_bboxes_ignore set to None.'
+
+ num_dec_layers = len(all_cls_scores)
+ all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
+ all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
+ all_gt_bboxes_ignore_list = [
+ gt_bboxes_ignore for _ in range(num_dec_layers)
+ ]
+ img_metas_list = [img_metas for _ in range(num_dec_layers)]
+
+ losses_cls, losses_bbox, losses_iou = multi_apply(
+ self.loss_single, all_cls_scores, all_bbox_preds,
+ all_gt_bboxes_list, all_gt_labels_list, img_metas_list,
+ all_gt_bboxes_ignore_list)
+
+ loss_dict = dict()
+ # loss of proposal generated from encode feature map.
+ if enc_cls_scores is not None:
+ binary_labels_list = [
+ torch.zeros_like(gt_labels_list[i])
+ for i in range(len(img_metas))
+ ]
+ enc_loss_cls, enc_losses_bbox, enc_losses_iou = \
+ self.loss_single(enc_cls_scores, enc_bbox_preds,
+ gt_bboxes_list, binary_labels_list,
+ img_metas, gt_bboxes_ignore)
+ loss_dict['enc_loss_cls'] = enc_loss_cls
+ loss_dict['enc_loss_bbox'] = enc_losses_bbox
+ loss_dict['enc_loss_iou'] = enc_losses_iou
+
+ # loss from the last decoder layer
+ loss_dict['loss_cls'] = losses_cls[-1]
+ loss_dict['loss_bbox'] = losses_bbox[-1]
+ loss_dict['loss_iou'] = losses_iou[-1]
+ # loss from other decoder layers
+ num_dec_layer = 0
+ for loss_cls_i, loss_bbox_i, loss_iou_i in zip(losses_cls[:-1],
+ losses_bbox[:-1],
+ losses_iou[:-1]):
+ loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
+ loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
+ loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i
+ num_dec_layer += 1
+ return loss_dict
+
+ @force_fp32(apply_to=('all_cls_scores_list', 'all_bbox_preds_list'))
+ def get_bboxes(self,
+ all_cls_scores,
+ all_bbox_preds,
+ enc_cls_scores,
+ enc_bbox_preds,
+ img_metas,
+ rescale=False):
+ """Transform network outputs for a batch into bbox predictions.
+
+ Args:
+ all_cls_scores (Tensor): Classification score of all
+ decoder layers, has shape
+ [nb_dec, bs, num_query, cls_out_channels].
+ all_bbox_preds (Tensor): Sigmoid regression
+ outputs of all decode layers. Each is a 4D-tensor with
+ normalized coordinate format (cx, cy, w, h) and shape
+ [nb_dec, bs, num_query, 4].
+            enc_cls_scores (Tensor): Classification scores of
+                points on the encoder feature map, has shape
+                (N, h*w, num_classes). Only passed when as_two_stage is
+                True, otherwise None.
+            enc_bbox_preds (Tensor): Regression results of each point
+                on the encoder feature map, has shape (N, h*w, 4). Only
+                passed when as_two_stage is True, otherwise None.
+ img_metas (list[dict]): Meta information of each image.
+ rescale (bool, optional): If True, return boxes in original
+ image space. Default False.
+
+ Returns:
+ list[list[Tensor, Tensor]]: Each item in result_list is 2-tuple. \
+ The first item is an (n, 5) tensor, where the first 4 columns \
+ are bounding box positions (tl_x, tl_y, br_x, br_y) and the \
+ 5-th column is a score between 0 and 1. The second item is a \
+ (n,) tensor where each item is the predicted class label of \
+ the corresponding box.
+ """
+ cls_scores = all_cls_scores[-1]
+ bbox_preds = all_bbox_preds[-1]
+
+ result_list = []
+ for img_id in range(len(img_metas)):
+ cls_score = cls_scores[img_id]
+ bbox_pred = bbox_preds[img_id]
+ img_shape = img_metas[img_id]['img_shape']
+ scale_factor = img_metas[img_id]['scale_factor']
+ proposals = self._get_bboxes_single(cls_score, bbox_pred,
+ img_shape, scale_factor,
+ rescale)
+ result_list.append(proposals)
+ return result_list
diff --git a/detrsmpl/models/heads/expose_head.py b/detrsmpl/models/heads/expose_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe825875189d5764cb48b6e35f625d4702157ee7
--- /dev/null
+++ b/detrsmpl/models/heads/expose_head.py
@@ -0,0 +1,526 @@
+import os
+import pickle
+from abc import abstractmethod
+from typing import List, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import build_activation_layer, initialize
+from mmcv.runner.base_module import BaseModule
+
+from detrsmpl.utils.geometry import rot6d_to_rotmat
+
+
+class IterativeRegression(nn.Module):
+ """Regressor for ExPose Head."""
+ def __init__(self,
+ module,
+ mean_param,
+ num_stages=1,
+ append_params=True,
+ learn_mean=False,
+ detach_mean=False,
+ dim=1,
+ **kwargs):
+ super(IterativeRegression, self).__init__()
+ self.module = module
+ self._num_stages = num_stages
+ self.dim = dim
+
+ if learn_mean:
+ self.register_parameter(
+ 'mean_param', nn.Parameter(mean_param, requires_grad=True))
+ else:
+ self.register_buffer('mean_param', mean_param)
+
+ self.append_params = append_params
+ self.detach_mean = detach_mean
+
+ def get_mean(self):
+ """Get the initial mean param."""
+ return self.mean_param.clone()
+
+ @property
+ def num_stages(self):
+ return self._num_stages
+
+ def forward(self,
+ features: torch.Tensor,
+ cond: Optional[torch.Tensor] = None):
+        """Compute deltas on top of the condition iteratively.
+
+        Args:
+            features (torch.Tensor): Input features.
+            cond (torch.Tensor, optional): Initial condition. If None, the
+                stored mean parameters are used.
+        """
+ batch_size = features.shape[0]
+ expand_shape = [batch_size] + [-1] * len(features.shape[1:])
+
+ parameters = []
+ deltas = []
+ module_input = features
+ if cond is None:
+ cond = self.mean_param.expand(*expand_shape).clone()
+
+ # Detach mean
+ if self.detach_mean:
+ cond = cond.detach()
+
+ if self.append_params:
+ assert features is not None, (
+ 'Features are none even though append_params is True')
+ module_input = torch.cat([module_input, cond], dim=self.dim)
+
+ deltas.append(self.module(module_input))
+ num_params = deltas[-1].shape[1]
+ parameters.append(cond[:, :num_params].clone() + deltas[-1])
+
+ for stage_idx in range(1, self.num_stages):
+ module_input = torch.cat([features, parameters[stage_idx - 1]],
+ dim=-1)
+ params_upd = self.module(module_input)
+ deltas.append(params_upd)
+ parameters.append(parameters[stage_idx - 1] + params_upd)
+
+ return parameters
+
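+# Minimal usage sketch for IterativeRegression (hypothetical shapes; MLP is
+# defined below in this file):
+#     mlp = MLP(input_dim=2048 + 32, output_dim=32, layers=[1024])
+#     mean = torch.zeros(1, 32)
+#     reg = IterativeRegression(mlp, mean, num_stages=3)
+#     params = reg(torch.randn(4, 2048))  # list of 3 tensors, each (4, 32)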
+
+class MLP(nn.Module):
+ """MLP
+ Args:
+ input_dim (int): Input dim of MLP.
+ output_dim (int): Output dim of MLP.
+ layers (List): Layer dims.
+ activ_type (str): Activation layer type.
+        dropout (float): Dropout probability.
+ gain (float): Xavier init gain value.
+ """
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int,
+ layers: List[int] = [],
+ activ_type: str = 'relu',
+ dropout: float = 0.5,
+ gain: float = 0.01,
+ ):
+ super(MLP, self).__init__()
+ curr_input_dim = input_dim
+ self.num_layers = len(layers)
+
+ self.blocks = nn.ModuleList()
+ for layer_idx, layer_dim in enumerate(layers):
+ if activ_type == 'none':
+ active = None
+ else:
+ active = build_activation_layer(
+ cfg=dict(type=activ_type, inplace=True))
+ linear = nn.Linear(curr_input_dim, layer_dim, bias=True)
+ curr_input_dim = layer_dim
+
+ layer = []
+ layer.append(linear)
+
+ if active is not None:
+ layer.append(active)
+
+ if dropout > 0.0:
+ layer.append(nn.Dropout(dropout))
+
+ block = nn.Sequential(*layer)
+ self.add_module('layer_{:03d}'.format(layer_idx), block)
+ self.blocks.append(block)
+
+ self.output_layer = nn.Linear(curr_input_dim, output_dim)
+ initialize(self.output_layer,
+ init_cfg=dict(type='Xavier',
+ gain=gain,
+ distribution='uniform'))
+
+ def forward(self, module_input):
+ curr_input = module_input
+ for block in self.blocks:
+ curr_input = block(curr_input)
+ return self.output_layer(curr_input)
+
+
+class ContinuousRotReprDecoder:
+ """ExPose Decoder Decode latent representation to rotation.
+
+ Args:
+ num_angles (int): Joint num.
+ dtype: dtype.
+ mean (torch.tensor): Mean value for params.
+ """
+ def __init__(self, num_angles, dtype=torch.float32, mean=None):
+ self.num_angles = num_angles
+ self.dtype = dtype
+
+ if isinstance(mean, dict):
+ mean = mean.get('cont_rot_repr', None)
+ if mean is None:
+ mean = torch.tensor([1.0, 0.0, 0.0, 1.0, 0.0, 0.0],
+ dtype=self.dtype).unsqueeze(dim=0).expand(
+ self.num_angles, -1).contiguous().view(-1)
+ if not torch.is_tensor(mean):
+ mean = torch.tensor(mean)
+ mean = mean.reshape(-1, 6)
+
+ if mean.shape[0] < self.num_angles:
+ mean = mean.repeat(self.num_angles // mean.shape[0] + 1,
+ 1).contiguous()
+ mean = mean[:self.num_angles]
+ elif mean.shape[0] > self.num_angles:
+ mean = mean[:self.num_angles]
+
+ mean = mean.reshape(-1)
+ self.mean = mean
+
+ def get_mean(self):
+ return self.mean.clone()
+
+ def get_dim_size(self):
+ return self.num_angles * 6
+
+ def __call__(self, module_input):
+ batch_size = module_input.shape[0]
+ reshaped_input = module_input.view(-1, 6)
+ rot_mats = rot6d_to_rotmat(reshaped_input)
+ # aa = rot6d_to_aa(reshaped_input)
+ # return aa.view(batch_size,-1,3)
+ return rot_mats.view(batch_size, -1, 3, 3)
+
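+# Sketch: ContinuousRotReprDecoder maps the flat 6D rotation representation
+# back to rotation matrices (shapes are illustrative):
+#     decoder = ContinuousRotReprDecoder(num_angles=21)
+#     decoder.get_dim_size()               # 126 = 21 * 6
+#     rots = decoder(torch.randn(8, 126))  # (8, 21, 3, 3)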
+
+class ExPoseHead(BaseModule):
+ """General Head for ExPose."""
+ def __init__(self, init_cfg=None):
+ super().__init__(init_cfg)
+
+ def load_regressor(self,
+ input_feat_dim: int = 2048,
+ param_mean: torch.Tensor = None,
+ regressor_cfg: dict = None):
+ """Build regressor for ExPose Head."""
+ param_dim = param_mean.numel()
+ regressor = MLP(input_feat_dim + param_dim, param_dim, **regressor_cfg)
+ self.regressor = IterativeRegression(regressor,
+ param_mean,
+ num_stages=3)
+
+ def load_param_decoder(self, mean_poses_dict):
+ """Build decoders for each pose."""
+ start = 0
+ mean_lst = []
+ self.pose_param_decoders = {}
+ for pose_param in self.pose_param_conf:
+ pose_name = pose_param['name']
+ num_angles = pose_param['num_angles']
+ if pose_param['use_mean']:
+ pose_decoder = ContinuousRotReprDecoder(
+ num_angles,
+ dtype=torch.float32,
+ mean=mean_poses_dict.get(pose_name, None))
+ else:
+ pose_decoder = ContinuousRotReprDecoder(num_angles,
+ dtype=torch.float32,
+ mean=None)
+ self.pose_param_decoders['{}_decoder'.format(
+ pose_name)] = pose_decoder
+ pose_dim = pose_decoder.get_dim_size()
+ pose_mean = pose_decoder.get_mean()
+ if pose_param['rotate_axis_x']:
+ pose_mean[3] = -1
+ idxs = list(range(start, start + pose_dim))
+ idxs = torch.tensor(idxs, dtype=torch.long)
+ self.register_buffer('{}_idxs'.format(pose_name), idxs)
+ start += pose_dim
+ mean_lst.append(pose_mean.view(-1))
+ return start, mean_lst
+
+ def get_camera_param(self, camera_cfg):
+ """Build camera param."""
+ camera_pos_scale = camera_cfg.get('pos_func')
+ if camera_pos_scale == 'softplus':
+ camera_scale_func = F.softplus
+ elif camera_pos_scale == 'exp':
+ camera_scale_func = torch.exp
+ elif camera_pos_scale == 'none' or camera_pos_scale == 'None':
+
+ def func(x):
+ return x
+
+ camera_scale_func = func
+ mean_scale = camera_cfg.get('mean_scale', 0.9)
+ if camera_pos_scale == 'softplus':
+ mean_scale = np.log(np.exp(mean_scale) - 1)
+ elif camera_pos_scale == 'exp':
+ mean_scale = np.log(mean_scale)
+ camera_mean = torch.tensor([mean_scale, 0.0, 0.0], dtype=torch.float32)
+ camera_param_dim = 3
+ return camera_mean, camera_param_dim, camera_scale_func
+
+ def flat_params_to_dict(self, param_tensor):
+ """Turn param tensors to dict."""
+ smplx_dict = {}
+ raw_dict = {}
+ for pose_param in self.pose_param_conf:
+ pose_name = pose_param['name']
+ pose_idxs = getattr(self, f'{pose_name}_idxs')
+ decoder = self.pose_param_decoders[f'{pose_name}_decoder']
+ pose = torch.index_select(param_tensor, 1, pose_idxs)
+ raw_dict[f'raw_{pose_name}'] = pose.clone()
+ smplx_dict[pose_name] = decoder(pose)
+ return smplx_dict, raw_dict
+
+ def get_mean(self, name, batch_size):
+ """Get mean value of params."""
+ mean_param = self.regressor.get_mean().view(-1)
+ if name is None:
+ return mean_param.reshape(1, -1).expand(batch_size, -1)
+ idxs = getattr(self, f'{name}_idxs')
+ return mean_param[idxs].reshape(1, -1).expand(batch_size, -1)
+
+ def get_num_betas(self):
+ return self.num_betas
+
+ def get_num_expression_coeffs(self):
+ return self.num_expression_coeffs
+
+ @abstractmethod
+ def forward(self, features):
+ pass
+
+
+class ExPoseBodyHead(ExPoseHead):
+ """Head for ExPose Body Model."""
+ def __init__(self,
+ init_cfg=None,
+ num_betas: int = 10,
+ num_expression_coeffs: int = 10,
+ mean_pose_path: str = '',
+ shape_mean_path: str = '',
+ pose_param_conf: list = None,
+ input_feat_dim: int = 2048,
+ regressor_cfg: dict = None,
+ camera_cfg: dict = None):
+ super().__init__(init_cfg)
+ self.num_betas = num_betas
+ self.num_expression_coeffs = num_expression_coeffs
+ # poses
+ self.pose_param_conf = pose_param_conf
+ mean_poses_dict = {}
+ if os.path.exists(mean_pose_path):
+ with open(mean_pose_path, 'rb') as f:
+ mean_poses_dict = pickle.load(f)
+ start, mean_lst = self.load_param_decoder(mean_poses_dict)
+
+ # shape
+ if os.path.exists(shape_mean_path):
+ shape_mean = torch.from_numpy(
+ np.load(shape_mean_path,
+ allow_pickle=True)).to(dtype=torch.float32).reshape(
+ 1, -1)[:, :num_betas].reshape(-1)
+ else:
+ shape_mean = torch.zeros([num_betas], dtype=torch.float32)
+ shape_idxs = list(range(start, start + num_betas))
+ self.register_buffer('shape_idxs',
+ torch.tensor(shape_idxs, dtype=torch.long))
+ start += num_betas
+ mean_lst.append(shape_mean.view(-1))
+
+ # expression
+ expression_mean = torch.zeros([num_expression_coeffs],
+ dtype=torch.float32)
+ expression_idxs = list(range(start, start + num_expression_coeffs))
+ self.register_buffer('expression_idxs',
+ torch.tensor(expression_idxs, dtype=torch.long))
+ start += num_expression_coeffs
+ mean_lst.append(expression_mean.view(-1))
+
+ # camera
+ mean, dim, scale_func = self.get_camera_param(camera_cfg)
+ self.camera_scale_func = scale_func
+ camera_idxs = list(range(start, start + dim))
+ self.register_buffer('camera_idxs',
+ torch.tensor(camera_idxs, dtype=torch.long))
+ start += dim
+ mean_lst.append(mean)
+
+ param_mean = torch.cat(mean_lst).view(1, -1)
+ self.load_regressor(input_feat_dim, param_mean, regressor_cfg)
+
+ def forward(self, features):
+ """Forward function of ExPose Body Head.
+
+ Args:
+            features (torch.Tensor): Features from the backbone.
+ """
+ body_parameters = self.regressor(features)[-1]
+ params_dict, raw_dict = self.flat_params_to_dict(body_parameters)
+ params_dict['betas'] = torch.index_select(body_parameters, 1,
+ self.shape_idxs)
+ params_dict['expression'] = torch.index_select(body_parameters, 1,
+ self.expression_idxs)
+
+ camera_params = torch.index_select(body_parameters, 1,
+ self.camera_idxs)
+ scale = camera_params[:, 0:1]
+ translation = camera_params[:, 1:3]
+ scale = self.camera_scale_func(scale)
+ camera_params = torch.cat([scale, translation], dim=1)
+ return {
+ 'pred_param': params_dict,
+ 'pred_cam': camera_params,
+ 'pred_raw': raw_dict
+ }
+
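+# Sketch of the ExPoseBodyHead output (keys follow the forward above; the
+# pose keys inside 'pred_param' depend on the configured pose_param_conf):
+#     out = body_head(torch.randn(2, 2048))
+#     out['pred_param']['betas'].shape       # (2, num_betas)
+#     out['pred_param']['expression'].shape  # (2, num_expression_coeffs)
+#     out['pred_cam'].shape                  # (2, 3) -> (scale, tx, ty)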
+
+class ExPoseHandHead(ExPoseHead):
+ """Head for ExPose Hand Model."""
+ def __init__(self,
+ init_cfg=None,
+ num_betas: int = 10,
+ mean_pose_path: str = '',
+ pose_param_conf: list = None,
+ input_feat_dim: int = 2048,
+ regressor_cfg: dict = None,
+ camera_cfg: dict = None):
+ super().__init__(init_cfg)
+ self.num_betas = num_betas
+ # poses
+ self.pose_param_conf = pose_param_conf
+ mean_poses_dict = {}
+ if os.path.exists(mean_pose_path):
+ with open(mean_pose_path, 'rb') as f:
+ mean_poses_dict = pickle.load(f)
+ start, mean_lst = self.load_param_decoder(mean_poses_dict)
+
+ shape_mean = torch.zeros([num_betas], dtype=torch.float32)
+ shape_idxs = list(range(start, start + num_betas))
+ self.register_buffer('shape_idxs',
+ torch.tensor(shape_idxs, dtype=torch.long))
+ start += num_betas
+ mean_lst.append(shape_mean.view(-1))
+
+ # camera
+ mean, dim, scale_func = self.get_camera_param(camera_cfg)
+ self.camera_scale_func = scale_func
+ camera_idxs = list(range(start, start + dim))
+ self.register_buffer('camera_idxs',
+ torch.tensor(camera_idxs, dtype=torch.long))
+ start += dim
+ mean_lst.append(mean)
+
+ param_mean = torch.cat(mean_lst).view(1, -1)
+ self.load_regressor(input_feat_dim, param_mean, regressor_cfg)
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+
+ def forward(self, features, cond=None):
+ """Forward function of ExPose Hand Head.
+
+ Args:
+            features (List[torch.Tensor]): Multi-level features from the
+                backbone; only the last one is used.
+            cond (torch.Tensor, optional): Initial params. If None, the mean
+                params are used.
+ """
+ batch_size = features[-1].size(0)
+ features = self.avgpool(features[-1]).view(batch_size, -1)
+ hand_parameters = self.regressor(features, cond=cond)[-1]
+ params_dict, raw_dict = self.flat_params_to_dict(hand_parameters)
+ params_dict['betas'] = torch.index_select(hand_parameters, 1,
+ self.shape_idxs)
+
+ camera_params = torch.index_select(hand_parameters, 1,
+ self.camera_idxs)
+ scale = camera_params[:, 0:1]
+ translation = camera_params[:, 1:3]
+ scale = self.camera_scale_func(scale)
+ camera_params = torch.cat([scale, translation], dim=1)
+ return {
+ 'pred_param': params_dict,
+ 'pred_cam': camera_params,
+ 'pred_raw': raw_dict
+ }
+
+
+class ExPoseFaceHead(ExPoseHead):
+ """Head for ExPose Face Model."""
+ def __init__(self,
+ init_cfg=None,
+ num_betas: int = 10,
+ num_expression_coeffs: int = 10,
+ pose_param_conf: list = None,
+ mean_pose_path: str = '',
+ input_feat_dim: int = 2048,
+ regressor_cfg: dict = None,
+ camera_cfg: dict = None):
+ super().__init__(init_cfg)
+ self.num_betas = num_betas
+ self.num_expression_coeffs = num_expression_coeffs
+ # poses
+ self.pose_param_conf = pose_param_conf
+ mean_poses_dict = {}
+ if os.path.exists(mean_pose_path):
+ with open(mean_pose_path, 'rb') as f:
+ mean_poses_dict = pickle.load(f)
+ start, mean_lst = self.load_param_decoder(mean_poses_dict)
+
+ # shape
+ shape_mean = torch.zeros([num_betas], dtype=torch.float32)
+ shape_idxs = list(range(start, start + num_betas))
+ self.register_buffer('shape_idxs',
+ torch.tensor(shape_idxs, dtype=torch.long))
+ start += num_betas
+ mean_lst.append(shape_mean.view(-1))
+
+ # expression
+ expression_mean = torch.zeros([num_expression_coeffs],
+ dtype=torch.float32)
+ expression_idxs = list(range(start, start + num_expression_coeffs))
+ self.register_buffer('expression_idxs',
+ torch.tensor(expression_idxs, dtype=torch.long))
+ start += num_expression_coeffs
+ mean_lst.append(expression_mean.view(-1))
+
+ # camera
+ mean, dim, scale_func = self.get_camera_param(camera_cfg)
+ self.camera_scale_func = scale_func
+ camera_idxs = list(range(start, start + dim))
+ self.register_buffer('camera_idxs',
+ torch.tensor(camera_idxs, dtype=torch.long))
+ start += dim
+ mean_lst.append(mean)
+
+ param_mean = torch.cat(mean_lst).view(1, -1)
+ self.load_regressor(input_feat_dim, param_mean, regressor_cfg)
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+
+ def forward(self, features, cond=None):
+ """Forward function of ExPose Face Head.
+
+ Args:
+            features (List[torch.Tensor]): Multi-level features from the
+                backbone; only the last one is used.
+            cond (torch.Tensor, optional): Initial params. If None, the mean
+                params are used.
+ """
+ batch_size = features[-1].size(0)
+ features = self.avgpool(features[-1]).view(batch_size, -1)
+ head_parameters = self.regressor(features, cond=cond)[-1]
+ params_dict, raw_dict = self.flat_params_to_dict(head_parameters)
+ params_dict['betas'] = torch.index_select(head_parameters, 1,
+ self.shape_idxs)
+ params_dict['expression'] = torch.index_select(head_parameters, 1,
+ self.expression_idxs)
+
+ camera_params = torch.index_select(head_parameters, 1,
+ self.camera_idxs)
+ scale = camera_params[:, 0:1]
+ translation = camera_params[:, 1:3]
+ scale = self.camera_scale_func(scale)
+ camera_params = torch.cat([scale, translation], dim=1)
+ return {
+ 'pred_param': params_dict,
+ 'pred_cam': camera_params,
+ 'pred_raw': raw_dict
+ }
diff --git a/detrsmpl/models/heads/hmr_head.py b/detrsmpl/models/heads/hmr_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..379845d2a51d48af72116e9a4414698080288395
--- /dev/null
+++ b/detrsmpl/models/heads/hmr_head.py
@@ -0,0 +1,99 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from mmcv.runner.base_module import BaseModule
+
+from detrsmpl.utils.geometry import rot6d_to_rotmat
+
+
+class HMRHead(BaseModule):
+ def __init__(self,
+ feat_dim,
+ smpl_mean_params=None,
+ npose=144,
+ nbeta=10,
+ ncam=3,
+ hdim=1024,
+ init_cfg=None):
+ super(HMRHead, self).__init__(init_cfg=init_cfg)
+ self.fc1 = nn.Linear(feat_dim + npose + nbeta + ncam, hdim)
+ self.drop1 = nn.Dropout()
+ self.fc2 = nn.Linear(hdim, hdim)
+ self.drop2 = nn.Dropout()
+ self.decpose = nn.Linear(hdim, npose)
+ self.decshape = nn.Linear(hdim, nbeta)
+ self.deccam = nn.Linear(hdim, ncam)
+
+ nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
+ nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
+ nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)
+
+ if smpl_mean_params is None:
+ init_pose = torch.zeros([1, npose])
+ init_shape = torch.zeros([1, nbeta])
+ init_cam = torch.FloatTensor([[1, 0, 0]])
+ else:
+ mean_params = np.load(smpl_mean_params)
+ init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
+ init_shape = torch.from_numpy(
+ mean_params['shape'][:].astype('float32')).unsqueeze(0)
+ init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0)
+ self.register_buffer('init_pose', init_pose)
+ self.register_buffer('init_shape', init_shape)
+ self.register_buffer('init_cam', init_cam)
+
+ def forward(self,
+ x,
+ init_pose=None,
+ init_shape=None,
+ init_cam=None,
+ n_iter=3):
+
+ # hmr head only support one layer feature
+ if isinstance(x, list) or isinstance(x, tuple):
+ x = x[-1]
+
+ output_seq = False
+ if len(x.shape) == 4:
+ # use feature from the last layer of the backbone
+ # apply global average pooling on the feature map
+ x = x.mean(dim=-1).mean(dim=-1)
+ elif len(x.shape) == 3:
+ # temporal feature
+ output_seq = True
+ B, T, L = x.shape
+ x = x.view(-1, L)
+
+ batch_size = x.shape[0]
+ if init_pose is None:
+ init_pose = self.init_pose.expand(batch_size, -1)
+ if init_shape is None:
+ init_shape = self.init_shape.expand(batch_size, -1)
+ if init_cam is None:
+ init_cam = self.init_cam.expand(batch_size, -1)
+
+ pred_pose = init_pose
+ pred_shape = init_shape
+ pred_cam = init_cam
+ for i in range(n_iter):
+ xc = torch.cat([x, pred_pose, pred_shape, pred_cam], 1)
+ xc = self.fc1(xc)
+ xc = self.drop1(xc)
+ xc = self.fc2(xc)
+ xc = self.drop2(xc)
+ pred_pose = self.decpose(xc) + pred_pose
+ pred_shape = self.decshape(xc) + pred_shape
+ pred_cam = self.deccam(xc) + pred_cam
+
+ pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)
+
+ if output_seq:
+ pred_rotmat = pred_rotmat.view(B, T, 24, 3, 3)
+ pred_shape = pred_shape.view(B, T, 10)
+ pred_cam = pred_cam.view(B, T, 3)
+ output = {
+ 'pred_pose': pred_rotmat,
+ 'pred_shape': pred_shape,
+ 'pred_cam': pred_cam
+ }
+ return output
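+
+
+# Minimal usage sketch for HMRHead (illustrative; feat_dim must match the
+# backbone's output channels):
+#     head = HMRHead(feat_dim=2048)
+#     out = head(torch.randn(2, 2048, 7, 7))  # pooled to (2, 2048) inside
+#     out['pred_pose'].shape   # (2, 24, 3, 3)
+#     out['pred_shape'].shape  # (2, 10)
+#     out['pred_cam'].shape    # (2, 3)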
diff --git a/detrsmpl/models/heads/hybrik_head.py b/detrsmpl/models/heads/hybrik_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cea3dcb83624b4611b6f8067db118a222671e44
--- /dev/null
+++ b/detrsmpl/models/heads/hybrik_head.py
@@ -0,0 +1,443 @@
+import numpy as np
+import torch
+import torch.cuda.comm
+import torch.nn as nn
+from mmcv.runner.base_module import BaseModule
+from torch.nn import functional as F
+
+from detrsmpl.core.conventions.keypoints_mapping import get_flip_pairs
+
+
+def norm_heatmap(norm_type, heatmap):
+ """Normalize heatmap.
+
+ Args:
+ norm_type (str):
+ type of normalization. Currently only 'softmax' is supported
+ heatmap (torch.Tensor):
+            model output heatmap with shape (Bx29xF^2), where F^2 is the
+            flattened spatial size of the feature map
+
+ Returns:
+ heatmap (torch.Tensor):
+ normalized heatmap according to specified type with
+ shape (Bx29xF^2)
+ """
+
+ # Input tensor shape: [N,C,...]
+ shape = heatmap.shape
+ if norm_type == 'softmax':
+ heatmap = heatmap.reshape(*shape[:2], -1)
+ # global soft max
+ heatmap = F.softmax(heatmap, 2)
+ return heatmap.reshape(*shape)
+ else:
+ raise NotImplementedError
+
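+# Sketch: with norm_type='softmax' the heatmap is normalized over the
+# flattened spatial dimension, e.g.
+#     hm = norm_heatmap('softmax', torch.randn(2, 29, 64 * 64 * 64))
+#     hm.sum(dim=-1)  # ~1.0 for every (batch, joint) entry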
+
+class HybrIKHead(BaseModule):
+ """HybrIK parameters regressor head.
+
+ Args:
+ feature_channel (int):
+ Number of input channels
+ deconv_dim (List[int]):
+ List of deconvolution dimensions
+ num_joints (int):
+ Number of keypoints
+ depth_dim (int):
+ Depth dimension
+ height_dim (int):
+ Height dimension
+ width_dim (int):
+ Width dimension
+ smpl_mean_params (str):
+ file name of the mean SMPL parameters
+ """
+ def __init__(
+ self,
+ feature_channel=512,
+ deconv_dim=[256, 256, 256],
+ num_joints=29,
+ depth_dim=64,
+ height_dim=64,
+ width_dim=64,
+ smpl_mean_params=None,
+ ):
+
+ super(HybrIKHead, self).__init__()
+
+ self.deconv_dim = deconv_dim
+ self._norm_layer = nn.BatchNorm2d
+ self.num_joints = num_joints
+ self.norm_type = 'softmax'
+ self.depth_dim = depth_dim
+ self.height_dim = height_dim
+ self.width_dim = width_dim
+ self.smpl_dtype = torch.float32
+ self.feature_channel = feature_channel
+
+ self.deconv_layers = self._make_deconv_layer()
+ self.final_layer = nn.Conv2d(self.deconv_dim[2],
+ self.num_joints * self.depth_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ self.joint_pairs_24 = get_flip_pairs('smpl')
+ self.joint_pairs_29 = get_flip_pairs('hybrik_29')
+
+ self.leaf_pairs = ((0, 1), (3, 4))
+ self.root_idx_smpl = 0
+
+ # mean shape
+ init_shape = np.load(smpl_mean_params)
+ self.register_buffer('init_shape', torch.Tensor(init_shape).float())
+
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.fc1 = nn.Linear(self.feature_channel, 1024)
+ self.drop1 = nn.Dropout(p=0.5)
+ self.fc2 = nn.Linear(1024, 1024)
+ self.drop2 = nn.Dropout(p=0.5)
+ self.decshape = nn.Linear(1024, 10)
+ self.decphi = nn.Linear(1024, 23 * 2) # [cos(phi), sin(phi)]
+
+ def _make_deconv_layer(self):
+ deconv_layers = []
+ deconv1 = nn.ConvTranspose2d(self.feature_channel,
+ self.deconv_dim[0],
+ kernel_size=4,
+ stride=2,
+ padding=int(4 / 2) - 1,
+ bias=False)
+ bn1 = self._norm_layer(self.deconv_dim[0])
+ deconv2 = nn.ConvTranspose2d(self.deconv_dim[0],
+ self.deconv_dim[1],
+ kernel_size=4,
+ stride=2,
+ padding=int(4 / 2) - 1,
+ bias=False)
+ bn2 = self._norm_layer(self.deconv_dim[1])
+ deconv3 = nn.ConvTranspose2d(self.deconv_dim[1],
+ self.deconv_dim[2],
+ kernel_size=4,
+ stride=2,
+ padding=int(4 / 2) - 1,
+ bias=False)
+ bn3 = self._norm_layer(self.deconv_dim[2])
+
+ deconv_layers.append(deconv1)
+ deconv_layers.append(bn1)
+ deconv_layers.append(nn.ReLU(inplace=True))
+ deconv_layers.append(deconv2)
+ deconv_layers.append(bn2)
+ deconv_layers.append(nn.ReLU(inplace=True))
+ deconv_layers.append(deconv3)
+ deconv_layers.append(bn3)
+ deconv_layers.append(nn.ReLU(inplace=True))
+
+ return nn.Sequential(*deconv_layers)
+
+ def _initialize(self):
+ for name, m in self.deconv_layers.named_modules():
+ if isinstance(m, nn.ConvTranspose2d):
+ nn.init.normal_(m.weight, std=0.001)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ for m in self.final_layer.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.normal_(m.weight, std=0.001)
+ nn.init.constant_(m.bias, 0)
+
+ def uvd_to_cam(self,
+ uvd_jts,
+ trans_inv,
+ intrinsic_param,
+ joint_root,
+ depth_factor,
+ return_relative=True):
+ """Project uvd coordinates to camera frame.
+
+ Args:
+ uvd_jts (torch.Tensor):
+ uvd coordinates with shape (BxNum_jointsx3)
+ trans_inv (torch.Tensor):
+ inverse affine transformation matrix with shape (Bx2x3)
+ intrinsic_param (torch.Tensor):
+ camera intrinsic matrix with shape (Bx3x3)
+ joint_root (torch.Tensor):
+ root joint coordinate with shape (Bx3)
+            depth_factor (torch.Tensor):
+                depth factor with shape (Bx1)
+            return_relative (bool):
+                Set to True to return root-relative (root-normalized)
+                coordinates.
+ Default: True.
+
+ Returns:
+ xyz_jts (torch.Tensor):
+ uvd coordinates in camera frame with shape (BxNum_jointsx3)
+ """
+ assert uvd_jts.dim() == 3 and uvd_jts.shape[2] == 3, uvd_jts.shape
+ uvd_jts_new = uvd_jts.clone()
+ assert torch.sum(torch.isnan(uvd_jts)) == 0, ('uvd_jts', uvd_jts)
+
+ # remap uv coordinate to input space
+ uvd_jts_new[:, :, 0] = (uvd_jts[:, :, 0] + 0.5) * self.width_dim * 4
+ uvd_jts_new[:, :, 1] = (uvd_jts[:, :, 1] + 0.5) * self.height_dim * 4
+ # remap d to mm
+ uvd_jts_new[:, :, 2] = uvd_jts[:, :, 2] * depth_factor
+ assert torch.sum(torch.isnan(uvd_jts_new)) == 0, ('uvd_jts_new',
+ uvd_jts_new)
+
+ dz = uvd_jts_new[:, :, 2]
+
+ # transform in-bbox coordinate to image coordinate
+ uv_homo_jts = torch.cat(
+ (uvd_jts_new[:, :, :2], torch.ones_like(uvd_jts_new)[:, :, 2:]),
+ dim=2)
+ # batch-wise matrix multiply : (B,1,2,3) * (B,K,3,1) -> (B,K,2,1)
+ uv_jts = torch.matmul(trans_inv.unsqueeze(1),
+ uv_homo_jts.unsqueeze(-1))
+ # transform (u,v,1) to (x,y,z)
+ cam_2d_homo = torch.cat((uv_jts, torch.ones_like(uv_jts)[:, :, :1, :]),
+ dim=2)
+ # batch-wise matrix multiply : (B,1,3,3) * (B,K,3,1) -> (B,K,3,1)
+ xyz_jts = torch.matmul(intrinsic_param.unsqueeze(1), cam_2d_homo)
+ xyz_jts = xyz_jts.squeeze(dim=3)
+ # recover absolute z : (B,K) + (B,1)
+ abs_z = dz + joint_root[:, 2].unsqueeze(-1)
+ # multiply absolute z : (B,K,3) * (B,K,1)
+ xyz_jts = xyz_jts * abs_z.unsqueeze(-1)
+
+ if return_relative:
+ # (B,K,3) - (B,1,3)
+ xyz_jts = xyz_jts - joint_root.unsqueeze(1)
+
+ xyz_jts = xyz_jts / depth_factor.unsqueeze(-1)
+
+ return xyz_jts
+
+ def flip_uvd_coord(self, pred_jts, flip=False, flatten=True):
+ """Flip uvd coordinates.
+
+ Args:
+ pred_jts (torch.Tensor):
+ predicted uvd coordinates with shape (Bx87)
+ flip (bool):
+                Set to True to flip uvd coordinates. Default: False.
+            flatten (bool):
+                Set to True to reshape uvd coordinates to shape (Bx29x3).
+                Default: True
+
+ Returns:
+ pred_jts (torch.Tensor):
+ flipped uvd coordinates with shape (Bx29x3)
+ """
+ if flatten:
+ assert pred_jts.dim() == 2
+ num_batches = pred_jts.shape[0]
+ pred_jts = pred_jts.reshape(num_batches, self.num_joints, 3)
+ else:
+ assert pred_jts.dim() == 3
+ num_batches = pred_jts.shape[0]
+
+ # flip
+ if flip:
+ pred_jts[:, :, 0] = -pred_jts[:, :, 0]
+ else:
+ pred_jts[:, :, 0] = -1 / self.width_dim - pred_jts[:, :, 0]
+
+ for pair in self.joint_pairs_29:
+ dim0, dim1 = pair
+ idx = torch.Tensor((dim0, dim1)).long()
+ inv_idx = torch.Tensor((dim1, dim0)).long()
+ pred_jts[:, idx] = pred_jts[:, inv_idx]
+
+ return pred_jts
+
+ def flip_phi(self, pred_phi):
+ """Flip phi.
+
+ Args:
+ pred_phi (torch.Tensor): phi in shape (Num_twistx2)
+
+ Returns:
+ pred_phi (torch.Tensor): flipped phi in shape (Num_twistx2)
+ """
+ pred_phi[:, :, 1] = -1 * pred_phi[:, :, 1]
+
+ for pair in self.joint_pairs_24:
+ dim0, dim1 = pair
+ idx = torch.Tensor((dim0 - 1, dim1 - 1)).long()
+ inv_idx = torch.Tensor((dim1 - 1, dim0 - 1)).long()
+ pred_phi[:, idx] = pred_phi[:, inv_idx]
+
+ return pred_phi
+
+ def forward(self,
+ feature,
+ trans_inv,
+ intrinsic_param,
+ joint_root,
+ depth_factor,
+ smpl_layer,
+ flip_item=None,
+ flip_output=False):
+ """Forward function.
+
+ Args:
+ feature (torch.Tensor): features extracted from backbone
+ trans_inv (torch.Tensor):
+ inverse affine transformation matrix with shape (Bx2x3)
+ intrinsic_param (torch.Tensor):
+ camera intrinsic matrix with shape (Bx3x3)
+ joint_root (torch.Tensor):
+ root joint coordinate with shape (Bx3)
+            depth_factor (torch.Tensor):
+                depth factor with shape (Bx1)
+            smpl_layer (nn.Module):
+                SMPL body model layer
+            flip_item (List[torch.Tensor]|None):
+                list containing items to flip
+            flip_output (bool):
+                Set to True to flip the output. Default: False
+
+ Returns:
+ output (dict): Dict containing model predictions.
+ """
+ batch_size = feature.shape[0]
+
+ x0 = feature
+ out = self.deconv_layers(x0)
+ out = self.final_layer(out)
+
+ out = out.reshape((out.shape[0], self.num_joints, -1))
+ out = norm_heatmap(self.norm_type, out)
+ assert out.dim() == 3, out.shape
+
+ if self.norm_type == 'sigmoid':
+ maxvals, _ = torch.max(out, dim=2, keepdim=True)
+ else:
+ maxvals = torch.ones((*out.shape[:2], 1),
+ dtype=torch.float,
+ device=out.device)
+
+ heatmaps = out / out.sum(dim=2, keepdim=True)
+
+ heatmaps = heatmaps.reshape(
+ (heatmaps.shape[0], self.num_joints, self.depth_dim,
+ self.height_dim, self.width_dim))
+
+ hm_x = heatmaps.sum((2, 3))
+ hm_y = heatmaps.sum((2, 4))
+ hm_z = heatmaps.sum((3, 4))
+
+ hm_x = hm_x * torch.cuda.comm.broadcast(torch.arange(
+ hm_x.shape[-1]).type(torch.cuda.FloatTensor),
+ devices=[hm_x.device.index])[0]
+ hm_y = hm_y * torch.cuda.comm.broadcast(torch.arange(
+ hm_y.shape[-1]).type(torch.cuda.FloatTensor),
+ devices=[hm_y.device.index])[0]
+ hm_z = hm_z * torch.cuda.comm.broadcast(torch.arange(
+ hm_z.shape[-1]).type(torch.cuda.FloatTensor),
+ devices=[hm_z.device.index])[0]
+ coord_x = hm_x.sum(dim=2, keepdim=True)
+ coord_y = hm_y.sum(dim=2, keepdim=True)
+ coord_z = hm_z.sum(dim=2, keepdim=True)
+
+ coord_x = coord_x / float(self.width_dim) - 0.5
+ coord_y = coord_y / float(self.height_dim) - 0.5
+ coord_z = coord_z / float(self.depth_dim) - 0.5
+
+ # -0.5 ~ 0.5
+ pred_uvd_jts_29 = torch.cat((coord_x, coord_y, coord_z), dim=2)
+
+ pred_uvd_jts_29_flat = pred_uvd_jts_29.reshape(
+ (batch_size, self.num_joints * 3))
+
+ x0 = self.avg_pool(x0)
+ x0 = x0.view(x0.size(0), -1)
+ init_shape = self.init_shape.expand(batch_size, -1) # (B, 10,)
+
+ xc = x0
+
+ xc = self.fc1(xc)
+ xc = self.drop1(xc)
+ xc = self.fc2(xc)
+ xc = self.drop2(xc)
+
+ delta_shape = self.decshape(xc)
+ pred_shape = delta_shape + init_shape
+ pred_phi = self.decphi(xc)
+
+ if flip_item is not None:
+ assert flip_output
+ pred_uvd_jts_29_orig, pred_phi_orig, pred_leaf_orig, \
+ pred_shape_orig = flip_item
+
+ if flip_output:
+            pred_uvd_jts_29 = self.flip_uvd_coord(pred_uvd_jts_29,
+                                                  flatten=False)
+ if flip_output and flip_item is not None:
+ pred_uvd_jts_29 = (pred_uvd_jts_29 + pred_uvd_jts_29_orig.reshape(
+ batch_size, 29, 3)) / 2
+
+ pred_uvd_jts_29_flat = pred_uvd_jts_29.reshape(
+ (batch_size, self.num_joints * 3))
+
+ # -0.5 ~ 0.5
+ # Rotate back
+ pred_xyz_jts_29 = self.uvd_to_cam(pred_uvd_jts_29, trans_inv,
+ intrinsic_param, joint_root,
+ depth_factor)
+ assert torch.sum(
+ torch.isnan(pred_xyz_jts_29)) == 0, ('pred_xyz_jts_29',
+ pred_xyz_jts_29)
+
+ pred_xyz_jts_29 = pred_xyz_jts_29 - \
+ pred_xyz_jts_29[:, self.root_idx_smpl, :].unsqueeze(1)
+
+ pred_phi = pred_phi.reshape(batch_size, 23, 2)
+
+ if flip_output:
+ pred_phi = self.flip_phi(pred_phi)
+
+ if flip_output and flip_item is not None:
+ pred_phi = (pred_phi + pred_phi_orig) / 2
+ pred_shape = (pred_shape + pred_shape_orig) / 2
+
+ hybrik_output = smpl_layer(
+ pose_skeleton=pred_xyz_jts_29.type(self.smpl_dtype) * 2,
+ betas=pred_shape.type(self.smpl_dtype),
+ phis=pred_phi.type(self.smpl_dtype),
+ global_orient=None,
+ return_verts=True)
+ pred_vertices = hybrik_output['vertices'].float()
+ # -0.5 ~ 0.5
+ pred_xyz_jts_24_struct = hybrik_output['joints'].float() / 2
+ # -0.5 ~ 0.5
+ pred_xyz_jts_17 = hybrik_output['joints_from_verts'].float() / 2
+ pred_poses = hybrik_output['poses'].float().reshape(
+ batch_size, 24, 3, 3)
+ pred_xyz_jts_24 = pred_xyz_jts_29[:, :24, :].reshape(batch_size, 72)
+ pred_xyz_jts_24_struct = pred_xyz_jts_24_struct.reshape(batch_size, 72)
+ pred_xyz_jts_17 = pred_xyz_jts_17.reshape(batch_size, 17 * 3)
+
+ output = {
+ 'pred_phi': pred_phi,
+ 'pred_delta_shape': delta_shape,
+ 'pred_shape': pred_shape,
+ 'pred_pose': pred_poses,
+ 'pred_uvd_jts': pred_uvd_jts_29_flat,
+ 'pred_xyz_jts_24': pred_xyz_jts_24,
+ 'pred_xyz_jts_24_struct': pred_xyz_jts_24_struct,
+ 'pred_xyz_jts_17': pred_xyz_jts_17,
+ 'pred_vertices': pred_vertices,
+ 'maxvals': maxvals,
+ }
+
+ return output
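+
+
+# Sketch of the main HybrIKHead outputs (B = batch size):
+#     pred_uvd_jts:    (B, 29 * 3)    uvd coords roughly in [-0.5, 0.5]
+#     pred_xyz_jts_24: (B, 72)        root-relative 3D joints
+#     pred_pose:       (B, 24, 3, 3)  rotation matrices from the SMPL layer
+#     pred_shape:      (B, 10)        SMPL betas
+#     pred_vertices:   (B, 6890, 3)   for the standard SMPL topology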
diff --git a/detrsmpl/models/heads/pare_head.py b/detrsmpl/models/heads/pare_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..a27d4cae5b04103bbd71ea2fd99a048830e9cec9
--- /dev/null
+++ b/detrsmpl/models/heads/pare_head.py
@@ -0,0 +1,611 @@
+"""This script is modified from [PARE](https://github.com/
+mkocabas/PARE/tree/master/pare/models/layers).
+
+Original license please see docs/additional_licenses.md.
+"""
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.runner.base_module import BaseModule
+from torch.nn.modules.utils import _pair
+
+from detrsmpl.utils.geometry import rot6d_to_rotmat
+
+
+class LocallyConnected2d(nn.Module):
+ """Locally Connected Layer.
+
+ Args:
+ in_channels (int):
+ the in channel of the features.
+ out_channels (int):
+ the out channel of the features.
+ output_size (List[int]):
+ the output size of the features.
+ kernel_size (int):
+ the size of the kernel.
+ stride (int):
+ the stride of the kernel.
+ Returns:
+ attended_features (torch.Tensor):
+ attended feature maps
+ """
+ def __init__(self,
+ in_channels,
+ out_channels,
+ output_size,
+ kernel_size,
+ stride,
+ bias=False):
+ super(LocallyConnected2d, self).__init__()
+ output_size = _pair(output_size)
+ self.weight = nn.Parameter(
+ torch.randn(1, out_channels, in_channels, output_size[0],
+ output_size[1], kernel_size**2),
+ requires_grad=True,
+ )
+ if bias:
+ self.bias = nn.Parameter(torch.randn(1, out_channels,
+ output_size[0],
+ output_size[1]),
+ requires_grad=True)
+ else:
+ self.register_parameter('bias', None)
+ self.kernel_size = _pair(kernel_size)
+ self.stride = _pair(stride)
+
+ def forward(self, x):
+ _, c, h, w = x.size()
+ kh, kw = self.kernel_size
+ dh, dw = self.stride
+ x = x.unfold(2, kh, dh).unfold(3, kw, dw)
+ x = x.contiguous().view(*x.size()[:-2], -1)
+ # Sum in in_channel and kernel_size dims
+ out = (x.unsqueeze(1) * self.weight).sum([2, -1])
+ if self.bias is not None:
+ out += self.bias
+ return out
+
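+# Sketch: LocallyConnected2d applies a different (unshared) weight at every
+# output location; with kernel_size=1 and stride=1 the spatial size is kept:
+#     lc = LocallyConnected2d(64, 3, output_size=(56, 56), kernel_size=1,
+#                             stride=1)
+#     lc(torch.randn(2, 64, 56, 56)).shape  # (2, 3, 56, 56)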
+
+class KeypointAttention(nn.Module):
+ """Keypoint Attention Layer.
+
+ Args:
+ use_conv (bool):
+ whether to use conv for the attended feature map.
+ Default: False
+ in_channels (List[int]):
+ the in channel of shape_cam features and pose features.
+ Default: (256, 64)
+ out_channels (List[int]):
+ the out channel of shape_cam features and pose features.
+ Default: (256, 64)
+ Returns:
+ attended_features (torch.Tensor):
+ attended feature maps
+ """
+ def __init__(self,
+ use_conv=False,
+ in_channels=(256, 64),
+ out_channels=(256, 64),
+ act='softmax',
+ use_scale=False):
+ super(KeypointAttention, self).__init__()
+ self.use_conv = use_conv
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.act = act
+ self.use_scale = use_scale
+ if use_conv:
+ self.conv1x1_pose = nn.Conv1d(in_channels[0],
+ out_channels[0],
+ kernel_size=1)
+ self.conv1x1_shape_cam = nn.Conv1d(in_channels[1],
+ out_channels[1],
+ kernel_size=1)
+
+ def forward(self, features, heatmaps):
+ batch_size, num_joints, height, width = heatmaps.shape
+
+ if self.use_scale:
+ scale = 1.0 / np.sqrt(height * width)
+ heatmaps = heatmaps * scale
+
+ if self.act == 'softmax':
+ normalized_heatmap = F.softmax(heatmaps.reshape(
+ batch_size, num_joints, -1),
+ dim=-1)
+ elif self.act == 'sigmoid':
+ normalized_heatmap = torch.sigmoid(
+ heatmaps.reshape(batch_size, num_joints, -1))
+ features = features.reshape(batch_size, -1, height * width)
+
+ attended_features = torch.matmul(normalized_heatmap,
+ features.transpose(2, 1))
+ attended_features = attended_features.transpose(2, 1)
+
+ if self.use_conv:
+ if attended_features.shape[1] == self.in_channels[0]:
+ attended_features = self.conv1x1_pose(attended_features)
+ else:
+ attended_features = self.conv1x1_shape_cam(attended_features)
+
+ return attended_features
+
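+# Sketch: KeypointAttention pools a feature map with per-joint heatmaps used
+# as soft attention weights (shapes illustrative):
+#     att = KeypointAttention(use_conv=False)
+#     feats = torch.randn(2, 256, 56, 56)
+#     hms = torch.randn(2, 24, 56, 56)
+#     att(feats, hms).shape  # (2, 256, 24) -> one feature vector per joint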
+
+def interpolate(feat, uv):
+ """
+ Args:
+ feat (torch.Tensor): [B, C, H, W] image features
+ uv (torch.Tensor): [B, 2, N] uv coordinates
+ in the image plane, range [-1, 1]
+ Returns:
+ samples[:, :, :, 0] (torch.Tensor):
+ [B, C, N] image features at the uv coordinates
+ """
+ if uv.shape[-1] != 2:
+ uv = uv.transpose(1, 2) # [B, N, 2]
+ uv = uv.unsqueeze(2) # [B, N, 1, 2]
+    # NOTE: for newer PyTorch versions, training results seem to be slightly
+    # degraded due to an implementation difference in F.grid_sample; for old
+    # versions, simply remove the align_corners argument.
+ if int(torch.__version__.split('.')[1]) < 4:
+ samples = torch.nn.functional.grid_sample(feat, uv) # [B, C, N, 1]
+ else:
+ samples = torch.nn.functional.grid_sample(
+ feat, uv, align_corners=True) # [B, C, N, 1]
+ return samples[:, :, :, 0] # [B, C, N]
+
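+# Sketch: `interpolate` samples per-point features from a feature map, e.g.
+#     feat = torch.randn(2, 64, 56, 56)
+#     uv = torch.rand(2, 2, 24) * 2 - 1  # 24 points in [-1, 1]
+#     sampled = interpolate(feat, uv)    # (2, 64, 24)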
+
+def _softmax(tensor, temperature, dim=-1):
+ return F.softmax(tensor * temperature, dim=dim)
+
+
+def softargmax2d(
+ heatmaps,
+ temperature=None,
+ normalize_keypoints=True,
+):
+ """Softargmax layer for heatmaps."""
+ dtype, device = heatmaps.dtype, heatmaps.device
+ if temperature is None:
+ temperature = torch.tensor(1.0, dtype=dtype, device=device)
+ batch_size, num_channels, height, width = heatmaps.shape
+ x = torch.arange(0, width, device=device, dtype=dtype).reshape(
+ 1, 1, 1, width).expand(batch_size, -1, height, -1)
+ y = torch.arange(0, height, device=device,
+ dtype=dtype).reshape(1, 1, height,
+ 1).expand(batch_size, -1, -1, width)
+ # Should be Bx2xHxW
+ points = torch.cat([x, y], dim=1)
+ normalized_heatmap = _softmax(heatmaps.reshape(batch_size, num_channels,
+ -1),
+ temperature=temperature.reshape(1, -1, 1),
+ dim=-1)
+
+ # Should be BxJx2
+ keypoints = (
+ normalized_heatmap.reshape(batch_size, -1, 1, height * width) *
+ points.reshape(batch_size, 1, 2, -1)).sum(dim=-1)
+
+ if normalize_keypoints:
+ # Normalize keypoints to [-1, 1]
+ keypoints[:, :, 0] = (keypoints[:, :, 0] / (width - 1) * 2 - 1)
+ keypoints[:, :, 1] = (keypoints[:, :, 1] / (height - 1) * 2 - 1)
+
+ return keypoints, normalized_heatmap.reshape(batch_size, -1, height, width)
+
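+# Sketch: softargmax2d turns heatmaps into differentiable 2D keypoints,
+# normalized to [-1, 1] when normalize_keypoints=True:
+#     kpts, norm_hm = softargmax2d(torch.randn(2, 24, 56, 56))
+#     kpts.shape     # (2, 24, 2)
+#     norm_hm.shape  # (2, 24, 56, 56)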
+
+class PareHead(BaseModule):
+ def __init__(
+ self,
+ num_joints=24,
+ num_input_features=480,
+ softmax_temp=1.0,
+ num_deconv_layers=3,
+ num_deconv_filters=(256, 256, 256),
+ num_deconv_kernels=(4, 4, 4),
+ num_camera_params=3,
+ num_features_smpl=64,
+ final_conv_kernel=1,
+ pose_mlp_num_layers=1,
+ shape_mlp_num_layers=1,
+ pose_mlp_hidden_size=256,
+ shape_mlp_hidden_size=256,
+ bn_momentum=0.1,
+ use_heatmaps='part_segm',
+ use_keypoint_attention=False,
+ use_postconv_keypoint_attention=False,
+ keypoint_attention_act='softmax', # softmax, sigmoid
+ use_scale_keypoint_attention=False,
+ backbone='hrnet_w32-conv', # hrnet, resnet
+ smpl_mean_params=None,
+ deconv_with_bias=False,
+ ):
+ """PARE parameters regressor head. This class is modified from.
+
+        [PARE](https://github.com/
+ mkocabas/PARE/blob/master/pare/models/head/pare_head.py). Original
+ license please see docs/additional_licenses.md.
+
+ Args:
+ num_joints (int):
+ Number of joints, should be 24 for smpl.
+ num_input_features (int):
+ Number of input featuremap channels.
+            softmax_temp (float):
+                Softmax temperature.
+ num_deconv_layers (int):
+ Number of deconvolution layers.
+ num_deconv_filters (List[int]):
+ Number of filters for each deconvolution layer,
+ len(num_deconv_filters) == num_deconv_layers.
+ num_deconv_kernels (List[int]):
+ Kernel size for each deconvolution layer,
+ len(num_deconv_kernels) == num_deconv_layers.
+ num_camera_params (int):
+ Number of predicted camera parameter dimension.
+ num_features_smpl (int):
+ Number of feature map channels.
+            final_conv_kernel (int):
+                Kernel size of the final convolution layers.
+            pose_mlp_num_layers (int):
+                Number of MLP layers for pose parameter regression.
+            shape_mlp_num_layers (int):
+                Number of MLP layers for shape parameter regression.
+            pose_mlp_hidden_size (int):
+                Hidden size of the pose MLP layers.
+            shape_mlp_hidden_size (int):
+                Hidden size of the shape MLP layers.
+            bn_momentum (float):
+                Momentum for batch normalization.
+            use_heatmaps (str):
+                Type of heatmaps to use.
+            use_keypoint_attention (bool):
+                Whether to use attention based on heatmaps.
+            use_postconv_keypoint_attention (bool):
+                Whether to apply a 1x1 convolution after keypoint attention.
+            keypoint_attention_act (str):
+                Type of activation function for the attention layers.
+            use_scale_keypoint_attention (bool):
+                Whether to scale the attention
+                according to the size of the attention map.
+            deconv_with_bias (bool):
+                Whether to use bias in the deconvolution layers.
+            backbone (str):
+                Type of the backbone.
+            smpl_mean_params (str):
+                File name of the mean SMPL parameters.
+ """
+
+ super(PareHead, self).__init__()
+ self.backbone = backbone
+ self.num_joints = num_joints
+ self.deconv_with_bias = deconv_with_bias
+ self.use_heatmaps = use_heatmaps
+ self.pose_mlp_num_layers = pose_mlp_num_layers
+ self.shape_mlp_num_layers = shape_mlp_num_layers
+ self.pose_mlp_hidden_size = pose_mlp_hidden_size
+ self.shape_mlp_hidden_size = shape_mlp_hidden_size
+ self.use_keypoint_attention = use_keypoint_attention
+
+ self.num_input_features = num_input_features
+ self.bn_momentum = bn_momentum
+ if self.use_heatmaps == 'part_segm':
+
+ self.use_keypoint_attention = True
+
+ if backbone.startswith('hrnet'):
+
+ self.keypoint_deconv_layers = self._make_conv_layer(
+ num_deconv_layers,
+ num_deconv_filters,
+ (3, ) * num_deconv_layers,
+ )
+ self.num_input_features = num_input_features
+ self.smpl_deconv_layers = self._make_conv_layer(
+ num_deconv_layers,
+ num_deconv_filters,
+ (3, ) * num_deconv_layers,
+ )
+ else:
+ # part branch that estimates 2d keypoints
+
+ conv_fn = self._make_deconv_layer
+
+ self.keypoint_deconv_layers = conv_fn(
+ num_deconv_layers,
+ num_deconv_filters,
+ num_deconv_kernels,
+ )
+            # reset the input feature channels for the SMPL branch
+            # (2048 for the final ResNet layer)
+            self.num_input_features = num_input_features
+ self.smpl_deconv_layers = conv_fn(
+ num_deconv_layers,
+ num_deconv_filters,
+ num_deconv_kernels,
+ )
+
+ pose_mlp_inp_dim = num_deconv_filters[-1]
+ smpl_final_dim = num_features_smpl
+ shape_mlp_inp_dim = num_joints * smpl_final_dim
+
+ self.keypoint_final_layer = nn.Conv2d(
+ in_channels=num_deconv_filters[-1],
+ out_channels=num_joints +
+ 1 if self.use_heatmaps in ('part_segm',
+ 'part_segm_pool') else num_joints,
+ kernel_size=final_conv_kernel,
+ stride=1,
+ padding=1 if final_conv_kernel == 3 else 0,
+ )
+
+ self.smpl_final_layer = nn.Conv2d(
+ in_channels=num_deconv_filters[-1],
+ out_channels=smpl_final_dim,
+ kernel_size=final_conv_kernel,
+ stride=1,
+ padding=1 if final_conv_kernel == 3 else 0,
+ )
+
+ # temperature for softargmax function
+ self.register_buffer('temperature', torch.tensor(softmax_temp))
+ mean_params = np.load(smpl_mean_params)
+ init_pose = torch.from_numpy(mean_params['pose'][:]).unsqueeze(0)
+ init_shape = torch.from_numpy(
+ mean_params['shape'][:].astype('float32')).unsqueeze(0)
+ init_cam = torch.from_numpy(mean_params['cam']).unsqueeze(0)
+ self.register_buffer('init_pose', init_pose)
+ self.register_buffer('init_shape', init_shape)
+ self.register_buffer('init_cam', init_cam)
+
+ self.pose_mlp_inp_dim = pose_mlp_inp_dim
+ self.shape_mlp_inp_dim = shape_mlp_inp_dim
+
+ self.shape_mlp = self._get_shape_mlp(output_size=10)
+ self.cam_mlp = self._get_shape_mlp(output_size=num_camera_params)
+
+ self.pose_mlp = self._get_pose_mlp(num_joints=num_joints,
+ output_size=6)
+
+ self.keypoint_attention = KeypointAttention(
+ use_conv=use_postconv_keypoint_attention,
+ in_channels=(self.pose_mlp_inp_dim, smpl_final_dim),
+ out_channels=(self.pose_mlp_inp_dim, smpl_final_dim),
+ act=keypoint_attention_act,
+ use_scale=use_scale_keypoint_attention,
+ )
+
+ def _get_shape_mlp(self, output_size):
+ """mlp layers for shape regression."""
+ if self.shape_mlp_num_layers == 1:
+ return nn.Linear(self.shape_mlp_inp_dim, output_size)
+
+ module_list = []
+ for i in range(self.shape_mlp_num_layers):
+ if i == 0:
+ module_list.append(
+ nn.Linear(self.shape_mlp_inp_dim,
+ self.shape_mlp_hidden_size))
+ elif i == self.shape_mlp_num_layers - 1:
+ module_list.append(
+ nn.Linear(self.shape_mlp_hidden_size, output_size))
+ else:
+ module_list.append(
+ nn.Linear(self.shape_mlp_hidden_size,
+ self.shape_mlp_hidden_size))
+ return nn.Sequential(*module_list)
+
+ def _get_pose_mlp(self, num_joints, output_size):
+ """mlp layers for pose regression."""
+ if self.pose_mlp_num_layers == 1:
+
+ return LocallyConnected2d(
+ in_channels=self.pose_mlp_inp_dim,
+ out_channels=output_size,
+ output_size=[num_joints, 1],
+ kernel_size=1,
+ stride=1,
+ )
+
+ module_list = []
+ for i in range(self.pose_mlp_num_layers):
+ if i == 0:
+ module_list.append(
+ LocallyConnected2d(
+ in_channels=self.pose_mlp_inp_dim,
+ out_channels=self.pose_mlp_hidden_size,
+ output_size=[num_joints, 1],
+ kernel_size=1,
+ stride=1,
+ ))
+ elif i == self.pose_mlp_num_layers - 1:
+ module_list.append(
+ LocallyConnected2d(
+ in_channels=self.pose_mlp_hidden_size,
+ out_channels=output_size,
+ output_size=[num_joints, 1],
+ kernel_size=1,
+ stride=1,
+ ))
+ else:
+ module_list.append(
+ LocallyConnected2d(
+ in_channels=self.pose_mlp_hidden_size,
+ out_channels=self.pose_mlp_hidden_size,
+ output_size=[num_joints, 1],
+ kernel_size=1,
+ stride=1,
+ ))
+ return nn.Sequential(*module_list)
+
+ def _get_deconv_cfg(self, deconv_kernel):
+ """get deconv padding, output padding according to kernel size."""
+ if deconv_kernel == 4:
+ padding = 1
+ output_padding = 0
+ elif deconv_kernel == 3:
+ padding = 1
+ output_padding = 1
+ elif deconv_kernel == 2:
+ padding = 0
+ output_padding = 0
+
+ return deconv_kernel, padding, output_padding
+
+ def _make_conv_layer(self, num_layers, num_filters, num_kernels):
+ """make convolution layers."""
+        assert num_layers == len(num_filters), \
+            'ERROR: num_conv_layers is different from len(num_conv_filters)'
+        assert num_layers == len(num_kernels), \
+            'ERROR: num_conv_layers is different from len(num_conv_kernels)'
+ layers = []
+ for i in range(num_layers):
+ kernel, padding, output_padding = \
+ self._get_deconv_cfg(num_kernels[i])
+
+ planes = num_filters[i]
+ layers.append(
+ nn.Conv2d(in_channels=self.num_input_features,
+ out_channels=planes,
+ kernel_size=kernel,
+ stride=1,
+ padding=padding,
+ bias=self.deconv_with_bias))
+ layers.append(nn.BatchNorm2d(planes, momentum=self.bn_momentum))
+ layers.append(nn.ReLU(inplace=True))
+ self.num_input_features = planes
+
+ return nn.Sequential(*layers)
+
+ def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
+ """make deconvolution layers."""
+        assert num_layers == len(num_filters), \
+            'ERROR: num_deconv_layers != len(num_deconv_filters)'
+        assert num_layers == len(num_kernels), \
+            'ERROR: num_deconv_layers != len(num_deconv_kernels)'
+
+ layers = []
+ for i in range(num_layers):
+ kernel, padding, output_padding = \
+ self._get_deconv_cfg(num_kernels[i])
+
+ planes = num_filters[i]
+ layers.append(
+ nn.ConvTranspose2d(in_channels=self.num_input_features,
+ out_channels=planes,
+ kernel_size=kernel,
+ stride=2,
+ padding=padding,
+ output_padding=output_padding,
+ bias=self.deconv_with_bias))
+ layers.append(nn.BatchNorm2d(planes, momentum=self.bn_momentum))
+ layers.append(nn.ReLU(inplace=True))
+ # if self.use_self_attention:
+ # layers.append(SelfAttention(planes))
+ self.num_input_features = planes
+
+ return nn.Sequential(*layers)
+
+ def forward(self, features):
+ batch_size = features.shape[0]
+
+ init_pose = self.init_pose.expand(batch_size, -1) # N, Jx6
+ init_shape = self.init_shape.expand(batch_size, -1)
+ init_cam = self.init_cam.expand(batch_size, -1)
+
+ output = {}
+
+ part_feats = self._get_2d_branch_feats(features)
+
+ part_attention = self._get_part_attention_map(part_feats, output)
+
+ smpl_feats = self._get_3d_smpl_feats(features, part_feats)
+
+ point_local_feat, cam_shape_feats = self._get_local_feats(
+ smpl_feats, part_attention, output)
+
+ pred_pose, pred_shape, pred_cam = self._get_final_preds(
+ point_local_feat, cam_shape_feats, init_pose, init_shape, init_cam)
+
+ pred_rotmat = rot6d_to_rotmat(pred_pose).reshape(batch_size, 24, 3, 3)
+
+ output.update({
+ 'pred_pose': pred_rotmat,
+ 'pred_cam': pred_cam,
+ 'pred_shape': pred_shape,
+ })
+ return output
+
+ def _get_local_feats(self, smpl_feats, part_attention, output):
+ # 1x1 conv
+ """get keypoints and camera features from backbone features."""
+
+ cam_shape_feats = self.smpl_final_layer(smpl_feats)
+
+ if self.use_keypoint_attention:
+ point_local_feat = self.keypoint_attention(smpl_feats,
+ part_attention)
+ cam_shape_feats = self.keypoint_attention(cam_shape_feats,
+ part_attention)
+ else:
+ point_local_feat = interpolate(smpl_feats, output['pred_kp2d'])
+ cam_shape_feats = interpolate(cam_shape_feats, output['pred_kp2d'])
+ return point_local_feat, cam_shape_feats
+
+ def _get_2d_branch_feats(self, features):
+ """get part features from backbone features."""
+ part_feats = self.keypoint_deconv_layers(features)
+
+ return part_feats
+
+ def _get_3d_smpl_feats(self, features, part_feats):
+ """get smpl feature maps from backbone features."""
+
+ smpl_feats = self.smpl_deconv_layers(features)
+
+ return smpl_feats
+
+ def _get_part_attention_map(self, part_feats, output):
+ """get attention map from part feature map."""
+ heatmaps = self.keypoint_final_layer(part_feats)
+
+ if self.use_heatmaps == 'part_segm':
+
+ output['pred_segm_mask'] = heatmaps
+            # remove the background channel
+ heatmaps = heatmaps[:, 1:, :, :]
+ else:
+ pred_kp2d, _ = softargmax2d(heatmaps, self.temperature)
+ output['pred_kp2d'] = pred_kp2d
+ output['pred_heatmaps_2d'] = heatmaps
+ return heatmaps
+
+ def _get_final_preds(self, pose_feats, cam_shape_feats, init_pose,
+ init_shape, init_cam):
+ """get final preds."""
+ return self._pare_get_final_preds(pose_feats, cam_shape_feats,
+ init_pose, init_shape, init_cam)
+
+ def _pare_get_final_preds(self, pose_feats, cam_shape_feats, init_pose,
+ init_shape, init_cam):
+ """get final preds."""
+        pose_feats = pose_feats.unsqueeze(-1)  # (N, C, J, 1)
+
+ if init_pose.shape[-1] == 6:
+ # This means init_pose comes from a previous iteration
+ init_pose = init_pose.transpose(2, 1).unsqueeze(-1)
+ else:
+ # This means init pose comes from mean pose
+ init_pose = init_pose.reshape(init_pose.shape[0], 6,
+ -1).unsqueeze(-1)
+
+ shape_feats = cam_shape_feats
+
+ shape_feats = torch.flatten(shape_feats, start_dim=1)
+
+ pred_pose = self.pose_mlp(pose_feats)
+ pred_cam = self.cam_mlp(shape_feats)
+ pred_shape = self.shape_mlp(shape_feats)
+
+ pred_pose = pred_pose.squeeze(-1).transpose(2, 1) # N, J, 6
+ return pred_pose, pred_shape, pred_cam
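+
+
+if __name__ == '__main__':
+    # Minimal smoke test (illustrative only). The mean-parameter path below is
+    # an assumption: point it at an SMPL mean-params .npz file providing
+    # 'pose' (144,), 'shape' (10,) and 'cam' (3,) arrays.
+    head = PareHead(num_input_features=480,
+                    backbone='hrnet_w32-conv',
+                    smpl_mean_params='data/body_models/smpl_mean_params.npz')
+    feats = torch.randn(2, 480, 56, 56)
+    out = head(feats)
+    print(out['pred_pose'].shape,   # torch.Size([2, 24, 3, 3])
+          out['pred_shape'].shape,  # torch.Size([2, 10])
+          out['pred_cam'].shape)    # torch.Size([2, 3])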
diff --git a/detrsmpl/models/losses/__init__.py b/detrsmpl/models/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/detrsmpl/models/losses/balanced_mse_loss.py b/detrsmpl/models/losses/balanced_mse_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..fee145e114eaf9ca7efa83ddc4bf9765efb3c7fa
--- /dev/null
+++ b/detrsmpl/models/losses/balanced_mse_loss.py
@@ -0,0 +1,146 @@
+# ------------------------------------------------------------------------------
+# Adapted from https://github.com/jiawei-ren/BalancedMSE
+# Original licence: Copyright (c) 2022 Jiawei Ren, under the MIT License.
+# ------------------------------------------------------------------------------
+
+from typing import Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from mmcv.runner import get_dist_info
+from torch.nn.modules.loss import _Loss
+
+from .utils import weighted_loss
+
+
+@weighted_loss
+def bmc_loss_md(pred: torch.Tensor, target: torch.Tensor,
+ noise_var: torch.Tensor, all_gather: bool,
+ loss_mse_weight: float,
+ loss_debias_weight: float) -> torch.Tensor:
+ """
+ Args:
+ pred (torch.Tensor): The prediction. Shape should be (N, L).
+ target (torch.Tensor): The learning target of the prediction.
+ noise_var (torch.Tensor): Noise var of ground truth distribution.
+ all_gather (bool): Whether gather tensors across all sub-processes.
+ Only used in DDP training scheme.
+ loss_mse_weight (float, optional): The weight of the mse term.
+ loss_debias_weight (float, optional): The weight of the debiased term.
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ N = pred.shape[0]
+ L = pred.shape[1]
+ device = pred.device
+
+ loss_mse = F.mse_loss(pred, target, reduction='none').sum(-1)
+ loss_mse = loss_mse / noise_var
+
+ if all_gather:
+ rank, world_size = get_dist_info()
+ bs, length = target.shape
+ all_bs = [torch.zeros(1).to(device) for _ in range(world_size)]
+ dist.all_gather(all_bs, torch.Tensor([bs]).to(device))
+ all_bs_int = [int(v.item()) for v in all_bs]
+ max_bs_int = max(all_bs_int)
+ target_padding = torch.zeros(max_bs_int, length).to(device)
+ target_padding[:bs] = target
+ all_tensor = []
+ for _ in range(world_size):
+ all_tensor.append(torch.zeros(max_bs_int, length).type_as(target))
+ dist.all_gather(all_tensor, target_padding)
+ # remove padding
+ for i in range(world_size):
+ all_tensor[i] = all_tensor[i][:all_bs_int[i]]
+ target = torch.cat(all_tensor, dim=0)
+
+ # Debias term
+ target = target.unsqueeze(0).repeat(N, 1, 1)
+ pred = pred.unsqueeze(1).expand_as(target)
+ debias_term = F.mse_loss(pred, target, reduction='none').sum(-1)
+ debias_term = -0.5 * debias_term / noise_var
+ loss_debias = torch.logsumexp(debias_term, dim=1).squeeze(-1)
+ loss = loss_mse * loss_mse_weight + loss_debias * loss_debias_weight
+ # recover loss scale of mse_loss
+ loss = loss / L * noise_var.detach()
+ return loss
+
+
+class BMCLossMD(_Loss):
+ """Balanced MSE loss, use batch monte-carlo to estimate distribution.
+ https://arxiv.org/abs/2203.16427.
+
+ Args:
+ init_noise_sigma (float, optional): The initial value of noise sigma.
+ This sigma is used to represent ground truth distribution.
+ Defaults to 1.0.
+ all_gather (bool, optional): Whether gather tensors across all
+ sub-processes. If set True, BMC will have more precise estimation
+ with more time cost. Default: False.
+ reduction (str, optional): The method that reduces the loss to a
+ scalar. Options are "none", "mean" and "sum".
+ loss_mse_weight (float, optional): The weight of the mse term.
+ Defaults to 1.0.
+ loss_debias_weight (float, optional): The weight of the debiased term.
+ Defaults to 1.0.
+ """
+ def __init__(self,
+ init_noise_sigma: Optional[float] = 1.0,
+ all_gather: Optional[bool] = False,
+ reduction: Optional[str] = 'mean',
+ loss_mse_weight: Optional[float] = 1.0,
+ loss_debias_weight: Optional[float] = 1.0):
+ super(BMCLossMD, self).__init__()
+ self.noise_sigma = torch.nn.Parameter(
+ torch.tensor(init_noise_sigma).float())
+ self.all_gather = all_gather
+ assert reduction in (None, 'none', 'mean', 'sum')
+ reduction = 'none' if reduction is None else reduction
+ self.reduction = reduction
+ self.loss_mse_weight = loss_mse_weight
+ self.loss_debias_weight = loss_debias_weight
+
+ def forward(
+ self,
+ pred: torch.Tensor,
+ target: torch.Tensor,
+ weight: Optional[Union[torch.Tensor, None]] = None,
+ avg_factor: Optional[Union[int, None]] = None,
+ reduction_override: Optional[Union[str,
+ None]] = None) -> torch.Tensor:
+ """Forward function of loss.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction.
+ weight (torch.Tensor, optional): Weight of the loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None.
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (reduction_override
+ if reduction_override else self.reduction)
+ noise_var = (self.noise_sigma**2).type_as(pred)
+ pred = pred.view(pred.shape[0], -1)
+ target = target.view(target.shape[0], -1)
+ loss = bmc_loss_md(pred,
+ target,
+ noise_var=noise_var,
+ all_gather=self.all_gather,
+ loss_mse_weight=self.loss_mse_weight,
+ loss_debias_weight=self.loss_debias_weight,
+ weight=weight,
+ reduction=reduction,
+ avg_factor=avg_factor)
+ return loss
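+
+
+if __name__ == '__main__':
+    # Illustrative single-process usage (all_gather disabled): the balanced
+    # MSE behaves like an MSE term plus a batch-wise debiasing term.
+    loss_fn = BMCLossMD(init_noise_sigma=1.0, all_gather=False)
+    pred = torch.randn(8, 10, requires_grad=True)
+    target = torch.randn(8, 10)
+    loss = loss_fn(pred, target)
+    loss.backward()
+    print(loss.item())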
diff --git a/detrsmpl/models/losses/builder.py b/detrsmpl/models/losses/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7c1db0f3b1b11ec918e3c9a5a094e3dcfce2978
--- /dev/null
+++ b/detrsmpl/models/losses/builder.py
@@ -0,0 +1,60 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from mmcv.utils import Registry
+
+from .balanced_mse_loss import BMCLossMD
+from .cross_entropy_loss import CrossEntropyLoss
+from .focal_loss import FocalLoss
+from .gan_loss import GANLoss
+from .iou_loss import BoundedIoULoss, CIoULoss, DIoULoss, GIoULoss, IoULoss
+from .mse_loss import KeypointMSELoss, MSELoss
+from .prior_loss import (
+ CameraPriorLoss,
+ JointPriorLoss,
+ LimbLengthLoss,
+ MaxMixturePrior,
+ PoseRegLoss,
+ ShapePriorLoss,
+ ShapeThresholdPriorLoss,
+ SmoothJointLoss,
+ SmoothPelvisLoss,
+ SmoothTranslationLoss,
+)
+from .rotaion_distance_loss import RotationDistance
+from .smooth_l1_loss import L1Loss, SmoothL1Loss
+
+LOSSES = Registry('losses')
+
+LOSSES.register_module(name='GANLoss', module=GANLoss)
+LOSSES.register_module(name='MSELoss', module=MSELoss)
+LOSSES.register_module(name='KeypointMSELoss', module=KeypointMSELoss)
+LOSSES.register_module(name='ShapePriorLoss', module=ShapePriorLoss)
+LOSSES.register_module(name='PoseRegLoss', module=PoseRegLoss)
+LOSSES.register_module(name='LimbLengthLoss', module=LimbLengthLoss)
+LOSSES.register_module(name='JointPriorLoss', module=JointPriorLoss)
+LOSSES.register_module(name='SmoothJointLoss', module=SmoothJointLoss)
+LOSSES.register_module(name='SmoothPelvisLoss', module=SmoothPelvisLoss)
+LOSSES.register_module(name='SmoothTranslationLoss',
+ module=SmoothTranslationLoss)
+LOSSES.register_module(name='ShapeThresholdPriorLoss',
+ module=ShapeThresholdPriorLoss)
+LOSSES.register_module(name='CameraPriorLoss', module=CameraPriorLoss)
+LOSSES.register_module(name='MaxMixturePrior', module=MaxMixturePrior)
+LOSSES.register_module(name='L1Loss', module=L1Loss)
+LOSSES.register_module(name='SmoothL1Loss', module=SmoothL1Loss)
+LOSSES.register_module(name='CrossEntropyLoss', module=CrossEntropyLoss)
+LOSSES.register_module(name='RotationDistance', module=RotationDistance)
+LOSSES.register_module(name='BMCLossMD', module=BMCLossMD)
+LOSSES.register_module(name='FocalLoss', module=FocalLoss)
+LOSSES.register_module(name='IoULoss', module=IoULoss)
+LOSSES.register_module(name='BoundedIoULoss', module=BoundedIoULoss)
+LOSSES.register_module(name='GIoULoss', module=GIoULoss)
+LOSSES.register_module(name='DIoULoss', module=DIoULoss)
+LOSSES.register_module(name='CIoULoss', module=CIoULoss)
+
+
+def build_loss(cfg):
+ """Build loss."""
+ if cfg is None:
+ return None
+ return LOSSES.build(cfg)
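+
+
+if __name__ == '__main__':
+    # Illustrative only: losses are instantiated from config dicts via the
+    # registry, following the usual mmcv pattern.
+    mse = build_loss(dict(type='MSELoss', reduction='mean', loss_weight=1.0))
+    print(mse)
+    assert build_loss(None) is None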
diff --git a/detrsmpl/models/losses/cross_entropy_loss.py b/detrsmpl/models/losses/cross_entropy_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7ecdc375309c0fd097371faebffca8b842338d2
--- /dev/null
+++ b/detrsmpl/models/losses/cross_entropy_loss.py
@@ -0,0 +1,254 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .utils import weight_reduce_loss
+
+
+def cross_entropy(pred,
+ label,
+ weight=None,
+ reduction='mean',
+ avg_factor=None,
+ class_weight=None,
+ ignore_index=-100):
+ """Calculate the CrossEntropy loss.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, C), C is the number
+ of classes.
+ label (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): Sample-wise loss weight.
+ reduction (str, optional): The method used to reduce the loss.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ class_weight (list[float], optional): The weight for each class.
+ ignore_index (int | None): The label index to be ignored.
+ If None, it will be set to default value. Default: -100.
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ # The default value of ignore_index is the same as F.cross_entropy
+ ignore_index = -100 if ignore_index is None else ignore_index
+ # element-wise losses
+ loss = F.cross_entropy(pred,
+ label,
+ weight=class_weight,
+ reduction='none',
+ ignore_index=ignore_index)
+
+ # apply weights and do the reduction
+ if weight is not None:
+ weight = weight.float()
+ loss = weight_reduce_loss(loss,
+ weight=weight,
+ reduction=reduction,
+ avg_factor=avg_factor)
+
+ return loss
+
+
+def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):
+ """Expand onehot labels to match the size of prediction."""
+ bin_labels = labels.new_full((labels.size(0), label_channels), 0)
+ valid_mask = (labels >= 0) & (labels != ignore_index)
+ inds = torch.nonzero(valid_mask & (labels < label_channels),
+ as_tuple=False)
+
+ if inds.numel() > 0:
+ bin_labels[inds, labels[inds]] = 1
+
+ valid_mask = valid_mask.view(-1, 1).expand(labels.size(0),
+ label_channels).float()
+ if label_weights is None:
+ bin_label_weights = valid_mask
+ else:
+ bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels)
+ bin_label_weights *= valid_mask
+
+ return bin_labels, bin_label_weights
+
+
+def binary_cross_entropy(pred,
+ label,
+ weight=None,
+ reduction='mean',
+ avg_factor=None,
+ class_weight=None,
+ ignore_index=-100):
+ """Calculate the binary CrossEntropy loss.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, 1).
+ label (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): Sample-wise loss weight.
+ reduction (str, optional): The method used to reduce the loss.
+ Options are "none", "mean" and "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ class_weight (list[float], optional): The weight for each class.
+ ignore_index (int | None): The label index to be ignored.
+ If None, it will be set to default value. Default: -100.
+
+ Returns:
+ torch.Tensor: The calculated loss.
+ """
+ # The default value of ignore_index is the same as F.cross_entropy
+ ignore_index = -100 if ignore_index is None else ignore_index
+ if pred.dim() != label.dim():
+ label, weight = _expand_onehot_labels(label, weight, pred.size(-1),
+ ignore_index)
+
+ # weighted element-wise losses
+ if weight is not None:
+ weight = weight.float()
+ loss = F.binary_cross_entropy_with_logits(pred,
+ label.float(),
+ pos_weight=class_weight,
+ reduction='none')
+ # do the reduction for the weighted loss
+ loss = weight_reduce_loss(loss,
+ weight,
+ reduction=reduction,
+ avg_factor=avg_factor)
+
+ return loss
+
+
+def mask_cross_entropy(pred,
+ target,
+ label,
+ reduction='mean',
+ avg_factor=None,
+ class_weight=None,
+ ignore_index=None):
+ """Calculate the CrossEntropy loss for masks.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, C, *), C is the
+ number of classes. The trailing * indicates arbitrary shape.
+ target (torch.Tensor): The learning label of the prediction.
+        label (torch.Tensor): ``label`` indicates the class label of the
+            object corresponding to the mask. This will be used to select the
+            mask of the class which the object belongs to when the mask
+            prediction is not class-agnostic.
+ reduction (str, optional): The method used to reduce the loss.
+ Options are "none", "mean" and "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ class_weight (list[float], optional): The weight for each class.
+        ignore_index (None): Placeholder, to be consistent with other losses.
+            Default: None.
+
+ Returns:
+ torch.Tensor: The calculated loss
+
+ Example:
+ >>> N, C = 3, 11
+ >>> H, W = 2, 2
+ >>> pred = torch.randn(N, C, H, W) * 1000
+ >>> target = torch.rand(N, H, W)
+ >>> label = torch.randint(0, C, size=(N,))
+ >>> reduction = 'mean'
+ >>> avg_factor = None
+ >>> class_weights = None
+ >>> loss = mask_cross_entropy(pred, target, label, reduction,
+ >>> avg_factor, class_weights)
+ >>> assert loss.shape == (1,)
+ """
+ assert ignore_index is None, 'BCE loss does not support ignore_index'
+ # TODO: handle these two reserved arguments
+ assert reduction == 'mean' and avg_factor is None
+ num_rois = pred.size()[0]
+ inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
+ pred_slice = pred[inds, label].squeeze(1)
+ return F.binary_cross_entropy_with_logits(pred_slice,
+ target,
+ weight=class_weight,
+ reduction='mean')[None]
+
+
+class CrossEntropyLoss(nn.Module):
+ def __init__(self,
+ use_sigmoid=False,
+ use_mask=False,
+ reduction='mean',
+ class_weight=None,
+ ignore_index=None,
+ loss_weight=1.0):
+ """CrossEntropyLoss.
+
+ Args:
+            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+                instead of softmax. Defaults to False.
+ use_mask (bool, optional): Whether to use mask cross entropy loss.
+ Defaults to False.
+            reduction (str, optional): The method used to reduce the loss.
+                Defaults to 'mean'. Options are "none", "mean" and "sum".
+ class_weight (list[float], optional): Weight of each class.
+ Defaults to None.
+ ignore_index (int | None): The label index to be ignored.
+ Defaults to None.
+ loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+ """
+ super(CrossEntropyLoss, self).__init__()
+ assert (use_sigmoid is False) or (use_mask is False)
+ self.use_sigmoid = use_sigmoid
+ self.use_mask = use_mask
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.class_weight = class_weight
+ self.ignore_index = ignore_index
+
+ if self.use_sigmoid:
+ self.cls_criterion = binary_cross_entropy
+ elif self.use_mask:
+ self.cls_criterion = mask_cross_entropy
+ else:
+ self.cls_criterion = cross_entropy
+
+ def forward(self,
+ cls_score,
+ label,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ ignore_index=None,
+ **kwargs):
+ """Forward function.
+
+ Args:
+ cls_score (torch.Tensor): The prediction.
+ label (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): Sample-wise loss weight.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The method used to reduce the
+ loss. Options are "none", "mean" and "sum".
+ ignore_index (int | None): The label index to be ignored.
+ If not None, it will override the default value. Default: None.
+ Returns:
+ torch.Tensor: The calculated loss.
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (reduction_override
+ if reduction_override else self.reduction)
+ if ignore_index is None:
+ ignore_index = self.ignore_index
+
+ if self.class_weight is not None:
+ class_weight = cls_score.new_tensor(self.class_weight,
+ device=cls_score.device)
+ else:
+ class_weight = None
+ loss_cls = self.loss_weight * self.cls_criterion(
+ cls_score,
+ label,
+ weight,
+ class_weight=class_weight,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ ignore_index=ignore_index,
+ **kwargs)
+ return loss_cls
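+
+
+if __name__ == '__main__':
+    # Illustrative only: standard multi-class usage with integer labels.
+    criterion = CrossEntropyLoss(use_sigmoid=False, loss_weight=1.0)
+    logits = torch.randn(4, 5)           # (N, C)
+    labels = torch.randint(0, 5, (4, ))  # (N, )
+    print(criterion(logits, labels))     # scalar mean loss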
diff --git a/detrsmpl/models/losses/focal_loss.py b/detrsmpl/models/losses/focal_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..5687252d28978f838a66b456c4442de9d72df99e
--- /dev/null
+++ b/detrsmpl/models/losses/focal_loss.py
@@ -0,0 +1,241 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss
+
+from .utils import weight_reduce_loss
+
+
+# This method is only for debugging
+def py_sigmoid_focal_loss(pred,
+ target,
+ weight=None,
+ gamma=2.0,
+ alpha=0.25,
+ reduction='mean',
+ avg_factor=None):
+    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, C), C is the
+ number of classes
+ target (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): Sample-wise loss weight.
+ gamma (float, optional): The gamma for calculating the modulating
+ factor. Defaults to 2.0.
+ alpha (float, optional): A balanced form for Focal Loss.
+ Defaults to 0.25.
+ reduction (str, optional): The method used to reduce the loss into
+ a scalar. Defaults to 'mean'.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ """
+ pred_sigmoid = pred.sigmoid()
+ target = target.type_as(pred)
+ pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
+ focal_weight = (alpha * target + (1 - alpha) *
+ (1 - target)) * pt.pow(gamma)
+ loss = F.binary_cross_entropy_with_logits(pred, target,
+ reduction='none') * focal_weight
+ if weight is not None:
+ if weight.shape != loss.shape:
+ if weight.size(0) == loss.size(0):
+ # For most cases, weight is of shape (num_priors, ),
+ # which means it does not have the second axis num_class
+ weight = weight.view(-1, 1)
+ else:
+ # Sometimes, weight per anchor per class is also needed. e.g.
+ # in FSAF. But it may be flattened of shape
+ # (num_priors x num_class, ), while loss is still of shape
+ # (num_priors, num_class).
+ assert weight.numel() == loss.numel()
+ weight = weight.view(loss.size(0), -1)
+ assert weight.ndim == loss.ndim
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+ return loss
+
+
+def py_focal_loss_with_prob(pred,
+ target,
+ weight=None,
+ gamma=2.0,
+ alpha=0.25,
+ reduction='mean',
+ avg_factor=None):
+    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
+ Different from `py_sigmoid_focal_loss`, this function accepts probability
+ as input.
+
+ Args:
+ pred (torch.Tensor): The prediction probability with shape (N, C),
+ C is the number of classes.
+ target (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): Sample-wise loss weight.
+ gamma (float, optional): The gamma for calculating the modulating
+ factor. Defaults to 2.0.
+ alpha (float, optional): A balanced form for Focal Loss.
+ Defaults to 0.25.
+ reduction (str, optional): The method used to reduce the loss into
+ a scalar. Defaults to 'mean'.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ """
+ num_classes = pred.size(1)
+ target = F.one_hot(target, num_classes=num_classes + 1)
+ target = target[:, :num_classes]
+
+ target = target.type_as(pred)
+ pt = (1 - pred) * target + pred * (1 - target)
+ focal_weight = (alpha * target + (1 - alpha) *
+ (1 - target)) * pt.pow(gamma)
+ loss = F.binary_cross_entropy(pred, target,
+ reduction='none') * focal_weight
+ if weight is not None:
+ if weight.shape != loss.shape:
+ if weight.size(0) == loss.size(0):
+ # For most cases, weight is of shape (num_priors, ),
+ # which means it does not have the second axis num_class
+ weight = weight.view(-1, 1)
+ else:
+ # Sometimes, weight per anchor per class is also needed. e.g.
+ # in FSAF. But it may be flattened of shape
+ # (num_priors x num_class, ), while loss is still of shape
+ # (num_priors, num_class).
+ assert weight.numel() == loss.numel()
+ weight = weight.view(loss.size(0), -1)
+ assert weight.ndim == loss.ndim
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+ return loss
+
+
+def sigmoid_focal_loss(pred,
+ target,
+ weight=None,
+ gamma=2.0,
+ alpha=0.25,
+ reduction='mean',
+ avg_factor=None):
+    r"""A wrapper of the CUDA version of `Focal Loss
+    <https://arxiv.org/abs/1708.02002>`_.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, C), C is the number
+ of classes.
+ target (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): Sample-wise loss weight.
+ gamma (float, optional): The gamma for calculating the modulating
+ factor. Defaults to 2.0.
+ alpha (float, optional): A balanced form for Focal Loss.
+ Defaults to 0.25.
+ reduction (str, optional): The method used to reduce the loss into
+ a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ """
+ # Function.apply does not accept keyword arguments, so the decorator
+ # "weighted_loss" is not applicable
+ loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), gamma,
+ alpha, None, 'none')
+ if weight is not None:
+ if weight.shape != loss.shape:
+ if weight.size(0) == loss.size(0):
+ # For most cases, weight is of shape (num_priors, ),
+ # which means it does not have the second axis num_class
+ weight = weight.view(-1, 1)
+ else:
+ # Sometimes, weight per anchor per class is also needed. e.g.
+ # in FSAF. But it may be flattened of shape
+ # (num_priors x num_class, ), while loss is still of shape
+ # (num_priors, num_class).
+ assert weight.numel() == loss.numel()
+ weight = weight.view(loss.size(0), -1)
+ assert weight.ndim == loss.ndim
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+ return loss
+
+
+class FocalLoss(nn.Module):
+ def __init__(self,
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ reduction='mean',
+ loss_weight=1.0,
+ activated=False):
+        """`Focal Loss <https://arxiv.org/abs/1708.02002>`_.
+
+ Args:
+            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+                instead of softmax. Defaults to True.
+ gamma (float, optional): The gamma for calculating the modulating
+ factor. Defaults to 2.0.
+ alpha (float, optional): A balanced form for Focal Loss.
+ Defaults to 0.25.
+ reduction (str, optional): The method used to reduce the loss into
+ a scalar. Defaults to 'mean'. Options are "none", "mean" and
+ "sum".
+ loss_weight (float, optional): Weight of loss. Defaults to 1.0.
+ activated (bool, optional): Whether the input is activated.
+ If True, it means the input has been activated and can be
+ treated as probabilities. Else, it should be treated as logits.
+ Defaults to False.
+ """
+ super(FocalLoss, self).__init__()
+ assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
+ self.use_sigmoid = use_sigmoid
+ self.gamma = gamma
+ self.alpha = alpha
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.activated = activated
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Options are "none", "mean" and "sum".
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (reduction_override
+ if reduction_override else self.reduction)
+ if self.use_sigmoid:
+ if self.activated:
+ calculate_loss_func = py_focal_loss_with_prob
+ else:
+ if torch.cuda.is_available() and pred.is_cuda:
+ calculate_loss_func = sigmoid_focal_loss
+ else:
+ num_classes = pred.size(1)
+ target = F.one_hot(target, num_classes=num_classes + 1)
+ target = target[:, :num_classes]
+ calculate_loss_func = py_sigmoid_focal_loss
+
+ loss_cls = self.loss_weight * calculate_loss_func(
+ pred,
+ target,
+ weight,
+ gamma=self.gamma,
+ alpha=self.alpha,
+ reduction=reduction,
+ avg_factor=avg_factor)
+
+ else:
+ raise NotImplementedError
+ return loss_cls
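+
+
+if __name__ == '__main__':
+    # Illustrative only: sigmoid focal loss over C foreground classes, with
+    # the integer label C treated as background (dropped by the one-hot
+    # slicing in the forward pass above).
+    focal = FocalLoss(use_sigmoid=True, gamma=2.0, alpha=0.25)
+    cls_score = torch.randn(6, 3)        # (N, C) logits
+    labels = torch.randint(0, 4, (6, ))  # values in [0, C]
+    print(focal(cls_score, labels))      # scalar mean loss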
diff --git a/detrsmpl/models/losses/gan_loss.py b/detrsmpl/models/losses/gan_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f6b085acdbcfcce11cff6fb13bfebc7edb16080
--- /dev/null
+++ b/detrsmpl/models/losses/gan_loss.py
@@ -0,0 +1,90 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+
+class GANLoss(nn.Module):
+ """Define GAN loss.
+
+ Args:
+ gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
+ real_label_val (float): The value for real label. Default: 1.0.
+ fake_label_val (float): The value for fake label. Default: 0.0.
+ loss_weight (float): Loss weight. Default: 1.0.
+            Note that loss_weight is only applied to generators; it is
+            always 1.0 for discriminators.
+ """
+ def __init__(self,
+ gan_type,
+ real_label_val=1.0,
+ fake_label_val=0.0,
+ loss_weight=1.0):
+ super().__init__()
+ self.gan_type = gan_type
+ self.loss_weight = loss_weight
+ self.real_label_val = real_label_val
+ self.fake_label_val = fake_label_val
+
+ if self.gan_type == 'vanilla':
+ self.loss = nn.BCEWithLogitsLoss()
+ elif self.gan_type == 'lsgan':
+ self.loss = nn.MSELoss()
+ elif self.gan_type == 'wgan':
+ self.loss = self._wgan_loss
+ elif self.gan_type == 'hinge':
+ self.loss = nn.ReLU()
+ else:
+ raise NotImplementedError(
+ f'GAN type {self.gan_type} is not implemented.')
+
+ @staticmethod
+ def _wgan_loss(input, target):
+ """wgan loss.
+
+ Args:
+ input (Tensor): Input tensor.
+ target (bool): Target label.
+ Returns:
+ Tensor: wgan loss.
+ """
+ return -input.mean() if target else input.mean()
+
+ def get_target_label(self, input, target_is_real):
+ """Get target label.
+
+ Args:
+ input (Tensor): Input tensor.
+ target_is_real (bool): Whether the target is real or fake.
+ Returns:
+ (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
+ return Tensor.
+ """
+
+ if self.gan_type == 'wgan':
+ return target_is_real
+ target_val = (self.real_label_val
+ if target_is_real else self.fake_label_val)
+ return input.new_ones(input.size()) * target_val
+
+ def forward(self, input, target_is_real, is_disc=False):
+ """
+ Args:
+ input (Tensor): The input for the loss module, i.e., the network
+ prediction.
+            target_is_real (bool): Whether the target is real or fake.
+            is_disc (bool): Whether the loss is for discriminators or not.
+ Default: False.
+ Returns:
+ Tensor: GAN loss value.
+ """
+ target_label = self.get_target_label(input, target_is_real)
+ if self.gan_type == 'hinge':
+ if is_disc: # for discriminators in hinge-gan
+ input = -input if target_is_real else input
+ loss = self.loss(1 + input).mean()
+ else: # for generators in hinge-gan
+ loss = -input.mean()
+ else: # other gan types
+ loss = self.loss(input, target_label)
+
+ # loss_weight is always 1.0 for discriminators
+ return loss if is_disc else loss * self.loss_weight
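+
+
+if __name__ == '__main__':
+    # Illustrative only: vanilla GAN loss computed from discriminator logits,
+    # once for a generator update and once for a discriminator update.
+    import torch
+    gan_loss = GANLoss(gan_type='vanilla')
+    d_out_fake = torch.randn(4, 1)
+    g_loss = gan_loss(d_out_fake, target_is_real=True, is_disc=False)
+    d_loss = gan_loss(d_out_fake, target_is_real=False, is_disc=True)
+    print(g_loss.item(), d_loss.item())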
diff --git a/detrsmpl/models/losses/iou_loss.py b/detrsmpl/models/losses/iou_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d3fdd3417fd2ee1a25fa28ce6c90c4700794ab2
--- /dev/null
+++ b/detrsmpl/models/losses/iou_loss.py
@@ -0,0 +1,458 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import warnings
+
+import mmcv
+import torch
+import torch.nn as nn
+from mmdet.core import bbox_overlaps
+
+from .utils import weighted_loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def iou_loss(pred, target, linear=False, mode='log', eps=1e-6):
+ """IoU loss.
+
+ Computing the IoU loss between a set of predicted bboxes and target bboxes.
+ The loss is calculated as negative log of IoU.
+
+ Args:
+ pred (torch.Tensor): Predicted bboxes of format (x1, y1, x2, y2),
+ shape (n, 4).
+ target (torch.Tensor): Corresponding gt bboxes, shape (n, 4).
+ linear (bool, optional): If True, use linear scale of loss instead of
+ log scale. Default: False.
+ mode (str): Loss scaling mode, including "linear", "square", and "log".
+ Default: 'log'
+ eps (float): Eps to avoid log(0).
+
+ Return:
+ torch.Tensor: Loss tensor.
+ """
+ assert mode in ['linear', 'square', 'log']
+ if linear:
+ mode = 'linear'
+ warnings.warn('DeprecationWarning: Setting "linear=True" in '
+ 'iou_loss is deprecated, please use "mode=`linear`" '
+ 'instead.')
+ ious = bbox_overlaps(pred, target, is_aligned=True).clamp(min=eps)
+ if mode == 'linear':
+ loss = 1 - ious
+ elif mode == 'square':
+ loss = 1 - ious**2
+ elif mode == 'log':
+ loss = -ious.log()
+ else:
+ raise NotImplementedError
+ return loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def bounded_iou_loss(pred, target, beta=0.2, eps=1e-3):
+ """BIoULoss.
+
+    This is an implementation of paper
+    `Improving Object Localization with Fitness NMS and Bounded IoU Loss.
+    <https://arxiv.org/abs/1711.00164>`_.
+
+ Args:
+ pred (torch.Tensor): Predicted bboxes.
+ target (torch.Tensor): Target bboxes.
+ beta (float): beta parameter in smoothl1.
+ eps (float): eps to avoid NaN.
+ """
+ pred_ctrx = (pred[:, 0] + pred[:, 2]) * 0.5
+ pred_ctry = (pred[:, 1] + pred[:, 3]) * 0.5
+ pred_w = pred[:, 2] - pred[:, 0]
+ pred_h = pred[:, 3] - pred[:, 1]
+ with torch.no_grad():
+ target_ctrx = (target[:, 0] + target[:, 2]) * 0.5
+ target_ctry = (target[:, 1] + target[:, 3]) * 0.5
+ target_w = target[:, 2] - target[:, 0]
+ target_h = target[:, 3] - target[:, 1]
+
+ dx = target_ctrx - pred_ctrx
+ dy = target_ctry - pred_ctry
+
+ loss_dx = 1 - torch.max(
+ (target_w - 2 * dx.abs()) /
+ (target_w + 2 * dx.abs() + eps), torch.zeros_like(dx))
+ loss_dy = 1 - torch.max(
+ (target_h - 2 * dy.abs()) /
+ (target_h + 2 * dy.abs() + eps), torch.zeros_like(dy))
+ loss_dw = 1 - torch.min(target_w / (pred_w + eps), pred_w /
+ (target_w + eps))
+ loss_dh = 1 - torch.min(target_h / (pred_h + eps), pred_h /
+ (target_h + eps))
+ # view(..., -1) does not work for empty tensor
+ loss_comb = torch.stack([loss_dx, loss_dy, loss_dw, loss_dh],
+ dim=-1).flatten(1)
+
+ loss = torch.where(loss_comb < beta, 0.5 * loss_comb * loss_comb / beta,
+ loss_comb - 0.5 * beta)
+ return loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def giou_loss(pred, target, eps=1e-7):
+    r"""`Generalized Intersection over Union: A Metric and A Loss for Bounding
+    Box Regression <https://arxiv.org/abs/1902.09630>`_.
+
+ Args:
+ pred (torch.Tensor): Predicted bboxes of format (x1, y1, x2, y2),
+ shape (n, 4).
+ target (torch.Tensor): Corresponding gt bboxes, shape (n, 4).
+ eps (float): Eps to avoid log(0).
+
+ Return:
+ Tensor: Loss tensor.
+ """
+ gious = bbox_overlaps(pred, target, mode='giou', is_aligned=True, eps=eps)
+ loss = 1 - gious
+ return loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def diou_loss(pred, target, eps=1e-7):
+ r"""`Implementation of Distance-IoU Loss: Faster and Better
+ Learning for Bounding Box Regression, https://arxiv.org/abs/1911.08287`_.
+
+ Code is modified from https://github.com/Zzh-tju/DIoU.
+
+ Args:
+ pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
+ shape (n, 4).
+ target (Tensor): Corresponding gt bboxes, shape (n, 4).
+ eps (float): Eps to avoid log(0).
+ Return:
+ Tensor: Loss tensor.
+ """
+ # overlap
+ lt = torch.max(pred[:, :2], target[:, :2])
+ rb = torch.min(pred[:, 2:], target[:, 2:])
+ wh = (rb - lt).clamp(min=0)
+ overlap = wh[:, 0] * wh[:, 1]
+
+ # union
+ ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
+ ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
+ union = ap + ag - overlap + eps
+
+ # IoU
+ ious = overlap / union
+
+ # enclose area
+ enclose_x1y1 = torch.min(pred[:, :2], target[:, :2])
+ enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
+ enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)
+
+ cw = enclose_wh[:, 0]
+ ch = enclose_wh[:, 1]
+
+ c2 = cw**2 + ch**2 + eps
+
+ b1_x1, b1_y1 = pred[:, 0], pred[:, 1]
+ b1_x2, b1_y2 = pred[:, 2], pred[:, 3]
+ b2_x1, b2_y1 = target[:, 0], target[:, 1]
+ b2_x2, b2_y2 = target[:, 2], target[:, 3]
+
+ left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2))**2 / 4
+ right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2))**2 / 4
+ rho2 = left + right
+
+ # DIoU
+ dious = ious - rho2 / c2
+ loss = 1 - dious
+ return loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def ciou_loss(pred, target, eps=1e-7):
+    r"""Implementation of paper `Enhancing Geometric Factors into
+    Model Learning and Inference for Object Detection and Instance
+    Segmentation <https://arxiv.org/abs/2005.03572>`_.
+
+ Code is modified from https://github.com/Zzh-tju/CIoU.
+
+ Args:
+ pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
+ shape (n, 4).
+ target (Tensor): Corresponding gt bboxes, shape (n, 4).
+ eps (float): Eps to avoid log(0).
+ Return:
+ Tensor: Loss tensor.
+ """
+ # overlap
+ lt = torch.max(pred[:, :2], target[:, :2])
+ rb = torch.min(pred[:, 2:], target[:, 2:])
+ wh = (rb - lt).clamp(min=0)
+ overlap = wh[:, 0] * wh[:, 1]
+
+ # union
+ ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
+ ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
+ union = ap + ag - overlap + eps
+
+ # IoU
+ ious = overlap / union
+
+ # enclose area
+ enclose_x1y1 = torch.min(pred[:, :2], target[:, :2])
+ enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
+ enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)
+
+ cw = enclose_wh[:, 0]
+ ch = enclose_wh[:, 1]
+
+ c2 = cw**2 + ch**2 + eps
+
+ b1_x1, b1_y1 = pred[:, 0], pred[:, 1]
+ b1_x2, b1_y2 = pred[:, 2], pred[:, 3]
+ b2_x1, b2_y1 = target[:, 0], target[:, 1]
+ b2_x2, b2_y2 = target[:, 2], target[:, 3]
+
+ w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
+ w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
+
+ left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2))**2 / 4
+ right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2))**2 / 4
+ rho2 = left + right
+
+ factor = 4 / math.pi**2
+ v = factor * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
+
+ with torch.no_grad():
+ alpha = (ious > 0.5).float() * v / (1 - ious + v)
+
+ # CIoU
+ cious = ious - (rho2 / c2 + alpha * v)
+ loss = 1 - cious.clamp(min=-1.0, max=1.0)
+ return loss
+
+
+class IoULoss(nn.Module):
+ """IoULoss.
+
+ Computing the IoU loss between a set of predicted bboxes and target bboxes.
+
+ Args:
+ linear (bool): If True, use linear scale of loss else determined
+ by mode. Default: False.
+ eps (float): Eps to avoid log(0).
+ reduction (str): Options are "none", "mean" and "sum".
+ loss_weight (float): Weight of loss.
+ mode (str): Loss scaling mode, including "linear", "square", and "log".
+ Default: 'log'
+ """
+ def __init__(self,
+ linear=False,
+ eps=1e-6,
+ reduction='mean',
+ loss_weight=1.0,
+ mode='log'):
+ super(IoULoss, self).__init__()
+ assert mode in ['linear', 'square', 'log']
+ if linear:
+ mode = 'linear'
+ warnings.warn('DeprecationWarning: Setting "linear=True" in '
+ 'IOULoss is deprecated, please use "mode=`linear`" '
+ 'instead.')
+ self.mode = mode
+ self.linear = linear
+ self.eps = eps
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None. Options are "none", "mean" and "sum".
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (reduction_override
+ if reduction_override else self.reduction)
+ if (weight is not None) and (not torch.any(weight > 0)) and (
+ reduction != 'none'):
+ if pred.dim() == weight.dim() + 1:
+ weight = weight.unsqueeze(1)
+ return (pred * weight).sum() # 0
+ if weight is not None and weight.dim() > 1:
+ # TODO: remove this in the future
+ # reduce the weight of shape (n, 4) to (n,) to match the
+ # iou_loss of shape (n,)
+ assert weight.shape == pred.shape
+ weight = weight.mean(-1)
+ loss = self.loss_weight * iou_loss(pred,
+ target,
+ weight,
+ mode=self.mode,
+ eps=self.eps,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss
+
+
+class BoundedIoULoss(nn.Module):
+ def __init__(self, beta=0.2, eps=1e-3, reduction='mean', loss_weight=1.0):
+ super(BoundedIoULoss, self).__init__()
+ self.beta = beta
+ self.eps = eps
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ if weight is not None and not torch.any(weight > 0):
+ if pred.dim() == weight.dim() + 1:
+ weight = weight.unsqueeze(1)
+ return (pred * weight).sum() # 0
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (reduction_override
+ if reduction_override else self.reduction)
+ loss = self.loss_weight * bounded_iou_loss(pred,
+ target,
+ weight,
+ beta=self.beta,
+ eps=self.eps,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss
+
+
+class GIoULoss(nn.Module):
+ def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0):
+ super(GIoULoss, self).__init__()
+ self.eps = eps
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ if weight is not None and not torch.any(weight > 0):
+ if pred.dim() == weight.dim() + 1:
+ weight = weight.unsqueeze(1)
+ return (pred * weight).sum() # 0
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (reduction_override
+ if reduction_override else self.reduction)
+ if weight is not None and weight.dim() > 1:
+ # TODO: remove this in the future
+ # reduce the weight of shape (n, 4) to (n,) to match the
+ # giou_loss of shape (n,)
+ assert weight.shape == pred.shape
+ weight = weight.mean(-1)
+ loss = self.loss_weight * giou_loss(pred,
+ target,
+ weight,
+ eps=self.eps,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss
+
+
+class DIoULoss(nn.Module):
+ def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0):
+ super(DIoULoss, self).__init__()
+ self.eps = eps
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ if weight is not None and not torch.any(weight > 0):
+ if pred.dim() == weight.dim() + 1:
+ weight = weight.unsqueeze(1)
+ return (pred * weight).sum() # 0
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (reduction_override
+ if reduction_override else self.reduction)
+ if weight is not None and weight.dim() > 1:
+ # TODO: remove this in the future
+ # reduce the weight of shape (n, 4) to (n,) to match the
+            # diou_loss of shape (n,)
+ assert weight.shape == pred.shape
+ weight = weight.mean(-1)
+ loss = self.loss_weight * diou_loss(pred,
+ target,
+ weight,
+ eps=self.eps,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss
+
+
+class CIoULoss(nn.Module):
+ def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0):
+ super(CIoULoss, self).__init__()
+ self.eps = eps
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ if weight is not None and not torch.any(weight > 0):
+ if pred.dim() == weight.dim() + 1:
+ weight = weight.unsqueeze(1)
+ return (pred * weight).sum() # 0
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (reduction_override
+ if reduction_override else self.reduction)
+ if weight is not None and weight.dim() > 1:
+ # TODO: remove this in the future
+ # reduce the weight of shape (n, 4) to (n,) to match the
+            # ciou_loss of shape (n,)
+ assert weight.shape == pred.shape
+ weight = weight.mean(-1)
+ loss = self.loss_weight * ciou_loss(pred,
+ target,
+ weight,
+ eps=self.eps,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss
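+
+
+if __name__ == '__main__':
+    # Illustrative only: GIoU loss on aligned (x1, y1, x2, y2) boxes. The
+    # first pair matches exactly (loss ~0); the second only partially
+    # overlaps, so its loss is larger.
+    boxes_pred = torch.tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]])
+    boxes_gt = torch.tensor([[0., 0., 10., 10.], [0., 0., 10., 10.]])
+    print(GIoULoss(reduction='none')(boxes_pred, boxes_gt))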
diff --git a/detrsmpl/models/losses/mse_loss.py b/detrsmpl/models/losses/mse_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..24d3ae6e5f5b3b16cb4c0501f246c75320b4fdb0
--- /dev/null
+++ b/detrsmpl/models/losses/mse_loss.py
@@ -0,0 +1,171 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .utils import weighted_loss
+
+
+def gmof(x, sigma):
+ """Geman-McClure error function."""
+ x_squared = x**2
+ sigma_squared = sigma**2
+ return (sigma_squared * x_squared) / (sigma_squared + x_squared)
+
+
+@weighted_loss
+def mse_loss(pred, target):
+    """Wrapper of MSE loss."""
+ return F.mse_loss(pred, target, reduction='none')
+
+
+@weighted_loss
+def mse_loss_with_gmof(pred, target, sigma):
+ """Extended MSE Loss with GMOF."""
+ loss = F.mse_loss(pred, target, reduction='none')
+ loss = gmof(loss, sigma)
+ return loss
+
+
+class MSELoss(nn.Module):
+ """MSELoss.
+
+ Args:
+ reduction (str, optional): The method that reduces the loss to a
+ scalar. Options are "none", "mean" and "sum".
+ loss_weight (float, optional): The weight of the loss. Defaults to 1.0
+ """
+ def __init__(self, reduction='mean', loss_weight=1.0):
+ super().__init__()
+ assert reduction in (None, 'none', 'mean', 'sum')
+ reduction = 'none' if reduction is None else reduction
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None):
+ """Forward function of loss.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction.
+ weight (torch.Tensor, optional): Weight of the loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None.
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (reduction_override
+ if reduction_override else self.reduction)
+ loss = self.loss_weight * mse_loss(
+ pred, target, weight, reduction=reduction, avg_factor=avg_factor)
+ return loss
+
+
+class KeypointMSELoss(nn.Module):
+ """MSELoss for 2D and 3D keypoints.
+
+ Args:
+ reduction (str, optional): The method that reduces the loss to a
+ scalar. Options are "none", "mean" and "sum".
+ loss_weight (float, optional): The weight of the loss. Defaults to 1.0
+        sigma (float, optional): Weighting parameter of the Geman-McClure
+            error function. Defaults to 1.0 (no effect).
+        keypoint_weight (List[float], optional): Weighting parameter for each
+            keypoint. Shape should be (K, ). K: number of keypoints.
+            Defaults to None (no effect).
+ """
+ def __init__(self,
+ reduction='mean',
+ loss_weight=1.0,
+ sigma=1.0,
+ keypoint_weight=None):
+ super().__init__()
+ assert reduction in (None, 'none', 'mean', 'sum')
+ reduction = 'none' if reduction is None else reduction
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.sigma = sigma
+ if keypoint_weight is None:
+ self.keypoint_weight = None
+ else:
+ self.keypoint_weight = torch.Tensor(keypoint_weight)
+
+ def forward(self,
+ pred,
+ target,
+ pred_conf=None,
+ target_conf=None,
+ keypoint_weight=None,
+ avg_factor=None,
+ loss_weight_override=None,
+ reduction_override=None):
+ """Forward function of loss.
+
+ Args:
+            pred (torch.Tensor): The prediction. Shape should be (N, K, 2/3).
+                N: batch size. K: number of keypoints.
+ target (torch.Tensor): The learning target of the prediction.
+ Shape should be the same as pred.
+ pred_conf (optional, torch.Tensor): Confidence of
+ predicted keypoints. Shape should be (N, K).
+ target_conf (optional, torch.Tensor): Confidence of
+ target keypoints. Shape should be the same as pred_conf.
+            keypoint_weight (optional, torch.Tensor): Keypoint-wise weight.
+                Shape should be (K, ). This weight allows different weights
+                to be assigned to different body parts.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ loss_weight_override (float, optional): The overall weight of loss
+ used to override the original weight of loss.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None.
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (reduction_override
+ if reduction_override else self.reduction)
+ loss_weight = (loss_weight_override if loss_weight_override is not None
+ else self.loss_weight)
+
+ B, K, D = pred.shape
+ pred_conf = pred_conf.view((B, K, 1)) \
+ if pred_conf is not None else 1.0
+ target_conf = target_conf.view((B, K, 1)) \
+ if target_conf is not None else 1.0
+        if keypoint_weight is not None:
+            keypoint_weight = keypoint_weight.view((1, K, 1))
+        elif self.keypoint_weight is not None:
+            keypoint_weight = \
+                self.keypoint_weight.view((1, K, 1)).type_as(pred)
+        else:
+            keypoint_weight = 1.0
+
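+        # The effective per-keypoint weight is the product of the (optional)
+        # keypoint-level prior weight and the predicted/target confidences;
+        # any missing term defaults to 1.0 and has no effect.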
+ weight = keypoint_weight * pred_conf * target_conf
+ assert isinstance(
+ weight,
+ float) or weight.shape == (B, K, 1) or weight.shape == (1, K, 1)
+
+ # B, J, D = pred.shape[:2]
+ # if len(weight.shape) == 1:
+ # # for simplify tools
+ # weight = weight.view(1, -1, 1)
+ # else:
+ # # for body model estimator
+ # weight = weight.view(B, J, 1)
+
+ loss = loss_weight * mse_loss_with_gmof(pred,
+ target,
+ weight,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ sigma=self.sigma)
+
+ return loss
diff --git a/detrsmpl/models/losses/prior_loss.py b/detrsmpl/models/losses/prior_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..5509916afd539074f871856053c7bd61d8f2ccf3
--- /dev/null
+++ b/detrsmpl/models/losses/prior_loss.py
@@ -0,0 +1,754 @@
+import itertools
+import os
+import pickle
+import sys
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from detrsmpl.core.conventions.joints_mapping.standard_joint_angles import (
+ STANDARD_JOINT_ANGLE_LIMITS,
+ TRANSFORMATION_AA_TO_SJA,
+ TRANSFORMATION_SJA_TO_AA,
+)
+from detrsmpl.utils.keypoint_utils import search_limbs
+from detrsmpl.utils.transforms import aa_to_rot6d, aa_to_sja
+
+
+class ShapePriorLoss(nn.Module):
+ """Prior loss for body shape parameters.
+
+ Args:
+ reduction (str, optional): The method that reduces the loss to a
+ scalar. Options are "none", "mean" and "sum".
+ loss_weight (float, optional): The weight of the loss. Defaults to 1.0
+ """
+ def __init__(self, reduction='mean', loss_weight=1.0):
+ super().__init__()
+ assert reduction in (None, 'none', 'mean', 'sum')
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ betas,
+ loss_weight_override=None,
+ reduction_override=None):
+ """Forward function of loss.
+
+ Args:
+ betas (torch.Tensor): The body shape parameters
+ loss_weight_override (float, optional): The weight of loss used to
+ override the original weight of loss
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (reduction_override
+ if reduction_override else self.reduction)
+ loss_weight = (loss_weight_override if loss_weight_override is not None
+ else self.loss_weight)
+
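+        # Quadratic prior on betas: penalizes deviation from the mean shape
+        # (betas == 0), discouraging extreme body shapes.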
+ shape_prior_loss = loss_weight * betas**2
+
+ if reduction == 'mean':
+ shape_prior_loss = shape_prior_loss.mean()
+ elif reduction == 'sum':
+ shape_prior_loss = shape_prior_loss.sum()
+
+ return shape_prior_loss
+
+
+class ShapeThresholdPriorLoss(nn.Module):
+    """Threshold loss for betas. Soft constraint to prevent the parameters
+    from leaving the feasible set. Implements a penalty that encourages the
+    parameters to stay in the feasible set of solutions.
+
+ Args:
+ margin (int, optional): The threshold value
+        norm (str, optional): The loss method. Options are 'l1' and 'l2'.
+ loss_weight (float, optional): The weight of the loss. Defaults to 1.0
+ """
+ def __init__(self, margin=1, norm='l2', epsilon=1e-7, loss_weight=1.0):
+ super().__init__()
+ self.margin = margin
+        assert norm in ['l1', 'l2'], 'Norm variable must be l1 or l2'
+ self.norm = norm
+ self.epsilon = epsilon
+ self.loss_weight = loss_weight
+
+ def forward(self, betas):
+ """Forward function of loss.
+
+ Args:
+ betas (torch.Tensor): The body shape parameters
+ Returns:
+ torch.Tensor: The calculated loss
+ """
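+        # Only components whose magnitude exceeds the margin are penalized;
+        # the sum is normalized by the number of violating entries and
+        # epsilon guards against division by zero when nothing violates.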
+ abs_values = betas.abs()
+ mask = abs_values.gt(self.margin)
+ invalid_values = torch.masked_select(betas, mask)
+
+ if self.norm == 'l1':
+ return self.loss_weight * invalid_values.abs().sum() / (
+ mask.to(dtype=betas.dtype).sum() + self.epsilon)
+ elif self.norm == 'l2':
+ return self.loss_weight * invalid_values.pow(2).sum() / (
+ mask.to(dtype=betas.dtype).sum() + self.epsilon)
+
+
+class PoseRegLoss(nn.Module):
+    """Regularizer loss for body pose parameters.
+
+ Args:
+ reduction (str, optional): The method that reduces the loss to a
+ scalar. Options are "none", "mean" and "sum".
+ loss_weight (float, optional): The weight of the loss. Defaults to 1.0
+ """
+ def __init__(self, reduction='mean', loss_weight=1.0):
+ super().__init__()
+ assert reduction in (None, 'none', 'mean', 'sum')
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ body_pose,
+ weight=None,
+ avg_factor=None,
+ loss_weight_override=None,
+ reduction_override=None):
+ reduction = (reduction_override
+ if reduction_override else self.reduction)
+ loss_weight = (loss_weight_override if loss_weight_override is not None
+ else self.loss_weight)
+
+ pose_prior_loss = loss_weight * (body_pose**2)
+
+ if reduction == 'mean':
+ pose_prior_loss = pose_prior_loss.mean()
+ elif reduction == 'sum':
+ pose_prior_loss = pose_prior_loss.sum()
+
+ return pose_prior_loss
+
+
+class LimbLengthLoss(nn.Module):
+    """Limb length loss for body shape parameters. As betas are associated
+    with the height of a person, fitting on limb length helps determine body
+    shape parameters. It penalizes the squared difference between the target
+    and predicted limb lengths. Note that it should take keypoints3d as
+    input, as limb length computed from keypoints2d varies with the camera.
+
+ Args:
+ convention (str): Limb convention to search for keypoint connections.
+ reduction (str, optional): The method that reduces the loss to a
+ scalar. Options are "none", "mean" and "sum".
+ loss_weight (float, optional): The weight of the loss. Defaults to 1.0
+ eps (float, optional): epsilon for computing normalized limb vector.
+ Defaults to 1e-4.
+ """
+ def __init__(self,
+ convention,
+ reduction='mean',
+ loss_weight=1.0,
+ eps=1e-4):
+ super().__init__()
+ assert reduction in (None, 'none', 'mean', 'sum')
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.eps = eps
+ limb_idxs, _ = search_limbs(data_source=convention)
+ limb_idxs = sorted(limb_idxs['body'])
+ self.limb_idxs = np.array(
+ list(x for x, _ in itertools.groupby(limb_idxs)))
+
+ def _compute_limb_length(self, keypoints3d):
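+        # Each row of limb_idxs holds a (start, end) keypoint index pair; the
+        # limb length is the Euclidean distance between the two endpoints.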
+ kp_src = keypoints3d[:, self.limb_idxs[:, 0], :3]
+ kp_dst = keypoints3d[:, self.limb_idxs[:, 1], :3]
+ limb_vec = kp_dst - kp_src
+ limb_length = torch.norm(limb_vec, dim=2)
+ return limb_length
+
+ def _keypoint_conf_to_limb_conf(self, keypoint_conf):
+ limb_conf = torch.min(keypoint_conf[:, self.limb_idxs[:, 1]],
+ keypoint_conf[:, self.limb_idxs[:, 0]])
+ return limb_conf
+
+ def forward(self,
+ pred,
+ target,
+ pred_conf=None,
+ target_conf=None,
+ loss_weight_override=None,
+ reduction_override=None):
+ """Forward function of LimbLengthLoss.
+
+ Args:
+            pred (torch.Tensor): The predicted SMPL keypoints3d.
+                Shape should be (N, K, 3).
+                N: batch size. K: number of keypoints.
+ target (torch.Tensor): The ground-truth keypoints3d.
+ Shape should be (N, K, 3).
+ pred_conf (torch.Tensor, optional): Confidence of
+ predicted keypoints. Shape should be (N, K).
+ target_conf (torch.Tensor, optional): Confidence of
+ target keypoints. Shape should be (N, K).
+ loss_weight_override (float, optional): The weight of loss used to
+ override the original weight of loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert pred.dim() == 3 and pred.shape[-1] == 3
+ assert pred.shape == target.shape
+ if pred_conf is not None:
+ assert pred_conf.dim() == 2
+ assert pred_conf.shape == pred.shape[:2]
+ if target_conf is not None:
+ assert target_conf.dim() == 2
+ assert target_conf.shape == target.shape[:2]
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (reduction_override
+ if reduction_override else self.reduction)
+ loss_weight = (loss_weight_override if loss_weight_override is not None
+ else self.loss_weight)
+
+ limb_len_target = self._compute_limb_length(target)
+ limb_len_pred = self._compute_limb_length(pred)
+
+ if target_conf is None:
+ target_conf = torch.ones_like(target[..., 0])
+ if pred_conf is None:
+ pred_conf = torch.ones_like(pred[..., 0])
+ limb_conf_target = self._keypoint_conf_to_limb_conf(target_conf)
+ limb_conf_pred = self._keypoint_conf_to_limb_conf(pred_conf)
+ limb_conf = limb_conf_target * limb_conf_pred
+
+ diff_len = limb_len_target - limb_len_pred
+ loss = diff_len**2 * limb_conf
+
+ if reduction == 'mean':
+ loss = loss.mean()
+ elif reduction == 'sum':
+ loss = loss.sum()
+
+ loss *= loss_weight
+
+ return loss
+
+
+class JointPriorLoss(nn.Module):
+ """Prior loss for joint angles.
+
+ Args:
+ reduction (str, optional): The method that reduces the loss to a
+ scalar. Options are "none", "mean" and "sum".
+ loss_weight (float, optional): The weight of the loss. Defaults to 1.0
+        use_full_body (bool, optional): Whether to use the full set of joint
+            constraints (in standard joint angles).
+        smooth_spine (bool, optional): Whether to encourage smooth spine
+            rotations.
+        smooth_spine_loss_weight (float, optional): An additional weight
+            factor applied to the smooth spine loss.
+ """
+ def __init__(self,
+ reduction='mean',
+ loss_weight=1.0,
+ use_full_body=False,
+ smooth_spine=False,
+ smooth_spine_loss_weight=1.0):
+ super().__init__()
+ assert reduction in (None, 'none', 'mean', 'sum')
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.use_full_body = use_full_body
+ self.smooth_spine = smooth_spine
+ self.smooth_spine_loss_weight = smooth_spine_loss_weight
+
+ if self.use_full_body:
+ self.register_buffer('R_t', TRANSFORMATION_AA_TO_SJA)
+ self.register_buffer('R_t_inv', TRANSFORMATION_SJA_TO_AA)
+ self.register_buffer('sja_limits', STANDARD_JOINT_ANGLE_LIMITS)
+
+ def forward(self,
+ body_pose,
+ loss_weight_override=None,
+ reduction_override=None):
+ """Forward function of loss.
+
+ Args:
+ body_pose (torch.Tensor): The body pose parameters
+ loss_weight_override (float, optional): The weight of loss used to
+ override the original weight of loss
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (reduction_override
+ if reduction_override else self.reduction)
+ loss_weight = (loss_weight_override if loss_weight_override is not None
+ else self.loss_weight)
+
+ if self.use_full_body:
+ batch_size = body_pose.shape[0]
+ body_pose_reshape = body_pose.reshape(batch_size, -1, 3)
+ assert body_pose_reshape.shape[1] in (21, 23) # smpl-x, smpl
+ body_pose_reshape = body_pose_reshape[:, :21, :]
+
+ body_pose_sja = aa_to_sja(body_pose_reshape, self.R_t,
+ self.R_t_inv)
+
+ lower_limits = self.sja_limits[:, :, 0] # shape: (21, 3)
+ upper_limits = self.sja_limits[:, :, 1] # shape: (21, 3)
+
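+            # Exponential barrier: (exp(relu(violation)) - 1)^2 is zero while
+            # a standard joint angle stays within its limits and grows rapidly
+            # once it crosses the lower or upper bound.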
+ lower_loss = (torch.exp(F.relu(lower_limits - body_pose_sja)) -
+ 1).pow(2)
+ upper_loss = (torch.exp(F.relu(body_pose_sja - upper_limits)) -
+ 1).pow(2)
+
+ standard_joint_angle_prior_loss = (lower_loss + upper_loss).view(
+                body_pose.shape[0], -1)  # shape: (n, 63)
+
+ joint_prior_loss = standard_joint_angle_prior_loss
+
+ else:
+ # default joint prior loss applied on elbows and knees
+ joint_prior_loss = (torch.exp(
+ body_pose[:, [55, 58, 12, 15]] *
+ torch.tensor([1., -1., -1, -1.], device=body_pose.device)) -
+ 1)**2
+
+ if self.smooth_spine:
+ spine1 = body_pose[:, [9, 10, 11]]
+ spine2 = body_pose[:, [18, 19, 20]]
+ spine3 = body_pose[:, [27, 28, 29]]
+ smooth_spine_loss_12 = (torch.exp(F.relu(-spine1 * spine2)) -
+ 1).pow(2) * self.smooth_spine_loss_weight
+ smooth_spine_loss_23 = (torch.exp(F.relu(-spine2 * spine3)) -
+ 1).pow(2) * self.smooth_spine_loss_weight
+
+ joint_prior_loss = torch.cat(
+ [joint_prior_loss, smooth_spine_loss_12, smooth_spine_loss_23],
+ axis=1)
+
+ joint_prior_loss = loss_weight * joint_prior_loss
+
+ if reduction == 'mean':
+ joint_prior_loss = joint_prior_loss.mean()
+ elif reduction == 'sum':
+ joint_prior_loss = joint_prior_loss.sum()
+
+ return joint_prior_loss
+
+
+class SmoothJointLoss(nn.Module):
+ """Smooth loss for joint angles.
+
+ Args:
+ reduction (str, optional): The method that reduces the loss to a
+ scalar. Options are "none", "mean" and "sum".
+ loss_weight (float, optional): The weight of the loss. Defaults to 1.0
+        degree (bool, optional): Whether the input tensor is in degrees
+            rather than radians. Defaults to False.
+        loss_func (str, optional): The loss function, either 'L1' or 'L2'.
+            Defaults to 'L1'.
+ """
+ def __init__(self,
+ reduction='mean',
+ loss_weight=1.0,
+ degree=False,
+ loss_func='L1'):
+ super().__init__()
+ assert reduction in (None, 'none', 'mean', 'sum')
+ assert loss_func in ('L1', 'L2')
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.degree = degree
+ self.loss_func = loss_func
+
+ def forward(self,
+ body_pose,
+ loss_weight_override=None,
+ reduction_override=None):
+ """Forward function of SmoothJointLoss.
+
+ Args:
+ body_pose (torch.Tensor): The body pose parameters
+ loss_weight_override (float, optional): The weight of loss used to
+ override the original weight of loss
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (reduction_override
+ if reduction_override else self.reduction)
+ loss_weight = (loss_weight_override if loss_weight_override is not None
+ else self.loss_weight)
+
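+        # Treat the batch dimension as time: convert each pose to the 6D
+        # rotation representation and penalize frame-to-frame differences.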
+ theta = body_pose.reshape(body_pose.shape[0], -1, 3)
+ if self.degree:
+ theta = torch.deg2rad(theta)
+ rot_6d = aa_to_rot6d(theta)
+ rot_6d_diff = rot_6d[1:] - rot_6d[:-1]
+
+ if self.loss_func == 'L2':
+ smooth_joint_loss = (rot_6d_diff**2).sum(dim=[1, 2])
+ elif self.loss_func == 'L1':
+ smooth_joint_loss = rot_6d_diff.abs().sum(dim=[1, 2])
+ else:
+            raise TypeError(f'{self.loss_func} is not defined')
+
+ # add zero padding to retain original batch_size
+ smooth_joint_loss = torch.cat(
+ [torch.zeros_like(smooth_joint_loss)[:1], smooth_joint_loss])
+
+ if reduction == 'mean':
+ smooth_joint_loss = smooth_joint_loss.mean()
+ elif reduction == 'sum':
+ smooth_joint_loss = smooth_joint_loss.sum()
+
+ smooth_joint_loss *= loss_weight
+
+ return smooth_joint_loss
+
+
+class SmoothPelvisLoss(nn.Module):
+ """Smooth loss for pelvis angles.
+
+ Args:
+ reduction (str, optional): The method that reduces the loss to a
+ scalar. Options are "none", "mean" and "sum".
+ loss_weight (float, optional): The weight of the loss. Defaults to 1.0
+        degree (bool, optional): Whether the input tensor is in degrees
+            rather than radians. Defaults to False.
+ """
+ def __init__(self, reduction='mean', loss_weight=1.0, degree=False):
+ super().__init__()
+ assert reduction in (None, 'none', 'mean', 'sum')
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.degree = degree
+
+ def forward(self,
+ global_orient,
+ loss_weight_override=None,
+ reduction_override=None):
+ """Forward function of SmoothPelvisLoss.
+
+ Args:
+ global_orient (torch.Tensor): The global orientation parameters
+ loss_weight_override (float, optional): The weight of loss used to
+ override the original weight of loss
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (reduction_override
+ if reduction_override else self.reduction)
+ loss_weight = (loss_weight_override if loss_weight_override is not None
+ else self.loss_weight)
+
+ if self.degree:
+ global_orient = torch.deg2rad(global_orient)
+
+ pelvis = global_orient.unsqueeze(1)
+ rot_6d = aa_to_rot6d(pelvis)
+
+ rot_6d_diff = rot_6d[1:] - rot_6d[:-1]
+ smooth_pelvis_loss = rot_6d_diff.abs().sum(dim=-1)
+
+ # add zero padding to retain original batch_size
+ smooth_pelvis_loss = torch.cat(
+ [torch.zeros_like(smooth_pelvis_loss)[:1],
+ smooth_pelvis_loss]).sum(dim=-1)
+
+ smooth_pelvis_loss = loss_weight * smooth_pelvis_loss
+
+ if reduction == 'mean':
+ smooth_pelvis_loss = smooth_pelvis_loss.mean()
+ elif reduction == 'sum':
+ smooth_pelvis_loss = smooth_pelvis_loss.sum()
+
+ return smooth_pelvis_loss
+
+
+class SmoothTranslationLoss(nn.Module):
+ """Smooth loss for translations.
+
+ Args:
+ reduction (str, optional): The method that reduces the loss to a
+ scalar. Options are "none", "mean" and "sum".
+ loss_weight (float, optional): The weight of the loss. Defaults to 1.0
+ """
+ def __init__(self, reduction='mean', loss_weight=1.0):
+ super().__init__()
+ assert reduction in (None, 'none', 'mean', 'sum')
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ translation,
+ loss_weight_override=None,
+ reduction_override=None):
+ """Forward function of loss.
+
+ Args:
+ translation (torch.Tensor): The body translation parameters
+ loss_weight_override (float, optional): The weight of loss used to
+ override the original weight of loss
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (reduction_override
+ if reduction_override else self.reduction)
+ loss_weight = (loss_weight_override if loss_weight_override is not None
+ else self.loss_weight)
+
+ translation_diff = translation[1:] - translation[:-1]
+ smooth_translation_loss = translation_diff.abs().sum(dim=-1,
+ keepdim=True)
+
+ # add zero padding to retain original batch_size
+ smooth_translation_loss = torch.cat([
+ torch.zeros_like(smooth_translation_loss)[:1],
+ smooth_translation_loss
+ ]).sum(dim=-1)
+
+ smooth_translation_loss *= 1e3
+
+ smooth_translation_loss = loss_weight * \
+ smooth_translation_loss
+
+ if reduction == 'mean':
+ smooth_translation_loss = smooth_translation_loss.mean()
+ elif reduction == 'sum':
+ smooth_translation_loss = smooth_translation_loss.sum()
+
+ return smooth_translation_loss
+
+
+class CameraPriorLoss(nn.Module):
+ """Prior loss for predicted camera.
+
+ Args:
+ reduction (str, optional): The method that reduces the loss to a
+ scalar. Options are "none", "mean" and "sum".
+ scale (float, optional): The scale coefficient for regularizing camera
+ parameters. Defaults to 10
+ loss_weight (float, optional): The weight of the loss. Defaults to 1.0
+ """
+ def __init__(self, scale=10, reduction='mean', loss_weight=1.0):
+ super().__init__()
+ self.scale = scale
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ cameras,
+ loss_weight_override=None,
+ reduction_override=None):
+ """Forward function of loss.
+
+ Args:
+ cameras (torch.Tensor): The predicted camera parameters
+ loss_weight_override (float, optional): The weight of loss used to
+ override the original weight of loss
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (reduction_override
+ if reduction_override else self.reduction)
+ loss_weight = (loss_weight_override if loss_weight_override is not None
+ else self.loss_weight)
+
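+        # exp(-self.scale * cameras[:, 0]) stays small when the first camera
+        # parameter (typically the weak-perspective scale) is a reasonable
+        # positive value and blows up as it approaches zero or turns negative.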
+ camera_prior_loss = torch.exp(-cameras[:, 0] * self.scale)
+ camera_prior_loss = torch.pow(camera_prior_loss, 2) * loss_weight
+
+ if reduction == 'mean':
+ camera_prior_loss = camera_prior_loss.mean()
+ elif reduction == 'sum':
+ camera_prior_loss = camera_prior_loss.sum()
+
+ return camera_prior_loss
+
+
+class MaxMixturePrior(nn.Module):
+ """Ref: SMPLify-X
+ https://github.com/vchoutas/smplify-x/blob/master/smplifyx/prior.py
+ """
+ def __init__(self,
+ prior_folder='data',
+ num_gaussians=8,
+ dtype=torch.float32,
+ epsilon=1e-16,
+ use_merged=True,
+ reduction=None,
+ loss_weight=1.0):
+ super(MaxMixturePrior, self).__init__()
+
+ assert reduction in (None, 'none', 'mean', 'sum')
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ if dtype == torch.float32:
+ np_dtype = np.float32
+ elif dtype == torch.float64:
+ np_dtype = np.float64
+ else:
+ print('Unknown float type {}, exiting!'.format(dtype))
+ sys.exit(-1)
+
+ self.num_gaussians = num_gaussians
+ self.epsilon = epsilon
+ self.use_merged = use_merged
+ gmm_fn = 'gmm_{:02d}.pkl'.format(num_gaussians)
+
+ full_gmm_fn = os.path.join(prior_folder, gmm_fn)
+ if not os.path.exists(full_gmm_fn):
+ print('The path to the mixture prior "{}"'.format(full_gmm_fn) +
+ ' does not exist, exiting!')
+ sys.exit(-1)
+
+ with open(full_gmm_fn, 'rb') as f:
+ gmm = pickle.load(f, encoding='latin1')
+
+ if type(gmm) == dict:
+ means = gmm['means'].astype(np_dtype)
+ covs = gmm['covars'].astype(np_dtype)
+ weights = gmm['weights'].astype(np_dtype)
+ elif 'sklearn.mixture.gmm.GMM' in str(type(gmm)):
+ means = gmm.means_.astype(np_dtype)
+ covs = gmm.covars_.astype(np_dtype)
+ weights = gmm.weights_.astype(np_dtype)
+ else:
+ print('Unknown type for the prior: {}, exiting!'.format(type(gmm)))
+ sys.exit(-1)
+
+ self.register_buffer('means', torch.tensor(means, dtype=dtype))
+
+ self.register_buffer('covs', torch.tensor(covs, dtype=dtype))
+
+ precisions = [np.linalg.inv(cov) for cov in covs]
+ precisions = np.stack(precisions).astype(np_dtype)
+
+ self.register_buffer('precisions', torch.tensor(precisions,
+ dtype=dtype))
+
+ # The constant term:
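+        # (69 = 23 SMPL body joints x 3 axis-angle components, i.e. the
+        # dimensionality of the pose modelled by the GMM prior.)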
+ sqrdets = np.array([(np.sqrt(np.linalg.det(c)))
+ for c in gmm['covars']])
+ const = (2 * np.pi)**(69 / 2.)
+
+ nll_weights = np.asarray(gmm['weights'] / (const *
+ (sqrdets / sqrdets.min())))
+ nll_weights = torch.tensor(nll_weights, dtype=dtype).unsqueeze(dim=0)
+ self.register_buffer('nll_weights', nll_weights)
+
+ weights = torch.tensor(gmm['weights'], dtype=dtype).unsqueeze(dim=0)
+ self.register_buffer('weights', weights)
+
+ self.register_buffer('pi_term',
+ torch.log(torch.tensor(2 * np.pi, dtype=dtype)))
+
+ cov_dets = [
+ np.log(np.linalg.det(cov.astype(np_dtype)) + epsilon)
+ for cov in covs
+ ]
+ self.register_buffer('cov_dets', torch.tensor(cov_dets, dtype=dtype))
+
+ # The dimensionality of the random variable
+ self.random_var_dim = self.means.shape[1]
+
+ def get_mean(self):
+ """Returns the mean of the mixture."""
+ mean_pose = torch.matmul(self.weights, self.means)
+ return mean_pose
+
+ def merged_log_likelihood(self, pose):
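+        # Max-mixture approximation: evaluate the weighted negative
+        # log-likelihood under every Gaussian component and keep the smallest
+        # value, i.e. the best-fitting component for each sample.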
+ diff_from_mean = pose.unsqueeze(dim=1) - self.means
+
+ prec_diff_prod = torch.einsum('mij,bmj->bmi',
+ [self.precisions, diff_from_mean])
+ diff_prec_quadratic = (prec_diff_prod * diff_from_mean).sum(dim=-1)
+
+ curr_loglikelihood = 0.5 * diff_prec_quadratic - \
+ torch.log(self.nll_weights)
+ # curr_loglikelihood = 0.5 * (self.cov_dets.unsqueeze(dim=0) +
+ # self.random_var_dim * self.pi_term +
+ # diff_prec_quadratic
+ # ) - torch.log(self.weights)
+
+ min_likelihood, _ = torch.min(curr_loglikelihood, dim=1)
+ return min_likelihood
+
+ def log_likelihood(self, pose):
+ """Create graph operation for negative log-likelihood calculation."""
+ likelihoods = []
+
+ for idx in range(self.num_gaussians):
+ mean = self.means[idx]
+ prec = self.precisions[idx]
+ cov = self.covs[idx]
+ diff_from_mean = pose - mean
+
+ curr_loglikelihood = torch.einsum('bj,ji->bi',
+ [diff_from_mean, prec])
+ curr_loglikelihood = torch.einsum(
+ 'bi,bi->b', [curr_loglikelihood, diff_from_mean])
+ cov_term = torch.log(torch.det(cov) + self.epsilon)
+ curr_loglikelihood += 0.5 * (cov_term +
+ self.random_var_dim * self.pi_term)
+ likelihoods.append(curr_loglikelihood)
+
+ log_likelihoods = torch.stack(likelihoods, dim=1)
+ min_idx = torch.argmin(log_likelihoods, dim=1)
+ weight_component = self.nll_weights[:, min_idx]
+ weight_component = -torch.log(weight_component)
+
+ return weight_component + log_likelihoods[:, min_idx]
+
+ def forward(self,
+ body_pose,
+ loss_weight_override=None,
+ reduction_override=None):
+
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (reduction_override
+ if reduction_override else self.reduction)
+ loss_weight = (loss_weight_override if loss_weight_override is not None
+ else self.loss_weight)
+
+ if self.use_merged:
+ pose_prior_loss = self.merged_log_likelihood(body_pose)
+ else:
+ pose_prior_loss = self.log_likelihood(body_pose)
+
+ pose_prior_loss = loss_weight * pose_prior_loss
+
+ if reduction == 'mean':
+ pose_prior_loss = pose_prior_loss.mean()
+ elif reduction == 'sum':
+ pose_prior_loss = pose_prior_loss.sum()
+
+ return pose_prior_loss
diff --git a/detrsmpl/models/losses/rotaion_distance_loss.py b/detrsmpl/models/losses/rotaion_distance_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..51118826fcd737d7d6728dd5b065c366727d8b84
--- /dev/null
+++ b/detrsmpl/models/losses/rotaion_distance_loss.py
@@ -0,0 +1,62 @@
+import torch
+import torch.nn as nn
+
+
+def rotation_distance_loss(pred, target, epsilon):
+    """Rotation distance loss: geodesic angle between rotation matrices."""
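+    # trace(R_pred^T @ R_target) equals the sum of the element-wise product of
+    # the two 3x3 matrices; the geodesic angle is acos((trace - 1) / 2),
+    # clamped away from +/-1 so that acos stays differentiable.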
+ tr = torch.einsum(
+ 'bij,bij->b',
+ [pred.view(-1, 3, 3), target.view(-1, 3, 3)])
+ theta = (tr - 1) * 0.5
+ loss = torch.acos(torch.clamp(theta, -1 + epsilon, 1 - epsilon))
+ return loss
+
+
+class RotationDistance(nn.Module):
+ """Rotation Distance Loss.
+
+ Args:
+ reduction (str, optional): The method that reduces the loss to a
+ scalar. Options are "none", "mean" and "sum".
+ epsilon (float, optional): A minimal value to avoid NaN.
+ loss_weight (float, optional): The weight of the loss. Defaults to 1.0
+ """
+ def __init__(self, reduction='mean', epsilon=1e-7, loss_weight=1.0):
+ super(RotationDistance, self).__init__()
+ assert reduction in (None, 'none', 'mean', 'sum')
+ reduction = 'none' if reduction is None else reduction
+ self.reduction = reduction
+ self.epsilon = epsilon
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None):
+ """Forward function of loss.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction.
+ weight (torch.Tensor, optional): Weight of the loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None.
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ loss = self.loss_weight * rotation_distance_loss(
+ pred, target, epsilon=self.epsilon)
+ if weight is not None:
+ loss = loss.view(pred.shape[0], -1) * weight.view(
+ pred.shape[0], -1)
+ return loss.sum() / (weight.gt(0).sum() + self.epsilon)
+ else:
+ return loss.sum() / pred.shape[0]
diff --git a/detrsmpl/models/losses/smooth_l1_loss.py b/detrsmpl/models/losses/smooth_l1_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..63cf2709503ec72955356a7beca14f93a1a9b5eb
--- /dev/null
+++ b/detrsmpl/models/losses/smooth_l1_loss.py
@@ -0,0 +1,128 @@
+import torch
+import torch.nn as nn
+
+from .utils import weighted_loss
+
+
+@weighted_loss
+def smooth_l1_loss(pred, target, beta=1.0):
+ """Smooth L1 loss.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction.
+ beta (float, optional): The threshold in the piecewise function.
+ Defaults to 1.0.
+ Returns:
+ torch.Tensor: Calculated loss
+ """
+ assert beta > 0
+ assert pred.size() == target.size() and target.numel() > 0
+ diff = torch.abs(pred - target)
+ loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
+ diff - 0.5 * beta)
+ return loss
+
+
+@weighted_loss
+def l1_loss(pred, target):
+ """L1 loss.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction.
+ Returns:
+ torch.Tensor: Calculated loss
+ """
+ assert pred.size() == target.size() and target.numel() > 0
+ loss = torch.abs(pred - target)
+ return loss
+
+
+class SmoothL1Loss(nn.Module):
+ """Smooth L1 loss.
+
+ Args:
+ beta (float, optional): The threshold in the piecewise function.
+ Defaults to 1.0.
+ reduction (str, optional): The method to reduce the loss.
+ Options are "none", "mean" and "sum". Defaults to "mean".
+ loss_weight (float, optional): The weight of loss.
+ """
+ def __init__(self, beta=1.0, reduction='mean', loss_weight=1.0):
+ super(SmoothL1Loss, self).__init__()
+ self.beta = beta
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None.
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (reduction_override
+ if reduction_override else self.reduction)
+ loss = self.loss_weight * smooth_l1_loss(pred,
+ target,
+ weight,
+ beta=self.beta,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss
+
+
+class L1Loss(nn.Module):
+ """L1 loss.
+
+ Args:
+ reduction (str, optional): The method to reduce the loss.
+ Options are "none", "mean" and "sum".
+ loss_weight (float, optional): The weight of loss.
+ """
+ def __init__(self, reduction='mean', loss_weight=1.0):
+ super(L1Loss, self).__init__()
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None.
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (reduction_override
+ if reduction_override else self.reduction)
+ loss = self.loss_weight * l1_loss(
+ pred, target, weight, reduction=reduction, avg_factor=avg_factor)
+ return loss
diff --git a/detrsmpl/models/losses/utils.py b/detrsmpl/models/losses/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..322a0f149ef0a975fad3b2ee41c94757e9af0e37
--- /dev/null
+++ b/detrsmpl/models/losses/utils.py
@@ -0,0 +1,119 @@
+import functools
+
+import torch
+import torch.nn.functional as F
+
+
+def reduce_loss(loss, reduction):
+ """Reduce loss as specified.
+
+ Args:
+ loss (Tensor): Elementwise loss tensor.
+ reduction (str): Options are "none", "mean" and "sum".
+
+ Return:
+ Tensor: Reduced loss tensor.
+ """
+ reduction_enum = F._Reduction.get_enum(reduction)
+ # none: 0, elementwise_mean:1, sum: 2
+ if reduction_enum == 0:
+ return loss
+ elif reduction_enum == 1:
+ return loss.mean()
+ elif reduction_enum == 2:
+ return loss.sum()
+
+
+def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
+ """Apply element-wise weight and reduce loss.
+
+ Args:
+ loss (Tensor): Element-wise loss.
+ weight (Tensor): Element-wise weights.
+ reduction (str): Same as built-in losses of PyTorch.
+ avg_factor (float): Average factor when computing the mean of losses.
+
+ Returns:
+ Tensor: Processed loss values.
+ """
+ # if weight is specified, apply element-wise weight
+ if weight is not None:
+ loss = loss * weight
+
+ # if avg_factor is not specified, just reduce the loss
+ if avg_factor is None:
+ loss = reduce_loss(loss, reduction)
+ else:
+ # if reduction is mean, then average the loss by avg_factor
+ if reduction == 'mean':
+ loss = loss.sum() / avg_factor
+ # if reduction is 'none', then do nothing, otherwise raise an error
+ elif reduction != 'none':
+ raise ValueError('avg_factor can not be used with reduction="sum"')
+ return loss
+
+
+def weighted_loss(loss_func):
+ """Create a weighted version of a given loss function.
+
+ To use this decorator, the loss function must have the signature like
+ `loss_func(pred, target, **kwargs)`. The function only needs to compute
+ element-wise loss without any reduction. This decorator will add weight
+ and reduction arguments to the function. The decorated function will have
+ the signature like `loss_func(pred, target, weight=None, reduction='mean',
+ avg_factor=None, **kwargs)`.
+
+ :Example:
+
+ >>> import torch
+ >>> @weighted_loss
+ >>> def l1_loss(pred, target):
+ >>> return (pred - target).abs()
+
+ >>> pred = torch.Tensor([0, 2, 3])
+ >>> target = torch.Tensor([1, 1, 1])
+ >>> weight = torch.Tensor([1, 0, 1])
+
+ >>> l1_loss(pred, target)
+ tensor(1.3333)
+ >>> l1_loss(pred, target, weight)
+ tensor(1.)
+ >>> l1_loss(pred, target, reduction='none')
+ tensor([1., 1., 2.])
+ >>> l1_loss(pred, target, weight, avg_factor=2)
+ tensor(1.5000)
+ """
+ @functools.wraps(loss_func)
+ def wrapper(pred,
+ target,
+ weight=None,
+ reduction='mean',
+ avg_factor=None,
+ **kwargs):
+ # get element-wise loss
+ loss = loss_func(pred, target, **kwargs)
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+ return loss
+
+ return wrapper
+
+
+def convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor:
+ """This function converts target class indices to one-hot vectors, given
+ the number of classes.
+
+ Args:
+ targets (Tensor): The ground truth label of the prediction
+ with shape (N, 1)
+ classes (int): the number of classes.
+
+ Returns:
+        Tensor: One-hot encoded targets of shape (N, classes).
+ """
+ assert (torch.max(targets).item() <
+ classes), 'Class Index must be less than number of classes'
+ one_hot_targets = torch.zeros((targets.shape[0], classes),
+ dtype=torch.long,
+ device=targets.device)
+ one_hot_targets.scatter_(1, targets.long(), 1)
+ return one_hot_targets
diff --git a/detrsmpl/models/necks/__init__.py b/detrsmpl/models/necks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..395994c223756dacb7d5f78063f71d42f6741bf7
--- /dev/null
+++ b/detrsmpl/models/necks/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .channel_mapper import ChannelMapper
+
+__all__ = ['ChannelMapper']
diff --git a/detrsmpl/models/necks/builder.py b/detrsmpl/models/necks/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..65c9b0c1ccf0999dc11d38ba9086d5cd7a009bae
--- /dev/null
+++ b/detrsmpl/models/necks/builder.py
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from mmcv.utils import Registry
+
+from .temporal_encoder import TemporalGRUEncoder
+
+NECKS = Registry('necks')
+
+NECKS.register_module(name='TemporalGRUEncoder', module=TemporalGRUEncoder)
+
+
+def build_neck(cfg):
+ """Build neck."""
+ if cfg is None:
+ return None
+ return NECKS.build(cfg)
diff --git a/detrsmpl/models/necks/channel_mapper.py b/detrsmpl/models/necks/channel_mapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..eac9a72b30be99a689ea4477c6090258f7001074
--- /dev/null
+++ b/detrsmpl/models/necks/channel_mapper.py
@@ -0,0 +1,98 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule
+
+from .builder import NECKS
+
+
+@NECKS.register_module()
+class ChannelMapper(BaseModule):
+ r"""Channel Mapper to reduce/increase channels of backbone features.
+
+    It maps each scale of backbone features to the same number of output
+    channels with a ConvModule.
+
+ Args:
+ in_channels (List[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale).
+ kernel_size (int, optional): kernel_size for reducing channels (used
+ at each scale). Default: 3.
+ conv_cfg (dict, optional): Config dict for convolution layer.
+ Default: None.
+ norm_cfg (dict, optional): Config dict for normalization layer.
+ Default: None.
+ act_cfg (dict, optional): Config dict for activation layer in
+ ConvModule. Default: dict(type='ReLU').
+ num_outs (int, optional): Number of output feature maps. There
+ would be extra_convs when num_outs larger than the length
+ of in_channels.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Example:
+ >>> import torch
+ >>> in_channels = [2, 3, 5, 7]
+ >>> scales = [340, 170, 84, 43]
+ >>> inputs = [torch.rand(1, c, s, s)
+ ... for c, s in zip(in_channels, scales)]
+ >>> self = ChannelMapper(in_channels, 11, 3).eval()
+ >>> outputs = self.forward(inputs)
+ >>> for i in range(len(outputs)):
+ ... print(f'outputs[{i}].shape = {outputs[i].shape}')
+ outputs[0].shape = torch.Size([1, 11, 340, 340])
+ outputs[1].shape = torch.Size([1, 11, 170, 170])
+ outputs[2].shape = torch.Size([1, 11, 84, 84])
+ outputs[3].shape = torch.Size([1, 11, 43, 43])
+ """
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU'),
+ num_outs=None,
+ init_cfg=dict(type='Xavier',
+ layer='Conv2d',
+ distribution='uniform')):
+ super(ChannelMapper, self).__init__(init_cfg)
+ assert isinstance(in_channels, list)
+ self.extra_convs = None
+ if num_outs is None:
+ num_outs = len(in_channels)
+ self.convs = nn.ModuleList()
+ for in_channel in in_channels:
+ self.convs.append(
+ ConvModule(in_channel,
+ out_channels,
+ kernel_size,
+ padding=(kernel_size - 1) // 2,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ if num_outs > len(in_channels):
+ self.extra_convs = nn.ModuleList()
+ for i in range(len(in_channels), num_outs):
+ if i == len(in_channels):
+ in_channel = in_channels[-1]
+ else:
+ in_channel = out_channels
+ self.extra_convs.append(
+ ConvModule(in_channel,
+ out_channels,
+ 3,
+ stride=2,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+
+ def forward(self, inputs):
+ """Forward function."""
+ assert len(inputs) == len(self.convs)
+ outs = [self.convs[i](inputs[i]) for i in range(len(inputs))]
+ if self.extra_convs:
+ for i in range(len(self.extra_convs)):
+ if i == 0:
+ outs.append(self.extra_convs[0](inputs[-1]))
+ else:
+ outs.append(self.extra_convs[i](outs[-1]))
+ return tuple(outs)
diff --git a/detrsmpl/models/necks/temporal_encoder.py b/detrsmpl/models/necks/temporal_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccb0da5ff7225dcb9263f22fbd0d8bc549e5bdcd
--- /dev/null
+++ b/detrsmpl/models/necks/temporal_encoder.py
@@ -0,0 +1,41 @@
+from typing import Optional, Union
+
+import torch.nn as nn
+from mmcv.runner.base_module import BaseModule
+
+
+class TemporalGRUEncoder(BaseModule):
+ """TemporalEncoder used for VIBE. Adapted from
+ https://github.com/mkocabas/VIBE.
+
+ Args:
+ input_size (int, optional): dimension of input feature. Default: 2048.
+        num_layers (int, optional): number of layers for GRU. Default: 1.
+ hidden_size (int, optional): hidden size for GRU. Default: 2048.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
+ """
+ def __init__(self,
+ input_size: Optional[int] = 2048,
+ num_layers: Optional[int] = 1,
+ hidden_size: Optional[int] = 2048,
+ init_cfg: Optional[Union[list, dict, None]] = None):
+ super(TemporalGRUEncoder, self).__init__(init_cfg)
+
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.gru = nn.GRU(input_size=input_size,
+ hidden_size=hidden_size,
+ bidirectional=False,
+ num_layers=num_layers)
+ self.relu = nn.ReLU()
+        self.linear = nn.Linear(hidden_size, input_size)
+
+ def forward(self, x):
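+        # x is expected to be (N, T, input_size); nn.GRU defaults to
+        # sequence-first (T, N, C) inputs, hence the permutes. The linear
+        # output is added back to the GRU input as a residual connection.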
+ N, T = x.shape[:2]
+ x = x.permute(1, 0, 2)
+ y, _ = self.gru(x)
+ y = self.linear(self.relu(y).view(-1, self.hidden_size))
+ y = y.view(T, N, self.input_size) + x
+ y = y.permute(1, 0, 2).contiguous()
+ return y
diff --git a/detrsmpl/models/registrants/__init__.py b/detrsmpl/models/registrants/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/detrsmpl/models/registrants/builder.py b/detrsmpl/models/registrants/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9da8c4e72cf0ea797b6148d45281668b02ef3940
--- /dev/null
+++ b/detrsmpl/models/registrants/builder.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from mmcv.utils import Registry
+
+from .smplify import SMPLify
+from .smplifyx import SMPLifyX
+
+REGISTRANTS = Registry('registrants')
+
+REGISTRANTS.register_module(name='SMPLify', module=SMPLify)
+REGISTRANTS.register_module(name='SMPLifyX', module=SMPLifyX)
+
+
+def build_registrant(cfg):
+ """Build registrant."""
+ if cfg is None:
+ return None
+ return REGISTRANTS.build(cfg)
diff --git a/detrsmpl/models/registrants/smplify.py b/detrsmpl/models/registrants/smplify.py
new file mode 100644
index 0000000000000000000000000000000000000000..b37cb8094aac246d1082728b57a9411f9161a867
--- /dev/null
+++ b/detrsmpl/models/registrants/smplify.py
@@ -0,0 +1,829 @@
+from typing import List, Tuple, Union
+
+import numpy as np
+import torch
+from mmcv.runner import build_optimizer
+
+from detrsmpl.core.cameras import build_cameras
+from detrsmpl.core.conventions.keypoints_mapping import (
+ get_keypoint_idx,
+ get_keypoint_idxs_by_part,
+)
+from ..body_models.builder import build_body_model
+from ..losses.builder import build_loss
+
+
+class OptimizableParameters():
+ """Collects parameters for optimization."""
+
+ def __init__(self):
+ self.opt_params = []
+
+ def set_param(self, fit_param: torch.Tensor, param: torch.Tensor) -> None:
+ """Set requires_grad and collect parameters for optimization.
+
+ Args:
+ fit_param: whether to optimize this body model parameter
+ param: body model parameter
+
+ Returns:
+ None
+ """
+ if fit_param:
+ param.requires_grad = True
+ self.opt_params.append(param)
+ else:
+ param.requires_grad = False
+
+ def parameters(self) -> List[torch.Tensor]:
+ """Returns parameters. Compatible with mmcv's build_parameters()
+
+ Returns:
+ opt_params: a list of body model parameters for optimization
+ """
+ return self.opt_params
+
+
+class SMPLify(object):
+ """Re-implementation of SMPLify with extended features.
+
+ - video input
+ - 3D keypoints
+ """
+
+ def __init__(self,
+ body_model: Union[dict, torch.nn.Module],
+ num_epochs: int = 20,
+ camera: Union[dict, torch.nn.Module] = None,
+ img_res: Union[Tuple[int], int] = 224,
+ stages: dict = None,
+ optimizer: dict = None,
+ keypoints2d_loss: dict = None,
+ keypoints3d_loss: dict = None,
+ shape_prior_loss: dict = None,
+ joint_prior_loss: dict = None,
+ smooth_loss: dict = None,
+ pose_prior_loss: dict = None,
+ pose_reg_loss: dict = None,
+ limb_length_loss: dict = None,
+ use_one_betas_per_video: bool = False,
+ ignore_keypoints: List[int] = None,
+ device=torch.device(
+ 'cuda' if torch.cuda.is_available() else 'cpu'),
+ verbose: bool = False) -> None:
+ """
+ Args:
+ body_model: config or an object of body model.
+ num_epochs: number of epochs of registration
+ camera: config or an object of camera
+ img_res: image resolution. If tuple, values are (width, height)
+ stages: config of registration stages
+ optimizer: config of optimizer
+ keypoints2d_loss: config of keypoint 2D loss
+ keypoints3d_loss: config of keypoint 3D loss
+ shape_prior_loss: config of shape prior loss.
+ Used to prevent extreme shapes.
+ joint_prior_loss: config of joint prior loss.
+ Used to prevent large joint rotations.
+ smooth_loss: config of smooth loss.
+ Used to prevent jittering by temporal smoothing.
+ pose_prior_loss: config of pose prior loss.
+ Used to prevent unnatural pose.
+ pose_reg_loss: config of pose regularizer loss.
+ Used to prevent pose being too large.
+ limb_length_loss: config of limb length loss.
+ Used to prevent the change of body shape.
+ use_one_betas_per_video: whether to use the same beta parameters
+ for all frames in a single video sequence.
+ ignore_keypoints: list of keypoint names to ignore in keypoint
+ loss computation
+ device: torch device
+ verbose: whether to print information during registration
+
+ Returns:
+ None
+ """
+
+ self.use_one_betas_per_video = use_one_betas_per_video
+ self.num_epochs = num_epochs
+ self.img_res = img_res
+ self.device = device
+ self.stage_config = stages
+ self.optimizer = optimizer
+ self.keypoints2d_mse_loss = build_loss(keypoints2d_loss)
+ self.keypoints3d_mse_loss = build_loss(keypoints3d_loss)
+ self.shape_prior_loss = build_loss(shape_prior_loss)
+ self.joint_prior_loss = build_loss(joint_prior_loss)
+ self.smooth_loss = build_loss(smooth_loss)
+ self.pose_prior_loss = build_loss(pose_prior_loss)
+ self.pose_reg_loss = build_loss(pose_reg_loss)
+ self.limb_length_loss = build_loss(limb_length_loss)
+
+ if self.joint_prior_loss is not None:
+ self.joint_prior_loss = self.joint_prior_loss.to(self.device)
+ if self.smooth_loss is not None:
+ self.smooth_loss = self.smooth_loss.to(self.device)
+ if self.pose_prior_loss is not None:
+ self.pose_prior_loss = self.pose_prior_loss.to(self.device)
+ if self.pose_reg_loss is not None:
+ self.pose_reg_loss = self.pose_reg_loss.to(self.device)
+ if self.limb_length_loss is not None:
+ self.limb_length_loss = self.limb_length_loss.to(self.device)
+
+ # initialize body model
+ if isinstance(body_model, dict):
+ self.body_model = build_body_model(body_model).to(self.device)
+ elif isinstance(body_model, torch.nn.Module):
+ self.body_model = body_model.to(self.device)
+ else:
+ raise TypeError(f'body_model should be either dict or '
+ f'torch.nn.Module, but got {type(body_model)}')
+
+ # initialize camera
+ if camera is not None:
+ if isinstance(camera, dict):
+ self.camera = build_cameras(camera).to(self.device)
+ elif isinstance(camera, torch.nn.Module):
+ self.camera = camera.to(device)
+ else:
+ raise TypeError(f'camera should be either dict or '
+ f'torch.nn.Module, but got {type(camera)}')
+
+ self.ignore_keypoints = ignore_keypoints
+ self.verbose = verbose
+
+ self._set_keypoint_idxs()
+
+ def __call__(self,
+ keypoints2d: torch.Tensor = None,
+ keypoints2d_conf: torch.Tensor = None,
+ keypoints3d: torch.Tensor = None,
+ keypoints3d_conf: torch.Tensor = None,
+ init_global_orient: torch.Tensor = None,
+ init_transl: torch.Tensor = None,
+ init_body_pose: torch.Tensor = None,
+ init_betas: torch.Tensor = None,
+ return_verts: bool = False,
+ return_joints: bool = False,
+ return_full_pose: bool = False,
+ return_losses: bool = False) -> dict:
+ """Run registration.
+
+ Notes:
+ B: batch size
+ K: number of keypoints
+ D: shape dimension
+ Provide only keypoints2d or keypoints3d, not both.
+
+ Args:
+ keypoints2d: 2D keypoints of shape (B, K, 2)
+ keypoints2d_conf: 2D keypoint confidence of shape (B, K)
+ keypoints3d: 3D keypoints of shape (B, K, 3).
+ keypoints3d_conf: 3D keypoint confidence of shape (B, K)
+ init_global_orient: initial global_orient of shape (B, 3)
+ init_transl: initial transl of shape (B, 3)
+ init_body_pose: initial body_pose of shape (B, 69)
+ init_betas: initial betas of shape (B, D)
+ return_verts: whether to return vertices
+ return_joints: whether to return joints
+ return_full_pose: whether to return full pose
+ return_losses: whether to return loss dict
+
+ Returns:
+ ret: a dictionary that includes body model parameters,
+ and optional attributes such as vertices and joints
+ """
+ assert keypoints2d is not None or keypoints3d is not None, \
+            'Neither 2D nor 3D keypoints are provided.'
+ assert not (keypoints2d is not None and keypoints3d is not None), \
+ 'Do not provide both 2D and 3D keypoints.'
+ batch_size = keypoints2d.shape[0] if keypoints2d is not None \
+ else keypoints3d.shape[0]
+
+ global_orient = self._match_init_batch_size(
+ init_global_orient, self.body_model.global_orient, batch_size)
+ transl = self._match_init_batch_size(init_transl,
+ self.body_model.transl,
+ batch_size)
+ body_pose = self._match_init_batch_size(init_body_pose,
+ self.body_model.body_pose,
+ batch_size)
+ if init_betas is None and self.use_one_betas_per_video:
+ betas = torch.zeros(1, self.body_model.betas.shape[-1]).to(
+ self.device)
+ else:
+ betas = self._match_init_batch_size(init_betas,
+ self.body_model.betas,
+ batch_size)
+
+ for i in range(self.num_epochs):
+ for stage_idx, stage_config in enumerate(self.stage_config):
+ if self.verbose:
+ print(f'epoch {i}, stage {stage_idx}')
+ self._optimize_stage(
+ global_orient=global_orient,
+ transl=transl,
+ body_pose=body_pose,
+ betas=betas,
+ keypoints2d=keypoints2d,
+ keypoints2d_conf=keypoints2d_conf,
+ keypoints3d=keypoints3d,
+ keypoints3d_conf=keypoints3d_conf,
+ **stage_config,
+ )
+
+ # collate results
+ ret = {
+ 'global_orient': global_orient,
+ 'transl': transl,
+ 'body_pose': body_pose,
+ 'betas': betas
+ }
+
+ if return_verts or return_joints or \
+ return_full_pose or return_losses:
+ eval_ret = self.evaluate(
+ global_orient=global_orient,
+ body_pose=body_pose,
+ betas=betas,
+ transl=transl,
+ keypoints2d=keypoints2d,
+ keypoints2d_conf=keypoints2d_conf,
+ keypoints3d=keypoints3d,
+ keypoints3d_conf=keypoints3d_conf,
+ return_verts=return_verts,
+ return_full_pose=return_full_pose,
+ return_joints=return_joints,
+ reduction_override='none' # sample-wise loss
+ )
+
+ if return_verts:
+ ret['vertices'] = eval_ret['vertices']
+ if return_joints:
+ ret['joints'] = eval_ret['joints']
+ if return_full_pose:
+ ret['full_pose'] = eval_ret['full_pose']
+ if return_losses:
+ for k in eval_ret.keys():
+ if 'loss' in k:
+ ret[k] = eval_ret[k]
+
+ for k, v in ret.items():
+ if isinstance(v, torch.Tensor):
+ ret[k] = v.detach().clone()
+
+ return ret
+
+ def _optimize_stage(self,
+ betas: torch.Tensor,
+ body_pose: torch.Tensor,
+ global_orient: torch.Tensor,
+ transl: torch.Tensor,
+ fit_global_orient: bool = True,
+ fit_transl: bool = True,
+ fit_body_pose: bool = True,
+ fit_betas: bool = True,
+ keypoints2d: torch.Tensor = None,
+ keypoints2d_conf: torch.Tensor = None,
+ keypoints2d_weight: float = None,
+ keypoints3d: torch.Tensor = None,
+ keypoints3d_conf: torch.Tensor = None,
+ keypoints3d_weight: float = None,
+ shape_prior_weight: float = None,
+ joint_prior_weight: float = None,
+ smooth_loss_weight: float = None,
+ pose_prior_weight: float = None,
+ pose_reg_weight: float = None,
+ limb_length_weight: float = None,
+ joint_weights: dict = {},
+ num_iter: int = 1,
+ ftol: float = 1e-4,
+ **kwargs) -> None:
+ """Optimize a stage of body model parameters according to
+ configuration.
+
+ Notes:
+ B: batch size
+ K: number of keypoints
+ D: shape dimension
+
+ Args:
+ betas: shape (B, D)
+ body_pose: shape (B, 69)
+ global_orient: shape (B, 3)
+ transl: shape (B, 3)
+ fit_global_orient: whether to optimize global_orient
+ fit_transl: whether to optimize transl
+ fit_body_pose: whether to optimize body_pose
+ fit_betas: whether to optimize betas
+ keypoints2d: 2D keypoints of shape (B, K, 2)
+ keypoints2d_conf: 2D keypoint confidence of shape (B, K)
+ keypoints2d_weight: weight of 2D keypoint loss
+ keypoints3d: 3D keypoints of shape (B, K, 3).
+ keypoints3d_conf: 3D keypoint confidence of shape (B, K)
+ keypoints3d_weight: weight of 3D keypoint loss
+ shape_prior_weight: weight of shape prior loss
+ joint_prior_weight: weight of joint prior loss
+ smooth_loss_weight: weight of smooth loss
+ pose_prior_weight: weight of pose prior loss
+ pose_reg_weight: weight of pose regularization loss
+ limb_length_weight: weight of limb length loss
+ joint_weights: per joint weight of shape (K, )
+ num_iter: number of iterations
+ ftol: early stop tolerance for relative change in loss
+
+ Returns:
+ None
+ """
+
+ parameters = OptimizableParameters()
+ parameters.set_param(fit_global_orient, global_orient)
+ parameters.set_param(fit_transl, transl)
+ parameters.set_param(fit_body_pose, body_pose)
+ parameters.set_param(fit_betas, betas)
+
+ optimizer = build_optimizer(parameters, self.optimizer)
+
+ pre_loss = None
+ for iter_idx in range(num_iter):
+
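+            # Optimizers such as LBFGS may re-evaluate the objective several
+            # times per step, so gradient zeroing, the forward pass and
+            # backward() all live inside the closure passed to optimizer.step.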
+ def closure():
+ optimizer.zero_grad()
+ betas_video = self._expand_betas(body_pose.shape[0], betas)
+
+ loss_dict = self.evaluate(
+ global_orient=global_orient,
+ body_pose=body_pose,
+ betas=betas_video,
+ transl=transl,
+ keypoints2d=keypoints2d,
+ keypoints2d_conf=keypoints2d_conf,
+ keypoints2d_weight=keypoints2d_weight,
+ keypoints3d=keypoints3d,
+ keypoints3d_conf=keypoints3d_conf,
+ keypoints3d_weight=keypoints3d_weight,
+ joint_prior_weight=joint_prior_weight,
+ shape_prior_weight=shape_prior_weight,
+ smooth_loss_weight=smooth_loss_weight,
+ pose_prior_weight=pose_prior_weight,
+ pose_reg_weight=pose_reg_weight,
+ limb_length_weight=limb_length_weight,
+ joint_weights=joint_weights)
+
+ loss = loss_dict['total_loss']
+ loss.backward()
+ return loss
+
+ loss = optimizer.step(closure)
+ if iter_idx > 0 and pre_loss is not None and ftol > 0:
+ loss_rel_change = self._compute_relative_change(
+ pre_loss, loss.item())
+ if loss_rel_change < ftol:
+ if self.verbose:
+ print(f'[ftol={ftol}] Early stop at {iter_idx} iter!')
+ break
+ pre_loss = loss.item()
+
+ def evaluate(
+ self,
+ betas: torch.Tensor = None,
+ body_pose: torch.Tensor = None,
+ global_orient: torch.Tensor = None,
+ transl: torch.Tensor = None,
+ keypoints2d: torch.Tensor = None,
+ keypoints2d_conf: torch.Tensor = None,
+ keypoints2d_weight: float = None,
+ keypoints3d: torch.Tensor = None,
+ keypoints3d_conf: torch.Tensor = None,
+ keypoints3d_weight: float = None,
+ shape_prior_weight: float = None,
+ joint_prior_weight: float = None,
+ smooth_loss_weight: float = None,
+ pose_prior_weight: float = None,
+ pose_reg_weight: float = None,
+ limb_length_weight: float = None,
+ joint_weights: dict = {},
+ return_verts: bool = False,
+ return_full_pose: bool = False,
+ return_joints: bool = False,
+ reduction_override: str = None,
+ ) -> dict:
+ """Evaluate fitted parameters through loss computation. This function
+        serves two purposes: 1) internally, for loss backpropagation;
+        2) externally, for fitting quality evaluation.
+
+ Notes:
+ B: batch size
+ K: number of keypoints
+ D: shape dimension
+
+ Args:
+ betas: shape (B, D)
+ body_pose: shape (B, 69)
+ global_orient: shape (B, 3)
+ transl: shape (B, 3)
+ keypoints2d: 2D keypoints of shape (B, K, 2)
+ keypoints2d_conf: 2D keypoint confidence of shape (B, K)
+ keypoints2d_weight: weight of 2D keypoint loss
+ keypoints3d: 3D keypoints of shape (B, K, 3).
+ keypoints3d_conf: 3D keypoint confidence of shape (B, K)
+ keypoints3d_weight: weight of 3D keypoint loss
+ shape_prior_weight: weight of shape prior loss
+ joint_prior_weight: weight of joint prior loss
+ smooth_loss_weight: weight of smooth loss
+ pose_prior_weight: weight of pose prior loss
+ pose_reg_weight: weight of pose regularization loss
+ limb_length_weight: weight of limb length loss
+ joint_weights: per joint weight of shape (K, )
+ return_verts: whether to return vertices
+ return_joints: whether to return joints
+ return_full_pose: whether to return full pose
+ reduction_override: reduction method, e.g., 'none', 'sum', 'mean'
+
+ Returns:
+ ret: a dictionary that includes body model parameters,
+ and optional attributes such as vertices and joints
+ """
+
+ ret = {}
+
+ body_model_output = self.body_model(
+ global_orient=global_orient,
+ body_pose=body_pose,
+ betas=betas,
+ transl=transl,
+ return_verts=return_verts,
+ return_full_pose=return_full_pose)
+
+ model_joints = body_model_output['joints']
+ model_joint_mask = body_model_output['joint_mask']
+
+ loss_dict = self._compute_loss(
+ model_joints,
+ model_joint_mask,
+ keypoints2d=keypoints2d,
+ keypoints2d_conf=keypoints2d_conf,
+ keypoints2d_weight=keypoints2d_weight,
+ keypoints3d=keypoints3d,
+ keypoints3d_conf=keypoints3d_conf,
+ keypoints3d_weight=keypoints3d_weight,
+ joint_prior_weight=joint_prior_weight,
+ shape_prior_weight=shape_prior_weight,
+ smooth_loss_weight=smooth_loss_weight,
+ pose_prior_weight=pose_prior_weight,
+ pose_reg_weight=pose_reg_weight,
+ limb_length_weight=limb_length_weight,
+ joint_weights=joint_weights,
+ reduction_override=reduction_override,
+ global_orient=global_orient,
+ body_pose=body_pose,
+ betas=betas)
+ ret.update(loss_dict)
+
+ if return_verts:
+ ret['vertices'] = body_model_output['vertices']
+ if return_full_pose:
+ ret['full_pose'] = body_model_output['full_pose']
+ if return_joints:
+ ret['joints'] = model_joints
+
+ return ret
+
+ def _compute_loss(self,
+ model_joints: torch.Tensor,
+ model_joint_conf: torch.Tensor,
+ keypoints2d: torch.Tensor = None,
+ keypoints2d_conf: torch.Tensor = None,
+ keypoints2d_weight: float = None,
+ keypoints3d: torch.Tensor = None,
+ keypoints3d_conf: torch.Tensor = None,
+ keypoints3d_weight: float = None,
+ shape_prior_weight: float = None,
+ joint_prior_weight: float = None,
+ smooth_loss_weight: float = None,
+ pose_prior_weight: float = None,
+ pose_reg_weight: float = None,
+ limb_length_weight: float = None,
+ joint_weights: dict = {},
+ reduction_override: str = None,
+ global_orient: torch.Tensor = None,
+ body_pose: torch.Tensor = None,
+ betas: torch.Tensor = None):
+ """Loss computation.
+
+ Notes:
+ B: batch size
+ K: number of keypoints
+ D: shape dimension
+
+ Args:
+            model_joints: 3D joints regressed from the body model, of shape
+                (B, K, 3)
+ model_joint_conf: 3D joint confidence of shape (B, K). It is
+ normally all 1, except for zero-pads due to convert_kps in
+ the SMPL wrapper.
+ keypoints2d: 2D keypoints of shape (B, K, 2)
+ keypoints2d_conf: 2D keypoint confidence of shape (B, K)
+ keypoints2d_weight: weight of 2D keypoint loss
+ keypoints3d: 3D keypoints of shape (B, K, 3).
+ keypoints3d_conf: 3D keypoint confidence of shape (B, K)
+ keypoints3d_weight: weight of 3D keypoint loss
+ shape_prior_weight: weight of shape prior loss
+ joint_prior_weight: weight of joint prior loss
+ smooth_loss_weight: weight of smooth loss
+ pose_prior_weight: weight of pose prior loss
+ joint_weights: per joint weight of shape (K, )
+ reduction_override: reduction method, e.g., 'none', 'sum', 'mean'
+ body_pose: shape (B, 69), for loss computation
+ betas: shape (B, D), for loss computation
+
+ Returns:
+ losses: a dict that contains all losses
+ """
+ losses = {}
+
+ weight = self._get_weight(**joint_weights)
+
+ # 2D keypoint loss
+ if keypoints2d is not None and not self._skip_loss(
+ self.keypoints2d_mse_loss, keypoints2d_weight):
+ # bs = model_joints.shape[0]
+ # projected_joints = perspective_projection(
+ # model_joints,
+ # torch.eye(3).expand((bs, 3, 3)).to(model_joints.device),
+ # torch.zeros((bs, 3)).to(model_joints.device), 5000.0,
+ # torch.Tensor([self.img_res / 2,
+ # self.img_res / 2]).to(model_joints.device))
+ projected_joints_xyd = self.camera.transform_points_screen(
+ model_joints)
+ projected_joints = projected_joints_xyd[..., :2]
+
+ # normalize keypoints to [-1,1]
+ projected_joints = 2 * projected_joints / (self.img_res - 1) - 1
+ keypoints2d = 2 * keypoints2d / (self.img_res - 1) - 1
+
+ keypoint2d_loss = self.keypoints2d_mse_loss(
+ pred=projected_joints,
+ pred_conf=model_joint_conf,
+ target=keypoints2d,
+ target_conf=keypoints2d_conf,
+ keypoint_weight=weight,
+ loss_weight_override=keypoints2d_weight,
+ reduction_override=reduction_override)
+ losses['keypoint2d_loss'] = keypoint2d_loss
+
+ # 3D keypoint loss
+ if keypoints3d is not None and not self._skip_loss(
+ self.keypoints3d_mse_loss, keypoints3d_weight):
+ keypoints3d_loss = self.keypoints3d_mse_loss(
+ pred=model_joints,
+ pred_conf=model_joint_conf,
+ target=keypoints3d,
+ target_conf=keypoints3d_conf,
+ keypoint_weight=weight,
+ loss_weight_override=keypoints3d_weight,
+ reduction_override=reduction_override)
+ losses['keypoints3d_loss'] = keypoints3d_loss
+
+ # regularizer to prevent betas from taking large values
+ if not self._skip_loss(self.shape_prior_loss, shape_prior_weight):
+ shape_prior_loss = self.shape_prior_loss(
+ betas=betas,
+ loss_weight_override=shape_prior_weight,
+ reduction_override=reduction_override)
+ losses['shape_prior_loss'] = shape_prior_loss
+
+ # joint prior loss
+ if not self._skip_loss(self.joint_prior_loss, joint_prior_weight):
+ joint_prior_loss = self.joint_prior_loss(
+ body_pose=body_pose,
+ loss_weight_override=joint_prior_weight,
+ reduction_override=reduction_override)
+ losses['joint_prior_loss'] = joint_prior_loss
+
+ # smooth body loss
+ if not self._skip_loss(self.smooth_loss, smooth_loss_weight):
+ smooth_loss = self.smooth_loss(
+ body_pose=body_pose,
+ loss_weight_override=smooth_loss_weight,
+ reduction_override=reduction_override)
+ losses['smooth_loss'] = smooth_loss
+
+ # pose prior loss
+ if not self._skip_loss(self.pose_prior_loss, pose_prior_weight):
+ pose_prior_loss = self.pose_prior_loss(
+ body_pose=body_pose,
+ loss_weight_override=pose_prior_weight,
+ reduction_override=reduction_override)
+ losses['pose_prior_loss'] = pose_prior_loss
+
+ # pose reg loss
+ if not self._skip_loss(self.pose_reg_loss, pose_reg_weight):
+ pose_reg_loss = self.pose_reg_loss(
+ body_pose=body_pose,
+ loss_weight_override=pose_reg_weight,
+ reduction_override=reduction_override)
+ losses['pose_reg_loss'] = pose_reg_loss
+
+ # limb length loss
+ if not self._skip_loss(self.limb_length_loss, limb_length_weight):
+ limb_length_loss = self.limb_length_loss(
+ pred=model_joints,
+ pred_conf=model_joint_conf,
+ target=keypoints3d,
+ target_conf=keypoints3d_conf,
+ loss_weight_override=limb_length_weight,
+ reduction_override=reduction_override)
+ losses['limb_length_loss'] = limb_length_loss
+
+        if self.verbose:
+            msg = ''
+            for loss_name, loss in losses.items():
+                msg += f'{loss_name}={loss.mean().item():.6f}, '
+            print(msg.strip(', '))
+
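+        # Collapse any remaining keypoint/coordinate dimensions of each loss
+        # term so that total_loss keeps at most the batch dimension.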
+ total_loss = 0
+ for loss_name, loss in losses.items():
+ if loss.ndim == 3:
+ total_loss = total_loss + loss.sum(dim=(2, 1))
+ elif loss.ndim == 2:
+ total_loss = total_loss + loss.sum(dim=-1)
+ else:
+ total_loss = total_loss + loss
+ losses['total_loss'] = total_loss
+
+ return losses
+
+ def _match_init_batch_size(self, init_param: torch.Tensor,
+ init_param_body_model: torch.Tensor,
+ batch_size: int) -> torch.Tensor:
+ """A helper function to ensure body model parameters have the same
+ batch size as the input keypoints.
+
+ Args:
+ init_param: input initial body model parameters, may be None
+ init_param_body_model: initial body model parameters from the
+ body model
+ batch_size: batch size of keypoints
+
+ Returns:
+ param: body model parameters with batch size aligned
+ """
+
+ # param takes init values
+ param = init_param.detach().clone() \
+ if init_param is not None \
+ else init_param_body_model.detach().clone()
+
+ # expand batch dimension to match batch size
+ param_batch_size = param.shape[0]
+ if param_batch_size != batch_size:
+ if param_batch_size == 1:
+ param = param.repeat(batch_size, *[1] * (param.ndim - 1))
+ else:
+ raise ValueError('Init param does not match the batch size of '
+ 'keypoints, and is not 1.')
+
+ # shape check
+ assert param.shape[0] == batch_size
+ assert param.shape[1:] == init_param_body_model.shape[1:], \
+ f'Shape mismatch: {param.shape} vs {init_param_body_model.shape}'
+
+ return param
+
+ def _set_keypoint_idxs(self) -> None:
+        """Set keypoint indices for 1) body parts that are assigned different
+        weights and 2) keypoints to be ignored in keypoint loss computation.
+
+ Returns:
+ None
+ """
+ convention = self.body_model.keypoint_dst
+
+ # obtain ignore keypoint indices
+ if self.ignore_keypoints is not None:
+ self.ignore_keypoint_idxs = []
+ for keypoint_name in self.ignore_keypoints:
+ keypoint_idx = get_keypoint_idx(
+ keypoint_name, convention=convention)
+ if keypoint_idx != -1:
+ self.ignore_keypoint_idxs.append(keypoint_idx)
+
+ # obtain body part keypoint indices
+ shoulder_keypoint_idxs = get_keypoint_idxs_by_part(
+ 'shoulder', convention=convention)
+ hip_keypoint_idxs = get_keypoint_idxs_by_part(
+ 'hip', convention=convention)
+ self.shoulder_hip_keypoint_idxs = [
+ *shoulder_keypoint_idxs, *hip_keypoint_idxs
+ ]
+
+ def _get_weight(self,
+ use_shoulder_hip_only: bool = False,
+ body_weight: float = 1.0) -> torch.Tensor:
+ """Get per keypoint weight.
+
+ Notes:
+ K: number of keypoints
+
+ Args:
+ use_shoulder_hip_only: whether to use only shoulder and hip
+ keypoints for loss computation. This is useful in the
+ warming-up stage to find a reasonably good initialization.
+ body_weight: weight of body keypoints. Body part segmentation
+ definition is included in the HumanData convention.
+
+ Returns:
+ weight: per keypoint weight tensor of shape (K)
+ """
+
+ num_keypoint = self.body_model.num_joints
+
+ if use_shoulder_hip_only:
+ weight = torch.zeros([num_keypoint]).to(self.device)
+ weight[self.shoulder_hip_keypoint_idxs] = 1.0
+ weight = weight * body_weight
+ else:
+ weight = torch.ones([num_keypoint]).to(self.device)
+ weight = weight * body_weight
+
+ if hasattr(self, 'ignore_keypoint_idxs'):
+ weight[self.ignore_keypoint_idxs] = 0.0
+
+ return weight
+
+ def _expand_betas(self, batch_size, betas):
+        """A helper function that expands the first dim of betas to match the
+        batch size, so that the same shape parameters can be shared by all
+        frames in a video sequence.
+
+ Notes:
+ B: batch size
+ K: number of keypoints
+ D: shape dimension
+
+ Args:
+ batch_size: batch size
+ betas: shape (B, D)
+
+ Returns:
+ betas_video: expanded betas
+ """
+ # no expansion needed
+ if batch_size == betas.shape[0]:
+ return betas
+
+ # first dim is 1
+ else:
+ feat_dim = betas.shape[-1]
+ betas_video = betas.view(1, feat_dim).expand(batch_size, feat_dim)
+
+ return betas_video
+
+ @staticmethod
+ def _compute_relative_change(pre_v, cur_v):
+        """Compute the relative change of the loss. If the relative change is
+        small enough, early stop can be applied to accelerate the
+        optimization. (1) When either value is larger than 1, the relative
+        change is the absolute difference divided by the larger absolute
+        value. (2) When both values are smaller than 1, it degrades to the
+        absolute change, since dividing the difference of two small, close
+        values by their max could otherwise yield a misleadingly large value.
+
+ Args:
+ pre_v: previous value
+ cur_v: current value
+
+ Returns:
+ float: relative change
+ """
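+        # Illustrative values (not from the original code): with pre_v=100.0
+        # and cur_v=99.0 the relative change is 1.0 / 100.0 = 0.01; with
+        # pre_v=0.5 and cur_v=0.4 the denominator clamps to 1, so the result
+        # is the absolute change 0.1.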
+ return np.abs(pre_v - cur_v) / max([np.abs(pre_v), np.abs(cur_v), 1])
+
+ @staticmethod
+ def _skip_loss(loss, loss_weight_override):
+        """Whether to skip loss computation. If loss is None, the computation
+        is skipped directly to avoid a RuntimeError. If loss is not None, the
+        table below shows the return value. A return value of True means the
+        loss can be skipped: its weighted contribution would be 0 even if it
+        were computed, so skipping it saves computation.
+
+ | loss.loss_weight | loss_weight_override | returns |
+ | ---------------- | -------------------- | ------- |
+ | == 0 | None | True |
+ | != 0 | None | False |
+ | == 0 | == 0 | True |
+ | != 0 | == 0 | True |
+ | == 0 | != 0 | False |
+ | != 0 | != 0 | False |
+
+ Args:
+ loss: loss is an object that has attribute loss_weight.
+ loss.loss_weight is assigned when loss is initialized.
+ loss_weight_override: loss_weight used to override loss.loss_weight
+
+ Returns:
+ bool: True means skipping loss computation, and vice versa
+ """
+ if (loss is None) or (loss.loss_weight == 0 and loss_weight_override is
+ None) or (loss_weight_override == 0):
+ return True
+ return False
diff --git a/detrsmpl/models/registrants/smplifyx.py b/detrsmpl/models/registrants/smplifyx.py
new file mode 100644
index 0000000000000000000000000000000000000000..7402440ddb4287a67b560c408d8d63e450301729
--- /dev/null
+++ b/detrsmpl/models/registrants/smplifyx.py
@@ -0,0 +1,489 @@
+import torch
+from mmcv.runner import build_optimizer
+
+from detrsmpl.core.conventions.keypoints_mapping import (
+ get_keypoint_idx,
+ get_keypoint_idxs_by_part,
+)
+from .smplify import OptimizableParameters, SMPLify
+
+
+class SMPLifyX(SMPLify):
+ """Re-implementation of SMPLify-X with extended features.
+
+ - video input
+ - 3D keypoints
+ """
+ def __call__(self,
+ keypoints2d: torch.Tensor = None,
+ keypoints2d_conf: torch.Tensor = None,
+ keypoints3d: torch.Tensor = None,
+ keypoints3d_conf: torch.Tensor = None,
+ init_global_orient: torch.Tensor = None,
+ init_transl: torch.Tensor = None,
+ init_body_pose: torch.Tensor = None,
+ init_betas: torch.Tensor = None,
+ init_left_hand_pose: torch.Tensor = None,
+ init_right_hand_pose: torch.Tensor = None,
+ init_expression: torch.Tensor = None,
+ init_jaw_pose: torch.Tensor = None,
+ init_leye_pose: torch.Tensor = None,
+ init_reye_pose: torch.Tensor = None,
+ return_verts: bool = False,
+ return_joints: bool = False,
+ return_full_pose: bool = False,
+ return_losses: bool = False) -> dict:
+ """Run registration.
+
+ Notes:
+ B: batch size
+ K: number of keypoints
+ D: body shape dimension
+ D_H: hand pose dimension
+ D_E: expression dimension
+ Provide only keypoints2d or keypoints3d, not both.
+
+ Args:
+ keypoints2d: 2D keypoints of shape (B, K, 2)
+ keypoints2d_conf: 2D keypoint confidence of shape (B, K)
+ keypoints3d: 3D keypoints of shape (B, K, 3).
+ keypoints3d_conf: 3D keypoint confidence of shape (B, K)
+ init_global_orient: initial global_orient of shape (B, 3)
+ init_transl: initial transl of shape (B, 3)
+ init_body_pose: initial body_pose of shape (B, 69)
+ init_betas: initial betas of shape (B, D)
+ init_left_hand_pose: initial left hand pose of shape (B, D_H)
+ init_right_hand_pose: initial right hand pose of shape (B, D_H)
+            init_expression: initial expression of shape (B, D_E)
+ init_jaw_pose: initial jaw pose of shape (B, 3)
+ init_leye_pose: initial left eye pose of shape (B, 3)
+ init_reye_pose: initial right eye pose of shape (B, 3)
+ return_verts: whether to return vertices
+ return_joints: whether to return joints
+ return_full_pose: whether to return full pose
+ return_losses: whether to return loss dict
+
+ Returns:
+ ret: a dictionary that includes body model parameters,
+ and optional attributes such as vertices and joints
+ """
+
+ assert keypoints2d is not None or keypoints3d is not None, \
+            'Neither 2D nor 3D keypoints are provided.'
+ assert not (keypoints2d is not None and keypoints3d is not None), \
+ 'Do not provide both 2D and 3D keypoints.'
+ batch_size = keypoints2d.shape[0] if keypoints2d is not None \
+ else keypoints3d.shape[0]
+
+ global_orient = self._match_init_batch_size(
+ init_global_orient, self.body_model.global_orient, batch_size)
+ transl = self._match_init_batch_size(init_transl,
+ self.body_model.transl,
+ batch_size)
+ body_pose = self._match_init_batch_size(init_body_pose,
+ self.body_model.body_pose,
+ batch_size)
+ left_hand_pose = self._match_init_batch_size(
+ init_left_hand_pose, self.body_model.left_hand_pose, batch_size)
+ right_hand_pose = self._match_init_batch_size(
+ init_right_hand_pose, self.body_model.right_hand_pose, batch_size)
+ expression = self._match_init_batch_size(init_expression,
+ self.body_model.expression,
+ batch_size)
+ jaw_pose = self._match_init_batch_size(init_jaw_pose,
+ self.body_model.jaw_pose,
+ batch_size)
+ leye_pose = self._match_init_batch_size(init_leye_pose,
+ self.body_model.leye_pose,
+ batch_size)
+ reye_pose = self._match_init_batch_size(init_reye_pose,
+ self.body_model.reye_pose,
+ batch_size)
+ if init_betas is None and self.use_one_betas_per_video:
+ betas = torch.zeros(1, self.body_model.betas.shape[-1]).to(
+ self.device)
+ else:
+ betas = self._match_init_batch_size(init_betas,
+ self.body_model.betas,
+ batch_size)
+
+ for i in range(self.num_epochs):
+ for stage_idx, stage_config in enumerate(self.stage_config):
+ # print(stage_name)
+ self._optimize_stage(
+ global_orient=global_orient,
+ transl=transl,
+ body_pose=body_pose,
+ betas=betas,
+ left_hand_pose=left_hand_pose,
+ right_hand_pose=right_hand_pose,
+ expression=expression,
+ jaw_pose=jaw_pose,
+ leye_pose=leye_pose,
+ reye_pose=reye_pose,
+ keypoints2d=keypoints2d,
+ keypoints2d_conf=keypoints2d_conf,
+ keypoints3d=keypoints3d,
+ keypoints3d_conf=keypoints3d_conf,
+ **stage_config,
+ )
+
+ return {
+ 'global_orient': global_orient,
+ 'transl': transl,
+ 'body_pose': body_pose,
+ 'betas': betas,
+ 'left_hand_pose': left_hand_pose,
+ 'right_hand_pose': right_hand_pose,
+ 'expression': expression,
+ 'jaw_pose': jaw_pose,
+ 'leye_pose': leye_pose,
+ 'reye_pose': reye_pose
+ }
+
+ def _optimize_stage(self,
+ betas: torch.Tensor,
+ body_pose: torch.Tensor,
+ global_orient: torch.Tensor,
+ transl: torch.Tensor,
+ left_hand_pose: torch.Tensor,
+ right_hand_pose: torch.Tensor,
+ expression: torch.Tensor,
+ jaw_pose: torch.Tensor,
+ leye_pose: torch.Tensor,
+ reye_pose: torch.Tensor,
+ fit_global_orient: bool = True,
+ fit_transl: bool = True,
+ fit_body_pose: bool = True,
+ fit_betas: bool = True,
+ fit_left_hand_pose: bool = True,
+ fit_right_hand_pose: bool = True,
+ fit_expression: bool = True,
+ fit_jaw_pose: bool = True,
+ fit_leye_pose: bool = True,
+ fit_reye_pose: bool = True,
+ keypoints2d: torch.Tensor = None,
+ keypoints2d_conf: torch.Tensor = None,
+ keypoints2d_weight: float = None,
+ keypoints3d: torch.Tensor = None,
+ keypoints3d_conf: torch.Tensor = None,
+ keypoints3d_weight: float = None,
+ shape_prior_weight: float = None,
+ joint_prior_weight: float = None,
+ smooth_loss_weight: float = None,
+ pose_prior_weight: float = None,
+ pose_reg_weight: float = None,
+ limb_length_weight: float = None,
+ joint_weights: dict = {},
+ ftol: float = 1e-4,
+ num_iter: int = 1) -> None:
+ """Optimize a stage of body model parameters according to
+ configuration.
+
+ Notes:
+ B: batch size
+ K: number of keypoints
+ D: shape dimension
+
+ Args:
+ betas: shape (B, D)
+ body_pose: shape (B, 69)
+ global_orient: shape (B, 3)
+ transl: shape (B, 3)
+ fit_global_orient: whether to optimize global_orient
+ fit_transl: whether to optimize transl
+ fit_body_pose: whether to optimize body_pose
+ fit_betas: whether to optimize betas
+ fit_left_hand_pose: whether to optimize left hand pose
+ fit_right_hand_pose: whether to optimize right hand pose
+ fit_expression: whether to optimize expression
+ fit_jaw_pose: whether to optimize jaw pose
+ fit_leye_pose: whether to optimize left eye pose
+ fit_reye_pose: whether to optimize right eye pose
+ keypoints2d: 2D keypoints of shape (B, K, 2)
+ keypoints2d_conf: 2D keypoint confidence of shape (B, K)
+ keypoints2d_weight: weight of 2D keypoint loss
+ keypoints3d: 3D keypoints of shape (B, K, 3).
+ keypoints3d_conf: 3D keypoint confidence of shape (B, K)
+ keypoints3d_weight: weight of 3D keypoint loss
+ shape_prior_weight: weight of shape prior loss
+ joint_prior_weight: weight of joint prior loss
+ smooth_loss_weight: weight of smooth loss
+ pose_prior_weight: weight of pose prior loss
+ pose_reg_weight: weight of pose regularization loss
+ limb_length_weight: weight of limb length loss
+ joint_weights: per joint weight of shape (K, )
+ num_iter: number of iterations
+ ftol: early stop tolerance for relative change in loss
+
+ Returns:
+ None
+ """
+
+ parameters = OptimizableParameters()
+ parameters.set_param(fit_global_orient, global_orient)
+ parameters.set_param(fit_transl, transl)
+ parameters.set_param(fit_body_pose, body_pose)
+ parameters.set_param(fit_betas, betas)
+ parameters.set_param(fit_left_hand_pose, left_hand_pose)
+ parameters.set_param(fit_right_hand_pose, right_hand_pose)
+ parameters.set_param(fit_expression, expression)
+ parameters.set_param(fit_jaw_pose, jaw_pose)
+ parameters.set_param(fit_leye_pose, leye_pose)
+ parameters.set_param(fit_reye_pose, reye_pose)
+
+ optimizer = build_optimizer(parameters, self.optimizer)
+
+ pre_loss = None
+ for iter_idx in range(num_iter):
+
+ def closure():
+ # body_pose_fixed = use_reference_spine(body_pose,
+ # init_body_pose)
+
+ optimizer.zero_grad()
+ betas_video = self._expand_betas(body_pose.shape[0], betas)
+
+ loss_dict = self.evaluate(
+ global_orient=global_orient,
+ body_pose=body_pose,
+ betas=betas_video,
+ transl=transl,
+ left_hand_pose=left_hand_pose,
+ right_hand_pose=right_hand_pose,
+ expression=expression,
+ jaw_pose=jaw_pose,
+ leye_pose=leye_pose,
+ reye_pose=reye_pose,
+ keypoints2d=keypoints2d,
+ keypoints2d_conf=keypoints2d_conf,
+ keypoints2d_weight=keypoints2d_weight,
+ keypoints3d=keypoints3d,
+ keypoints3d_conf=keypoints3d_conf,
+ keypoints3d_weight=keypoints3d_weight,
+ joint_prior_weight=joint_prior_weight,
+ shape_prior_weight=shape_prior_weight,
+ smooth_loss_weight=smooth_loss_weight,
+ pose_prior_weight=pose_prior_weight,
+ pose_reg_weight=pose_reg_weight,
+ limb_length_weight=limb_length_weight,
+ joint_weights=joint_weights)
+
+ loss = loss_dict['total_loss']
+ loss.backward()
+ return loss
+
+ loss = optimizer.step(closure)
+ if iter_idx > 0 and pre_loss is not None and ftol > 0:
+ loss_rel_change = self._compute_relative_change(
+ pre_loss, loss.item())
+ if loss_rel_change < ftol:
+                    if self.verbose:
+                        print(f'[ftol={ftol}] Early stop at {iter_idx} iter!')
+ break
+ pre_loss = loss.item()
+
+ def evaluate(
+ self,
+ betas: torch.Tensor = None,
+ body_pose: torch.Tensor = None,
+ global_orient: torch.Tensor = None,
+ transl: torch.Tensor = None,
+ left_hand_pose: torch.Tensor = None,
+ right_hand_pose: torch.Tensor = None,
+ expression: torch.Tensor = None,
+ jaw_pose: torch.Tensor = None,
+ leye_pose: torch.Tensor = None,
+ reye_pose: torch.Tensor = None,
+ keypoints2d: torch.Tensor = None,
+ keypoints2d_conf: torch.Tensor = None,
+ keypoints2d_weight: float = None,
+ keypoints3d: torch.Tensor = None,
+ keypoints3d_conf: torch.Tensor = None,
+ keypoints3d_weight: float = None,
+ shape_prior_weight: float = None,
+ joint_prior_weight: float = None,
+ smooth_loss_weight: float = None,
+ pose_prior_weight: float = None,
+ pose_reg_weight: float = None,
+ limb_length_weight: float = None,
+ joint_weights: dict = {},
+ return_verts: bool = False,
+ return_full_pose: bool = False,
+ return_joints: bool = False,
+ reduction_override: str = None,
+ ):
+ """Evaluate fitted parameters through loss computation. This function
+        serves two purposes: 1) internally, for loss backpropagation, and 2)
+        externally, for fitting quality evaluation.
+
+ Notes:
+ B: batch size
+ K: number of keypoints
+ D: body shape dimension
+ D_H: hand pose dimension
+ D_E: expression dimension
+
+ Args:
+ betas: shape (B, D)
+ body_pose: shape (B, 69)
+ global_orient: shape (B, 3)
+ transl: shape (B, 3)
+ left_hand_pose: shape (B, D_H)
+ right_hand_pose: shape (B, D_H)
+ expression: shape (B, D_E)
+ jaw_pose: shape (B, 3)
+ leye_pose: shape (B, 3)
+ reye_pose: shape (B, 3)
+ keypoints2d: 2D keypoints of shape (B, K, 2)
+ keypoints2d_conf: 2D keypoint confidence of shape (B, K)
+ keypoints2d_weight: weight of 2D keypoint loss
+ keypoints3d: 3D keypoints of shape (B, K, 3).
+ keypoints3d_conf: 3D keypoint confidence of shape (B, K)
+ keypoints3d_weight: weight of 3D keypoint loss
+ shape_prior_weight: weight of shape prior loss
+ joint_prior_weight: weight of joint prior loss
+ smooth_loss_weight: weight of smooth loss
+ pose_prior_weight: weight of pose prior loss
+ pose_reg_weight: weight of pose regularization loss
+ limb_length_weight: weight of limb length loss
+ joint_weights: per joint weight of shape (K, )
+ return_verts: whether to return vertices
+ return_joints: whether to return joints
+ return_full_pose: whether to return full pose
+ reduction_override: reduction method, e.g., 'none', 'sum', 'mean'
+
+ Returns:
+ ret: a dictionary that includes body model parameters,
+ and optional attributes such as vertices and joints
+ """
+
+ ret = {}
+
+ body_model_output = self.body_model(global_orient=global_orient,
+ body_pose=body_pose,
+ betas=betas,
+ transl=transl,
+ left_hand_pose=left_hand_pose,
+ right_hand_pose=right_hand_pose,
+ expression=expression,
+ jaw_pose=jaw_pose,
+ leye_pose=leye_pose,
+ reye_pose=reye_pose,
+ return_verts=return_verts,
+ return_full_pose=return_full_pose)
+
+ model_joints = body_model_output['joints']
+ model_joint_mask = body_model_output['joint_mask']
+
+ loss_dict = self._compute_loss(model_joints,
+ model_joint_mask,
+ keypoints2d=keypoints2d,
+ keypoints2d_conf=keypoints2d_conf,
+ keypoints2d_weight=keypoints2d_weight,
+ keypoints3d=keypoints3d,
+ keypoints3d_conf=keypoints3d_conf,
+ keypoints3d_weight=keypoints3d_weight,
+ joint_prior_weight=joint_prior_weight,
+ shape_prior_weight=shape_prior_weight,
+ smooth_loss_weight=smooth_loss_weight,
+ pose_prior_weight=pose_prior_weight,
+ pose_reg_weight=pose_reg_weight,
+ limb_length_weight=limb_length_weight,
+ joint_weights=joint_weights,
+ reduction_override=reduction_override,
+ body_pose=body_pose,
+ betas=betas)
+ ret.update(loss_dict)
+
+ if return_verts:
+ ret['vertices'] = body_model_output['vertices']
+ if return_full_pose:
+ ret['full_pose'] = body_model_output['full_pose']
+ if return_joints:
+ ret['joints'] = model_joints
+
+ return ret
+
+ def _set_keypoint_idxs(self):
+        """Set keypoint indices for 1) body parts that are assigned different
+        weights and 2) keypoints to be ignored in keypoint loss computation.
+
+ Returns:
+ None
+ """
+ convention = self.body_model.keypoint_dst
+
+ # obtain ignore keypoint indices
+ if self.ignore_keypoints is not None:
+ self.ignore_keypoint_idxs = []
+ for keypoint_name in self.ignore_keypoints:
+ keypoint_idx = get_keypoint_idx(keypoint_name,
+ convention=convention)
+ if keypoint_idx != -1:
+ self.ignore_keypoint_idxs.append(keypoint_idx)
+
+ # obtain body part keypoint indices
+ shoulder_keypoint_idxs = get_keypoint_idxs_by_part(
+ 'shoulder', convention=convention)
+ hip_keypoint_idxs = get_keypoint_idxs_by_part('hip',
+ convention=convention)
+ self.shoulder_hip_keypoint_idxs = [
+ *shoulder_keypoint_idxs, *hip_keypoint_idxs
+ ]
+
+ # head keypoints include all facial landmarks
+ self.face_keypoint_idxs = get_keypoint_idxs_by_part(
+ 'head', convention=convention)
+
+ left_hand_keypoint_idxs = get_keypoint_idxs_by_part(
+ 'left_hand', convention=convention)
+ right_hand_keypoint_idxs = get_keypoint_idxs_by_part(
+ 'right_hand', convention=convention)
+ self.hand_keypoint_idxs = [
+ *left_hand_keypoint_idxs, *right_hand_keypoint_idxs
+ ]
+
+ self.body_keypoint_idxs = get_keypoint_idxs_by_part(
+ 'body', convention=convention)
+
+ def _get_weight(self,
+ use_shoulder_hip_only: bool = False,
+ body_weight: float = 1.0,
+ hand_weight: float = 1.0,
+ face_weight: float = 1.0):
+ """Get per keypoint weight.
+
+ Notes:
+ K: number of keypoints
+
+ Args:
+ use_shoulder_hip_only: whether to use only shoulder and hip
+ keypoints for loss computation. This is useful in the
+ warming-up stage to find a reasonably good initialization.
+ body_weight: weight of body keypoints. Body part segmentation
+ definition is included in the HumanData convention.
+ hand_weight: weight of hand keypoints.
+ face_weight: weight of face keypoints.
+
+ Returns:
+ weight: per keypoint weight tensor of shape (K)
+ """
+ num_keypoint = self.body_model.num_joints
+
+ if use_shoulder_hip_only:
+ weight = torch.zeros([num_keypoint]).to(self.device)
+ weight[self.shoulder_hip_keypoint_idxs] = 1.0
+ else:
+ weight = torch.ones([num_keypoint]).to(self.device)
+
+ weight[self.body_keypoint_idxs] = \
+ weight[self.body_keypoint_idxs] * body_weight
+ weight[self.hand_keypoint_idxs] = \
+ weight[self.hand_keypoint_idxs] * hand_weight
+ weight[self.face_keypoint_idxs] = \
+ weight[self.face_keypoint_idxs] * face_weight
+
+ if hasattr(self, 'ignore_keypoint_idxs'):
+ weight[self.ignore_keypoint_idxs] = 0.0
+
+ return weight
diff --git a/detrsmpl/models/utils/SMPLX.py b/detrsmpl/models/utils/SMPLX.py
new file mode 100644
index 0000000000000000000000000000000000000000..d76142d71e070b3f166b4b39e74306ceff12983d
--- /dev/null
+++ b/detrsmpl/models/utils/SMPLX.py
@@ -0,0 +1,669 @@
+from typing import List
+
+import torch
+import torch.nn.functional as F
+from smplx.utils import find_joint_kin_chain
+
+from detrsmpl.core.conventions.keypoints_mapping import (
+ get_keypoint_idx,
+ get_keypoint_idxs_by_part,
+)
+from detrsmpl.utils.geometry import weak_perspective_projection
+
+
+class SMPLXHandMergeFunc():
+    """This function uses predictions from the hand model to update the hand
+    params (right_hand_pose, left_hand_pose, wrist_pose) in the predictions
+    from the body model."""
+ def __init__(self, body_model, convention='smplx'):
+ self.body_model = body_model
+ self.convention = convention
+ self.left_hand_idxs = get_keypoint_idxs_by_part(
+ 'left_hand', self.convention)
+ self.left_wrist_idx = get_keypoint_idx('left_wrist', self.convention)
+ self.left_hand_idxs.append(self.left_wrist_idx)
+ self.left_wrist_kin_chain = find_joint_kin_chain(
+ self.left_wrist_idx, self.body_model.parents)
+
+ self.right_hand_idxs = get_keypoint_idxs_by_part(
+ 'right_hand', self.convention)
+ self.right_wrist_idx = get_keypoint_idx('right_wrist', self.convention)
+ self.right_hand_idxs.append(self.right_wrist_idx)
+ self.right_wrist_kin_chain = find_joint_kin_chain(
+ self.right_wrist_idx, self.body_model.parents)
+
+ def __call__(self, body_predictions, hand_predictions):
+        """Merge the hand model predictions into the body model prediction.
+
+ Args:
+ body_predictions (dict): The prediction from body model.
+ hand_predictions (dict): The prediction from hand model.
+ Returns:
+ dict: Merged prediction.
+ """
+ pred_param = body_predictions['pred_param']
+ global_orient = pred_param['global_orient']
+ body_pose = pred_param['body_pose']
+ pred_cam = body_predictions['pred_cam']
+ batch_size = pred_cam.shape[0]
+ device = pred_cam.device
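+        # The hand model output is assumed to be a stacked batch of
+        # [right-hand crops, flipped left-hand crops] (see SMPLXHandCropFunc),
+        # so the first batch_size entries are right hands and the remaining
+        # ones are left hands.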
+ hands_from_body_idxs = torch.arange(0,
+ 2 * batch_size,
+ dtype=torch.long,
+ device=device)
+ right_hand_from_body_idxs = hands_from_body_idxs[:batch_size]
+ left_hand_from_body_idxs = hands_from_body_idxs[batch_size:]
+
+ parent_rots = []
+ right_wrist_parent_rot = find_joint_global_rotation(
+ self.right_wrist_kin_chain[1:], global_orient, body_pose)
+
+ left_wrist_parent_rot = find_joint_global_rotation(
+ self.left_wrist_kin_chain[1:], global_orient, body_pose)
+ left_to_right_wrist_parent_rot = flip_rotmat(left_wrist_parent_rot)
+
+ parent_rots += [right_wrist_parent_rot, left_to_right_wrist_parent_rot]
+ parent_rots = torch.cat(parent_rots, dim=0)
+
+ wrist_pose_from_hand = hand_predictions['pred_param']['global_orient']
+ # Undo the rotation of the parent joints to make the wrist rotation
+ # relative again
+ wrist_pose_from_hand = torch.matmul(
+ parent_rots.reshape(-1, 3, 3).transpose(1, 2),
+ wrist_pose_from_hand.reshape(-1, 3, 3))
+
+ right_hand_wrist = wrist_pose_from_hand[right_hand_from_body_idxs]
+ left_hand_wrist = flip_rotmat(
+ wrist_pose_from_hand[left_hand_from_body_idxs])
+ right_hand_pose = hand_predictions['pred_param']['right_hand_pose'][
+ right_hand_from_body_idxs]
+ left_hand_pose = flip_rotmat(
+ hand_predictions['pred_param']['right_hand_pose']
+ [left_hand_from_body_idxs])
+
+ body_predictions['pred_param']['right_hand_pose'] = right_hand_pose
+ body_predictions['pred_param']['left_hand_pose'] = left_hand_pose
+ body_predictions['pred_param']['body_pose'][:, self.right_wrist_idx -
+ 1] = right_hand_wrist
+ body_predictions['pred_param']['body_pose'][:, self.left_wrist_idx -
+ 1] = left_hand_wrist
+
+ return body_predictions
+
+
+class SMPLXFaceMergeFunc():
+    """This function uses predictions from the face model to update the face
+    params (jaw_pose, expression) in the predictions from the body model."""
+ def __init__(self,
+ body_model,
+ convention='smplx',
+ num_expression_coeffs=10):
+ self.body_model = body_model
+ self.convention = convention
+ self.num_expression_coeffs = num_expression_coeffs
+
+ def __call__(self, body_predictions, face_predictions):
+        """Merge the face model predictions into the body model prediction.
+
+ Args:
+ body_predictions (dict): The prediction from body model.
+ face_predictions (dict): The prediction from face model.
+ Returns:
+ dict: Merged prediction.
+ """
+ body_predictions['pred_param']['jaw_pose'] = face_predictions[
+ 'pred_param']['jaw_pose']
+ body_predictions['pred_param']['expression'] = face_predictions[
+ 'pred_param']['expression'][:, :self.num_expression_coeffs]
+ return body_predictions
+
+
+def points_to_bbox(points, bbox_scale_factor: float = 1.0):
+    """Get a scaled square bounding box from 2D keypoints."""
+ min_coords, _ = torch.min(points, dim=1)
+ xmin, ymin = min_coords[:, 0], min_coords[:, 1]
+ max_coords, _ = torch.max(points, dim=1)
+ xmax, ymax = max_coords[:, 0], max_coords[:, 1]
+
+ center = torch.stack([xmax + xmin, ymax + ymin], dim=-1) * 0.5
+
+ width = (xmax - xmin)
+ height = (ymax - ymin)
+
+ # Convert the bounding box to a square box
+ size = torch.max(width, height) * bbox_scale_factor
+
+ return center, size
+
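+# Illustrative usage of points_to_bbox (hypothetical tensor names): for 2D
+# keypoints of shape (B, K, 2), e.g.
+#     center, size = points_to_bbox(keypoints2d, bbox_scale_factor=1.2)
+# center has shape (B, 2) and size has shape (B,), describing a square box
+# around the keypoints, enlarged by the scale factor.
+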
+
+def get_crop_info(points,
+ img_metas,
+ scale_factor: float = 1.0,
+ crop_size: int = 256):
+    """Get the transformation that maps points on the cropped image to points
+    on the original image."""
+ device = points.device
+ dtype = points.dtype
+ batch_size = points.shape[0]
+ # Get the image to crop transformations and bounding box sizes
+ crop_transforms = []
+ img_bbox_sizes = []
+ for img_meta in img_metas:
+ crop_transforms.append(img_meta['crop_transform'])
+ img_bbox_sizes.append(img_meta['scale'].max())
+
+ img_bbox_sizes = torch.tensor(img_bbox_sizes, dtype=dtype, device=device)
+
+ crop_transforms = torch.tensor(crop_transforms, dtype=dtype, device=device)
+
+ crop_transforms = torch.cat([
+ crop_transforms,
+ torch.tensor([0.0, 0.0, 1.0], dtype=dtype, device=device).expand(
+ [batch_size, 1, 3])
+ ],
+ dim=1)
+
+ inv_crop_transforms = torch.inverse(crop_transforms)
+
+ # center on the cropped body image
+ center_body_crop, bbox_size = points_to_bbox(
+ points, bbox_scale_factor=scale_factor)
+
+ orig_bbox_size = bbox_size / crop_size * img_bbox_sizes
+
+ # Compute the center of the crop in the original image
+ center = (torch.einsum(
+ 'bij,bj->bi', [inv_crop_transforms[:, :2, :2], center_body_crop]) +
+ inv_crop_transforms[:, :2, 2])
+
+ return {
+ 'center': center.reshape(-1, 2),
+ 'orig_bbox_size': orig_bbox_size,
+ # 'bbox_size': bbox_size.reshape(-1),
+ 'inv_crop_transforms': inv_crop_transforms,
+ # 'center_body_crop': 2 * center_body_crop / (crop_size-1) - 1,
+ }
+
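+# get_crop_info assumes each img_meta provides a 2x3 'crop_transform' that
+# maps the original image to the body crop, plus a 'scale' whose max entry is
+# the body bounding box size in the original image; the transform is inverted
+# inside get_crop_info to map crop-space points back to the original image.
+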
+
+def concat_images(images: List[torch.Tensor]):
+    """Concatenate images of different sizes into a zero-padded batch."""
+ sizes = [img.shape[1:] for img in images]
+ H, W = [max(s) for s in zip(*sizes)]
+ batch_size = len(images)
+ batched_shape = (batch_size, images[0].shape[0], H, W)
+ batched = torch.zeros(batched_shape,
+ device=images[0].device,
+ dtype=images[0].dtype)
+ for ii, img in enumerate(images):
+ shape = img.shape
+ batched[ii, :shape[0], :shape[1], :shape[2]] = img
+ return batched
+
+
+def flip_rotmat(pose_rotmat):
+    """Flip a batch of rotation matrices left-right.
+
+    Negating entries (0, 1), (0, 2), (1, 0) and (2, 0) of each 3x3 matrix is
+    equivalent to conjugating it by diag(-1, 1, 1), i.e. mirroring the
+    rotation across the x = 0 plane.
+    """
+ rot_mats = pose_rotmat.reshape(-1, 9).clone()
+
+ rot_mats[:, [1, 2, 3, 6]] *= -1
+ return rot_mats.view_as(pose_rotmat)
+
+
+def find_joint_global_rotation(kin_chain, root_pose, body_pose):
+ """Computes the absolute rotation of a joint from the kinematic chain."""
+ # Create a single vector with all the poses
+ parents_pose = torch.cat([root_pose, body_pose], dim=1)[:, kin_chain]
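+    # kin_chain is assumed to be ordered from the joint up to the root (as
+    # returned by smplx's find_joint_kin_chain), so the successive left
+    # multiplications below accumulate R_root @ ... @ R_parent @ R_joint.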
+ output_pose = parents_pose[:, 0]
+ for idx in range(1, parents_pose.shape[1]):
+ output_pose = torch.bmm(parents_pose[:, idx], output_pose)
+ return output_pose
+
+
+class CropSampler():
+ """This function crops the HD images using bilinear interpolation."""
+ def __init__(self, crop_size: int = 256) -> None:
+ """Uses bilinear sampling to extract square crops.
+
+        This module expects a high resolution image as input and a bounding
+        box, described by its center and size. It then extracts a sub-image
+        at that location through bilinear interpolation.
+
+ Parameters
+ ----------
+ crop_size: int
+ The desired size for the crop.
+ """
+ super(CropSampler, self).__init__()
+
+ self.crop_size = crop_size
+ x = torch.arange(0, crop_size, dtype=torch.float32) / (crop_size - 1)
+ grid_y, grid_x = torch.meshgrid(x, x)
+
+ points = torch.stack([grid_y.flatten(), grid_x.flatten()], axis=1)
+
+ self.grid = points.unsqueeze(dim=0)
+
+ def _sample_padded(self, full_imgs, sampling_grid):
+        """Sample the crops via bilinear grid sampling."""
+ # Get the sub-images using bilinear interpolation
+ return F.grid_sample(full_imgs, sampling_grid, align_corners=True)
+
+ def __call__(self, full_imgs, center, bbox_size):
+ """Crops the HD images using the provided bounding boxes.
+
+ Parameters
+ ----------
+        full_imgs: torch.Tensor
+            A (B, C, H, W) tensor that contains the full resolution images
+ center: torch.Tensor
+ A Bx2 tensor that contains the coordinates of the center of
+ the bounding box that will be cropped from the original
+ image
+ bbox_size: torch.Tensor
+            A size B tensor that contains the size of the crop
+
+ Returns
+ -------
+        cropped_images: torch.Tensor
+ The images cropped from the high resolution input
+ sampling_grid: torch.Tensor
+ The grid used to sample the crops
+ """
+
+ batch_size, _, H, W = full_imgs.shape
+ self.grid = self.grid.to(device=full_imgs.device)
+ transforms = torch.eye(3,
+ dtype=full_imgs.dtype,
+ device=full_imgs.device).reshape(
+ 1, 3, 3).expand(batch_size, -1,
+ -1).contiguous()
+
+ hd_to_crop = torch.eye(3,
+ dtype=full_imgs.dtype,
+ device=full_imgs.device).reshape(
+ 1, 3, 3).expand(batch_size, -1,
+ -1).contiguous()
+
+ # Create the transformation that maps crop pixels to image coordinates,
+ # i.e. pixel (0, 0) from the crop_size x crop_size grid gets mapped to
+ # the top left of the bounding box, pixel
+ # (crop_size - 1, crop_size - 1) to the bottom right corner of the
+ # bounding box
+ transforms[:, 0, 0] = bbox_size # / (self.crop_size - 1)
+ transforms[:, 1, 1] = bbox_size # / (self.crop_size - 1)
+ transforms[:, 0, 2] = center[:, 0] - bbox_size * 0.5
+ transforms[:, 1, 2] = center[:, 1] - bbox_size * 0.5
+
+ hd_to_crop[:, 0, 0] = 2 * (self.crop_size - 1) / bbox_size
+ hd_to_crop[:, 1, 1] = 2 * (self.crop_size - 1) / bbox_size
+ hd_to_crop[:, 0,
+ 2] = -(center[:, 0] - bbox_size * 0.5) * hd_to_crop[:, 0,
+ 0] - 1
+ hd_to_crop[:, 1,
+ 2] = -(center[:, 1] - bbox_size * 0.5) * hd_to_crop[:, 1,
+ 1] - 1
+
+ size_bbox_sizer = torch.eye(3,
+ dtype=full_imgs.dtype,
+ device=full_imgs.device).reshape(
+ 1, 3, 3).expand(batch_size, -1,
+ -1).contiguous()
+
+ # Normalize the coordinates to [-1, 1] for the grid_sample function
+ size_bbox_sizer[:, 0, 0] = 2.0 / (W - 1)
+ size_bbox_sizer[:, 1, 1] = 2.0 / (H - 1)
+ size_bbox_sizer[:, :2, 2] = -1
+
+ # full_transform = transforms
+ full_transform = torch.bmm(size_bbox_sizer, transforms)
+
+ batch_grid = self.grid.expand(batch_size, -1, -1)
+ # Convert the grid to image coordinates using the transformations above
+ sampling_grid = (
+ torch.bmm(full_transform[:, :2, :2], batch_grid.transpose(1, 2)) +
+ full_transform[:, :2, [2]]).transpose(1, 2)
+ sampling_grid = sampling_grid.reshape(-1, self.crop_size,
+ self.crop_size,
+ 2).transpose(1, 2)
+
+ out_images = self._sample_padded(full_imgs, sampling_grid)
+
+ return {
+ 'images': out_images,
+ 'sampling_grid': sampling_grid.reshape(batch_size, -1, 2),
+ 'transform': transforms,
+ 'hd_to_crop': hd_to_crop,
+ }
+
+
+class SMPLXHandCropFunc():
+    """This function crops hand images from the original image.
+
+    It uses the keypoints predicted by the body model to locate the hand
+    positions.
+ """
+ def __init__(self,
+ model_head,
+ body_model,
+ convention='smplx',
+ img_res=256,
+ scale_factor=2.0,
+ crop_size=224,
+ condition_hand_wrist_pose=True,
+ condition_hand_shape=False,
+ condition_hand_finger_pose=True):
+ self.model_head = model_head
+ self.body_model = body_model
+ self.img_res = img_res
+ self.convention = convention
+ self.left_hand_idxs = get_keypoint_idxs_by_part(
+ 'left_hand', self.convention)
+ left_wrist_idx = get_keypoint_idx('left_wrist', self.convention)
+ self.left_hand_idxs.append(left_wrist_idx)
+ self.left_wrist_kin_chain = find_joint_kin_chain(
+ left_wrist_idx, self.body_model.parents)
+
+ self.right_hand_idxs = get_keypoint_idxs_by_part(
+ 'right_hand', self.convention)
+ right_wrist_idx = get_keypoint_idx('right_wrist', self.convention)
+ self.right_hand_idxs.append(right_wrist_idx)
+ self.right_wrist_kin_chain = find_joint_kin_chain(
+ right_wrist_idx, self.body_model.parents)
+
+ self.scale_factor = scale_factor
+ self.hand_cropper = CropSampler(crop_size)
+
+ self.condition_hand_wrist_pose = condition_hand_wrist_pose
+ self.condition_hand_shape = condition_hand_shape
+ self.condition_hand_finger_pose = condition_hand_finger_pose
+
+ def build_hand_mean(self, global_orient, body_pose, betas, left_hand_pose,
+ raw_right_hand_pose, batch_size):
+ """Builds the initial point for the iterative regressor of the hand."""
+ hand_mean = []
+
+ # if self.condition_hand_on_body:
+ # Convert the absolute pose to the latent representation
+ if self.condition_hand_wrist_pose:
+ # Compute the absolute pose of the right wrist
+ right_wrist_pose_abs = find_joint_global_rotation(
+ self.right_wrist_kin_chain, global_orient, body_pose)
+ right_wrist_pose = right_wrist_pose_abs[:, :3, :2].contiguous(
+ ).reshape(batch_size, -1)
+
+ # Compute the absolute rotation for the left wrist
+ left_wrist_pose_abs = find_joint_global_rotation(
+ self.left_wrist_kin_chain, global_orient, body_pose)
+ # Flip the left wrist to the right
+ left_to_right_wrist_pose = flip_rotmat(left_wrist_pose_abs)
+
+ # Convert to the latent representation
+ left_to_right_wrist_pose = left_to_right_wrist_pose[:, :3, :
+ 2].contiguous(
+ ).reshape(
+ batch_size,
+ -1)
+ else:
+ right_wrist_pose = self.model_head.get_mean('global_orient',
+ batch_size=batch_size)
+ left_to_right_wrist_pose = self.model_head.get_mean(
+ 'global_orient', batch_size=batch_size)
+
+ # Convert the pose of the left hand to the right hand and project
+ # it to the encoder space
+ left_to_right_hand_pose = flip_rotmat(
+ left_hand_pose)[:, :, :3, :2].contiguous().reshape(batch_size, -1)
+ right_hand_pose = raw_right_hand_pose.reshape(batch_size, -1)
+ camera_mean = self.model_head.get_mean('camera', batch_size=batch_size)
+
+ shape_condition = (betas if self.condition_hand_shape else
+ self.model_head.get_mean('shape',
+ batch_size=batch_size))
+ right_finger_pose_condition = (
+ right_hand_pose if self.condition_hand_finger_pose else
+ self.model_head.get_mean('right_hand_pose', batch_size=batch_size))
+ right_hand_mean = torch.cat([
+ right_wrist_pose, right_finger_pose_condition, shape_condition,
+ camera_mean
+ ],
+ dim=1)
+
+ left_finger_pose_condition = (
+ left_to_right_hand_pose if self.condition_hand_finger_pose else
+ self.model_head.get_mean('right_hand_pose', batch_size=batch_size))
+ # Should be Bx31
+ left_hand_mean = torch.cat([
+ left_to_right_wrist_pose, left_finger_pose_condition,
+ shape_condition, camera_mean
+ ],
+ dim=1)
+
+ hand_mean += [right_hand_mean, left_hand_mean]
+ hand_mean = torch.cat(hand_mean, dim=0)
+
+ return hand_mean
+
+ def __call__(self, body_predictions, img_metas):
+        """Crop hand regions from the original images.
+
+ Args:
+ body_predictions (dict): The prediction from body model.
+ img_metas (dict): Information of the input images.
+ Returns:
+ all_hand_imgs (torch.tensor): Cropped hand images.
+ hand_mean (torch.tensor): Mean value of hand params.
+ crop_info (dict): Hand crop transforms.
+ """
+ pred_param = body_predictions['pred_param']
+ pred_cam = body_predictions['pred_cam']
+ pred_raw = body_predictions['pred_raw']
+ pred_output = self.body_model(**pred_param)
+
+ pred_keypoints3d = pred_output['joints']
+ pred_keypoints2d = weak_perspective_projection(
+ pred_keypoints3d,
+ scale=pred_cam[:, 0],
+ translation=pred_cam[:, 1:3])
+ # concat ori_img
+ full_images = []
+ for img_meta in img_metas:
+ full_images.append(img_meta['ori_img'].to(device=pred_cam.device))
+ full_imgs = concat_images(full_images)
+
+ # left hand
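+        # pred_keypoints2d is assumed to be normalized to [-1, 1]; map it to
+        # pixel coordinates of the img_res x img_res body crop before
+        # estimating the hand bounding boxes (same for the right hand below).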
+ left_hand_joints = (pred_keypoints2d[:, self.left_hand_idxs] * 0.5 +
+ 0.5) * (self.img_res - 1)
+ left_hand_points_to_crop = get_crop_info(left_hand_joints, img_metas,
+ self.scale_factor,
+ self.img_res)
+ left_hand_center = left_hand_points_to_crop['center']
+ left_hand_orig_bbox_size = left_hand_points_to_crop['orig_bbox_size']
+ left_hand_inv_crop_transforms = left_hand_points_to_crop[
+ 'inv_crop_transforms']
+
+ left_hand_cropper_out = self.hand_cropper(full_imgs, left_hand_center,
+ left_hand_orig_bbox_size)
+ left_hand_crops = left_hand_cropper_out['images']
+ # left_hand_points = left_hand_cropper_out['sampling_grid']
+ left_hand_crop_transform = left_hand_cropper_out['transform']
+
+ # right hand
+ right_hand_joints = (pred_keypoints2d[:, self.right_hand_idxs] * 0.5 +
+ 0.5) * (self.img_res - 1)
+ right_hand_points_to_crop = get_crop_info(right_hand_joints, img_metas,
+ self.scale_factor,
+ self.img_res)
+ right_hand_center = right_hand_points_to_crop['center']
+ right_hand_orig_bbox_size = right_hand_points_to_crop['orig_bbox_size']
+ # right_hand_inv_crop_transforms = right_hand_points_to_crop[
+ # 'inv_crop_transforms']
+ right_hand_cropper_out = self.hand_cropper(full_imgs,
+ right_hand_center,
+ right_hand_orig_bbox_size)
+ right_hand_crops = right_hand_cropper_out['images']
+ # right_hand_points = right_hand_cropper_out['sampling_grid']
+ right_hand_crop_transform = right_hand_cropper_out['transform']
+
+ # concat
+ all_hand_imgs = []
+ all_hand_imgs.append(right_hand_crops)
+ all_hand_imgs.append(torch.flip(left_hand_crops, dims=(-1, )))
+
+ # [right_hand , left hand]
+ all_hand_imgs = torch.cat(all_hand_imgs, dim=0)
+ hand_mean = self.build_hand_mean(pred_param['global_orient'],
+ pred_param['body_pose'],
+ pred_param['betas'],
+ pred_param['left_hand_pose'],
+ pred_raw['raw_right_hand_pose'],
+ batch_size=full_imgs.shape[0])
+ crop_info = dict(
+ hand_inv_crop_transforms=left_hand_inv_crop_transforms,
+ left_hand_crop_transform=left_hand_crop_transform,
+ right_hand_crop_transform=right_hand_crop_transform)
+ return all_hand_imgs, hand_mean, crop_info
+
+
+class SMPLXFaceCropFunc():
+    """This function crops face images from the original image.
+
+    It uses the keypoints predicted by the body model to locate the face
+    position.
+ """
+ def __init__(self,
+ model_head,
+ body_model,
+ convention='smplx',
+ img_res=256,
+ scale_factor=2.0,
+ crop_size=256,
+ num_betas=10,
+ num_expression_coeffs=10,
+ condition_face_neck_pose=False,
+ condition_face_jaw_pose=True,
+ condition_face_shape=False,
+ condition_face_expression=True):
+ self.model_head = model_head
+ self.body_model = body_model
+ self.img_res = img_res
+ self.convention = convention
+ self.num_betas = num_betas
+ self.num_expression_coeffs = num_expression_coeffs
+
+ self.face_idx = get_keypoint_idxs_by_part('head', self.convention)
+ neck_idx = get_keypoint_idx('neck', self.convention)
+ self.neck_kin_chain = find_joint_kin_chain(neck_idx,
+ self.body_model.parents)
+
+ self.condition_face_neck_pose = condition_face_neck_pose
+ self.condition_face_jaw_pose = condition_face_jaw_pose
+ self.condition_face_shape = condition_face_shape
+ self.condition_face_expression = condition_face_expression
+
+ self.scale_factor = scale_factor
+ self.face_cropper = CropSampler(crop_size)
+
+ def build_face_mean(self, global_orient, body_pose, betas, raw_jaw_pose,
+ expression, batch_size):
+ """Builds the initial point for the iterative regressor of the face."""
+ face_mean = []
+        # Compute the absolute pose of the neck
+ neck_pose_abs = find_joint_global_rotation(self.neck_kin_chain,
+ global_orient, body_pose)
+        # Convert the absolute neck pose to the 6D rotation representation
+ neck_pose = neck_pose_abs[:, :3, :2].contiguous().reshape(
+ batch_size, -1)
+
+ camera_mean = self.model_head.get_mean('camera', batch_size=batch_size)
+
+ neck_pose_condition = (neck_pose if self.condition_face_neck_pose else
+ self.model_head.get_mean('global_orient',
+ batch_size=batch_size))
+
+ jaw_pose_condition = (raw_jaw_pose.reshape(batch_size, -1)
+ if self.condition_face_jaw_pose else
+ self.model_head.get_mean('jaw_pose',
+ batch_size=batch_size))
+ face_num_betas = self.model_head.get_num_betas()
+ shape_padding_size = face_num_betas - self.num_betas
+ betas_condition = (
+ F.pad(betas.reshape(batch_size, -1),
+ (0, shape_padding_size)) if self.condition_face_shape else
+ self.model_head.get_mean('shape', batch_size=batch_size))
+
+ face_num_expression_coeffs = self.model_head.get_num_expression_coeffs(
+ )
+ expr_padding_size = face_num_expression_coeffs \
+ - self.num_expression_coeffs
+ expression_condition = (
+ F.pad(expression.reshape(batch_size, -1),
+ (0, expr_padding_size)) if self.condition_face_expression
+ else self.model_head.get_mean('expression', batch_size=batch_size))
+
+ # Should be Bx(Head pose params)
+ face_mean.append(
+ torch.cat([
+ neck_pose_condition,
+ jaw_pose_condition,
+ betas_condition,
+ expression_condition,
+ camera_mean.reshape(batch_size, -1),
+ ],
+ dim=1))
+
+ face_mean = torch.cat(face_mean, dim=0)
+ return face_mean
+
+ def __call__(self, body_predictions, img_metas):
+        """Crop face regions from the original images.
+
+ Args:
+ body_predictions (dict): The prediction from body model.
+ img_metas (dict): Information of the input images.
+ Returns:
+ all_face_imgs (torch.tensor): Cropped face images.
+ face_mean (torch.tensor): Mean value of face params.
+ crop_info (dict): Face crop transforms.
+ """
+ pred_param = body_predictions['pred_param']
+ pred_cam = body_predictions['pred_cam']
+ pred_raw = body_predictions['pred_raw']
+
+ pred_output = self.body_model(**pred_param)
+
+ pred_keypoints3d = pred_output['joints']
+ pred_keypoints2d = weak_perspective_projection(
+ pred_keypoints3d,
+ scale=pred_cam[:, 0],
+ translation=pred_cam[:, 1:3])
+ # concat ori_img
+ full_images = []
+ for img_meta in img_metas:
+ full_images.append(img_meta['ori_img'].to(device=pred_cam.device))
+ full_imgs = concat_images(full_images)
+
+ face_joints = (pred_keypoints2d[:, self.face_idx] * 0.5 +
+ 0.5) * (self.img_res - 1)
+ face_points_to_crop = get_crop_info(face_joints, img_metas,
+ self.scale_factor, self.img_res)
+ face_center = face_points_to_crop['center']
+ face_orig_bbox_size = face_points_to_crop['orig_bbox_size']
+ face_inv_crop_transforms = face_points_to_crop['inv_crop_transforms']
+
+ face_cropper_out = self.face_cropper(full_imgs, face_center,
+ face_orig_bbox_size)
+ face_crops = face_cropper_out['images']
+ # face_points = face_cropper_out['sampling_grid']
+ face_crop_transform = face_cropper_out['transform']
+
+ all_face_imgs = [face_crops]
+ all_face_imgs = torch.cat(all_face_imgs, dim=0)
+
+ face_mean = self.build_face_mean(pred_param['global_orient'],
+ pred_param['body_pose'],
+ pred_param['betas'],
+ pred_raw['raw_jaw_pose'],
+ pred_param['expression'],
+ batch_size=full_imgs.shape[0])
+ crop_info = dict(face_inv_crop_transforms=face_inv_crop_transforms,
+ face_crop_transform=face_crop_transform)
+ return all_face_imgs, face_mean, crop_info
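+
+
+if __name__ == '__main__':
+    # Minimal self-contained sketch (not part of the original module): check
+    # the output shapes of points_to_bbox and CropSampler on random inputs.
+    _imgs = torch.rand(2, 3, 64, 64)  # hypothetical full-resolution images
+    _pts = torch.rand(2, 21, 2) * 64  # hypothetical 2D keypoints in pixels
+    _center, _size = points_to_bbox(_pts, bbox_scale_factor=1.2)
+    _out = CropSampler(crop_size=8)(_imgs, _center, _size)
+    assert _center.shape == (2, 2) and _size.shape == (2,)
+    assert _out['images'].shape == (2, 3, 8, 8)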
diff --git a/detrsmpl/models/utils/__init__.py b/detrsmpl/models/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..565fd43e8ec73a284cf1a013eb7eefacd8e1f984
--- /dev/null
+++ b/detrsmpl/models/utils/__init__.py
@@ -0,0 +1,23 @@
+from .builder import (
+ build_linear_layer,
+ build_positional_encoding,
+ build_transformer,
+)
+from .fits_dict import FitsDict
+from .inverse_kinematics import batch_inverse_kinematics_transform
+from .res_layer import ResLayer, SimplifiedBasicBlock
+from .SMPLX import (
+ SMPLXFaceCropFunc,
+ SMPLXFaceMergeFunc,
+ SMPLXHandCropFunc,
+ SMPLXHandMergeFunc,
+)
+
+
+__all__ = [
+    'build_linear_layer', 'build_positional_encoding', 'build_transformer',
+    'FitsDict', 'ResLayer', 'SimplifiedBasicBlock',
+    'batch_inverse_kinematics_transform', 'SMPLXHandCropFunc',
+    'SMPLXFaceMergeFunc', 'SMPLXFaceCropFunc', 'SMPLXHandMergeFunc',
+]
diff --git a/detrsmpl/models/utils/builder.py b/detrsmpl/models/utils/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e34647e4f0dda82e34c57b48b78549e09d406c67
--- /dev/null
+++ b/detrsmpl/models/utils/builder.py
@@ -0,0 +1,61 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.utils import Registry, build_from_cfg
+
+from .positional_encoding import (
+ LearnedPositionalEncoding,
+ SinePositionalEncoding,
+)
+
+TRANSFORMER = Registry('Transformer')
+LINEAR_LAYERS = Registry('linear layers')
+POSITIONAL_ENCODING = Registry('position encoding')
+
+LINEAR_LAYERS.register_module('Linear', module=nn.Linear)
+POSITIONAL_ENCODING.register_module('SinePositionalEncoding',
+ module=SinePositionalEncoding)
+POSITIONAL_ENCODING.register_module('LearnedPositionalEncoding',
+ module=LearnedPositionalEncoding)
+
+
+def build_transformer(cfg, default_args=None):
+ """Builder for Transformer."""
+ return build_from_cfg(cfg, TRANSFORMER, default_args)
+
+
+def build_linear_layer(cfg, *args, **kwargs):
+ """Build linear layer.
+ Args:
+ cfg (None or dict): The linear layer config, which should contain:
+ - type (str): Layer type.
+ - layer args: Args needed to instantiate an linear layer.
+ args (argument list): Arguments passed to the `__init__`
+ method of the corresponding linear layer.
+ kwargs (keyword arguments): Keyword arguments passed to the `__init__`
+ method of the corresponding linear layer.
+ Returns:
+ nn.Module: Created linear layer.
+ """
+ if cfg is None:
+ cfg_ = dict(type='Linear')
+ else:
+ if not isinstance(cfg, dict):
+ raise TypeError('cfg must be a dict')
+ if 'type' not in cfg:
+ raise KeyError('the cfg dict must contain the key "type"')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type not in LINEAR_LAYERS:
+ raise KeyError(f'Unrecognized linear type {layer_type}')
+ else:
+ linear_layer = LINEAR_LAYERS.get(layer_type)
+
+ layer = linear_layer(*args, **kwargs, **cfg_)
+
+ return layer
+
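+# Illustrative usage (hypothetical sizes, not from the original file):
+#     fc = build_linear_layer(dict(type='Linear'), in_features=256,
+#                             out_features=10)
+# builds a plain nn.Linear(256, 10) through the LINEAR_LAYERS registry.
+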
+
+def build_positional_encoding(cfg, default_args=None):
+ """Builder for Position Encoding."""
+ return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args)
diff --git a/detrsmpl/models/utils/fits_dict.py b/detrsmpl/models/utils/fits_dict.py
new file mode 100644
index 0000000000000000000000000000000000000000..c090170364cac97ba376e3a4253120ed441fa5ce
--- /dev/null
+++ b/detrsmpl/models/utils/fits_dict.py
@@ -0,0 +1,134 @@
+# ------------------------------------------------------------------------------
+# Adapted from https://github.com/nkolot/SPIN/blob/master/train/fits_dict.py
+# Original licence please see docs/additional_licenses.md
+# ------------------------------------------------------------------------------
+
+import os
+
+import cv2
+import numpy as np
+import torch
+
+from detrsmpl.utils.transforms import aa_to_rotmat
+
+train_datasets = ['h36m', 'mpi_inf_3dhp', 'lsp', 'lspet', 'mpii', 'coco']
+static_fits_load_dir = 'data/static_fits'
+save_dir = 'data/spin_fits'
+
+# Permutation of SMPL pose parameters when the body is flipped horizontally
+SMPL_JOINTS_FLIP_PERM = [
+ 0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21,
+ 20, 23, 22
+]
+SMPL_POSE_FLIP_PERM = []
+for i in SMPL_JOINTS_FLIP_PERM:
+ SMPL_POSE_FLIP_PERM.append(3 * i)
+ SMPL_POSE_FLIP_PERM.append(3 * i + 1)
+ SMPL_POSE_FLIP_PERM.append(3 * i + 2)
+
+
+class FitsDict():
+ """Dictionary keeping track of the best fit per image in the training set.
+
+ Ref: https://github.com/nkolot/SPIN/blob/master/train/fits_dict.py
+ """
+ def __init__(self, fits='static') -> None:
+ assert fits in ['static', 'final']
+ self.fits = fits
+ self.fits_dict = {}
+
+ # array used to flip SMPL pose parameters
+ self.flipped_parts = torch.tensor(SMPL_POSE_FLIP_PERM,
+ dtype=torch.int64)
+ # Load dictionary state
+ # for ds_name, ds in train_dataset.dataset_dict.items():
+ for ds_name in train_datasets:
+
+ # h36m has gt so no static fits
+ if ds_name == 'h36m' or self.fits == 'static':
+ dict_file = os.path.join(static_fits_load_dir,
+ ds_name + '_fits.npy')
+ content = np.load(dict_file)
+ self.fits_dict[ds_name] = torch.from_numpy(content)
+ del content
+ elif self.fits == 'final':
+ dict_file = os.path.join('data/final_fits', ds_name + '.npz')
+ # load like this to save mem
+ content = np.load(dict_file)
+ pose = torch.from_numpy(content['pose'])
+ betas = torch.from_numpy(content['betas'])
+ del content
+ params = torch.cat([pose, betas], dim=-1)
+ self.fits_dict[ds_name] = params
+
+ def save(self):
+ """Save dictionary state to disk."""
+ for ds_name in train_datasets:
+ dict_file = os.path.join(save_dir, ds_name + '_fits.npy')
+ np.save(dict_file, self.fits_dict[ds_name].cpu().numpy())
+
+ def __getitem__(self, x):
+ """Retrieve dictionary entries."""
+ dataset_name, ind, rot, is_flipped = x
+ batch_size = len(dataset_name)
+ pose = torch.zeros((batch_size, 72))
+ betas = torch.zeros((batch_size, 10))
+ for ds, i, n in zip(dataset_name, ind, range(batch_size)):
+ params = self.fits_dict[ds][i]
+ pose[n, :] = params[:72]
+ betas[n, :] = params[72:]
+ pose = pose.clone()
+
+ # Apply flipping and rotation
+ pose = self.rotate_pose(self.flip_pose(pose, is_flipped), rot)
+
+ betas = betas.clone()
+ return pose, betas
+
+ def __setitem__(self, x, val):
+ """Update dictionary entries."""
+ dataset_name, ind, rot, is_flipped, update = x
+ pose, betas = val
+ batch_size = len(dataset_name)
+
+ # Undo flipping and rotation
+ pose = self.flip_pose(self.rotate_pose(pose, -rot), is_flipped)
+
+ params = torch.cat((pose, betas), dim=-1).cpu()
+ for ds, i, n in zip(dataset_name, ind, range(batch_size)):
+ if update[n]:
+ self.fits_dict[ds][i] = params[n]
+
+ def flip_pose(self, pose, is_flipped):
+ """flip SMPL pose parameters."""
+ is_flipped = is_flipped.bool()
+ pose_f = pose.clone()
+ pose_f[is_flipped, :] = pose[is_flipped][:, self.flipped_parts]
+ # we also negate the second and the third dimension of the
+ # axis-angle representation
+ pose_f[is_flipped, 1::3] *= -1
+ pose_f[is_flipped, 2::3] *= -1
+ return pose_f
+
+ def rotate_pose(self, pose, rot):
+ """Rotate SMPL pose parameters by rot degrees."""
+ pose = pose.clone()
+ cos = torch.cos(-np.pi * rot / 180.)
+ sin = torch.sin(-np.pi * rot / 180.)
+ zeros = torch.zeros_like(cos)
+ r3 = torch.zeros(cos.shape[0], 1, 3, device=cos.device)
+ r3[:, 0, -1] = 1
+ R = torch.cat([
+ torch.stack([cos, -sin, zeros], dim=-1).unsqueeze(1),
+ torch.stack([sin, cos, zeros], dim=-1).unsqueeze(1), r3
+ ],
+ dim=1)
+ global_pose = pose[:, :3]
+ global_pose_rotmat = R @ aa_to_rotmat(global_pose)
+ global_pose_rotmat = global_pose_rotmat.cpu().numpy()
+ global_pose_np = np.zeros((global_pose.shape[0], 3))
+ for i in range(global_pose.shape[0]):
+ aa, _ = cv2.Rodrigues(global_pose_rotmat[i])
+ global_pose_np[i, :] = aa.squeeze()
+ pose[:, :3] = torch.from_numpy(global_pose_np).to(pose.device)
+ return pose
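Since `FitsDict.__init__` expects pre-computed SPIN fit files under `data/static_fits` / `data/final_fits`, the flipping logic above can be sanity-checked in isolation. The sketch below simply mirrors `flip_pose` on a random pose tensor and verifies that flipping twice is the identity; it does not call the class or read any files.

```python
# Standalone sketch of the SMPL pose flipping used in FitsDict.flip_pose.
import torch

SMPL_JOINTS_FLIP_PERM = [
    0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21,
    20, 23, 22
]
# Expand the joint-level permutation to the 72 axis-angle parameters.
flipped_parts = torch.tensor(
    [3 * i + k for i in SMPL_JOINTS_FLIP_PERM for k in range(3)],
    dtype=torch.int64)

pose = torch.randn(4, 72)                       # (batch, 24 joints * 3)
is_flipped = torch.tensor([1, 0, 1, 0]).bool()  # flip only samples 0 and 2

pose_f = pose.clone()
pose_f[is_flipped] = pose[is_flipped][:, flipped_parts]
# Mirroring about the x-axis negates the y and z axis-angle components.
pose_f[is_flipped, 1::3] *= -1
pose_f[is_flipped, 2::3] *= -1

# Flipping twice recovers the original pose.
pose_ff = pose_f.clone()
pose_ff[is_flipped] = pose_f[is_flipped][:, flipped_parts]
pose_ff[is_flipped, 1::3] *= -1
pose_ff[is_flipped, 2::3] *= -1
assert torch.allclose(pose_ff, pose)
```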
diff --git a/detrsmpl/models/utils/inverse_kinematics.py b/detrsmpl/models/utils/inverse_kinematics.py
new file mode 100644
index 0000000000000000000000000000000000000000..cff524e2275170960c9b68ff59aa382298acdae4
--- /dev/null
+++ b/detrsmpl/models/utils/inverse_kinematics.py
@@ -0,0 +1,432 @@
+"""This script is based on the release codes:
+
+"HybrIK: A Hybrid Analytical-Neural Inverse Kinematics Solution for 3D Human
+Pose and Shape Estimation. CVPR 2021"
+(https://github.com/Jeff-sjtu/HybrIK).
+"""
+
+from __future__ import absolute_import, division, print_function
+
+import torch
+
+from detrsmpl.utils.transforms import aa_to_rotmat
+
+
+def batch_inverse_kinematics_transform(pose_skeleton,
+ global_orient,
+ phis,
+ rest_pose,
+ children,
+ parents,
+ dtype=torch.float32,
+ train=False,
+ leaf_thetas=None):
+ """Applies inverse kinematics transform to joints in a batch.
+
+ Args:
+ pose_skeleton (torch.tensor):
+ Locations of estimated pose skeleton with shape (Bx29x3)
+ global_orient (torch.tensor|none):
+ Tensor of global rotation matrices with shape (Bx1x3x3)
+ phis (torch.tensor):
+ Rotation on bone axis parameters with shape (Bx23x2)
+ rest_pose (torch.tensor):
+ Locations of rest (Template) pose with shape (Bx29x3)
+ children (List[int]): list of indexes of kinematic children with len 29
+ parents (List[int]): list of indexes of kinematic parents with len 29
+ dtype (torch.dtype, optional):
+ Data type of the created tensors. Default: torch.float32
+ train (bool):
+ Store True in train mode. Default: False
+ leaf_thetas (torch.tensor, optional):
+            Rotation matrices for 5 leaf joints (Bx5x3x3). Default: None
+
+
+ Returns:
+ rot_mats (torch.tensor):
+            Rotation matrices of all joints with shape (Bx29x3x3)
+        rotate_rest_pose (torch.tensor):
+            Locations of the rotated rest/template pose with shape (Bx29x3)
+ """
+ batch_size = pose_skeleton.shape[0]
+ device = pose_skeleton.device
+
+ rel_rest_pose = rest_pose.clone()
+ # vec_t_k = t_k - t_pa(k)
+ rel_rest_pose[:, 1:] -= rest_pose[:, parents[1:]].clone()
+ rel_rest_pose = torch.unsqueeze(rel_rest_pose, dim=-1)
+
+ # rotate the T pose
+ rotate_rest_pose = torch.zeros_like(rel_rest_pose)
+ # set up the root
+ rotate_rest_pose[:, 0] = rel_rest_pose[:, 0]
+
+ rel_pose_skeleton = torch.unsqueeze(pose_skeleton.clone(), dim=-1).detach()
+ rel_pose_skeleton[:, 1:] -= rel_pose_skeleton[:, parents[1:]].clone()
+ rel_pose_skeleton[:, 0] = rel_rest_pose[:, 0]
+
+ # the predicted final pose
+ final_pose_skeleton = torch.unsqueeze(pose_skeleton.clone(), dim=-1)
+ if train:
+ final_pose_skeleton[:, 1:] -= \
+ final_pose_skeleton[:, parents[1:]].clone()
+ final_pose_skeleton[:, 0] = rel_rest_pose[:, 0]
+ else:
+ final_pose_skeleton += \
+ rel_rest_pose[:, 0:1] - final_pose_skeleton[:, 0:1]
+
+
+ assert phis.dim() == 3
+ phis = phis / (torch.norm(phis, dim=2, keepdim=True) + 1e-8)
+
+ if train:
+ global_orient_mat = batch_get_pelvis_orient(rel_pose_skeleton.clone(),
+ rel_rest_pose.clone(),
+ parents, children, dtype)
+ else:
+ global_orient_mat = batch_get_pelvis_orient_svd(
+ rel_pose_skeleton.clone(), rel_rest_pose.clone(), parents,
+ children, dtype)
+
+ rot_mat_chain = [global_orient_mat]
+ rot_mat_local = [global_orient_mat]
+ # leaf nodes rot_mats
+ if leaf_thetas is not None:
+ leaf_cnt = 0
+ leaf_rot_mats = leaf_thetas.view([batch_size, 5, 3, 3])
+
+ for i in range(1, parents.shape[0]):
+ if children[i] == -1:
+ # leaf nodes
+ if leaf_thetas is not None:
+ rot_mat = leaf_rot_mats[:, leaf_cnt, :, :]
+ leaf_cnt += 1
+
+ rotate_rest_pose[:, i] = rotate_rest_pose[:, parents[
+ i]] + torch.matmul(rot_mat_chain[parents[i]],
+ rel_rest_pose[:, i])
+
+ rot_mat_chain.append(
+ torch.matmul(rot_mat_chain[parents[i]], rot_mat))
+ rot_mat_local.append(rot_mat)
+ elif children[i] == -3:
+ # three children
+ rotate_rest_pose[:, i] = rotate_rest_pose[:, parents[i]] + \
+ torch.matmul(rot_mat_chain[parents[i]], rel_rest_pose[:, i])
+
+ spine_child = []
+ for c in range(1, parents.shape[0]):
+ if parents[c] == i and c not in spine_child:
+ spine_child.append(c)
+
+
+ children_final_loc = []
+ children_rest_loc = []
+ for c in spine_child:
+ temp = final_pose_skeleton[:, c] - rotate_rest_pose[:, i]
+ children_final_loc.append(temp)
+
+ children_rest_loc.append(rel_rest_pose[:, c].clone())
+
+ rot_mat = batch_get_3children_orient_svd(children_final_loc,
+ children_rest_loc,
+ rot_mat_chain[parents[i]],
+ spine_child, dtype)
+
+ rot_mat_chain.append(
+ torch.matmul(rot_mat_chain[parents[i]], rot_mat))
+ rot_mat_local.append(rot_mat)
+ else:
+ # Naive Hybrik
+ if train:
+ # i: the index of k-th joint
+ child_rest_loc = rel_rest_pose[:, i]
+ child_final_loc = final_pose_skeleton[:, i]
+
+ # q_pa(k) = q_pa^2(k) + R_pa(k)(t_pa(k) - t_pa^2(k))
+ rotate_rest_pose[:, i] = rotate_rest_pose[:, parents[i]] + \
+ torch.matmul(rot_mat_chain[parents[i]], rel_rest_pose[:, i])
+ # Adaptive HybrIK
+ if not train:
+ # children[i]: the index of k-th joint
+ child_rest_loc = rel_rest_pose[:, children[i]]
+ child_final_loc = final_pose_skeleton[:, children[
+ i]] - rotate_rest_pose[:, i]
+
+ orig_vec = rel_pose_skeleton[:, children[i]]
+ template_vec = rel_rest_pose[:, children[i]]
+ norm_t = torch.norm(template_vec, dim=1, keepdim=True)
+ orig_vec = orig_vec * norm_t / torch.norm(
+ orig_vec, dim=1, keepdim=True)
+
+ diff = torch.norm(child_final_loc - orig_vec,
+ dim=1,
+ keepdim=True)
+ big_diff_idx = torch.where(diff > 15 / 1000)[0]
+
+ child_final_loc[big_diff_idx] = orig_vec[big_diff_idx]
+
+ # train: vec_p_k = R_pa(k).T * (p_k - p_pa(k))
+ # test: vec_p_k = R_pa(k).T * (p_k - q_pa(k))
+ child_final_loc = torch.matmul(
+ rot_mat_chain[parents[i]].transpose(1, 2), child_final_loc)
+
+ # (B, 1, 1)
+ child_final_norm = torch.norm(child_final_loc, dim=1, keepdim=True)
+ child_rest_norm = torch.norm(child_rest_loc, dim=1, keepdim=True)
+
+ # vec_n
+ axis = torch.cross(child_rest_loc, child_final_loc, dim=1)
+ axis_norm = torch.norm(axis, dim=1, keepdim=True)
+
+ # (B, 1, 1)
+ cos = torch.sum(
+ child_rest_loc * child_final_loc, dim=1,
+ keepdim=True) / (child_rest_norm * child_final_norm + 1e-8)
+ sin = axis_norm / (child_rest_norm * child_final_norm + 1e-8)
+
+ # (B, 3, 1)
+ axis = axis / (axis_norm + 1e-8)
+
+            # Convert the swing rotation (axis-angle) to rot_mat via the Rodrigues formula
+ # (B, 1, 1)
+ rx, ry, rz = torch.split(axis, 1, dim=1)
+ zeros = torch.zeros((batch_size, 1, 1), dtype=dtype, device=device)
+
+ K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros],
+ dim=1).view((batch_size, 3, 3))
+ ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
+ rot_mat_loc = ident + sin * K + (1 - cos) * torch.bmm(K, K)
+
+ # Convert spin to rot_mat
+ # (B, 3, 1)
+ spin_axis = child_rest_loc / child_rest_norm
+ # (B, 1, 1)
+ rx, ry, rz = torch.split(spin_axis, 1, dim=1)
+ zeros = torch.zeros((batch_size, 1, 1), dtype=dtype, device=device)
+ K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros],
+ dim=1).view((batch_size, 3, 3))
+ ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
+ # (B, 1, 1)
+ cos, sin = torch.split(phis[:, i - 1], 1, dim=1)
+ cos = torch.unsqueeze(cos, dim=2)
+ sin = torch.unsqueeze(sin, dim=2)
+ rot_mat_spin = ident + sin * K + (1 - cos) * torch.bmm(K, K)
+ rot_mat = torch.matmul(rot_mat_loc, rot_mat_spin)
+
+ rot_mat_chain.append(
+ torch.matmul(rot_mat_chain[parents[i]], rot_mat))
+ rot_mat_local.append(rot_mat)
+
+ # (B, K + 1, 3, 3)
+ rot_mats = torch.stack(rot_mat_local, dim=1)
+
+ return rot_mats, rotate_rest_pose.squeeze(-1)
+
+
+def batch_get_pelvis_orient_svd(rel_pose_skeleton, rel_rest_pose, parents,
+ children, dtype):
+ """Get pelvis orientation svd for batch data.
+
+ Args:
+ rel_pose_skeleton (torch.tensor):
+ Locations of root-normalized pose skeleton with shape (Bx29x3)
+ rel_rest_pose (torch.tensor):
+            Locations of the rest/template pose with shape (Bx29x3)
+ parents (List[int]): list of indexes of kinematic parents with len 29
+ children (List[int]): list of indexes of kinematic children with len 29
+ dtype (torch.dtype, optional):
+ Data type of the created tensors, the default is torch.float32
+
+ Returns:
+ rot_mat (torch.tensor):
+ Rotation matrix of pelvis with shape (Bx3x3)
+ """
+ pelvis_child = [int(children[0])]
+ for i in range(1, parents.shape[0]):
+ if parents[i] == 0 and i not in pelvis_child:
+ pelvis_child.append(i)
+
+ rest_mat = []
+ target_mat = []
+ for child in pelvis_child:
+ rest_mat.append(rel_rest_pose[:, child].clone())
+ target_mat.append(rel_pose_skeleton[:, child].clone())
+
+ rest_mat = torch.cat(rest_mat, dim=2)
+ target_mat = torch.cat(target_mat, dim=2)
+ S = rest_mat.bmm(target_mat.transpose(1, 2))
+
+ mask_zero = S.sum(dim=(1, 2))
+
+ S_non_zero = S[mask_zero != 0].reshape(-1, 3, 3)
+
+ U, _, V = torch.svd(S_non_zero)
+
+ rot_mat = torch.zeros_like(S)
+ rot_mat[mask_zero == 0] = torch.eye(3, device=S.device)
+
+ rot_mat_non_zero = torch.bmm(V, U.transpose(1, 2))
+ rot_mat[mask_zero != 0] = rot_mat_non_zero
+
+ assert torch.sum(torch.isnan(rot_mat)) == 0, ('rot_mat', rot_mat)
+
+ return rot_mat
+
+
+def batch_get_pelvis_orient(rel_pose_skeleton, rel_rest_pose, parents,
+ children, dtype):
+ """Get pelvis orientation for batch data.
+
+ Args:
+ rel_pose_skeleton (torch.tensor):
+ Locations of root-normalized pose skeleton with shape (Bx29x3)
+ rel_rest_pose (torch.tensor):
+            Locations of the rest/template pose with shape (Bx29x3)
+ parents (List[int]): list of indexes of kinematic parents with len 29
+ children (List[int]): list of indexes of kinematic children with len 29
+ dtype (torch.dtype, optional):
+ Data type of the created tensors, the default is torch.float32
+
+ Returns:
+ rot_mat (torch.tensor):
+ Rotation matrix of pelvis with shape (Bx3x3)
+ """
+ batch_size = rel_pose_skeleton.shape[0]
+ device = rel_pose_skeleton.device
+
+ assert children[0] == 3
+ pelvis_child = [int(children[0])]
+ for i in range(1, parents.shape[0]):
+ if parents[i] == 0 and i not in pelvis_child:
+ pelvis_child.append(i)
+
+ spine_final_loc = rel_pose_skeleton[:, int(children[0])].clone()
+ spine_rest_loc = rel_rest_pose[:, int(children[0])].clone()
+ # spine_norm = torch.norm(spine_final_loc, dim=1, keepdim=True)
+ # spine_norm = spine_final_loc / (spine_norm + 1e-8)
+
+ # rot_mat_spine = vectors2rotmat(spine_rest_loc, spine_final_loc, dtype)
+
+ # (B, 1, 1)
+ vec_final_norm = torch.norm(spine_final_loc, dim=1, keepdim=True)
+ vec_rest_norm = torch.norm(spine_rest_loc, dim=1, keepdim=True)
+
+ spine_norm = spine_final_loc / (vec_final_norm + 1e-8)
+
+ # (B, 3, 1)
+ axis = torch.cross(spine_rest_loc, spine_final_loc, dim=1)
+ axis_norm = torch.norm(axis, dim=1, keepdim=True)
+ axis = axis / (axis_norm + 1e-8)
+ angle = torch.arccos(
+ torch.sum(spine_rest_loc * spine_final_loc, dim=1, keepdim=True) /
+ (vec_rest_norm * vec_final_norm + 1e-8))
+ axis_angle = (angle * axis).squeeze()
+ # aa to rotmat
+ rot_mat_spine = aa_to_rotmat(axis_angle)
+
+ assert torch.sum(torch.isnan(rot_mat_spine)) == 0, ('rot_mat_spine',
+ rot_mat_spine)
+ center_final_loc = 0
+ center_rest_loc = 0
+ for child in pelvis_child:
+ if child == int(children[0]):
+ continue
+ center_final_loc = center_final_loc + rel_pose_skeleton[:,
+ child].clone()
+ center_rest_loc = center_rest_loc + rel_rest_pose[:, child].clone()
+ center_final_loc = center_final_loc / (len(pelvis_child) - 1)
+ center_rest_loc = center_rest_loc / (len(pelvis_child) - 1)
+
+ center_rest_loc = torch.matmul(rot_mat_spine, center_rest_loc)
+
+ center_final_loc = center_final_loc - torch.sum(
+ center_final_loc * spine_norm, dim=1, keepdim=True) * spine_norm
+ center_rest_loc = center_rest_loc - torch.sum(
+ center_rest_loc * spine_norm, dim=1, keepdim=True) * spine_norm
+
+ center_final_loc_norm = torch.norm(center_final_loc, dim=1, keepdim=True)
+ center_rest_loc_norm = torch.norm(center_rest_loc, dim=1, keepdim=True)
+
+ # (B, 3, 1)
+ axis = torch.cross(center_rest_loc, center_final_loc, dim=1)
+ axis_norm = torch.norm(axis, dim=1, keepdim=True)
+
+ # (B, 1, 1)
+ cos = torch.sum(
+ center_rest_loc * center_final_loc, dim=1,
+ keepdim=True) / (center_rest_loc_norm * center_final_loc_norm + 1e-8)
+ sin = axis_norm / (center_rest_loc_norm * center_final_loc_norm + 1e-8)
+
+ assert torch.sum(torch.isnan(cos)) == 0, ('cos', cos)
+ assert torch.sum(torch.isnan(sin)) == 0, ('sin', sin)
+ # (B, 3, 1)
+ axis = axis / (axis_norm + 1e-8)
+
+    # Convert the swing rotation (axis-angle) to rot_mat via the Rodrigues formula
+ # (B, 1, 1)
+ rx, ry, rz = torch.split(axis, 1, dim=1)
+ zeros = torch.zeros((batch_size, 1, 1), dtype=dtype, device=device)
+
+ K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
+ .view((batch_size, 3, 3))
+ ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
+ rot_mat_center = ident + sin * K + (1 - cos) * torch.bmm(K, K)
+
+ rot_mat = torch.matmul(rot_mat_center, rot_mat_spine)
+
+ return rot_mat
+
+
+def batch_get_3children_orient_svd(rel_pose_skeleton, rel_rest_pose,
+ rot_mat_chain_parent, children_list, dtype):
+ """Get pelvis orientation for batch data.
+
+ Args:
+ rel_pose_skeleton (torch.tensor):
+ Locations of root-normalized pose skeleton with shape (Bx29x3)
+ rel_rest_pose (torch.tensor):
+            Locations of the rest/template pose with shape (Bx29x3)
+        rot_mat_chain_parent (torch.tensor):
+            Parent's accumulated rotation matrix with shape (Bx3x3)
+        children_list (List[int]): indices of the child joints
+ dtype (torch.dtype, optional):
+ Data type of the created tensors, the default is torch.float32
+
+ Returns:
+ rot_mat (torch.tensor):
+ Child's rotation matrix with shape (Bx3x3)
+ """
+ rest_mat = []
+ target_mat = []
+ for c, child in enumerate(children_list):
+ if isinstance(rel_pose_skeleton, list):
+ target = rel_pose_skeleton[c].clone()
+ template = rel_rest_pose[c].clone()
+ else:
+ target = rel_pose_skeleton[:, child].clone()
+ template = rel_rest_pose[:, child].clone()
+
+ target = torch.matmul(rot_mat_chain_parent.transpose(1, 2), target)
+
+ target_mat.append(target)
+ rest_mat.append(template)
+
+ rest_mat = torch.cat(rest_mat, dim=2)
+ target_mat = torch.cat(target_mat, dim=2)
+ S = rest_mat.bmm(target_mat.transpose(1, 2))
+
+ U, _, V = torch.svd(S)
+
+ rot_mat = torch.bmm(V, U.transpose(1, 2))
+ assert torch.sum(torch.isnan(rot_mat)) == 0, ('3children rot_mat', rot_mat)
+ return rot_mat
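The same Rodrigues construction (skew-symmetric `K` from a unit axis, then `I + sin*K + (1 - cos)*K^2`) appears several times above. The standalone sketch below reproduces that swing-rotation step for random bone vectors and checks that the resulting matrix rotates the normalized rest bone onto the normalized target bone.

```python
# Standalone sketch of the per-joint "swing" rotation used above.
import torch

B, dtype, device = 4, torch.float32, 'cpu'
child_rest_loc = torch.randn(B, 3, 1)    # template bone vector
child_final_loc = torch.randn(B, 3, 1)   # predicted bone vector

rest_norm = torch.norm(child_rest_loc, dim=1, keepdim=True)
final_norm = torch.norm(child_final_loc, dim=1, keepdim=True)

# Rotation axis and angle (via its sine/cosine), as in the code above.
axis = torch.cross(child_rest_loc, child_final_loc, dim=1)
axis_norm = torch.norm(axis, dim=1, keepdim=True)
cos = torch.sum(child_rest_loc * child_final_loc, dim=1,
                keepdim=True) / (rest_norm * final_norm + 1e-8)
sin = axis_norm / (rest_norm * final_norm + 1e-8)
axis = axis / (axis_norm + 1e-8)

# Skew-symmetric matrix of the unit axis, then Rodrigues' formula.
rx, ry, rz = torch.split(axis, 1, dim=1)
zeros = torch.zeros((B, 1, 1), dtype=dtype, device=device)
K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros],
              dim=1).view((B, 3, 3))
ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(0)
R = ident + sin * K + (1 - cos) * torch.bmm(K, K)

# R should rotate the normalized rest bone onto the normalized target bone.
rotated = torch.matmul(R, child_rest_loc / rest_norm)
assert torch.allclose(rotated, child_final_loc / final_norm, atol=1e-4)
```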
diff --git a/detrsmpl/models/utils/positional_encoding.py b/detrsmpl/models/utils/positional_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..c668c5e3564aea1de10f0042a9e458a86fc8e297
--- /dev/null
+++ b/detrsmpl/models/utils/positional_encoding.py
@@ -0,0 +1,159 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+from mmcv.runner import BaseModule
+
+
+class SinePositionalEncoding(BaseModule):
+ """Position encoding with sine and cosine functions.
+
+ See `End-to-End Object Detection with Transformers
+    <https://arxiv.org/abs/2005.12872>`_ for details.
+
+ Args:
+ num_feats (int): The feature dimension for each position
+ along x-axis or y-axis. Note the final returned dimension
+            for each position is 2 times this value.
+ temperature (int, optional): The temperature used for scaling
+ the position embedding. Defaults to 10000.
+ normalize (bool, optional): Whether to normalize the position
+ embedding. Defaults to False.
+ scale (float, optional): A scale factor that scales the position
+ embedding. The scale will be used only when `normalize` is True.
+ Defaults to 2*pi.
+ eps (float, optional): A value added to the denominator for
+ numerical stability. Defaults to 1e-6.
+        offset (float): An offset added to the embedding when the normalization is applied.
+ Defaults to 0.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+ def __init__(self,
+ num_feats,
+ temperature=10000,
+ normalize=False,
+ scale=2 * math.pi,
+ eps=1e-6,
+ offset=0.,
+ init_cfg=None):
+ super(SinePositionalEncoding, self).__init__(init_cfg)
+ if normalize:
+ assert isinstance(scale, (float, int)), 'when normalize is set,' \
+ 'scale should be provided and in float or int type, ' \
+ f'found {type(scale)}'
+ self.num_feats = num_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ self.scale = scale
+ self.eps = eps
+ self.offset = offset
+
+ def forward(self, mask):
+ """Forward function for `SinePositionalEncoding`.
+
+ Args:
+            mask (Tensor): ByteTensor mask. Non-zero values represent
+                ignored positions, while zero values mean valid positions
+ for this image. Shape [bs, h, w].
+
+ Returns:
+ pos (Tensor): Returned position embedding with shape
+ [bs, num_feats*2, h, w].
+ """
+ # For convenience of exporting to ONNX, it's required to convert
+ # `masks` from bool to int.
+ mask = mask.to(torch.int)
+ not_mask = 1 - mask # logical_not
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ y_embed = (y_embed + self.offset) / \
+ (y_embed[:, -1:, :] + self.eps) * self.scale
+ x_embed = (x_embed + self.offset) / \
+ (x_embed[:, :, -1:] + self.eps) * self.scale
+ dim_t = torch.arange(self.num_feats,
+ dtype=torch.float32,
+ device=mask.device)
+ dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats)
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ # use `view` instead of `flatten` for dynamically exporting to ONNX
+ B, H, W = mask.size()
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
+ dim=4).view(B, H, W, -1)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
+ dim=4).view(B, H, W, -1)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+ def __repr__(self):
+ """str: a string that describes the module"""
+ repr_str = self.__class__.__name__
+ repr_str += f'(num_feats={self.num_feats}, '
+ repr_str += f'temperature={self.temperature}, '
+ repr_str += f'normalize={self.normalize}, '
+ repr_str += f'scale={self.scale}, '
+ repr_str += f'eps={self.eps})'
+ return repr_str
+
+
+class LearnedPositionalEncoding(BaseModule):
+ """Position embedding with learnable embedding weights.
+
+ Args:
+ num_feats (int): The feature dimension for each position
+ along x-axis or y-axis. The final returned dimension for
+            each position is 2 times this value.
+ row_num_embed (int, optional): The dictionary size of row embeddings.
+ Default 50.
+ col_num_embed (int, optional): The dictionary size of col embeddings.
+ Default 50.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+ def __init__(self,
+ num_feats,
+ row_num_embed=50,
+ col_num_embed=50,
+ init_cfg=dict(type='Uniform', layer='Embedding')):
+ super(LearnedPositionalEncoding, self).__init__(init_cfg)
+ self.row_embed = nn.Embedding(row_num_embed, num_feats)
+ self.col_embed = nn.Embedding(col_num_embed, num_feats)
+ self.num_feats = num_feats
+ self.row_num_embed = row_num_embed
+ self.col_num_embed = col_num_embed
+
+ def forward(self, mask):
+ """Forward function for `LearnedPositionalEncoding`.
+
+ Args:
+            mask (Tensor): ByteTensor mask. Non-zero values represent
+                ignored positions, while zero values mean valid positions
+ for this image. Shape [bs, h, w].
+
+ Returns:
+ pos (Tensor): Returned position embedding with shape
+ [bs, num_feats*2, h, w].
+ """
+ h, w = mask.shape[-2:]
+ x = torch.arange(w, device=mask.device)
+ y = torch.arange(h, device=mask.device)
+ x_embed = self.col_embed(x)
+ y_embed = self.row_embed(y)
+ pos = torch.cat(
+ (x_embed.unsqueeze(0).repeat(h, 1, 1), y_embed.unsqueeze(1).repeat(
+ 1, w, 1)),
+ dim=-1).permute(2, 0,
+ 1).unsqueeze(0).repeat(mask.shape[0], 1, 1, 1)
+ return pos
+
+ def __repr__(self):
+ """str: a string that describes the module"""
+ repr_str = self.__class__.__name__
+ repr_str += f'(num_feats={self.num_feats}, '
+ repr_str += f'row_num_embed={self.row_num_embed}, '
+ repr_str += f'col_num_embed={self.col_num_embed})'
+ return repr_str
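For a quick shape check of `SinePositionalEncoding`, the following sketch instantiates it directly; it requires mmcv, since the class derives from `mmcv.runner.BaseModule`, and the module path matches the file added in this diff.

```python
# Quick shape check for the sine positional encoding above.
import torch

from detrsmpl.models.utils.positional_encoding import SinePositionalEncoding

pos_enc = SinePositionalEncoding(num_feats=128, normalize=True)
# All-zero mask: every position is valid.
mask = torch.zeros(1, 32, 40, dtype=torch.bool)
pos = pos_enc(mask)
print(pos.shape)  # torch.Size([1, 256, 32, 40]) -> [bs, 2 * num_feats, h, w]
```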
diff --git a/detrsmpl/models/utils/res_layer.py b/detrsmpl/models/utils/res_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..40cd79c13d9808097daa4934c3b9763565e4628b
--- /dev/null
+++ b/detrsmpl/models/utils/res_layer.py
@@ -0,0 +1,187 @@
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from mmcv.runner import BaseModule, Sequential
+from torch import nn as nn
+
+
+class ResLayer(Sequential):
+ """ResLayer to build ResNet style backbone.
+
+ Args:
+ block (nn.Module): block used to build ResLayer.
+ inplanes (int): inplanes of block.
+ planes (int): planes of block.
+ num_blocks (int): number of blocks.
+ stride (int): stride of the first block. Default: 1
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ downsample_first (bool): Downsample at the first block or last block.
+ False for Hourglass, True for ResNet. Default: True
+ """
+ def __init__(self,
+ block,
+ inplanes,
+ planes,
+ num_blocks,
+ stride=1,
+ avg_down=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ downsample_first=True,
+ **kwargs):
+ self.block = block
+
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = []
+ conv_stride = stride
+ if avg_down:
+ conv_stride = 1
+ downsample.append(
+ nn.AvgPool2d(kernel_size=stride,
+ stride=stride,
+ ceil_mode=True,
+ count_include_pad=False))
+ downsample.extend([
+ build_conv_layer(conv_cfg,
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=conv_stride,
+ bias=False),
+ build_norm_layer(norm_cfg, planes * block.expansion)[1]
+ ])
+ downsample = nn.Sequential(*downsample)
+
+ layers = []
+ if downsample_first:
+ layers.append(
+ block(inplanes=inplanes,
+ planes=planes,
+ stride=stride,
+ downsample=downsample,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+ inplanes = planes * block.expansion
+ for _ in range(1, num_blocks):
+ layers.append(
+ block(inplanes=inplanes,
+ planes=planes,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+
+ else: # downsample_first=False is for HourglassModule
+ for _ in range(num_blocks - 1):
+ layers.append(
+ block(inplanes=inplanes,
+ planes=inplanes,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+ layers.append(
+ block(inplanes=inplanes,
+ planes=planes,
+ stride=stride,
+ downsample=downsample,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+ super(ResLayer, self).__init__(*layers)
+
+
+class SimplifiedBasicBlock(BaseModule):
+ """Simplified version of original basic residual block. This is used in
+    `SCNet <https://arxiv.org/abs/2012.10150>`_.
+
+ - Norm layer is now optional
+ - Last ReLU in forward function is removed
+ """
+ expansion = 1
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ dcn=None,
+ plugins=None,
+                 init_cfg=None):
+        super(SimplifiedBasicBlock, self).__init__(init_cfg)
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+ assert not with_cp, 'Not implemented yet.'
+ self.with_norm = norm_cfg is not None
+ with_bias = True if norm_cfg is None else False
+ self.conv1 = build_conv_layer(conv_cfg,
+ inplanes,
+ planes,
+ 3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=with_bias)
+ if self.with_norm:
+ self.norm1_name, norm1 = build_norm_layer(norm_cfg,
+ planes,
+ postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(conv_cfg,
+ planes,
+ planes,
+ 3,
+ padding=1,
+ bias=with_bias)
+ if self.with_norm:
+ self.norm2_name, norm2 = build_norm_layer(norm_cfg,
+ planes,
+ postfix=2)
+ self.add_module(self.norm2_name, norm2)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ self.with_cp = with_cp
+
+ @property
+ def norm1(self):
+ """nn.Module: normalization layer after the first convolution layer"""
+ return getattr(self, self.norm1_name) if self.with_norm else None
+
+ @property
+ def norm2(self):
+ """nn.Module: normalization layer after the second convolution layer"""
+ return getattr(self, self.norm2_name) if self.with_norm else None
+
+ def forward(self, x):
+ """Forward function."""
+
+ identity = x
+
+ out = self.conv1(x)
+ if self.with_norm:
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ if self.with_norm:
+ out = self.norm2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
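To see the two classes above working together, here is a minimal sketch that builds a small residual stage from `SimplifiedBasicBlock` (expansion = 1); like the rest of this module it needs mmcv installed.

```python
# Minimal sketch: a small residual stage built from SimplifiedBasicBlock.
import torch

from detrsmpl.models.utils.res_layer import ResLayer, SimplifiedBasicBlock

# inplanes == planes * expansion and stride == 1, so no downsample branch.
stage = ResLayer(SimplifiedBasicBlock, inplanes=64, planes=64, num_blocks=2)
x = torch.randn(1, 64, 56, 56)
print(stage(x).shape)  # torch.Size([1, 64, 56, 56])
```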
diff --git a/detrsmpl/models/utils/transformer.py b/detrsmpl/models/utils/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1637938e74bba51a2e23ec0964bf0e5df30851ed
--- /dev/null
+++ b/detrsmpl/models/utils/transformer.py
@@ -0,0 +1,717 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import warnings
+
+import torch
+import torch.nn as nn
+from mmcv.cnn.bricks.registry import (
+ TRANSFORMER_LAYER,
+ TRANSFORMER_LAYER_SEQUENCE,
+)
+from mmcv.cnn.bricks.transformer import (
+ BaseTransformerLayer,
+ TransformerLayerSequence,
+ build_transformer_layer_sequence,
+)
+from mmcv.runner.base_module import BaseModule
+# from mmcv.utils import to_2tuple
+from torch.nn.init import normal_
+
+# from mmdet.models.utils.builder import TRANSFORMER
+from .builder import TRANSFORMER
+
+# import torch.nn.functional as F
+from mmcv.cnn import ( # build_activation_layer,; build_conv_layer,
+ build_norm_layer, xavier_init,
+)
+
+# from typing import Sequence
+
+try:
+ from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention
+
+except ImportError:
+ warnings.warn(
+ '`MultiScaleDeformableAttention` in MMCV has been moved to '
+ '`mmcv.ops.multi_scale_deform_attn`, please update your MMCV')
+ from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention
+
+
+def inverse_sigmoid(x, eps=1e-5):
+ """Inverse function of sigmoid.
+
+ Args:
+ x (Tensor): The tensor to do the
+ inverse.
+ eps (float): EPS avoid numerical
+ overflow. Defaults 1e-5.
+ Returns:
+ Tensor: The x has passed the inverse
+ function of sigmoid, has same
+ shape with input.
+ """
+ x = x.clamp(min=0, max=1)
+ x1 = x.clamp(min=eps)
+ x2 = (1 - x).clamp(min=eps)
+ return torch.log(x1 / x2)
+
+
+@TRANSFORMER_LAYER.register_module()
+class DetrTransformerDecoderLayer(BaseTransformerLayer):
+ """Implements decoder layer in DETR transformer.
+
+ Args:
+        attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict):
+            Configs for self_attention or cross_attention, the order
+            should be consistent with it in `operation_order`. If it is
+            a dict, it would be expanded to the number of attention
+            modules in `operation_order`.
+        feedforward_channels (int): The hidden dimension for FFNs.
+        ffn_dropout (float): Probability of an element to be zeroed
+            in ffn. Default: 0.0.
+        operation_order (tuple[str]): The execution order of operations
+            in the transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
+            Default: None.
+        act_cfg (dict): The activation config for FFNs.
+            Default: dict(type='ReLU', inplace=True).
+        norm_cfg (dict): Config dict for the normalization layer.
+            Default: `LN`.
+        ffn_num_fcs (int): The number of fully-connected layers in FFNs.
+            Default: 2.
+ """
+ def __init__(self,
+ attn_cfgs,
+ feedforward_channels,
+ ffn_dropout=0.0,
+ operation_order=None,
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN'),
+ ffn_num_fcs=2,
+ **kwargs):
+ super(DetrTransformerDecoderLayer,
+ self).__init__(attn_cfgs=attn_cfgs,
+ feedforward_channels=feedforward_channels,
+ ffn_dropout=ffn_dropout,
+ operation_order=operation_order,
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg,
+ ffn_num_fcs=ffn_num_fcs,
+ **kwargs)
+ assert len(operation_order) == 6
+ assert set(operation_order) == set(
+ ['self_attn', 'norm', 'cross_attn', 'ffn'])
+
+
+@TRANSFORMER_LAYER_SEQUENCE.register_module()
+class DetrTransformerEncoder(TransformerLayerSequence):
+ """TransformerEncoder of DETR.
+
+ Args:
+ post_norm_cfg (dict): Config of last normalization layer. Default:
+ `LN`. Only used when `self.pre_norm` is `True`
+ """
+ def __init__(self, *args, post_norm_cfg=dict(type='LN'), **kwargs):
+ super(DetrTransformerEncoder, self).__init__(*args, **kwargs)
+ if post_norm_cfg is not None:
+ self.post_norm = build_norm_layer(
+ post_norm_cfg, self.embed_dims)[1] if self.pre_norm else None
+ else:
+ assert not self.pre_norm, f'Use prenorm in ' \
+ f'{self.__class__.__name__},' \
+ f'Please specify post_norm_cfg'
+ self.post_norm = None
+
+ def forward(self, *args, **kwargs):
+ """Forward function for `TransformerCoder`.
+
+ Returns:
+ Tensor: forwarded results with shape [num_query, bs, embed_dims].
+ """
+ x = super(DetrTransformerEncoder, self).forward(*args, **kwargs)
+ if self.post_norm is not None:
+ x = self.post_norm(x)
+ return x
+
+
+@TRANSFORMER_LAYER_SEQUENCE.register_module()
+class DetrTransformerDecoder(TransformerLayerSequence):
+ """Implements the decoder in DETR transformer.
+
+ Args:
+ return_intermediate (bool): Whether to return intermediate outputs.
+ post_norm_cfg (dict): Config of last normalization layer. Default:
+ `LN`.
+ """
+ def __init__(self,
+ *args,
+ post_norm_cfg=dict(type='LN'),
+ return_intermediate=False,
+ **kwargs):
+
+ super(DetrTransformerDecoder, self).__init__(*args, **kwargs)
+ self.return_intermediate = return_intermediate
+ if post_norm_cfg is not None:
+ self.post_norm = build_norm_layer(post_norm_cfg,
+ self.embed_dims)[1]
+ else:
+ self.post_norm = None
+
+ def forward(self, query, *args, **kwargs):
+ """Forward function for `TransformerDecoder`.
+
+ Args:
+ query (Tensor): Input query with shape
+ `(num_query, bs, embed_dims)`.
+
+ Returns:
+ Tensor: Results with shape [1, num_query, bs, embed_dims] when
+ return_intermediate is `False`, otherwise it has shape
+ [num_layers, num_query, bs, embed_dims].
+ """
+ if not self.return_intermediate:
+ x = super().forward(query, *args, **kwargs)
+ if self.post_norm:
+ x = self.post_norm(x)[None]
+ return x
+
+ intermediate = []
+ for layer in self.layers:
+ query = layer(query, *args, **kwargs)
+ if self.return_intermediate:
+ if self.post_norm is not None:
+ intermediate.append(self.post_norm(query))
+ else:
+ intermediate.append(query)
+ return torch.stack(intermediate)
+
+
+@TRANSFORMER_LAYER_SEQUENCE.register_module()
+class DeformableDetrTransformerDecoder(TransformerLayerSequence):
+ """Implements the decoder in DETR transformer.
+
+ Args:
+ return_intermediate (bool): Whether to return intermediate outputs.
+ coder_norm_cfg (dict): Config of last normalization layer. Default:
+ `LN`.
+ """
+ def __init__(self, *args, return_intermediate=False, **kwargs):
+
+ super(DeformableDetrTransformerDecoder, self).__init__(*args, **kwargs)
+ self.return_intermediate = return_intermediate
+
+ def forward(self,
+ query,
+ *args,
+ reference_points=None,
+ valid_ratios=None,
+ reg_branches=None,
+ **kwargs):
+ """Forward function for `TransformerDecoder`.
+
+ Args:
+ query (Tensor): Input query with shape
+ `(num_query, bs, embed_dims)`.
+ reference_points (Tensor): The reference
+                points of offset, has shape
+                (bs, num_query, 4) when as_two_stage,
+                otherwise has shape (bs, num_query, 2).
+            valid_ratios (Tensor): The ratios of valid
+ points on the feature map, has shape
+ (bs, num_levels, 2)
+            reg_branches (obj:`nn.ModuleList`): Used for
+                refining the regression results. Only
+                passed when with_box_refine is True,
+                otherwise `None` is passed.
+
+ Returns:
+ Tensor: Results with shape [1, num_query, bs, embed_dims] when
+ return_intermediate is `False`, otherwise it has shape
+ [num_layers, num_query, bs, embed_dims].
+ """
+ output = query
+ intermediate = []
+ intermediate_reference_points = []
+ for lid, layer in enumerate(self.layers):
+ if reference_points.shape[-1] == 4:
+ reference_points_input = reference_points[:, :, None] * \
+ torch.cat([valid_ratios, valid_ratios], -1)[:, None]
+ else:
+ assert reference_points.shape[-1] == 2
+ reference_points_input = reference_points[:, :, None] * \
+ valid_ratios[:, None]
+ output = layer(output,
+ *args,
+ reference_points=reference_points_input,
+ **kwargs)
+ output = output.permute(1, 0, 2)
+
+ if reg_branches is not None:
+ tmp = reg_branches[lid](output)
+ if reference_points.shape[-1] == 4:
+ new_reference_points = tmp + inverse_sigmoid(
+ reference_points)
+ new_reference_points = new_reference_points.sigmoid()
+ else:
+ assert reference_points.shape[-1] == 2
+ new_reference_points = tmp
+ new_reference_points[..., :2] = tmp[
+ ..., :2] + inverse_sigmoid(reference_points)
+ new_reference_points = new_reference_points.sigmoid()
+ reference_points = new_reference_points.detach()
+
+ output = output.permute(1, 0, 2)
+ if self.return_intermediate:
+ intermediate.append(output)
+ intermediate_reference_points.append(reference_points)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate), torch.stack(
+ intermediate_reference_points)
+
+ return output, reference_points
+
+
+@TRANSFORMER.register_module()
+class Transformer(BaseModule):
+ """Implements the DETR transformer.
+
+    Following the official DETR implementation, this module is copy-pasted
+ from torch.nn.Transformer with modifications:
+
+ * positional encodings are passed in MultiheadAttention
+ * extra LN at the end of encoder is removed
+ * decoder returns a stack of activations from all decoding layers
+
+ See `paper: End-to-End Object Detection with Transformers
+    <https://arxiv.org/abs/2005.12872>`_ for details.
+
+ Args:
+ encoder (`mmcv.ConfigDict` | Dict): Config of
+ TransformerEncoder. Defaults to None.
+ decoder ((`mmcv.ConfigDict` | Dict)): Config of
+ TransformerDecoder. Defaults to None
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Defaults to None.
+ """
+ def __init__(self, encoder=None, decoder=None, init_cfg=None):
+ super(Transformer, self).__init__(init_cfg=init_cfg)
+ self.encoder = build_transformer_layer_sequence(encoder)
+ self.decoder = build_transformer_layer_sequence(decoder)
+ self.embed_dims = self.encoder.embed_dims
+
+ def init_weights(self):
+ # follow the official DETR to init parameters
+ for m in self.modules():
+ if hasattr(m, 'weight') and m.weight.dim() > 1:
+ xavier_init(m, distribution='uniform')
+ self._is_init = True
+
+ def forward(self, x, mask, query_embed, pos_embed):
+ """Forward function for `Transformer`.
+
+ Args:
+ x (Tensor): Input query with shape [bs, c, h, w] where
+ c = embed_dims.
+ mask (Tensor): The key_padding_mask used for encoder and decoder,
+ with shape [bs, h, w].
+ query_embed (Tensor): The query embedding for decoder, with shape
+ [num_query, c].
+ pos_embed (Tensor): The positional encoding for encoder and
+ decoder, with the same shape as `x`.
+
+ Returns:
+ tuple[Tensor]: results of decoder containing the following tensor.
+
+ - out_dec: Output from decoder. If return_intermediate_dec \
+ is True output has shape [num_dec_layers, bs,
+ num_query, embed_dims], else has shape [1, bs, \
+ num_query, embed_dims].
+ - memory: Output results from encoder, with shape \
+ [bs, embed_dims, h, w].
+ """
+ bs, c, h, w = x.shape
+ # use `view` instead of `flatten` for dynamically exporting to ONNX
+ x = x.view(bs, c, -1).permute(2, 0, 1) # [bs, c, h, w] -> [h*w, bs, c]
+ pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1)
+ query_embed = query_embed.unsqueeze(1).repeat(
+ 1, bs, 1) # [num_query, dim] -> [num_query, bs, dim]
+ mask = mask.view(bs, -1) # [bs, h, w] -> [bs, h*w]
+ memory = self.encoder(query=x,
+ key=None,
+ value=None,
+ query_pos=pos_embed,
+ query_key_padding_mask=mask)
+ target = torch.zeros_like(query_embed)
+ # out_dec: [num_layers, num_query, bs, dim]
+ out_dec = self.decoder(query=target,
+ key=memory,
+ value=memory,
+ key_pos=pos_embed,
+ query_pos=query_embed,
+ key_padding_mask=mask)
+ out_dec = out_dec.transpose(1, 2)
+ memory = memory.permute(1, 2, 0).reshape(bs, c, h, w)
+ return out_dec, memory
+
+
+@TRANSFORMER.register_module()
+class DeformableDetrTransformer(Transformer):
+ """Implements the DeformableDETR transformer.
+
+ Args:
+ as_two_stage (bool): Generate query from encoder features.
+ Default: False.
+        num_feature_levels (int): Number of feature maps from FPN.
+ Default: 4.
+ two_stage_num_proposals (int): Number of proposals when set
+ `as_two_stage` as True. Default: 300.
+ """
+ def __init__(self,
+ as_two_stage=False,
+ num_feature_levels=4,
+ two_stage_num_proposals=300,
+ **kwargs):
+ super(DeformableDetrTransformer, self).__init__(**kwargs)
+ self.as_two_stage = as_two_stage
+ self.num_feature_levels = num_feature_levels
+ self.two_stage_num_proposals = two_stage_num_proposals
+ self.embed_dims = self.encoder.embed_dims
+ self.init_layers()
+
+ def init_layers(self):
+ """Initialize layers of the DeformableDetrTransformer."""
+ self.level_embeds = nn.Parameter(
+ torch.Tensor(self.num_feature_levels, self.embed_dims))
+
+ if self.as_two_stage:
+ self.enc_output = nn.Linear(self.embed_dims, self.embed_dims)
+ self.enc_output_norm = nn.LayerNorm(self.embed_dims)
+ self.pos_trans = nn.Linear(self.embed_dims * 2,
+ self.embed_dims * 2)
+ self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2)
+ else:
+ self.reference_points = nn.Linear(self.embed_dims, 2)
+
+ def init_weights(self):
+ """Initialize the transformer weights."""
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ for m in self.modules():
+ if isinstance(m, MultiScaleDeformableAttention):
+ m.init_weights()
+ if not self.as_two_stage:
+ xavier_init(self.reference_points, distribution='uniform', bias=0.)
+ normal_(self.level_embeds)
+
+ def gen_encoder_output_proposals(self, memory, memory_padding_mask,
+ spatial_shapes):
+ """Generate proposals from encoded memory.
+
+ Args:
+ memory (Tensor) : The output of encoder,
+            has shape (bs, num_key, embed_dim). num_key is
+            equal to the number of points on the feature maps from
+            all levels.
+ memory_padding_mask (Tensor): Padding mask for memory.
+ has shape (bs, num_key).
+ spatial_shapes (Tensor): The shape of all feature maps.
+ has shape (num_level, 2).
+
+ Returns:
+ tuple: A tuple of feature map and bbox prediction.
+
+ - output_memory (Tensor): The input of decoder, \
+ has shape (bs, num_key, embed_dim). num_key is \
+                equal to the number of points on the feature maps from \
+ all levels.
+ - output_proposals (Tensor): The normalized proposal \
+                after an inverse sigmoid, has shape \
+ (bs, num_keys, 4).
+ """
+
+ N, S, C = memory.shape
+ proposals = []
+ _cur = 0
+ for lvl, (H, W) in enumerate(spatial_shapes):
+ mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H * W)].view(
+ N, H, W, 1)
+ valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
+ valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
+
+ grid_y, grid_x = torch.meshgrid(
+ torch.linspace(0,
+ H - 1,
+ H,
+ dtype=torch.float32,
+ device=memory.device),
+ torch.linspace(0,
+ W - 1,
+ W,
+ dtype=torch.float32,
+ device=memory.device))
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
+
+ scale = torch.cat([valid_W.unsqueeze(-1),
+ valid_H.unsqueeze(-1)], 1).view(N, 1, 1, 2)
+ grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale
+ wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
+ proposal = torch.cat((grid, wh), -1).view(N, -1, 4)
+ proposals.append(proposal)
+ _cur += (H * W)
+ output_proposals = torch.cat(proposals, 1)
+ output_proposals_valid = ((output_proposals > 0.01) &
+ (output_proposals < 0.99)).all(-1,
+ keepdim=True)
+ output_proposals = torch.log(output_proposals / (1 - output_proposals))
+ output_proposals = output_proposals.masked_fill(
+ memory_padding_mask.unsqueeze(-1), float('inf'))
+ output_proposals = output_proposals.masked_fill(
+ ~output_proposals_valid, float('inf'))
+
+ output_memory = memory
+ output_memory = output_memory.masked_fill(
+ memory_padding_mask.unsqueeze(-1), float(0))
+ output_memory = output_memory.masked_fill(~output_proposals_valid,
+ float(0))
+ output_memory = self.enc_output_norm(self.enc_output(output_memory))
+ return output_memory, output_proposals
+
+ @staticmethod
+ def get_reference_points(spatial_shapes, valid_ratios, device):
+ """Get the reference points used in decoder.
+
+ Args:
+ spatial_shapes (Tensor): The shape of all
+ feature maps, has shape (num_level, 2).
+            valid_ratios (Tensor): The ratios of valid
+ points on the feature map, has shape
+ (bs, num_levels, 2)
+ device (obj:`device`): The device where
+ reference_points should be.
+
+ Returns:
+ Tensor: reference points used in decoder, has \
+ shape (bs, num_keys, num_levels, 2).
+ """
+ reference_points_list = []
+ for lvl, (H, W) in enumerate(spatial_shapes):
+ # TODO check this 0.5
+ ref_y, ref_x = torch.meshgrid(
+ torch.linspace(0.5,
+ H - 0.5,
+ H,
+ dtype=torch.float32,
+ device=device),
+ torch.linspace(0.5,
+ W - 0.5,
+ W,
+ dtype=torch.float32,
+ device=device))
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] *
+ H)
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] *
+ W)
+ ref = torch.stack((ref_x, ref_y), -1)
+ reference_points_list.append(ref)
+ reference_points = torch.cat(reference_points_list, 1)
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
+ return reference_points
+
+ def get_valid_ratio(self, mask):
+ """Get the valid radios of feature maps of all level."""
+ _, H, W = mask.shape
+ valid_H = torch.sum(~mask[:, :, 0], 1)
+ valid_W = torch.sum(~mask[:, 0, :], 1)
+ valid_ratio_h = valid_H.float() / H
+ valid_ratio_w = valid_W.float() / W
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+ return valid_ratio
+
+ def get_proposal_pos_embed(self,
+ proposals,
+ num_pos_feats=128,
+ temperature=10000):
+ """Get the position embedding of proposal."""
+ scale = 2 * math.pi
+ dim_t = torch.arange(num_pos_feats,
+ dtype=torch.float32,
+ device=proposals.device)
+ dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats)
+ # N, L, 4
+ proposals = proposals.sigmoid() * scale
+ # N, L, 4, 128
+ pos = proposals[:, :, :, None] / dim_t
+ # N, L, 4, 64, 2
+ pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()),
+ dim=4).flatten(2)
+ return pos
+
+ def forward(self,
+ mlvl_feats,
+ mlvl_masks,
+ query_embed,
+ mlvl_pos_embeds,
+ reg_branches=None,
+ cls_branches=None,
+ smpl_branches=None,
+ **kwargs):
+ """Forward function for `Transformer`.
+
+ Args:
+ mlvl_feats (list(Tensor)): Input queries from
+ different level. Each element has shape
+ [bs, embed_dims, h, w].
+ mlvl_masks (list(Tensor)): The key_padding_mask from
+ different level used for encoder and decoder,
+ each element has shape [bs, h, w].
+ query_embed (Tensor): The query embedding for decoder,
+ with shape [num_query, c].
+ mlvl_pos_embeds (list(Tensor)): The positional encoding
+ of feats from different level, has the shape
+ [bs, embed_dims, h, w].
+ reg_branches (obj:`nn.ModuleList`): Regression heads for
+ feature maps from each decoder layer. Only would
+ be passed when
+ `with_box_refine` is True. Default to None.
+ cls_branches (obj:`nn.ModuleList`): Classification heads
+ for feature maps from each decoder layer. Only would
+ be passed when `as_two_stage`
+ is True. Default to None.
+
+
+ Returns:
+ tuple[Tensor]: results of decoder containing the following tensor.
+
+ - inter_states: Outputs from decoder. If
+ return_intermediate_dec is True output has shape \
+ (num_dec_layers, bs, num_query, embed_dims), else has \
+ shape (1, bs, num_query, embed_dims).
+ - init_reference_out: The initial value of reference \
+ points, has shape (bs, num_queries, 4).
+ - inter_references_out: The internal value of reference \
+ points in decoder, has shape \
+ (num_dec_layers, bs,num_query, embed_dims)
+ - enc_outputs_class: The classification score of \
+ proposals generated from \
+ encoder's feature maps, has shape \
+ (batch, h*w, num_classes). \
+ Only would be returned when `as_two_stage` is True, \
+ otherwise None.
+ - enc_outputs_coord_unact: The regression results \
+ generated from encoder's feature maps., has shape \
+ (batch, h*w, 4). Only would \
+ be returned when `as_two_stage` is True, \
+ otherwise None.
+ """
+ assert self.as_two_stage or query_embed is not None
+
+ feat_flatten = []
+ mask_flatten = []
+ lvl_pos_embed_flatten = []
+ spatial_shapes = []
+ for lvl, (feat, mask, pos_embed) in enumerate(
+ zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
+ bs, c, h, w = feat.shape
+ spatial_shape = (h, w)
+ spatial_shapes.append(spatial_shape)
+ feat = feat.flatten(2).transpose(1, 2)
+ mask = mask.flatten(1)
+ pos_embed = pos_embed.flatten(2).transpose(1, 2)
+ lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
+ feat_flatten.append(feat)
+ mask_flatten.append(mask)
+ feat_flatten = torch.cat(feat_flatten, 1)
+ mask_flatten = torch.cat(mask_flatten, 1)
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
+ spatial_shapes = torch.as_tensor(spatial_shapes,
+ dtype=torch.long,
+ device=feat_flatten.device)
+ level_start_index = torch.cat((spatial_shapes.new_zeros(
+ (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
+ valid_ratios = torch.stack(
+ [self.get_valid_ratio(m) for m in mlvl_masks], 1)
+
+ reference_points = \
+ self.get_reference_points(spatial_shapes,
+ valid_ratios,
+ device=feat.device)
+
+ feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims)
+ lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(
+ 1, 0, 2) # (H*W, bs, embed_dims)
+ memory = self.encoder(query=feat_flatten,
+ key=None,
+ value=None,
+ query_pos=lvl_pos_embed_flatten,
+ query_key_padding_mask=mask_flatten,
+ spatial_shapes=spatial_shapes,
+ reference_points=reference_points,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios,
+ **kwargs)
+
+ memory = memory.permute(1, 0, 2)
+ bs, _, c = memory.shape
+ if self.as_two_stage:
+ output_memory, output_proposals = \
+ self.gen_encoder_output_proposals(
+ memory, mask_flatten, spatial_shapes)
+ enc_outputs_class = cls_branches[self.decoder.num_layers](
+ output_memory)
+ enc_outputs_coord_unact = \
+ reg_branches[
+ self.decoder.num_layers](output_memory) + output_proposals
+
+ topk = self.two_stage_num_proposals
+ # We only use the first channel in enc_outputs_class as foreground,
+ # the other (num_classes - 1) channels are actually not used.
+ # Its targets are set to be 0s, which indicates the first
+ # class (foreground) because we use [0, num_classes - 1] to
+ # indicate class labels, background class is indicated by
+ # num_classes (similar convention in RPN).
+ # See https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/deformable_detr_head.py#L241 # noqa
+ # This follows the official implementation of Deformable DETR.
+ topk_proposals = torch.topk(enc_outputs_class[..., 0], topk,
+ dim=1)[1]
+ topk_coords_unact = torch.gather(
+ enc_outputs_coord_unact, 1,
+ topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
+ topk_coords_unact = topk_coords_unact.detach()
+ reference_points = topk_coords_unact.sigmoid()
+ init_reference_out = reference_points
+ pos_trans_out = self.pos_trans_norm(
+ self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
+ query_pos, query = torch.split(pos_trans_out, c, dim=2)
+ else:
+ query_pos, query = torch.split(query_embed, c, dim=1)
+ query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
+ query = query.unsqueeze(0).expand(bs, -1, -1)
+ reference_points = self.reference_points(query_pos).sigmoid()
+ init_reference_out = reference_points
+
+ # decoder
+ query = query.permute(1, 0, 2)
+ memory = memory.permute(1, 0, 2)
+ query_pos = query_pos.permute(1, 0, 2)
+ inter_states, inter_references = self.decoder(
+ query=query,
+ key=None,
+ value=memory,
+ query_pos=query_pos,
+ key_padding_mask=mask_flatten,
+ reference_points=reference_points,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios,
+ reg_branches=reg_branches,
+ smpl_branches=smpl_branches,
+ **kwargs)
+
+ inter_references_out = inter_references
+ if self.as_two_stage:
+ return inter_states, init_reference_out,\
+ inter_references_out, enc_outputs_class,\
+ enc_outputs_coord_unact
+ return inter_states, init_reference_out, \
+ inter_references_out, None, None
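Of the helpers in this file, `inverse_sigmoid` is the easiest to check in isolation: it is the logit function with clamping, used above to refine reference points in unconstrained space. A small numerical sketch (importing the module requires mmcv):

```python
# Numerical sanity check for inverse_sigmoid above.
import torch

from detrsmpl.models.utils.transformer import inverse_sigmoid

x = torch.tensor([0.0, 1e-7, 0.25, 0.5, 0.75, 1.0])
y = inverse_sigmoid(x, eps=1e-5)
print(y)
# Values well inside (eps, 1 - eps) round-trip through sigmoid;
# 0 and 1 are clamped to finite logits of about +-11.5 instead of +-inf.
assert torch.allclose(torch.sigmoid(y[2:5]), x[2:5], atol=1e-6)
```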
diff --git a/detrsmpl/utils/__init__.py b/detrsmpl/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/detrsmpl/utils/camera_utils.py b/detrsmpl/utils/camera_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7ef6e252297c7503b213a30f7dbc76237551d8c
--- /dev/null
+++ b/detrsmpl/utils/camera_utils.py
@@ -0,0 +1,210 @@
+import copy
+import os
+from typing import Iterable, Optional, Union
+
+import numpy as np
+import torch
+from pytorch3d.renderer.cameras import CamerasBase
+
+from detrsmpl.core.cameras import build_cameras
+from detrsmpl.core.conventions.cameras.convert_convention import (
+ convert_camera_matrix,
+ convert_world_view,
+)
+from detrsmpl.core.conventions.cameras.convert_projection import \
+ convert_perspective_to_weakperspective # prevent yapf isort conflict
+from detrsmpl.models.body_models.builder import build_body_model
+from detrsmpl.utils.transforms import aa_to_rotmat, rotmat_to_aa
+
+
+def convert_smpl_from_opencv_calibration(
+ R: Union[np.ndarray, torch.Tensor],
+ T: Union[np.ndarray, torch.Tensor],
+ K: Optional[Union[np.ndarray, torch.Tensor]] = None,
+ resolution: Optional[Union[Iterable[int], int]] = None,
+ verts: Optional[Union[np.ndarray, torch.Tensor]] = None,
+ poses: Optional[Union[np.ndarray, torch.Tensor]] = None,
+ transl: Optional[Union[np.ndarray, torch.Tensor]] = None,
+ model_path: Optional[str] = None,
+ betas: Optional[Union[np.ndarray, torch.Tensor]] = None,
+ model_type: Optional[str] = 'smpl',
+ gender: Optional[str] = 'neutral'):
+ """Convert opencv calibration smpl poses&transl parameters to model based
+ poses&transl or verts.
+
+ Args:
+ R (Union[np.ndarray, torch.Tensor]): (frame, 3, 3)
+        T (Union[np.ndarray, torch.Tensor]): (frame, 3)
+ K (Optional[Union[np.ndarray, torch.Tensor]], optional):
+ (frame, 3, 3) or (frame, 4, 4). Defaults to None.
+ resolution (Optional[Union[Iterable[int], int]], optional):
+ (height, width). Defaults to None.
+ verts (Optional[Union[np.ndarray, torch.Tensor]], optional):
+ (frame, num_verts, 3). Defaults to None.
+ poses (Optional[Union[np.ndarray, torch.Tensor]], optional):
+ (frame, 72/165). Defaults to None.
+ transl (Optional[Union[np.ndarray, torch.Tensor]], optional):
+ (frame, 3). Defaults to None.
+ model_path (Optional[str], optional): model path.
+ Defaults to None.
+ betas (Optional[Union[np.ndarray, torch.Tensor]], optional):
+ (frame, 10). Defaults to None.
+ model_type (Optional[str], optional): choose in 'smpl' or 'smplx'.
+ Defaults to 'smpl'.
+ gender (Optional[str], optional): choose in 'male', 'female',
+ 'neutral'.
+ Defaults to 'neutral'.
+
+ Raises:
+ ValueError: wrong input poses or transl.
+
+ Returns:
+ Tuple[torch.Tensor]: Return converted poses, transl, pred_cam
+ or verts, pred_cam.
+ """
+ R_, T_ = convert_world_view(R, T)
+
+ RT = torch.eye(4, 4)[None]
+ RT[:, :3, :3] = R_
+ RT[:, :3, 3] = T_
+
+ if verts is not None:
+ poses = None
+ betas = None
+ transl = None
+ else:
+ assert poses is not None
+ assert transl is not None
+ if isinstance(poses, dict):
+ poses = copy.deepcopy(poses)
+ for k in poses:
+ if isinstance(poses[k], np.ndarray):
+ poses[k] = torch.Tensor(poses[k])
+ elif isinstance(poses, np.ndarray):
+ poses = torch.Tensor(poses)
+ elif isinstance(poses, torch.Tensor):
+ poses = poses.clone()
+ else:
+ raise ValueError(f'Wrong data type of poses: {type(poses)}.')
+
+ if isinstance(transl, np.ndarray):
+ transl = torch.Tensor(transl)
+ elif isinstance(transl, torch.Tensor):
+ transl = transl.clone()
+ else:
+ raise ValueError('Should pass valid `transl`.')
+ transl = transl.view(-1, 3)
+
+ if isinstance(betas, np.ndarray):
+ betas = torch.Tensor(betas)
+ elif isinstance(betas, torch.Tensor):
+ betas = betas.clone()
+
+ body_model = build_body_model(
+ dict(type=model_type,
+ model_path=os.path.join(model_path, model_type),
+ gender=gender,
+ model_type=model_type))
+ if isinstance(poses, dict):
+ poses.update({'transl': transl, 'betas': betas})
+ else:
+ if isinstance(poses, np.ndarray):
+ poses = torch.tensor(poses)
+ poses = body_model.tensor2dict(full_pose=poses,
+ transl=transl,
+ betas=betas)
+ model_output = body_model(**poses)
+ verts = model_output['vertices']
+
+ global_orient = poses['global_orient']
+ global_orient = rotmat_to_aa(R_ @ aa_to_rotmat(global_orient))
+ poses['global_orient'] = global_orient
+ poses['transl'] = None
+ verts_rotated = model_output['vertices']
+ rotated_pose = body_model.dict2tensor(poses)
+
+ verts_converted = verts.clone().view(-1, 3)
+ verts_converted = RT @ torch.cat(
+ [verts_converted,
+ torch.ones(verts_converted.shape[0], 1)], dim=-1).unsqueeze(-1)
+ verts_converted = verts_converted.squeeze(-1)
+ verts_converted = verts_converted[:, :3] / verts_converted[:, 3:]
+ verts_converted = verts_converted.view(verts.shape[0], -1, 3)
+ num_frame = verts_converted.shape[0]
+ if poses is not None:
+ transl = torch.mean(verts_converted - verts_rotated, dim=1)
+
+ orig_cam = None
+ if K is not None:
+ zmean = torch.mean(verts_converted, dim=1)[:, 2]
+
+ K, _, _ = convert_camera_matrix(K,
+ is_perspective=True,
+ convention_dst='opencv',
+ convention_src='opencv',
+ in_ndc_dst=True,
+ in_ndc_src=False,
+ resolution_src=resolution)
+ K = K.repeat(num_frame, 1, 1)
+
+ orig_cam = convert_perspective_to_weakperspective(
+ K=K, zmean=zmean, in_ndc=True, resolution=resolution)
+
+ if poses is not None:
+ orig_cam[:, 0, 3] += transl[:, 0]
+ orig_cam[:, 1, 3] += transl[:, 1]
+ if poses is not None:
+ return rotated_pose, orig_cam
+ else:
+ return verts_converted, orig_cam
+
+
+def project_points(points3d: Union[np.ndarray, torch.Tensor],
+ cameras: CamerasBase = None,
+ resolution: Iterable[int] = None,
+ K: Union[torch.Tensor, np.ndarray] = None,
+ R: Union[torch.Tensor, np.ndarray] = None,
+ T: Union[torch.Tensor, np.ndarray] = None,
+ convention: str = 'opencv',
+ in_ndc: bool = False) -> Union[torch.Tensor, np.ndarray]:
+ """Project 3d points to image.
+
+ Args:
+ points3d (Union[np.ndarray, torch.Tensor]): shape could be (..., 3).
+ cameras (CamerasBase): pytorch3d cameras or mmhuman3d cameras.
+ resolution (Iterable[int]): (height, width) for rectangle or width for
+ square.
+ K (Union[torch.Tensor, np.ndarray], optional): intrinsic matrix.
+ Defaults to None.
+ R (Union[torch.Tensor, np.ndarray], optional): rotation matrix.
+ Defaults to None.
+ T (Union[torch.Tensor, np.ndarray], optional): translation matrix.
+ Defaults to None.
+ convention (str, optional): camera convention. Defaults to 'opencv'.
+ in_ndc (bool, optional): whether in NDC. Defaults to False.
+
+ Returns:
+ Union[torch.Tensor, np.ndarray]: transformed points of shape (..., 2).
+ """
+ if cameras is None:
+ cameras = build_cameras(
+ dict(type='perspective',
+ convention=convention,
+ in_ndc=in_ndc,
+ resolution=resolution,
+ K=K,
+ R=R,
+ T=T))
+ if cameras.get_image_size() is not None:
+ image_size = cameras.get_image_size()
+ else:
+ image_size = resolution
+ if isinstance(points3d, np.ndarray):
+ points3d = torch.Tensor(points3d[..., :3]).to(cameras.device)
+ points2d = cameras.transform_points_screen(
+ points3d, image_size=image_size).cpu().numpy()
+ elif isinstance(points3d, torch.Tensor):
+ points3d = points3d[..., :3].to(cameras.device)
+ points2d = cameras.transform_points_screen(points3d,
+ image_size=image_size)
+ return points2d
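+
+
+# Illustrative usage sketch (not part of the original API docs). The focal
+# length, principal point and resolution below are made-up values, and a
+# 3x3 OpenCV-style intrinsic is assumed to be accepted by the camera builder.
+#   import torch
+#   K = torch.tensor([[[500., 0., 256.], [0., 500., 256.], [0., 0., 1.]]])
+#   points3d = torch.rand(10, 3) + torch.tensor([0., 0., 2.])
+#   points2d = project_points(points3d, K=K, R=torch.eye(3)[None],
+#                             T=torch.zeros(1, 3), resolution=(512, 512))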
diff --git a/detrsmpl/utils/collect_env.py b/detrsmpl/utils/collect_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1733944e021354df70349d38abdd68f1b705228
--- /dev/null
+++ b/detrsmpl/utils/collect_env.py
@@ -0,0 +1,16 @@
+from mmcv.utils import collect_env as collect_base_env
+from mmcv.utils import get_git_hash
+
+import detrsmpl
+
+
+def collect_env():
+ """Collect the information of the running environments."""
+ env_info = collect_base_env()
+ env_info['MMHuman3d'] = detrsmpl.__version__ + '+' + get_git_hash()[:7]
+ return env_info
+
+
+if __name__ == '__main__':
+ for name, val in collect_env().items():
+ print(f'{name}: {val}')
diff --git a/detrsmpl/utils/demo_utils.py b/detrsmpl/utils/demo_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ab077444573c2b645f3c2714c2ad0a818d94da9
--- /dev/null
+++ b/detrsmpl/utils/demo_utils.py
@@ -0,0 +1,823 @@
+import colorsys
+import os
+from collections import defaultdict
+from contextlib import contextmanager
+from functools import partial
+from pathlib import Path
+
+import mmcv
+import numpy as np
+from mmcv import Timer
+from scipy import interpolate
+
+from detrsmpl.core.post_processing import build_post_processing
+
+try:
+ from typing import Literal
+except ImportError:
+ from typing_extensions import Literal
+
+
+def xyxy2xywh(bbox_xyxy):
+ """Transform the bbox format from x1y1x2y2 to xywh.
+
+ Args:
+ bbox_xyxy (np.ndarray): Bounding boxes (with scores), shaped (n, 4) or
+ (n, 5). (left, top, right, bottom, [score])
+
+ Returns:
+ np.ndarray: Bounding boxes (with scores),
+ shaped (n, 4) or (n, 5). (left, top, width, height, [score])
+ """
+ if not isinstance(bbox_xyxy, np.ndarray):
+ raise TypeError(
+ f'Input type is {type(bbox_xyxy)}, which should be numpy.ndarray.')
+ bbox_xywh = bbox_xyxy.copy()
+ bbox_xywh[..., 2] = bbox_xywh[..., 2] - bbox_xywh[..., 0]
+ bbox_xywh[..., 3] = bbox_xywh[..., 3] - bbox_xywh[..., 1]
+
+ return bbox_xywh
+
+
+def xywh2xyxy(bbox_xywh):
+ """Transform the bbox format from xywh to x1y1x2y2.
+
+ Args:
+ bbox_xywh (np.ndarray): Bounding boxes (with scores), shaped
+ (n, 4) or (n, 5). (left, top, width, height, [score])
+
+ Returns:
+ np.ndarray: Bounding boxes (with scores),
+ shaped (n, 4) or (n, 5). (left, top, right, bottom, [score])
+ """
+ if not isinstance(bbox_xywh, np.ndarray):
+ raise TypeError(
+ f'Input type is {type(bbox_xywh)}, which should be numpy.ndarray.')
+ bbox_xyxy = bbox_xywh.copy()
+ bbox_xyxy[..., 2] = bbox_xyxy[..., 2] + bbox_xyxy[..., 0] - 1
+ bbox_xyxy[..., 3] = bbox_xyxy[..., 3] + bbox_xyxy[..., 1] - 1
+
+ return bbox_xyxy
+
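+# Illustrative sketch: the two converters use slightly different corner
+# conventions, so a round trip is off by one pixel on the right/bottom edge.
+#   import numpy as np
+#   xyxy2xywh(np.array([[10., 20., 110., 220., 0.9]]))
+#   # -> [[10., 20., 100., 200., 0.9]]
+#   xywh2xyxy(np.array([[10., 20., 100., 200.]]))
+#   # -> [[10., 20., 109., 219.]]
+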
+
+def box2cs(bbox_xywh, aspect_ratio=1.0, bbox_scale_factor=1.25):
+ """Convert xywh coordinates to center and scale.
+
+ Args:
+        bbox_xywh (numpy.ndarray): bounding boxes in xywh format, shaped
+            (n, 4) or (n, 5). (left, top, width, height, [score])
+        aspect_ratio (float, optional): target width / height ratio.
+            Defaults to 1.0.
+        bbox_scale_factor (float, optional): scale factor for expanding the
+            bbox. Defaults to 1.25.
+ Returns:
+ numpy.ndarray: center of the bbox
+ numpy.ndarray: the scale of the bbox w & h
+ """
+ if not isinstance(bbox_xywh, np.ndarray):
+ raise TypeError(
+ f'Input type is {type(bbox_xywh)}, which should be numpy.ndarray.')
+
+ bbox_xywh = bbox_xywh.copy()
+ pixel_std = 1
+ center = np.stack([
+ bbox_xywh[..., 0] + bbox_xywh[..., 2] * 0.5,
+ bbox_xywh[..., 1] + bbox_xywh[..., 3] * 0.5
+ ], -1)
+
+ mask_h = bbox_xywh[..., 2] > aspect_ratio * bbox_xywh[..., 3]
+ mask_w = ~mask_h
+
+ bbox_xywh[mask_h, 3] = bbox_xywh[mask_h, 2] / aspect_ratio
+ bbox_xywh[mask_w, 2] = bbox_xywh[mask_w, 3] * aspect_ratio
+ scale = np.stack([
+ bbox_xywh[..., 2] * 1.0 / pixel_std,
+ bbox_xywh[..., 3] * 1.0 / pixel_std
+ ], -1)
+ scale = scale * bbox_scale_factor
+
+ return center, scale
+
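+# Illustrative sketch: a 100x200 box is squared up to the aspect ratio before
+# scaling, so both scale entries become 200 * 1.25.
+#   import numpy as np
+#   center, scale = box2cs(np.array([[10., 20., 100., 200.]]))
+#   # center -> [[60., 120.]], scale -> [[250., 250.]]
+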
+
+def convert_crop_cam_to_orig_img(cam: np.ndarray,
+ bbox: np.ndarray,
+ img_width: int,
+ img_height: int,
+ aspect_ratio: float = 1.0,
+ bbox_scale_factor: float = 1.25,
+ bbox_format: Literal['xyxy', 'xywh',
+ 'cs'] = 'xyxy'):
+ """This function is modified from [VIBE](https://github.com/
+ mkocabas/VIBE/blob/master/lib/utils/demo_utils.py#L242-L259). Original
+ license please see docs/additional_licenses.md.
+
+ Args:
+        cam (np.ndarray): weak perspective camera in cropped image
+            coordinates, shaped (frame, 3) or (frame, num_person, 3).
+ bbox (np.ndarray): bbox coordinates
+ img_width (int): original image width
+ img_height (int): original image height
+ aspect_ratio (float, optional): Defaults to 1.0.
+ bbox_scale_factor (float, optional): Defaults to 1.25.
+ bbox_format (Literal['xyxy', 'xywh', 'cs']): Defaults to 'xyxy'.
+            'xyxy' means the left-up point and right-bottom point of the
+ bbox.
+ 'xywh' means the left-up point and the width and height of the
+ bbox.
+ 'cs' means the center of the bbox (x,y) and the scale of the
+ bbox w & h.
+ Returns:
+ orig_cam: shape = (frame, 4) or (frame, num_person, 4)
+ """
+ if not isinstance(bbox, np.ndarray):
+ raise TypeError(
+ f'Input type is {type(bbox)}, which should be numpy.ndarray.')
+ bbox = bbox.copy()
+ if bbox_format == 'xyxy':
+ bbox_xywh = xyxy2xywh(bbox)
+ center, scale = box2cs(bbox_xywh, aspect_ratio, bbox_scale_factor)
+ bbox_cs = np.concatenate([center, scale], axis=-1)
+ elif bbox_format == 'xywh':
+ center, scale = box2cs(bbox, aspect_ratio, bbox_scale_factor)
+ bbox_cs = np.concatenate([center, scale], axis=-1)
+ elif bbox_format == 'cs':
+ bbox_cs = bbox
+ else:
+ raise ValueError('Only supports the format of `xyxy`, `cs` and `xywh`')
+
+ cx, cy, h = bbox_cs[..., 0], bbox_cs[..., 1], bbox_cs[..., 2] + 1e-6
+ hw, hh = img_width / 2., img_height / 2.
+ sx = cam[..., 0] * (1. / (img_width / h))
+ sy = cam[..., 0] * (1. / (img_height / h))
+ tx = ((cx - hw) / hw / (sx + 1e-6)) + cam[..., 1]
+ ty = ((cy - hh) / hh / (sy + 1e-6)) + cam[..., 2]
+
+ orig_cam = np.stack([sx, sy, tx, ty], axis=-1)
+ return orig_cam
+
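+# Illustrative sketch (hypothetical values): map a weak-perspective camera
+# predicted in a 224x224 crop back to a 1920x1080 frame using the crop bbox.
+#   import numpy as np
+#   pred_cam = np.array([[0.9, 0.05, -0.02]])          # (frame, 3)
+#   bbox_xyxy = np.array([[600., 300., 900., 800.]])   # (frame, 4)
+#   orig_cam = convert_crop_cam_to_orig_img(pred_cam, bbox_xyxy, 1920, 1080)
+#   # orig_cam has shape (frame, 4): [sx, sy, tx, ty] in full-image space.
+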
+
+def convert_bbox_to_intrinsic(bboxes: np.ndarray,
+ img_width: int = 224,
+ img_height: int = 224,
+ bbox_scale_factor: float = 1.25,
+ bbox_format: Literal['xyxy', 'xywh'] = 'xyxy'):
+ """Convert bbox to intrinsic parameters.
+
+ Args:
+        bboxes (np.ndarray): (frame, num_person, 4), (frame, 4), or (4,)
+ img_width (int): image width of training data.
+ img_height (int): image height of training data.
+ bbox_scale_factor (float): scale factor for expanding the bbox.
+ bbox_format (Literal['xyxy', 'xywh'] ): 'xyxy' means the left-up point
+            and right-bottom point of the bbox.
+ 'xywh' means the left-up point and the width and height of the
+ bbox.
+ Returns:
+ np.ndarray: (frame, num_person, 3, 3), (frame, 3, 3) or (3,3)
+ """
+ if not isinstance(bboxes, np.ndarray):
+ raise TypeError(
+ f'Input type is {type(bboxes)}, which should be numpy.ndarray.')
+ assert bbox_format in ['xyxy', 'xywh']
+
+ if bbox_format == 'xyxy':
+ bboxes = xyxy2xywh(bboxes)
+
+ center_x = bboxes[..., 0] + bboxes[..., 2] / 2.0
+ center_y = bboxes[..., 1] + bboxes[..., 3] / 2.0
+
+ W = np.max(bboxes[..., 2:], axis=-1) * bbox_scale_factor
+
+ num_frame = bboxes.shape[0]
+ if bboxes.ndim == 3:
+ num_person = bboxes.shape[1]
+ Ks = np.zeros((num_frame, num_person, 3, 3))
+ elif bboxes.ndim == 2:
+ Ks = np.zeros((num_frame, 3, 3))
+ elif bboxes.ndim == 1:
+ Ks = np.zeros((3, 3))
+ else:
+        raise ValueError(f'Wrong input bboxes shape {bboxes.shape}')
+
+ Ks[..., 0, 0] = W / img_width
+ Ks[..., 1, 1] = W / img_height
+ Ks[..., 0, 2] = center_x - W / 2.0
+ Ks[..., 1, 2] = center_y - W / 2.0
+ Ks[..., 2, 2] = 1
+ return Ks
+
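+# Illustrative sketch: a single 100x200 xyxy box expanded by 1.25 yields a
+# 250-pixel square crop, so the scale entries are 250 / 224.
+#   import numpy as np
+#   Ks = convert_bbox_to_intrinsic(np.array([[10., 20., 110., 220.]]))
+#   # Ks[0] ~ [[1.116, 0., -65.], [0., 1.116, -5.], [0., 0., 1.]]
+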
+
+def get_default_hmr_intrinsic(num_frame=1,
+ focal_length=1000,
+ det_width=224,
+ det_height=224) -> np.ndarray:
+ """Get default hmr intrinsic, defined by how you trained.
+
+ Args:
+ num_frame (int, optional): num of frames. Defaults to 1.
+ focal_length (int, optional): defined same as your training.
+ Defaults to 1000.
+ det_width (int, optional): the size you used to detect.
+ Defaults to 224.
+ det_height (int, optional): the size you used to detect.
+ Defaults to 224.
+
+ Returns:
+ np.ndarray: shape of (N, 3, 3)
+ """
+ K = np.zeros((num_frame, 3, 3))
+ K[:, 0, 0] = focal_length
+ K[:, 1, 1] = focal_length
+ K[:, 0, 2] = det_width / 2
+ K[:, 1, 2] = det_height / 2
+ K[:, 2, 2] = 1
+ return K
+
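+# Illustrative sketch: two identical pinhole intrinsics with focal length 1000
+# and the principal point at the centre of a 224x224 detection crop.
+#   K = get_default_hmr_intrinsic(num_frame=2)
+#   # K.shape -> (2, 3, 3); K[0, 0, 0] == 1000, K[0, 0, 2] == 112
+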
+
+def convert_kp2d_to_bbox(
+ kp2d: np.ndarray,
+ bbox_format: Literal['xyxy', 'xywh'] = 'xyxy') -> np.ndarray:
+ """Convert kp2d to bbox.
+
+ Args:
+ kp2d (np.ndarray): shape should be (num_frame, num_points, 2/3)
+ or (num_frame, num_person, num_points, 2/3).
+ bbox_format (Literal['xyxy', 'xywh'], optional): Defaults to 'xyxy'.
+
+ Returns:
+ np.ndarray: shape will be (num_frame, num_person, 4)
+ """
+ assert bbox_format in ['xyxy', 'xywh']
+ if kp2d.ndim == 2:
+ kp2d = kp2d[None, None]
+ elif kp2d.ndim == 3:
+ kp2d = kp2d[:, None]
+ num_frame, num_person, _, _ = kp2d.shape
+    # Take the min/max over the keypoint axis of the x and y channels to get
+    # the enclosing box; keepdims preserves the (frame, person, 1) layout for
+    # the concatenation below.
+    x1 = np.min(kp2d[..., 0], axis=-2, keepdims=True)
+    y1 = np.min(kp2d[..., 1], axis=-2, keepdims=True)
+    x2 = np.max(kp2d[..., 0], axis=-2, keepdims=True)
+    y2 = np.max(kp2d[..., 1], axis=-2, keepdims=True)
+ bbox = np.concatenate([x1, y1, x2, y2], axis=-1)
+ assert bbox.shape == (num_frame, num_person, 4)
+ if bbox_format == 'xywh':
+ bbox = xyxy2xywh(bbox)
+ return bbox
+
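+# Illustrative sketch (uses the min/max box computation above): three 2D
+# keypoints are enclosed by a single xyxy box.
+#   import numpy as np
+#   kp2d = np.array([[10., 20.], [110., 220.], [60., 120.]])
+#   convert_kp2d_to_bbox(kp2d)
+#   # -> [[[10., 20., 110., 220.]]], i.e. shape (1, 1, 4)
+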
+
+def convert_verts_to_cam_coord(verts,
+ pred_cams,
+ bboxes_xy,
+ focal_length=5000.,
+ bbox_scale_factor=1.25,
+ bbox_format='xyxy'):
+ """Convert vertices from the world coordinate to camera coordinate.
+
+ Args:
+ verts ([np.ndarray]): The vertices in the world coordinate.
+ The shape is (frame,num_person,6890,3), (frame,6890,3),
+ or (6890,3).
+ pred_cams ([np.ndarray]): Camera parameters estimated by HMR or SPIN.
+ The shape is (frame,num_person,3), (frame,3), or (3,).
+ bboxes_xy ([np.ndarray]): (frame, num_person, 4|5), (frame, 4|5),
+ or (4|5,)
+ focal_length ([float],optional): Defined same as your training.
+ bbox_scale_factor (float): scale factor for expanding the bbox.
+ bbox_format (Literal['xyxy', 'xywh'] ): 'xyxy' means the left-up point
+            and right-bottom point of the bbox.
+ 'xywh' means the left-up point and the width and height of the
+ bbox.
+ Returns:
+ np.ndarray: The vertices in the camera coordinate.
+ The shape is (frame,num_person,6890,3) or (frame,6890,3).
+ np.ndarray: The intrinsic parameters of the pred_cam.
+ The shape is (num_frame, 3, 3).
+ """
+ K0 = get_default_hmr_intrinsic(focal_length=focal_length,
+ det_height=224,
+ det_width=224)
+ K1 = convert_bbox_to_intrinsic(bboxes_xy,
+ bbox_scale_factor=bbox_scale_factor,
+ bbox_format=bbox_format)
+ # K1K0(RX+T)-> K0(K0_inv K1K0)
+ Ks = np.linalg.inv(K0) @ K1 @ K0
+ # convert vertices from world to camera
+ cam_trans = np.concatenate([
+ pred_cams[..., [1]], pred_cams[..., [2]], 2 * focal_length /
+ (224 * pred_cams[..., [0]] + 1e-9)
+ ], -1)
+ verts = verts + cam_trans[..., None, :]
+ if verts.ndim == 4:
+ verts = np.einsum('fnij,fnkj->fnki', Ks, verts)
+ elif verts.ndim == 3:
+ verts = np.einsum('fij,fkj->fki', Ks, verts)
+ elif verts.ndim == 2:
+ verts = np.einsum('fij,fkj->fki', Ks, verts[None])
+ return verts, K0
+
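+# Illustrative sketch (hypothetical shapes and values): move HMR/SPIN vertices
+# predicted in a crop into the camera frame of the full image.
+#   import numpy as np
+#   verts = np.zeros((1, 6890, 3))                     # (frame, 6890, 3)
+#   pred_cams = np.array([[0.9, 0.0, 0.0]])            # (frame, 3)
+#   bboxes_xyxy = np.array([[600., 300., 900., 800.]])
+#   verts_cam, K0 = convert_verts_to_cam_coord(verts, pred_cams, bboxes_xyxy)
+#   # verts_cam keeps the vertex layout; K0 is the (1, 3, 3) default intrinsic.
+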
+
+def smooth_process(x,
+ smooth_type='savgol',
+ cfg_base_dir='configs/_base_/post_processing/'):
+ """Smooth the array with the specified smoothing type.
+
+ Args:
+ x (np.ndarray): Shape should be (frame,num_person,K,C)
+ or (frame,K,C).
+ smooth_type (str, optional): Smooth type.
+ choose in ['oneeuro', 'gaus1d', 'savgol','smoothnet',
+ 'smoothnet_windowsize8','smoothnet_windowsize16',
+ 'smoothnet_windowsize32','smoothnet_windowsize64'].
+ Defaults to 'savgol'. 'smoothnet' is default with windowsize=8.
+ cfg_base_dir (str, optional): Config base dir,
+ default configs/_base_/post_processing/
+ Raises:
+ ValueError: check the input smoothing type.
+
+ Returns:
+ np.ndarray: Smoothed data. The shape should be
+ (frame,num_person,K,C) or (frame,K,C).
+ """
+ if smooth_type == 'smoothnet':
+ smooth_type = 'smoothnet_windowsize8'
+
+ assert smooth_type in [
+ 'oneeuro', 'gaus1d', 'savgol', 'smoothnet_windowsize8',
+ 'smoothnet_windowsize16', 'smoothnet_windowsize32',
+ 'smoothnet_windowsize64'
+ ]
+
+ cfg = os.path.join(cfg_base_dir, smooth_type + '.py')
+ if isinstance(cfg, str):
+ cfg = mmcv.Config.fromfile(cfg)
+ elif not isinstance(cfg, mmcv.Config):
+ raise TypeError('config must be a filename or Config object, '
+ f'but got {type(cfg)}')
+
+ x = x.copy()
+
+ assert x.ndim == 3 or x.ndim == 4
+
+ smooth_func = build_post_processing(dict(cfg['smooth_cfg']))
+
+ if x.ndim == 4:
+ for i in range(x.shape[1]):
+ x[:, i] = smooth_func(x[:, i])
+ elif x.ndim == 3:
+ x = smooth_func(x)
+
+ return x
+
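+# Illustrative sketch (assumes the savgol config exists under cfg_base_dir):
+#   import numpy as np
+#   kp3d = np.random.rand(100, 17, 3)   # (frame, K, C)
+#   kp3d_smooth = smooth_process(kp3d, smooth_type='savgol')
+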
+
+def speed_up_process(x,
+ speed_up_type='deciwatch',
+ cfg_base_dir='configs/_base_/post_processing/'):
+ """Speed up the process with the specified speed up type.
+
+ Args:
+        x (torch.Tensor): Shape should be (frame,num_person,K,C)
+ or (frame,K,C).
+ speed_up_type (str, optional): Speed up type.
+ choose in ['deciwatch',
+ 'deciwatch_interval5_q1',
+ 'deciwatch_interval5_q2',
+ 'deciwatch_interval5_q3',
+ 'deciwatch_interval5_q4',
+ 'deciwatch_interval5_q5',
+ 'deciwatch_interval10_q1',
+ 'deciwatch_interval10_q2',
+ 'deciwatch_interval10_q3',
+ 'deciwatch_interval10_q4',
+ 'deciwatch_interval10_q5',]. Defaults to 'deciwatch'.
+ cfg_base_dir (str, optional): Config base dir.
+ Defaults to 'configs/_base_/post_processing/'
+
+ Raises:
+ ValueError: check the input speed up type.
+
+ Returns:
+ np.ndarray: Completed data. The shape should be
+ (frame,num_person,K,C) or (frame,K,C).
+ """
+
+ if speed_up_type == 'deciwatch':
+ speed_up_type = 'deciwatch_interval5_q3'
+ assert speed_up_type in [
+ 'deciwatch_interval5_q1',
+ 'deciwatch_interval5_q2',
+ 'deciwatch_interval5_q3',
+ 'deciwatch_interval5_q4',
+ 'deciwatch_interval5_q5',
+ 'deciwatch_interval10_q1',
+ 'deciwatch_interval10_q2',
+ 'deciwatch_interval10_q3',
+ 'deciwatch_interval10_q4',
+ 'deciwatch_interval10_q5',
+ ]
+
+ cfg = os.path.join(cfg_base_dir, speed_up_type + '.py')
+ if isinstance(cfg, str):
+ cfg = mmcv.Config.fromfile(cfg)
+ elif not isinstance(cfg, mmcv.Config):
+ raise TypeError('config must be a filename or Config object, '
+ f'but got {type(cfg)}')
+ x = x.clone()
+
+ assert x.ndim == 4 or x.ndim == 5
+
+ cfg_dict = cfg['speed_up_cfg']
+ cfg_dict['device'] = x.device
+
+ speed_up_func = build_post_processing(cfg_dict)
+
+ if x.ndim == 5:
+ for i in range(x.shape[1]):
+ x[:, i] = speed_up_func(x[:, i])
+ elif x.ndim == 4:
+ x = speed_up_func(x)
+
+ return np.array(x.cpu())
+
+
+def get_speed_up_interval(speed_up_type,
+ cfg_base_dir='configs/_base_/post_processing/'):
+ """Get the interval of specific speed up type.
+
+ Args:
+ speed_up_type (str, optional): Speed up type.
+ choose in ['deciwatch',
+ 'deciwatch_interval5_q1',
+ 'deciwatch_interval5_q2',
+ 'deciwatch_interval5_q3',
+ 'deciwatch_interval5_q4',
+ 'deciwatch_interval5_q5',
+ 'deciwatch_interval10_q1',
+ 'deciwatch_interval10_q2',
+ 'deciwatch_interval10_q3',
+ 'deciwatch_interval10_q4',
+ 'deciwatch_interval10_q5',]. Defaults to 'deciwatch'.
+ cfg_base_dir (str, optional): Config base dir,
+ default configs/_base_/post_processing/
+
+ Raises:
+ ValueError: check the input speed up type.
+
+ Returns:
+ int: speed up interval
+ """
+
+ if speed_up_type == 'deciwatch':
+ speed_up_type = 'deciwatch_interval5_q3'
+ assert speed_up_type in [
+ 'deciwatch_interval5_q1',
+ 'deciwatch_interval5_q2',
+ 'deciwatch_interval5_q3',
+ 'deciwatch_interval5_q4',
+ 'deciwatch_interval5_q5',
+ 'deciwatch_interval10_q1',
+ 'deciwatch_interval10_q2',
+ 'deciwatch_interval10_q3',
+ 'deciwatch_interval10_q4',
+ 'deciwatch_interval10_q5',
+ ]
+ cfg = os.path.join(cfg_base_dir, speed_up_type + '.py')
+ if isinstance(cfg, str):
+ cfg = mmcv.Config.fromfile(cfg)
+ elif not isinstance(cfg, mmcv.Config):
+ raise TypeError('config must be a filename or Config object, '
+ f'but got {type(cfg)}')
+
+ return cfg['speed_up_cfg']['interval']
+
+
+def speed_up_interpolate(selected_frames, speed_up_frames, smpl_poses,
+ smpl_betas, pred_cams, bboxes_xyxy):
+ """Interpolate smpl_betas, pred_cams, and bboxes_xyxyx for speed up.
+
+ Args:
+ selected_frames (np.ndarray): Shape should be (selected frame number).
+ speed_up_frames (int): Total speed up frame number
+ smpl_poses (np.ndarray): selected frame smpl poses parameter
+        smpl_betas (np.ndarray): selected frame smpl shape parameter
+ pred_cams (np.ndarray): selected frame camera parameter
+ bboxes_xyxy (np.ndarray): selected frame bbox
+
+ Returns:
+ smpl_poses (np.ndarray): interpolated frame smpl poses parameter
+        smpl_betas (np.ndarray): interpolated frame smpl shape parameter
+ pred_cams (np.ndarray): interpolated frame camera parameter
+ bboxes_xyxy (np.ndarray): interpolated frame bbox
+ """
+ selected_frames = selected_frames[selected_frames <= speed_up_frames]
+ pred_cams[:speed_up_frames, :] = interpolate.interp1d(
+ selected_frames, pred_cams[selected_frames, :], kind='linear',
+ axis=0)(np.arange(0, max(selected_frames)))
+ bboxes_xyxy[:speed_up_frames, :] = interpolate.interp1d(
+ selected_frames,
+ bboxes_xyxy[selected_frames, :],
+ kind='linear',
+ axis=0)(np.arange(0, max(selected_frames)))
+ smpl_betas[:speed_up_frames, :] = interpolate.interp1d(
+ selected_frames, smpl_betas[selected_frames, :], kind='linear',
+ axis=0)(np.arange(0, max(selected_frames)))
+
+ return smpl_poses, smpl_betas, pred_cams, bboxes_xyxy
+
+
+def process_mmtracking_results(mmtracking_results,
+ max_track_id,
+ bbox_thr=None):
+ """Process mmtracking results.
+
+ Args:
+ mmtracking_results ([list]): mmtracking_results.
+ bbox_thr (float): threshold for bounding boxes.
+ max_track_id (int): the maximum track id.
+ Returns:
+ person_results ([list]): a list of tracked bounding boxes
+ max_track_id (int): the maximum track id.
+ instance_num (int): the number of instance.
+ """
+ person_results = []
+ # 'track_results' is changed to 'track_bboxes'
+ # in https://github.com/open-mmlab/mmtracking/pull/300
+ if 'track_bboxes' in mmtracking_results:
+ tracking_results = mmtracking_results['track_bboxes'][0]
+ elif 'track_results' in mmtracking_results:
+ tracking_results = mmtracking_results['track_results'][0]
+
+ tracking_results = np.array(tracking_results)
+
+ if bbox_thr is not None:
+ assert tracking_results.shape[-1] == 6
+ valid_idx = np.where(tracking_results[:, 5] > bbox_thr)[0]
+ tracking_results = tracking_results[valid_idx]
+
+ for track in tracking_results:
+ person = {}
+ person['track_id'] = int(track[0])
+ if max_track_id < int(track[0]):
+ max_track_id = int(track[0])
+ person['bbox'] = track[1:]
+ person_results.append(person)
+ person_results = sorted(person_results, key=lambda x: x.get('track_id', 0))
+ instance_num = len(person_results)
+ return person_results, max_track_id, instance_num
+
+
+def process_mmdet_results(mmdet_results, cat_id=1, bbox_thr=None):
+ """Process mmdet results, and return a list of bboxes.
+
+ Args:
+ mmdet_results (list|tuple): mmdet results.
+ bbox_thr (float): threshold for bounding boxes.
+ cat_id (int): category id (default: 1 for human)
+
+ Returns:
+ person_results (list): a list of detected bounding boxes
+ """
+ if isinstance(mmdet_results, tuple):
+ det_results = mmdet_results[0]
+ else:
+ det_results = mmdet_results
+
+ bboxes = det_results[cat_id - 1]
+
+ person_results = []
+ bboxes = np.array(bboxes)
+
+ if bbox_thr is not None:
+ assert bboxes.shape[-1] == 5
+ valid_idx = np.where(bboxes[:, 4] > bbox_thr)[0]
+ bboxes = bboxes[valid_idx]
+
+ for bbox in bboxes:
+ person = {}
+ person['bbox'] = bbox
+ person_results.append(person)
+
+ return person_results
+
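+# Illustrative sketch: a single-class detection result with one confident box.
+#   import numpy as np
+#   mmdet_results = [np.array([[0., 0., 50., 80., 0.99]])]
+#   process_mmdet_results(mmdet_results, cat_id=1, bbox_thr=0.5)
+#   # -> a one-element list: [{'bbox': <the 5-value row above>}]
+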
+
+def prepare_frames(input_path=None):
+ """Prepare frames from input_path.
+
+ Args:
+ input_path (str, optional): Defaults to None.
+
+ Raises:
+ ValueError: check the input path.
+
+ Returns:
+ List[np.ndarray]: prepared frames
+ """
+ if Path(input_path).is_file():
+ img_list = [mmcv.imread(input_path)]
+ if img_list[0] is None:
+ video = mmcv.VideoReader(input_path)
+ assert video.opened, f'Failed to load file {input_path}'
+ img_list = list(video)
+ elif Path(input_path).is_dir():
+ # input_type = 'folder'
+ file_list = [
+ os.path.join(input_path, fn) for fn in os.listdir(input_path)
+ if fn.lower().endswith(('.png', '.jpg'))
+ ]
+ file_list.sort()
+ img_list = [mmcv.imread(img_path) for img_path in file_list]
+ assert len(img_list), f'Failed to load image from {input_path}'
+ else:
+        raise ValueError('Input path should be a file or folder.'
+ f' Got invalid input path: {input_path}')
+ return img_list
+
+
+def extract_feature_sequence(extracted_results,
+ frame_idx,
+ causal,
+ seq_len,
+ step=1):
+ """Extract the target frame from person results, and pad the sequence to a
+ fixed length.
+
+ Args:
+ extracted_results (List[List[Dict]]): Multi-frame feature extraction
+ results stored in a nested list. Each element of the outer list
+ is the feature extraction results of a single frame, and each
+ element of the inner list is the feature information of one person,
+ which contains:
+ features (ndarray): extracted features
+ track_id (int): unique id of each person, required when
+                ``with_track_id==True``
+            bbox ((4, ) or (5, )): left, top, right, bottom, [score]
+ frame_idx (int): The index of the frame in the original video.
+ causal (bool): If True, the target frame is the first frame in
+ a sequence. Otherwise, the target frame is in the middle of a
+ sequence.
+ seq_len (int): The number of frames in the input sequence.
+ step (int): Step size to extract frames from the video.
+
+ Returns:
+ List[List[Dict]]: Multi-frame feature extraction results stored in a
+ nested list with a length of seq_len.
+ int: The target frame index in the padded sequence.
+ """
+
+ if causal:
+ frames_left = 0
+ frames_right = seq_len - 1
+ else:
+ frames_left = (seq_len - 1) // 2
+ frames_right = frames_left
+ num_frames = len(extracted_results)
+
+ # get the padded sequence
+ pad_left = max(0, frames_left - frame_idx // step)
+ pad_right = max(0, frames_right - (num_frames - 1 - frame_idx) // step)
+ start = max(frame_idx % step, frame_idx - frames_left * step)
+ end = min(num_frames - (num_frames - 1 - frame_idx) % step,
+ frame_idx + frames_right * step + 1)
+ extracted_results_seq = [extracted_results[0]] * pad_left + \
+ extracted_results[start:end:step] + [extracted_results[-1]] * pad_right
+ return extracted_results_seq
+
+
+def get_different_colors(number_of_colors,
+ flag=0,
+ alpha: float = 1.0,
+ mode: str = 'bgr',
+ int_dtype: bool = True):
+ """Get a numpy of colors of shape (N, 3)."""
+ mode = mode.lower()
+ assert set(mode).issubset({'r', 'g', 'b', 'a'})
+ nst0 = np.random.get_state()
+ np.random.seed(flag)
+ colors = []
+ for i in np.arange(0., 360., 360. / number_of_colors):
+ hue = i / 360.
+ lightness = (50 + np.random.rand() * 10) / 100.
+ saturation = (90 + np.random.rand() * 10) / 100.
+ colors.append(colorsys.hls_to_rgb(hue, lightness, saturation))
+ colors_np = np.asarray(colors)
+ if int_dtype:
+ colors_bgr = (255 * colors_np).astype(np.uint8)
+ else:
+ colors_bgr = colors_np.astype(np.float32)
+ # recover the random state
+ np.random.set_state(nst0)
+ color_dict = {}
+ if 'a' in mode:
+        # one alpha value per color so each channel column has width 1
+        color_dict['a'] = np.ones((colors_bgr.shape[0], 1)) * alpha
+ color_dict['b'] = colors_bgr[:, 0:1]
+ color_dict['g'] = colors_bgr[:, 1:2]
+ color_dict['r'] = colors_bgr[:, 2:3]
+ colors_final = []
+ for channel in mode:
+ colors_final.append(color_dict[channel])
+ colors_final = np.concatenate(colors_final, -1)
+ return colors_final
+
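+# Illustrative sketch: five distinct uint8 BGR colors, reproducible because
+# the random state is seeded by `flag` and restored afterwards.
+#   colors = get_different_colors(5)
+#   # colors.shape -> (5, 3), dtype uint8
+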
+
+class RunningAverage():
+ r"""A helper class to calculate running average in a sliding window.
+
+ Args:
+ window (int): The size of the sliding window.
+ """
+ def __init__(self, window: int = 1):
+ self.window = window
+ self._data = []
+
+ def update(self, value):
+ """Update a new data sample."""
+ self._data.append(value)
+ self._data = self._data[-self.window:]
+
+ def average(self):
+ """Get the average value of current window."""
+ return np.mean(self._data)
+
+
+class StopWatch:
+ r"""A helper class to measure FPS and detailed time consuming of each phase
+ in a video processing loop or similar scenarios.
+
+ Args:
+ window (int): The sliding window size to calculate the running average
+            of the time cost.
+
+ Example:
+ >>> from mmpose.utils import StopWatch
+ >>> import time
+ >>> stop_watch = StopWatch(window=10)
+ >>> with stop_watch.timeit('total'):
+ >>> time.sleep(0.1)
+ >>> # 'timeit' support nested use
+ >>> with stop_watch.timeit('phase1'):
+ >>> time.sleep(0.1)
+ >>> with stop_watch.timeit('phase2'):
+ >>> time.sleep(0.2)
+ >>> time.sleep(0.2)
+ >>> report = stop_watch.report()
+ """
+ def __init__(self, window=1):
+ self.window = window
+ self._record = defaultdict(partial(RunningAverage, window=self.window))
+ self._timer_stack = []
+
+ @contextmanager
+ def timeit(self, timer_name='_FPS_'):
+ """Timing a code snippet with an assigned name.
+
+ Args:
+ timer_name (str): The unique name of the interested code snippet to
+ handle multiple timers and generate reports. Note that '_FPS_'
+ is a special key that the measurement will be in `fps` instead
+ of `millisecond`. Also see `report` and `report_strings`.
+ Default: '_FPS_'.
+ Note:
+ This function should always be used in a `with` statement, as shown
+ in the example.
+ """
+ self._timer_stack.append((timer_name, Timer()))
+ try:
+ yield
+ finally:
+ timer_name, timer = self._timer_stack.pop()
+ self._record[timer_name].update(timer.since_start())
+
+ def report(self, key=None):
+ """Report timing information.
+
+ Returns:
+ dict: The key is the timer name and the value is the \
+                corresponding average time cost.
+ """
+ result = {
+ name: r.average() * 1000.
+ for name, r in self._record.items()
+ }
+
+ if '_FPS_' in result:
+ result['_FPS_'] = 1000. / result.pop('_FPS_')
+
+ if key is None:
+ return result
+ return result[key]
+
+ def report_strings(self):
+ """Report timing information in texture strings.
+
+ Returns:
+ list(str): Each element is the information string of a timed \
+ event, in format of '{timer_name}: {time_in_ms}'. \
+ Specially, if timer_name is '_FPS_', the result will \
+ be converted to fps.
+ """
+ result = self.report()
+ strings = []
+        if '_FPS_' in result:
+            # pop '_FPS_' so the fps entry is not duplicated below
+            strings.append(f'FPS: {result.pop("_FPS_"):>5.1f}')
+ strings += [f'{name}: {val:>3.0f}' for name, val in result.items()]
+ return strings
+
+    def reset(self):
+        """Clear all recorded timings and any active timers."""
+        self._record = defaultdict(partial(RunningAverage, window=self.window))
+        self._timer_stack = []
diff --git a/detrsmpl/utils/dist_utils.py b/detrsmpl/utils/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..83800706f8e4bb81395eb8e33ecc69028bd98f3e
--- /dev/null
+++ b/detrsmpl/utils/dist_utils.py
@@ -0,0 +1,67 @@
+from collections import OrderedDict
+
+import torch.distributed as dist
+from mmcv.runner import OptimizerHook
+from torch._utils import (
+ _flatten_dense_tensors,
+ _take_tensors,
+ _unflatten_dense_tensors,
+)
+
+
+def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
+ if bucket_size_mb > 0:
+ bucket_size_bytes = bucket_size_mb * 1024 * 1024
+ buckets = _take_tensors(tensors, bucket_size_bytes)
+ else:
+ buckets = OrderedDict()
+ for tensor in tensors:
+ tp = tensor.type()
+ if tp not in buckets:
+ buckets[tp] = []
+ buckets[tp].append(tensor)
+ buckets = buckets.values()
+
+ for bucket in buckets:
+ flat_tensors = _flatten_dense_tensors(bucket)
+ dist.all_reduce(flat_tensors)
+ flat_tensors.div_(world_size)
+ for tensor, synced in zip(
+ bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
+ tensor.copy_(synced)
+
+
+def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
+ grads = [
+ param.grad.data for param in params
+ if param.requires_grad and param.grad is not None
+ ]
+ world_size = dist.get_world_size()
+ if coalesce:
+ _allreduce_coalesced(grads, world_size, bucket_size_mb)
+ else:
+ for tensor in grads:
+ dist.all_reduce(tensor.div_(world_size))
+
+
+class DistOptimizerHook(OptimizerHook):
+ def __init__(self, grad_clip=None, coalesce=True, bucket_size_mb=-1):
+ self.grad_clip = grad_clip
+ self.coalesce = coalesce
+ self.bucket_size_mb = bucket_size_mb
+
+ def after_train_iter(self, runner):
+ runner.optimizer.zero_grad()
+ runner.outputs['loss'].backward()
+ if self.grad_clip is not None:
+ self.clip_grads(runner.model.parameters())
+ runner.optimizer.step()
+
+
+def reduce_mean(tensor):
+ """"Obtain the mean of tensor on different GPUs."""
+ if not (dist.is_available() and dist.is_initialized()):
+ return tensor
+ tensor = tensor.clone()
+ dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
+ return tensor
diff --git a/detrsmpl/utils/ffmpeg_utils.py b/detrsmpl/utils/ffmpeg_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a401f6d3ae16055c870f1ec3f8a6fa99ab612824
--- /dev/null
+++ b/detrsmpl/utils/ffmpeg_utils.py
@@ -0,0 +1,1376 @@
+import glob
+import json
+import os
+import shutil
+import string
+import subprocess
+import sys
+from pathlib import Path
+from typing import Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+
+from detrsmpl.utils.path_utils import check_input_path, prepare_output_path
+
+try:
+ from typing import Literal
+except ImportError:
+ from typing_extensions import Literal
+
+
+class video_writer:
+
+ def __init__(self,
+ output_path: str,
+ resolution: Iterable[int],
+ fps: float = 30.0,
+ num_frame: int = 1e9,
+ disable_log: bool = False) -> None:
+ prepare_output_path(
+ output_path,
+ allowed_suffix=['.mp4'],
+ tag='output video',
+ path_type='file',
+ overwrite=True)
+ height, width = resolution
+ width += width % 2
+ height += height % 2
+ command = [
+ 'ffmpeg',
+ '-y', # (optional) overwrite output file if it exists
+ '-f',
+ 'rawvideo',
+ '-pix_fmt',
+ 'bgr24',
+ '-s',
+ f'{int(width)}x{int(height)}',
+ '-r',
+ f'{fps}', # frames per second
+ '-loglevel',
+ 'error',
+ '-threads',
+ '1',
+ '-i',
+ '-', # The input comes from a pipe
+ '-vcodec',
+ 'libx264',
+ '-r',
+ f'{fps}', # frames per second
+ '-an', # Tells FFMPEG not to expect any audio
+ output_path,
+ ]
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ process = subprocess.Popen(
+ command,
+ stdin=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ )
+ if process.stdin is None or process.stderr is None:
+ raise BrokenPipeError('No buffer received.')
+ self.process = process
+ self.num_frame = num_frame
+ self.len = 0
+
+ def write(self, image_array: np.ndarray):
+ if self.len <= self.num_frame:
+ try:
+ self.process.stdin.write(image_array.tobytes())
+ self.len += 1
+ except KeyboardInterrupt:
+ self.__del__()
+
+ def __del__(self):
+ self.process.stdin.close()
+ self.process.stderr.close()
+ self.process.wait()
+
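+# Illustrative sketch (hypothetical output path, needs ffmpeg on PATH): stream
+# frames to ffmpeg one by one instead of materialising the whole array as
+# array_to_video below does.
+#   import numpy as np
+#   writer = video_writer('demo_out/stream.mp4', resolution=(256, 256), fps=30)
+#   for _ in range(30):
+#       writer.write(np.zeros((256, 256, 3), dtype=np.uint8))
+#   del writer  # closes the pipe and finalises the file
+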
+
+def array_to_video(
+ image_array: np.ndarray,
+ output_path: str,
+ fps: Union[int, float] = 30,
+ resolution: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
+ disable_log: bool = False,
+) -> None:
+ """Convert an array to a video directly, gif not supported.
+
+ Args:
+ image_array (np.ndarray): shape should be (f * h * w * 3).
+ output_path (str): output video file path.
+        fps (Union[int, float], optional): fps. Defaults to 30.
+ resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+ optional): (height, width) of the output video.
+ Defaults to None.
+        disable_log (bool, optional): whether to hide the ffmpeg command log.
+ Defaults to False.
+ Raises:
+ FileNotFoundError: check output path.
+ TypeError: check input array.
+
+ Returns:
+ None.
+ """
+ if not isinstance(image_array, np.ndarray):
+ raise TypeError('Input should be np.ndarray.')
+ assert image_array.ndim == 4
+ assert image_array.shape[-1] == 3
+ prepare_output_path(
+ output_path,
+ allowed_suffix=['.mp4'],
+ tag='output video',
+ path_type='file',
+ overwrite=True)
+ if resolution:
+ height, width = resolution
+ width += width % 2
+ height += height % 2
+ else:
+ image_array = pad_for_libx264(image_array)
+ height, width = image_array.shape[1], image_array.shape[2]
+ command = [
+ 'ffmpeg',
+ '-y', # (optional) overwrite output file if it exists
+ '-f',
+ 'rawvideo',
+ '-s',
+ f'{int(width)}x{int(height)}', # size of one frame
+ '-pix_fmt',
+ 'bgr24',
+ '-r',
+ f'{fps}', # frames per second
+ '-loglevel',
+ 'error',
+ '-threads',
+ '4',
+        '-i',
+ '-', # The input comes from a pipe
+ '-vcodec',
+ 'libx264',
+ '-an', # Tells FFMPEG not to expect any audio
+ output_path,
+ ]
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ process = subprocess.Popen(
+ command,
+ stdin=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ )
+ if process.stdin is None or process.stderr is None:
+ raise BrokenPipeError('No buffer received.')
+ index = 0
+ while True:
+ if index >= image_array.shape[0]:
+ break
+ process.stdin.write(image_array[index].tobytes())
+ index += 1
+ process.stdin.close()
+ process.stderr.close()
+ process.wait()
+
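+# Illustrative sketch (hypothetical output path, needs ffmpeg on PATH): encode
+# one second of black 256x256 frames to an mp4.
+#   import numpy as np
+#   frames = np.zeros((30, 256, 256, 3), dtype=np.uint8)
+#   array_to_video(frames, 'demo_out/black.mp4', fps=30)
+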
+
+def array_to_images(
+ image_array: np.ndarray,
+ output_folder: str,
+ img_format: str = '%06d.png',
+ resolution: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
+ disable_log: bool = False,
+) -> None:
+ """Convert an array to images directly.
+
+ Args:
+ image_array (np.ndarray): shape should be (f * h * w * 3).
+ output_folder (str): output folder for the images.
+ img_format (str, optional): format of the images.
+ Defaults to '%06d.png'.
+ resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+ optional): resolution(height, width) of output.
+ Defaults to None.
+        disable_log (bool, optional): whether to hide the ffmpeg command log.
+ Defaults to False.
+
+ Raises:
+ FileNotFoundError: check output folder.
+ TypeError: check input array.
+
+ Returns:
+ None
+ """
+ prepare_output_path(
+ output_folder,
+ allowed_suffix=[],
+ tag='output image folder',
+ path_type='dir',
+ overwrite=True)
+
+ if not isinstance(image_array, np.ndarray):
+ raise TypeError('Input should be np.ndarray.')
+ assert image_array.ndim == 4
+ assert image_array.shape[-1] == 3
+ if resolution:
+ height, width = resolution
+ else:
+ height, width = image_array.shape[1], image_array.shape[2]
+ command = [
+ 'ffmpeg',
+ '-y', # (optional) overwrite output file if it exists
+ '-f',
+ 'rawvideo',
+ '-s',
+ f'{int(width)}x{int(height)}', # size of one frame
+ '-pix_fmt',
+ 'bgr24', # bgr24 for matching OpenCV
+ '-loglevel',
+ 'error',
+ '-threads',
+ '4',
+ '-i',
+ '-', # The input comes from a pipe
+ '-f',
+ 'image2',
+ '-start_number',
+ '0',
+ os.path.join(output_folder, img_format),
+ ]
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ process = subprocess.Popen(
+ command,
+ stdin=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ bufsize=10**8,
+ close_fds=True)
+ if process.stdin is None or process.stderr is None:
+ raise BrokenPipeError('No buffer received.')
+ index = 0
+ while True:
+ if index >= image_array.shape[0]:
+ break
+ process.stdin.write(image_array[index].tobytes())
+ index += 1
+ process.stdin.close()
+ process.stderr.close()
+ process.wait()
+
+
+def video_to_array(
+ input_path: str,
+ resolution: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
+ start: int = 0,
+ end: Optional[int] = None,
+ disable_log: bool = False,
+) -> np.ndarray:
+ """
+ Read a video/gif as an array of (f * h * w * 3).
+
+ Args:
+ input_path (str): input path.
+ resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+ optional): resolution(height, width) of output.
+ Defaults to None.
+ start (int, optional): start frame index. Inclusive.
+ If < 0, will be converted to frame_index range in [0, frame_num].
+ Defaults to 0.
+ end (int, optional): end frame index. Exclusive.
+ Could be positive int or negative int or None.
+ If None, all frames from start till the last frame are included.
+ Defaults to None.
+        disable_log (bool, optional): whether to hide the ffmpeg command log.
+ Defaults to False.
+
+ Raises:
+ FileNotFoundError: check the input path.
+
+ Returns:
+ np.ndarray: shape will be (f * h * w * 3).
+ """
+ check_input_path(
+ input_path,
+ allowed_suffix=['.mp4', 'mkv', 'avi', '.gif'],
+ tag='input video',
+ path_type='file')
+
+ info = vid_info_reader(input_path)
+ if resolution:
+ height, width = resolution
+ else:
+ width, height = int(info['width']), int(info['height'])
+ num_frames = int(info['nb_frames'])
+ start = (min(start, num_frames - 1) + num_frames) % num_frames
+ end = (min(end, num_frames - 1) +
+ num_frames) % num_frames if end is not None else num_frames
+ command = [
+ 'ffmpeg',
+ '-i',
+ input_path,
+ '-filter_complex',
+ f'[0]trim=start_frame={start}:end_frame={end}[v0]',
+ '-map',
+ '[v0]',
+ '-pix_fmt',
+ 'bgr24', # bgr24 for matching OpenCV
+ '-s',
+ f'{int(width)}x{int(height)}',
+ '-f',
+ 'image2pipe',
+ '-vcodec',
+ 'rawvideo',
+ '-loglevel',
+ 'error',
+ 'pipe:'
+ ]
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ # Execute FFmpeg as sub-process with stdout as a pipe
+ process = subprocess.Popen(command, stdout=subprocess.PIPE, bufsize=10**8)
+ if process.stdout is None:
+ raise BrokenPipeError('No buffer received.')
+ # Read decoded video frames from the PIPE until no more frames to read
+ array = []
+ while True:
+ # Read decoded video frame (in raw video format) from stdout process.
+ buffer = process.stdout.read(int(width * height * 3))
+ # Break the loop if buffer length is not W*H*3\
+ # (when FFmpeg streaming ends).
+ if len(buffer) != width * height * 3:
+ break
+ img = np.frombuffer(buffer, np.uint8).reshape(height, width, 3)
+ array.append(img[np.newaxis])
+ process.stdout.flush()
+ process.stdout.close()
+ process.wait()
+ return np.concatenate(array)
+
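+# Illustrative sketch (hypothetical input path): read the first 30 frames of a
+# video back as a (30, H, W, 3) uint8 array in BGR order.
+#   frames = video_to_array('demo_out/black.mp4', start=0, end=30)
+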
+
+def images_to_sorted_images(input_folder, output_folder, img_format='%06d'):
+ """Copy and rename a folder of images into a new folder following the
+ `img_format`.
+
+ Args:
+ input_folder (str): input folder.
+ output_folder (str): output folder.
+ img_format (str, optional): image format name, do not need extension.
+ Defaults to '%06d'.
+
+ Returns:
+ str: image format of the rename images.
+ """
+ img_format = img_format.rsplit('.', 1)[0]
+    file_list = []
+    # default extension in case the folder contains no images
+    ext = 'png'
+ os.makedirs(output_folder, exist_ok=True)
+ pngs = glob.glob(os.path.join(input_folder, '*.png'))
+ if pngs:
+ ext = 'png'
+ file_list.extend(pngs)
+ jpgs = glob.glob(os.path.join(input_folder, '*.jpg'))
+ if jpgs:
+ ext = 'jpg'
+ file_list.extend(jpgs)
+ file_list.sort()
+ for index, file_name in enumerate(file_list):
+ shutil.copy(
+ file_name,
+ os.path.join(output_folder, (img_format + '.%s') % (index, ext)))
+ return img_format + '.%s' % ext
+
+
+def images_to_array(
+ input_folder: str,
+ resolution: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
+ img_format: str = '%06d.png',
+ start: int = 0,
+ end: Optional[int] = None,
+ remove_raw_files: bool = False,
+ disable_log: bool = False,
+) -> np.ndarray:
+ """
+ Read a folder of images as an array of (f * h * w * 3).
+
+ Args:
+ input_folder (str): folder of input images.
+ resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]]:
+ resolution(height, width) of output. Defaults to None.
+ img_format (str, optional): format of images to be read.
+ Defaults to '%06d.png'.
+ start (int, optional): start frame index. Inclusive.
+ If < 0, will be converted to frame_index range in [0, frame_num].
+ Defaults to 0.
+ end (int, optional): end frame index. Exclusive.
+ Could be positive int or negative int or None.
+ If None, all frames from start till the last frame are included.
+ Defaults to None.
+ remove_raw_files (bool, optional): whether remove raw images.
+ Defaults to False.
+        disable_log (bool, optional): whether to hide the ffmpeg command log.
+ Defaults to False.
+ Raises:
+ FileNotFoundError: check the input path.
+
+ Returns:
+ np.ndarray: shape will be (f * h * w * 3).
+ """
+ check_input_path(
+ input_folder,
+ allowed_suffix=[''],
+ tag='input image folder',
+ path_type='dir')
+
+ input_folderinfo = Path(input_folder)
+
+ temp_input_folder = None
+ if img_format is None:
+ temp_input_folder = os.path.join(input_folderinfo.parent,
+ input_folderinfo.name + '_temp')
+ img_format = images_to_sorted_images(
+ input_folder=input_folder, output_folder=temp_input_folder)
+ input_folder = temp_input_folder
+
+ info = vid_info_reader(f'{input_folder}/{img_format}' % start)
+ if resolution:
+ height, width = resolution
+ else:
+ width, height = int(info['width']), int(info['height'])
+
+ num_frames = len(os.listdir(input_folder))
+ start = max(start, 0) % num_frames
+ end = min(end, num_frames) % (num_frames + 1) \
+ if end is not None else num_frames
+ command = [
+ 'ffmpeg',
+ '-y',
+ '-threads',
+ '1',
+ '-start_number',
+ f'{start}',
+ '-i',
+ f'{input_folder}/{img_format}',
+ '-frames:v',
+ f'{end - start}',
+ '-f',
+ 'rawvideo',
+ '-pix_fmt',
+ 'bgr24', # bgr24 for matching OpenCV
+ '-s',
+ f'{int(width)}x{int(height)}',
+ '-loglevel',
+ 'error',
+ '-'
+ ]
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ process = subprocess.Popen(command, stdout=subprocess.PIPE, bufsize=10**8)
+ if process.stdout is None:
+ raise BrokenPipeError('No buffer received.')
+ # Read decoded video frames from the PIPE until no more frames to read
+ array = []
+ while True:
+ # Read decoded video frame (in raw video format) from stdout process.
+ buffer = process.stdout.read(int(width * height * 3))
+ # Break the loop if buffer length is not W*H*3\
+ # (when FFmpeg streaming ends).
+
+ if len(buffer) != width * height * 3:
+ break
+ img = np.frombuffer(buffer, np.uint8).reshape(height, width, 3)
+ array.append(img[np.newaxis])
+ process.stdout.flush()
+ process.stdout.close()
+ process.wait()
+ if temp_input_folder is not None:
+ if Path(temp_input_folder).is_dir():
+ shutil.rmtree(temp_input_folder)
+ if remove_raw_files:
+ if Path(input_folder).is_dir():
+ shutil.rmtree(input_folder)
+
+ return np.concatenate(array)
+
+
+class vid_info_reader(object):
+
+ def __init__(self, input_path) -> None:
+ """Get video information from video, mimiced from ffmpeg-python.
+ https://github.com/kkroening/ffmpeg-python.
+
+ Args:
+            input_path (str): video file path.
+
+ Raises:
+ FileNotFoundError: check the input path.
+
+ Returns:
+ None.
+ """
+ check_input_path(
+ input_path,
+ allowed_suffix=['.mp4', '.gif', '.png', '.jpg', '.jpeg'],
+ tag='input file',
+ path_type='file')
+ cmd = [
+ 'ffprobe', '-show_format', '-show_streams', '-of', 'json',
+ input_path
+ ]
+ process = subprocess.Popen(
+ cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ out, _ = process.communicate()
+ probe = json.loads(out.decode('utf-8'))
+ video_stream = next((stream for stream in probe['streams']
+ if stream['codec_type'] == 'video'), None)
+ if video_stream is None:
+ print('No video stream found', file=sys.stderr)
+ sys.exit(1)
+ self.video_stream = video_stream
+
+ def __getitem__(
+ self,
+ key: Literal['index', 'codec_name', 'codec_long_name', 'profile',
+ 'codec_type', 'codec_time_base', 'codec_tag_string',
+ 'codec_tag', 'width', 'height', 'coded_width',
+ 'coded_height', 'has_b_frames', 'pix_fmt', 'level',
+ 'chroma_location', 'refs', 'is_avc', 'nal_length_size',
+ 'r_frame_rate', 'avg_frame_rate', 'time_base',
+ 'start_pts', 'start_time', 'duration_ts', 'duration',
+ 'bit_rate', 'bits_per_raw_sample', 'nb_frames',
+ 'disposition', 'tags']):
+ """Key (str): select in ['index', 'codec_name', 'codec_long_name',
+ 'profile', 'codec_type', 'codec_time_base', 'codec_tag_string',
+ 'codec_tag', 'width', 'height', 'coded_width', 'coded_height',
+ 'has_b_frames', 'pix_fmt', 'level', 'chroma_location', 'refs',
+ 'is_avc', 'nal_length_size', 'r_frame_rate', 'avg_frame_rate',
+ 'time_base', 'start_pts', 'start_time', 'duration_ts', 'duration',
+ 'bit_rate', 'bits_per_raw_sample', 'nb_frames', 'disposition',
+ 'tags']"""
+ return self.video_stream[key]
+
+
+def video_to_gif(
+ input_path: str,
+ output_path: str,
+ resolution: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
+ fps: Union[float, int] = 15,
+ disable_log: bool = False,
+) -> None:
+ """Convert a video to a gif file.
+
+ Args:
+ input_path (str): video file path.
+ output_path (str): gif file path.
+ resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+ optional): (height, width) of the output video.
+ Defaults to None.
+ fps (Union[float, int], optional): frames per second. Defaults to 15.
+        disable_log (bool, optional): whether to hide the ffmpeg command log.
+ Defaults to False.
+
+ Raises:
+ FileNotFoundError: check the input path.
+ FileNotFoundError: check the output path.
+
+ Returns:
+ None.
+ """
+ check_input_path(
+ input_path,
+ allowed_suffix=['.mp4'],
+ tag='input video',
+ path_type='file')
+ prepare_output_path(
+ output_path,
+ allowed_suffix=['.gif'],
+ tag='output gif',
+ path_type='file',
+ overwrite=True)
+
+ info = vid_info_reader(input_path)
+ duration = info['duration']
+ if resolution:
+ height, width = resolution
+ else:
+ width, height = int(info['width']), int(info['height'])
+
+ command = [
+ 'ffmpeg', '-r',
+ str(info['r_frame_rate']), '-i', input_path, '-r', f'{fps}', '-s',
+ f'{width}x{height}', '-loglevel', 'error', '-t', f'{duration}',
+ '-threads', '4', '-y', output_path
+ ]
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ subprocess.call(command)
+
+
+def video_to_images(input_path: str,
+ output_folder: str,
+ resolution: Optional[Union[Tuple[int, int],
+ Tuple[float, float]]] = None,
+ img_format: str = '%06d.png',
+ start: int = 0,
+ end: Optional[int] = None,
+ disable_log: bool = False) -> None:
+ """Convert a video to a folder of images.
+
+ Args:
+ input_path (str): video file path
+ output_folder (str): output folder to store the images
+ resolution (Optional[Tuple[int, int]], optional):
+ (height, width) of output. defaults to None.
+ img_format (str, optional): format of images to be read.
+ Defaults to '%06d.png'.
+ start (int, optional): start frame index. Inclusive.
+ If < 0, will be converted to frame_index range in [0, frame_num].
+ Defaults to 0.
+ end (int, optional): end frame index. Exclusive.
+ Could be positive int or negative int or None.
+ If None, all frames from start till the last frame are included.
+ Defaults to None.
+        disable_log (bool, optional): whether to hide the ffmpeg command log.
+ Defaults to False.
+ Raises:
+ FileNotFoundError: check the input path
+ FileNotFoundError: check the output path
+
+ Returns:
+ None
+ """
+ check_input_path(
+ input_path,
+ allowed_suffix=['.mp4'],
+ tag='input video',
+ path_type='file')
+ prepare_output_path(
+ output_folder,
+ allowed_suffix=[],
+ tag='output image folder',
+ path_type='dir',
+ overwrite=True)
+ info = vid_info_reader(input_path)
+ num_frames = int(info['nb_frames'])
+ start = (min(start, num_frames - 1) + num_frames) % num_frames
+ end = (min(end, num_frames - 1) +
+ num_frames) % num_frames if end is not None else num_frames
+
+ command = [
+ 'ffmpeg', '-i', input_path, '-filter_complex',
+ f'[0]trim=start_frame={start}:end_frame={end}[v0]', '-map', '[v0]',
+ '-f', 'image2', '-v', 'error', '-start_number', '0', '-threads', '1',
+ f'{output_folder}/{img_format}'
+ ]
+ if resolution:
+ height, width = resolution
+ command.insert(3, '-s')
+ command.insert(4, '%dx%d' % (width, height))
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ subprocess.call(command)
+
+
+def images_to_video(input_folder: str,
+ output_path: str,
+ remove_raw_file: bool = False,
+ img_format: str = '%06d.png',
+ fps: Union[int, float] = 30,
+ resolution: Optional[Union[Tuple[int, int],
+ Tuple[float, float]]] = None,
+ start: int = 0,
+ end: Optional[int] = None,
+ disable_log: bool = False) -> None:
+ """Convert a folder of images to a video.
+
+ Args:
+ input_folder (str): input image folder
+ output_path (str): output video file path
+ remove_raw_file (bool, optional): whether remove raw images.
+ Defaults to False.
+        img_format (str, optional): format to name the images.
+ Defaults to '%06d.png'.
+ fps (Union[int, float], optional): output video fps. Defaults to 30.
+ resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+ optional): (height, width) of output.
+ defaults to None.
+ start (int, optional): start frame index. Inclusive.
+ If < 0, will be converted to frame_index range in [0, frame_num].
+ Defaults to 0.
+ end (int, optional): end frame index. Exclusive.
+ Could be positive int or negative int or None.
+ If None, all frames from start till the last frame are included.
+ Defaults to None.
+        disable_log (bool, optional): whether to hide the ffmpeg command log.
+ Defaults to False.
+ Raises:
+ FileNotFoundError: check the input path.
+ FileNotFoundError: check the output path.
+
+ Returns:
+ None
+ """
+ check_input_path(
+ input_folder,
+ allowed_suffix=[],
+ tag='input image folder',
+ path_type='dir')
+ prepare_output_path(
+ output_path,
+ allowed_suffix=['.mp4'],
+ tag='output video',
+ path_type='file',
+ overwrite=True)
+ input_folderinfo = Path(input_folder)
+ num_frames = len(os.listdir(input_folder))
+ start = (min(start, num_frames - 1) + num_frames) % num_frames
+ end = (min(end, num_frames - 1) +
+ num_frames) % num_frames if end is not None else num_frames
+ temp_input_folder = None
+ if img_format is None:
+ temp_input_folder = os.path.join(input_folderinfo.parent,
+ input_folderinfo.name + '_temp')
+ img_format = images_to_sorted_images(input_folder, temp_input_folder)
+
+ command = [
+ 'ffmpeg',
+ '-y',
+ '-threads',
+ '4',
+ '-start_number',
+ f'{start}',
+ '-r',
+ f'{fps}',
+ '-i',
+ f'{input_folder}/{img_format}'
+ if temp_input_folder is None else f'{temp_input_folder}/{img_format}',
+ '-frames:v',
+ f'{end - start}',
+ '-profile:v',
+ 'baseline',
+ '-level',
+ '3.0',
+ '-c:v',
+ 'libx264',
+ '-pix_fmt',
+ 'yuv420p',
+ '-vf',
+ 'scale=trunc(iw/2)*2:trunc(ih/2)*2', # Ensure width and height are divisible by 2
+ '-an',
+ '-v',
+ 'error',
+ '-loglevel',
+ 'error',
+ output_path,
+ ]
+ if resolution:
+ height, width = resolution
+ width += width % 2
+ height += height % 2
+ command.insert(1, '-s')
+ command.insert(2, '%dx%d' % (width, height))
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ subprocess.call(command)
+ if remove_raw_file:
+ if Path(input_folder).is_dir():
+ shutil.rmtree(input_folder)
+ if temp_input_folder is not None:
+ if Path(temp_input_folder).is_dir():
+ shutil.rmtree(temp_input_folder)
+
+
+def images_to_gif(
+ input_folder: str,
+ output_path: str,
+ remove_raw_file: bool = False,
+ img_format: str = '%06d.png',
+ fps: int = 15,
+ resolution: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
+ start: int = 0,
+ end: Optional[int] = None,
+ disable_log: bool = False,
+) -> None:
+ """Convert series of images to a video, similar to images_to_video, but
+ provide more suitable parameters.
+
+ Args:
+ input_folder (str): input image folder.
+ output_path (str): output gif file path.
+ remove_raw_file (bool, optional): whether remove raw images.
+ Defaults to False.
+ img_format (str, optional): format to name the images.
+ Defaults to '%06d.png'.
+ fps (int, optional): output video fps. Defaults to 15.
+ resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+ optional): (height, width) of output. Defaults to None.
+ start (int, optional): start frame index. Inclusive.
+ If < 0, will be converted to frame_index range in [0, frame_num].
+ Defaults to 0.
+ end (int, optional): end frame index. Exclusive.
+ Could be positive int or negative int or None.
+ If None, all frames from start till the last frame are included.
+ Defaults to None.
+        disable_log (bool, optional): whether to hide the ffmpeg command log.
+ Defaults to False.
+ Raises:
+ FileNotFoundError: check the input path.
+ FileNotFoundError: check the output path.
+
+ Returns:
+ None
+ """
+ input_folderinfo = Path(input_folder)
+ check_input_path(
+ input_folder,
+ allowed_suffix=[],
+ tag='input image folder',
+ path_type='dir')
+ prepare_output_path(
+ output_path,
+ allowed_suffix=['.gif'],
+ tag='output gif',
+ path_type='file',
+ overwrite=True)
+ num_frames = len(os.listdir(input_folder))
+ start = (min(start, num_frames - 1) + num_frames) % num_frames
+ end = (min(end, num_frames - 1) +
+ num_frames) % num_frames if end is not None else num_frames
+ temp_input_folder = None
+ if img_format is None:
+ file_list = []
+ temp_input_folder = os.path.join(input_folderinfo.parent,
+ input_folderinfo.name + '_temp')
+ os.makedirs(temp_input_folder, exist_ok=True)
+ pngs = glob.glob(os.path.join(input_folder, '*.png'))
+ ext = 'png'
+ if pngs:
+ ext = 'png'
+ file_list.extend(pngs)
+ jpgs = glob.glob(os.path.join(input_folder, '*.jpg'))
+ if jpgs:
+ ext = 'jpg'
+ file_list.extend(jpgs)
+ file_list.sort()
+ for index, file_name in enumerate(file_list):
+ shutil.copy(
+ file_name,
+ os.path.join(temp_input_folder, '%06d.%s' % (index + 1, ext)))
+ input_folder = temp_input_folder
+ img_format = '%06d.' + ext
+
+ command = [
+ 'ffmpeg',
+ '-y',
+ '-threads',
+ '4',
+ '-start_number',
+ f'{start}',
+ '-r',
+ f'{fps}',
+ '-i',
+ f'{input_folder}/{img_format}',
+ '-frames:v',
+ f'{end - start}',
+ '-loglevel',
+ 'error',
+ '-v',
+ 'error',
+ output_path,
+ ]
+ if resolution:
+ height, width = resolution
+ command.insert(1, '-s')
+ command.insert(2, '%dx%d' % (width, height))
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ subprocess.call(command)
+ if remove_raw_file:
+ shutil.rmtree(input_folder)
+ if temp_input_folder is not None:
+ shutil.rmtree(temp_input_folder)
+
+
+def gif_to_video(input_path: str,
+ output_path: str,
+ fps: int = 30,
+ remove_raw_file: bool = False,
+ resolution: Optional[Union[Tuple[int, int],
+ Tuple[float, float]]] = None,
+ disable_log: bool = False) -> None:
+ """Convert a gif file to a video.
+
+ Args:
+ input_path (str): input gif file path.
+ output_path (str): output video file path.
+ fps (int, optional): fps. Defaults to 30.
+ remove_raw_file (bool, optional): whether remove original input file.
+ Defaults to False.
+ resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+ optional): (height, width) of output. Defaults to None.
+        disable_log (bool, optional): whether to hide the ffmpeg command log.
+ Defaults to False.
+ Raises:
+ FileNotFoundError: check the input path.
+ FileNotFoundError: check the output path.
+
+ Returns:
+ None
+ """
+ check_input_path(
+ input_path, allowed_suffix=['.gif'], tag='input gif', path_type='file')
+ prepare_output_path(
+ output_path,
+ allowed_suffix=['.mp4'],
+ tag='output video',
+ path_type='file',
+ overwrite=True)
+ command = [
+ 'ffmpeg', '-i', input_path, '-r', f'{fps}', '-loglevel', 'error', '-y',
+ output_path, '-threads', '4'
+ ]
+ if resolution:
+ height, width = resolution
+ command.insert(3, '-s')
+ command.insert(4, '%dx%d' % (width, height))
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ subprocess.call(command)
+ if remove_raw_file:
+ subprocess.call(['rm', '-f', input_path])
+
+
+def gif_to_images(input_path: str,
+ output_folder: str,
+ fps: int = 30,
+ img_format: str = '%06d.png',
+ resolution: Optional[Union[Tuple[int, int],
+ Tuple[float, float]]] = None,
+ disable_log: bool = False) -> None:
+ """Convert a gif file to a folder of images.
+
+ Args:
+ input_path (str): input gif file path.
+ output_folder (str): output folder to save the images.
+ fps (int, optional): fps. Defaults to 30.
+ img_format (str, optional): output image name format.
+ Defaults to '%06d.png'.
+ resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+ optional): (height, width) of output.
+ Defaults to None.
+        disable_log (bool, optional): whether to hide the ffmpeg command log.
+ Defaults to False.
+ Raises:
+ FileNotFoundError: check the input path.
+ FileNotFoundError: check the output path.
+
+ Returns:
+ None
+ """
+ check_input_path(
+ input_path, allowed_suffix=['.gif'], tag='input gif', path_type='file')
+ prepare_output_path(
+ output_folder,
+ allowed_suffix=[],
+ tag='output image folder',
+ path_type='dir',
+ overwrite=True)
+ command = [
+ 'ffmpeg', '-r', f'{fps}', '-i', input_path, '-loglevel', 'error', '-f',
+ 'image2', '-v', 'error', '-threads', '4', '-y', '-start_number', '0',
+ f'{output_folder}/{img_format}'
+ ]
+ if resolution:
+ height, width = resolution
+ command.insert(3, '-s')
+ command.insert(4, '%dx%d' % (width, height))
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ subprocess.call(command)
+
+
+def crop_video(
+ input_path: str,
+ output_path: str,
+ box: Optional[Union[List[int], Tuple[int, int, int, int]]] = None,
+ resolution: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
+ disable_log: bool = False,
+) -> None:
+    """Spatially crop a video or gif file.
+
+ Args:
+ input_path (str): input video or gif file path.
+ output_path (str): output video or gif file path.
+        box (Iterable[int], optional): [x, y] of the top-left corner of the
+            crop region followed by its width and height. Defaults to None,
+            which keeps the full frame.
+        resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+            optional): (height, width) of output. Defaults to None.
+        disable_log (bool, optional): whether to suppress the ffmpeg command
+            info. Defaults to False.
+ Raises:
+ FileNotFoundError: check the input path.
+ FileNotFoundError: check the output path.
+
+ Returns:
+        None
+ """
+ check_input_path(
+ input_path,
+ allowed_suffix=['.gif', '.mp4'],
+ tag='input video',
+ path_type='file')
+ prepare_output_path(
+ output_path,
+ allowed_suffix=['.gif', '.mp4'],
+ tag='output video',
+ path_type='file',
+ overwrite=True)
+
+ info = vid_info_reader(input_path)
+ width, height = int(info['width']), int(info['height'])
+
+ if box is None:
+ box = [0, 0, width, height]
+
+ assert len(box) == 4
+ x, y, w, h = box
+ assert (w > 0 and h > 0)
+ command = [
+ 'ffmpeg', '-i', input_path, '-vcodec', 'libx264', '-vf',
+ 'crop=%d:%d:%d:%d' % (w, h, x, y), '-loglevel', 'error', '-y',
+ output_path
+ ]
+ if resolution:
+ height, width = resolution
+ width += width % 2
+ height += height % 2
+ command.insert(-1, '-s')
+ command.insert(-1, '%dx%d' % (width, height))
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ subprocess.call(command)
+
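The `box` argument follows an (x, y, width, height) convention, which the function translates into ffmpeg's `crop=w:h:x:y` filter. A minimal usage sketch (paths are placeholders):

```python
# Keep a 256x256 patch whose top-left corner sits at (x=64, y=32).
# Internally this becomes the ffmpeg filter 'crop=256:256:64:32'.
crop_video('demo_out/full.mp4', 'demo_out/patch.mp4', box=[64, 32, 256, 256])
```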
+
+def slice_video(input_path: str,
+ output_path: str,
+ start: int = 0,
+ end: Optional[int] = None,
+ resolution: Optional[Union[Tuple[int, int],
+ Tuple[float, float]]] = None,
+ disable_log: bool = False) -> None:
+ """Temporally crop a video/gif into another video/gif.
+
+ Args:
+ input_path (str): input video or gif file path.
+        output_path (str): output video or gif file path.
+ start (int, optional): start frame index. Defaults to 0.
+ end (int, optional): end frame index. Exclusive.
+ Could be positive int or negative int or None.
+ If None, all frames from start till the last frame are included.
+ Defaults to None.
+ resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+ optional): (height, width) of output. Defaults to None.
+        disable_log (bool, optional): whether to suppress the ffmpeg command
+            info. Defaults to False.
+ Raises:
+ FileNotFoundError: check the input path.
+ FileNotFoundError: check the output path.
+
+ Returns:
+        None
+ """
+ info = vid_info_reader(input_path)
+ num_frames = int(info['nb_frames'])
+ start = (min(start, num_frames - 1) + num_frames) % num_frames
+ end = (min(end, num_frames - 1) +
+ num_frames) % num_frames if end is not None else num_frames
+ command = [
+ 'ffmpeg', '-y', '-i', input_path, '-filter_complex',
+ f'[0]trim=start_frame={start}:end_frame={end}[v0]', '-map', '[v0]',
+ '-loglevel', 'error', '-vcodec', 'libx264', output_path
+ ]
+ if resolution:
+ height, width = resolution
+ width += width % 2
+ height += height % 2
+ command.insert(1, '-s')
+ command.insert(2, '%dx%d' % (width, height))
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ subprocess.call(command)
+
+
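The start/end handling in `slice_video` relies on a small modulo trick to map negative or out-of-range frame indices into `[0, num_frames)`. A minimal sketch with made-up numbers:

```python
# Illustration only: num_frames would normally come from ffprobe metadata.
num_frames = 100

def normalize_start(start: int) -> int:
    # Clamp to the last frame, then wrap negative indices around.
    return (min(start, num_frames - 1) + num_frames) % num_frames

assert normalize_start(0) == 0      # first frame
assert normalize_start(-10) == 90   # tenth frame from the end
assert normalize_start(150) == 99   # clamped to the last frame
```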
+def spatial_concat_video(input_path_list: List[str],
+ output_path: str,
+ array: List[int] = [1, 1],
+ direction: Literal['h', 'w'] = 'h',
+ resolution: Union[Tuple[int,
+ int], List[int], List[float],
+ Tuple[float, float]] = (512, 512),
+ remove_raw_files: bool = False,
+ padding: int = 0,
+ disable_log: bool = False) -> None:
+ """Spatially concat some videos as an array video.
+
+ Args:
+ input_path_list (list): input video or gif file list.
+ output_path (str): output video or gif file path.
+        array (List[int], optional): number of rows and columns of
+            the video array. Defaults to [1, 1].
+        direction (str, optional): choose between 'h' and 'w', controlling
+            whether the grid is filled horizontally or vertically.
+            Defaults to 'h'.
+        resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+            optional): (height, width) of output.
+            Defaults to (512, 512).
+        remove_raw_files (bool, optional): whether to remove the raw input
+            videos. Defaults to False.
+        padding (int, optional): width in pixels between videos.
+            Defaults to 0.
+        disable_log (bool, optional): whether to suppress the ffmpeg command
+            info. Defaults to False.
+ Raises:
+ FileNotFoundError: check the input path.
+ FileNotFoundError: check the output path.
+
+ Returns:
+ None
+ """
+ lowercase = string.ascii_lowercase
+ assert len(array) == 2
+ assert (array[0] * array[1]) >= len(input_path_list)
+ for path in input_path_list:
+ check_input_path(
+ path,
+ allowed_suffix=['.gif', '.mp4'],
+ tag='input video',
+ path_type='file')
+ prepare_output_path(
+ output_path,
+ allowed_suffix=['.gif', '.mp4'],
+ tag='output video',
+ path_type='file',
+ overwrite=True)
+
+ command = ['ffmpeg']
+ height, width = resolution
+ scale_command = []
+ for index, vid_file in enumerate(input_path_list):
+ command.append('-i')
+ command.append(vid_file)
+ scale_command.append(
+ '[%d:v]scale=%d:%d:force_original_aspect_ratio=0[v%d];' %
+ (index, width, height, index))
+
+ scale_command = ' '.join(scale_command)
+ pad_command = '[v%d]pad=%d:%d[%s];' % (0, width * array[1] + padding *
+ (array[1] - 1),
+ height * array[0] + padding *
+ (array[0] - 1), lowercase[0])
+ for index in range(1, len(input_path_list)):
+ if direction == 'h':
+ pad_width = index % array[1] * (width + padding)
+ pad_height = index // array[1] * (height + padding)
+ else:
+ pad_width = index % array[0] * (width + padding)
+ pad_height = index // array[0] * (height + padding)
+
+ pad_command += '[%s][v%d]overlay=%d:%d' % (lowercase[index - 1], index,
+ pad_width, pad_height)
+ if index != len(input_path_list) - 1:
+ pad_command += '[%s];' % lowercase[index]
+
+ command += [
+ '-filter_complex',
+ '%s%s' % (scale_command, pad_command), '-loglevel', 'error', '-y',
+ output_path
+ ]
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ subprocess.call(command)
+
+ if remove_raw_files:
+ command = ['rm', '-f'] + input_path_list
+ subprocess.call(command)
+
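To make the string building above concrete, this is roughly the `-filter_complex` graph `spatial_concat_video` assembles for two 512x512 clips placed side by side (`array=[1, 2]`, `direction='h'`, `padding=0`); the file names in the comment are hypothetical:

```python
# Each input is scaled to 512x512, the first is padded onto a 1024x512 canvas,
# and the second is overlaid at x=512.
filter_complex = (
    '[0:v]scale=512:512:force_original_aspect_ratio=0[v0];'
    ' [1:v]scale=512:512:force_original_aspect_ratio=0[v1];'
    '[v0]pad=1024:512[a];'
    '[a][v1]overlay=512:0'
)
# ffmpeg -i left.mp4 -i right.mp4 -filter_complex "<filter_complex>" -y out.mp4
```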
+
+def temporal_concat_video(input_path_list: List[str],
+ output_path: str,
+ resolution: Union[Tuple[int, int],
+ Tuple[float, float]] = (512, 512),
+ remove_raw_files: bool = False,
+ disable_log: bool = False) -> None:
+    """Concatenate videos or gifs into a temporal sequence, and save as a
+    new video or gif file.
+
+ Args:
+ input_path_list (List[str]): list of input video paths.
+ output_path (str): output video file path.
+        resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+            optional): (height, width) of output.
+            Defaults to (512, 512).
+        remove_raw_files (bool, optional): whether to remove the input videos.
+            Defaults to False.
+        disable_log (bool, optional): whether to suppress the ffmpeg command
+            info. Defaults to False.
+ Raises:
+ FileNotFoundError: check the input path.
+ FileNotFoundError: check the output path.
+
+ Returns:
+ None.
+ """
+ for path in input_path_list:
+ check_input_path(
+ path,
+ allowed_suffix=['.gif', '.mp4'],
+ tag='input video',
+ path_type='file')
+ prepare_output_path(
+ output_path,
+ allowed_suffix=['.gif', '.mp4'],
+ tag='output video',
+ path_type='file',
+ overwrite=True)
+
+ height, width = resolution
+ command = ['ffmpeg']
+ concat_command = []
+ scale_command = []
+ for index, vid_file in enumerate(input_path_list):
+ command.append('-i')
+ command.append(vid_file)
+ scale_command.append(
+ '[%d:v]scale=%d:%d:force_original_aspect_ratio=0[v%d];' %
+ (index, width, height, index))
+ concat_command.append('[v%d]' % index)
+ concat_command = ''.join(concat_command)
+ scale_command = ''.join(scale_command)
+ command += [
+ '-filter_complex',
+ '%s%sconcat=n=%d:v=1:a=0[v]' %
+ (scale_command, concat_command, len(input_path_list)), '-loglevel',
+ 'error', '-map', '[v]', '-c:v', 'libx264', '-y', output_path
+ ]
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ subprocess.call(command)
+
+ if remove_raw_files:
+ command = ['rm'] + input_path_list
+ subprocess.call(command)
+
+
+def compress_video(input_path: str,
+ output_path: str,
+ compress_rate: int = 1,
+ down_sample_scale: Union[float, int] = 1,
+ fps: int = 30,
+ disable_log: bool = False) -> None:
+ """Compress a video file.
+
+ Args:
+ input_path (str): input video file path.
+ output_path (str): output video file path.
+        compress_rate (int, optional): compression rate; divides the output
+            bit rate. Defaults to 1.
+        down_sample_scale (Union[float, int], optional): spatial down sample
+            scale. Defaults to 1.
+        fps (int, optional): Frames per second. Defaults to 30.
+        disable_log (bool, optional): whether to suppress the ffmpeg command
+            info. Defaults to False.
+ Raises:
+ FileNotFoundError: check the input path.
+ FileNotFoundError: check the output path.
+
+ Returns:
+ None.
+ """
+ input_pathinfo = Path(input_path)
+
+ check_input_path(
+ input_path,
+ allowed_suffix=['.gif', '.mp4'],
+ tag='input video',
+ path_type='file')
+ prepare_output_path(
+ output_path,
+ allowed_suffix=['.gif', '.mp4'],
+ tag='output video',
+ path_type='file',
+ overwrite=True)
+
+ info = vid_info_reader(input_path)
+
+ width = int(info['width'])
+ height = int(info['height'])
+ bit_rate = int(info['bit_rate'])
+ duration = float(info['duration'])
+ if (output_path == input_path) or (not output_path):
+ temp_outpath = os.path.join(
+ os.path.abspath(input_pathinfo.parent),
+ 'temp_file' + input_pathinfo.suffix)
+ else:
+ temp_outpath = output_path
+ new_width = int(width / down_sample_scale)
+ new_width += new_width % 2
+ new_height = int(height / down_sample_scale)
+ new_height += new_height % 2
+ command = [
+ 'ffmpeg', '-y', '-r',
+ str(info['r_frame_rate']), '-i', input_path, '-loglevel', 'error',
+ '-b:v', f'{bit_rate / (compress_rate * down_sample_scale)}', '-r',
+ f'{fps}', '-t', f'{duration}', '-s',
+ '%dx%d' % (new_width, new_height), temp_outpath
+ ]
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ subprocess.call(command)
+ if (output_path == input_path) or (not output_path):
+ subprocess.call(['mv', '-f', temp_outpath, input_path])
+
+
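A hedged usage sketch of `compress_video`: with `compress_rate=2` and `down_sample_scale=2` the output bit rate is divided by four and both spatial dimensions are halved (the paths below are placeholders):

```python
compress_video(
    input_path='demo_out/raw.mp4',
    output_path='demo_out/raw_small.mp4',
    compress_rate=2,       # together with down_sample_scale, divides bit rate
    down_sample_scale=2,   # halves width and height
    fps=30,
)
```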
+def pad_for_libx264(image_array):
+ """Pad zeros if width or height of image_array is not divisible by 2.
+    Otherwise libx264 will complain:
+
+ \"[libx264 @ 0x1b1d560] width not divisible by 2 \"
+
+ Args:
+ image_array (np.ndarray):
+ Image or images load by cv2.imread().
+ Possible shapes:
+ 1. [height, width]
+ 2. [height, width, channels]
+ 3. [images, height, width]
+ 4. [images, height, width, channels]
+
+ Returns:
+ np.ndarray:
+            An image with both edges divisible by 2.
+ """
+ if image_array.ndim == 2 or \
+ (image_array.ndim == 3 and image_array.shape[2] == 3):
+ hei_index = 0
+ wid_index = 1
+ elif image_array.ndim == 4 or \
+ (image_array.ndim == 3 and image_array.shape[2] != 3):
+ hei_index = 1
+ wid_index = 2
+ else:
+ return image_array
+ hei_pad = image_array.shape[hei_index] % 2
+ wid_pad = image_array.shape[wid_index] % 2
+ if hei_pad + wid_pad > 0:
+ pad_width = []
+ for dim_index in range(image_array.ndim):
+ if dim_index == hei_index:
+ pad_width.append((0, hei_pad))
+ elif dim_index == wid_index:
+ pad_width.append((0, wid_pad))
+ else:
+ pad_width.append((0, 0))
+ values = 0
+ image_array = \
+ np.pad(image_array,
+ pad_width,
+ mode='constant', constant_values=values)
+ return image_array
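A quick sanity check of `pad_for_libx264` on a batch of frames with odd height and width (shapes are illustrative only):

```python
import numpy as np

frames = np.zeros((3, 481, 639, 3), dtype=np.uint8)   # odd height and width
padded = pad_for_libx264(frames)
assert padded.shape == (3, 482, 640, 3)               # both edges now even
```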
diff --git a/detrsmpl/utils/geometry.py b/detrsmpl/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cfb508f5f9305cc799a4258d76c0460f7b565d4
--- /dev/null
+++ b/detrsmpl/utils/geometry.py
@@ -0,0 +1,536 @@
+import numpy as np
+import torch
+from torch.nn import functional as F
+import torchgeometry as tgm
+
+def batch_rodrigues(theta):
+ """Convert axis-angle representation to rotation matrix.
+
+ Args:
+ theta: size = [B, 3]
+ Returns:
+        Rotation matrix corresponding to the axis-angle input -- size = [B, 3, 3]
+ """
+ l1norm = torch.norm(theta + 1e-8, p=2, dim=1)
+ angle = torch.unsqueeze(l1norm, -1)
+ normalized = torch.div(theta, angle)
+ angle = angle * 0.5
+ v_cos = torch.cos(angle)
+ v_sin = torch.sin(angle)
+ quat = torch.cat([v_cos, v_sin * normalized], dim=1)
+ return quat_to_rotmat(quat)
+
+
+def quat_to_rotmat(quat):
+ """Convert quaternion coefficients to rotation matrix.
+
+ Args:
+ quat: size = [B, 4] 4 <===>(w, x, y, z)
+ Returns:
+ Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
+ """
+ norm_quat = quat
+ norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
+ w = norm_quat[:, 0]
+ x = norm_quat[:, 1]
+ y = norm_quat[:, 2]
+ z = norm_quat[:, 3]
+ B = quat.size(0)
+
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
+ wx, wy, wz = w * x, w * y, w * z
+ xy, xz, yz = x * y, x * z, y * z
+
+ rotMat = torch.stack([
+ w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy,
+ w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz,
+ w2 - x2 - y2 + z2
+ ],
+ dim=1).view(B, 3, 3)
+ return rotMat
+
+
+def rot6d_to_rotmat(x):
+ """Convert 6D rotation representation to 3x3 rotation matrix.
+
+ Based on Zhou et al., "On the Continuity of Rotation
+ Representations in Neural Networks", CVPR 2019
+ Input:
+ (B,6) Batch of 6-D rotation representations
+ Output:
+ (B,3,3) Batch of corresponding rotation matrices
+ """
+    if isinstance(x, np.ndarray):
+        # np.ndarray.view() takes a dtype, not a shape; convert to a tensor
+        # so the reshape and the torch ops below work for both input types.
+        x = torch.from_numpy(x)
+    x = x.reshape(-1, 3, 2)
+ a1 = x[:, :, 0]
+ a2 = x[:, :, 1]
+ b1 = F.normalize(a1)
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
+ b3 = torch.cross(b1, b2)
+ return torch.stack((b1, b2, b3), dim=-1)
+
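Since the 6D representation is turned into a rotation via Gram-Schmidt (b1 and b2 orthonormalized, b3 = b1 x b2), the output should always be a proper rotation matrix. A minimal check, assuming `rot6d_to_rotmat` is imported from this module:

```python
import torch

x6d = torch.randn(4, 6)
R = rot6d_to_rotmat(x6d)                                   # (4, 3, 3)
assert torch.allclose(R.transpose(1, 2) @ R,
                      torch.eye(3).expand(4, 3, 3), atol=1e-5)
assert torch.allclose(torch.det(R), torch.ones(4), atol=1e-5)
```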
+def rot6d_to_axis_angle(x):
+ batch_size = x.shape[0]
+
+ x = x.view(-1, 3, 2)
+ a1 = x[:, :, 0]
+ a2 = x[:, :, 1]
+ b1 = F.normalize(a1)
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
+ b3 = torch.cross(b1, b2)
+ rot_mat = torch.stack((b1, b2, b3), dim=-1) # 3x3 rotation matrix
+
+ rot_mat = torch.cat([rot_mat, torch.zeros((batch_size, 3, 1)).cuda().float()], 2) # 3x4 rotation matrix
+ axis_angle = tgm.rotation_matrix_to_angle_axis(rot_mat).reshape(-1, 3) # axis-angle
+ axis_angle[torch.isnan(axis_angle)] = 0.0
+ return axis_angle
+
+def rotation_matrix_to_angle_axis(rotation_matrix):
+ """
+ This function is borrowed from https://github.com/kornia/kornia
+ Convert 3x4 rotation matrix to Rodrigues vector
+ Args:
+ rotation_matrix (Tensor): rotation matrix.
+ Returns:
+ Tensor: Rodrigues vector transformation.
+ Shape:
+ - Input: :math:`(N, 3, 4)`
+ - Output: :math:`(N, 3)`
+ Example:
+ >>> input = torch.rand(2, 3, 4) # Nx3x4
+ >>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3
+ """
+ if rotation_matrix.shape[1:] == (3, 3):
+ rot_mat = rotation_matrix.reshape(-1, 3, 3)
+ hom = torch.tensor([0, 0, 1],
+ dtype=torch.float32,
+ device=rotation_matrix.device)
+ hom = hom.reshape(1, 3, 1).expand(rot_mat.shape[0], -1, -1)
+ rotation_matrix = torch.cat([rot_mat, hom], dim=-1)
+
+ quaternion = rotation_matrix_to_quaternion(rotation_matrix)
+ aa = quaternion_to_angle_axis(quaternion)
+ aa[torch.isnan(aa)] = 0.0
+ return aa
+
+
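As a sanity check of these borrowed conversions, a round trip from axis-angle to rotation matrix and back should reproduce the input for rotation angles below pi (test values are arbitrary):

```python
import torch

theta = torch.tensor([[0.3, 0.0, 0.0], [0.0, -1.2, 0.4]])
R = batch_rodrigues(theta)                        # (2, 3, 3)
theta_back = rotation_matrix_to_angle_axis(R)     # also accepts (N, 3, 3)
assert torch.allclose(theta, theta_back, atol=1e-4)
```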
+def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
+ """
+ This function is borrowed from https://github.com/kornia/kornia
+ Convert quaternion vector to angle axis of rotation.
+ Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
+ Args:
+ quaternion (torch.Tensor): tensor with quaternions.
+ Return:
+ torch.Tensor: tensor with angle axis of rotation.
+ Shape:
+ - Input: :math:`(*, 4)` where `*` means, any number of dimensions
+ - Output: :math:`(*, 3)`
+ Example:
+ >>> quaternion = torch.rand(2, 4) # Nx4
+ >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3
+ """
+ if not torch.is_tensor(quaternion):
+ raise TypeError('Input type is not a torch.Tensor. Got {}'.format(
+ type(quaternion)))
+
+ if not quaternion.shape[-1] == 4:
+ raise ValueError(
+ 'Input must be a tensor of shape Nx4 or 4. Got {}'.format(
+ quaternion.shape))
+ # unpack input and compute conversion
+ q1: torch.Tensor = quaternion[..., 1]
+ q2: torch.Tensor = quaternion[..., 2]
+ q3: torch.Tensor = quaternion[..., 3]
+ sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3
+
+ sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta)
+ cos_theta: torch.Tensor = quaternion[..., 0]
+ two_theta: torch.Tensor = 2.0 * torch.where(
+ cos_theta < 0.0, torch.atan2(-sin_theta, -cos_theta),
+ torch.atan2(sin_theta, cos_theta))
+
+ k_pos: torch.Tensor = two_theta / sin_theta
+ k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta)
+ k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
+
+ angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3]
+ angle_axis[..., 0] += q1 * k
+ angle_axis[..., 1] += q2 * k
+ angle_axis[..., 2] += q3 * k
+ return angle_axis
+
+
+def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
+ """
+ This function is borrowed from https://github.com/kornia/kornia
+ Convert 3x4 rotation matrix to 4d quaternion vector
+ This algorithm is based on algorithm described in
+ https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
+ Args:
+ rotation_matrix (Tensor): the rotation matrix to convert.
+ Return:
+ Tensor: the rotation in quaternion
+ Shape:
+ - Input: :math:`(N, 3, 4)`
+ - Output: :math:`(N, 4)`
+ Example:
+ >>> input = torch.rand(4, 3, 4) # Nx3x4
+ >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4
+ """
+ if not torch.is_tensor(rotation_matrix):
+ raise TypeError('Input type is not a torch.Tensor. Got {}'.format(
+ type(rotation_matrix)))
+
+ if len(rotation_matrix.shape) > 3:
+ raise ValueError(
+ 'Input size must be a three dimensional tensor. Got {}'.format(
+ rotation_matrix.shape))
+ if not rotation_matrix.shape[-2:] == (3, 4):
+ raise ValueError(
+ 'Input size must be a N x 3 x 4 tensor. Got {}'.format(
+ rotation_matrix.shape))
+
+ rmat_t = torch.transpose(rotation_matrix, 1, 2)
+
+ mask_d2 = rmat_t[:, 2, 2] < eps
+
+ mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
+ mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
+
+ t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
+ q0 = torch.stack([
+ rmat_t[:, 1, 2] - rmat_t[:, 2, 1], t0,
+ rmat_t[:, 0, 1] + rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2]
+ ], -1)
+ t0_rep = t0.repeat(4, 1).t()
+
+ t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
+ q1 = torch.stack([
+ rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
+ t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]
+ ], -1)
+ t1_rep = t1.repeat(4, 1).t()
+
+ t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
+ q2 = torch.stack([
+ rmat_t[:, 0, 1] - rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
+ rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2
+ ], -1)
+ t2_rep = t2.repeat(4, 1).t()
+
+ t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
+ q3 = torch.stack([
+ t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
+ rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] - rmat_t[:, 1, 0]
+ ], -1)
+ t3_rep = t3.repeat(4, 1).t()
+
+ mask_c0 = mask_d2 * mask_d0_d1
+ mask_c1 = mask_d2 * ~mask_d0_d1
+ mask_c2 = ~mask_d2 * mask_d0_nd1
+ mask_c3 = ~mask_d2 * ~mask_d0_nd1
+ mask_c0 = mask_c0.view(-1, 1).type_as(q0)
+ mask_c1 = mask_c1.view(-1, 1).type_as(q1)
+ mask_c2 = mask_c2.view(-1, 1).type_as(q2)
+ mask_c3 = mask_c3.view(-1, 1).type_as(q3)
+
+ q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
+ q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa
+ t2_rep * mask_c2 + t3_rep * mask_c3) # noqa
+ q *= 0.5
+ return q
+
+
+def perspective_projection(points, rotation, translation, focal_length,
+ camera_center):
+ """This function computes the perspective projection of a set of points.
+
+ Input:
+ points (bs, N, 3): 3D points
+ rotation (bs, 3, 3): Camera rotation
+ translation (bs, 3): Camera translation
+ focal_length (bs,) or scalar: Focal length
+ camera_center (bs, 2): Camera center
+ """
+ batch_size = points.shape[0]
+ K = torch.zeros([batch_size, 3, 3], device=points.device)
+ K[:, 0, 0] = focal_length
+ K[:, 1, 1] = focal_length
+ K[:, 2, 2] = 1.
+ K[:, :-1, -1] = camera_center
+
+ # Transform points
+ points = torch.einsum('bij,bkj->bki', rotation, points)
+ points = points + translation.unsqueeze(1)
+
+ # Apply perspective distortion
+ projected_points = points / points[:, :, -1].unsqueeze(-1)
+
+ # Apply camera intrinsics
+ projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
+
+ return projected_points[:, :, :-1]
+
+
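A small worked example of `perspective_projection` with an identity rotation: a point (0.1, 0.2, 5) with focal length 1000 and principal point (112, 112) lands at (0.1/5*1000 + 112, 0.2/5*1000 + 112) = (132, 152). The numbers are arbitrary:

```python
import torch

points = torch.tensor([[[0.1, 0.2, 5.0]]])              # (B=1, N=1, 3)
rotation = torch.eye(3).unsqueeze(0)                     # identity
translation = torch.zeros(1, 3)
uv = perspective_projection(points, rotation, translation,
                            focal_length=1000.0,
                            camera_center=torch.tensor([[112.0, 112.0]]))
assert torch.allclose(uv, torch.tensor([[[132.0, 152.0]]]))
```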
+def estimate_translation_np(S,
+ joints_2d,
+ joints_conf,
+ focal_length=5000,
+ img_size=224):
+    """Find the camera translation that brings the 3D joints S closest to
+    the corresponding 2D joints joints_2d.
+
+ Input:
+ S: (25, 3) 3D joint locations
+ joints: (25, 3) 2D joint locations and confidence
+ Returns:
+ (3,) camera translation vector
+ """
+
+ num_joints = S.shape[0]
+ # focal length
+ f = np.array([focal_length, focal_length])
+ # optical center
+ center = np.array([img_size / 2., img_size / 2.])
+
+ # transformations
+ Z = np.reshape(np.tile(S[:, 2], (2, 1)).T, -1)
+ XY = np.reshape(S[:, 0:2], -1)
+ OO = np.tile(center, num_joints)
+ F = np.tile(f, num_joints)
+ weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1)
+
+ # least squares
+ Q = np.array([
+ F * np.tile(np.array([1, 0]), num_joints),
+ F * np.tile(np.array([0, 1]), num_joints),
+ OO - np.reshape(joints_2d, -1)
+ ]).T
+ c = (np.reshape(joints_2d, -1) - OO) * Z - F * XY
+
+ # weighted least squares
+ W = np.diagflat(weight2)
+ Q = np.dot(W, Q)
+ c = np.dot(W, c)
+
+ # square matrix
+ A = np.dot(Q.T, Q)
+ b = np.dot(Q.T, c)
+
+ # solution
+ trans = np.linalg.solve(A, b)
+
+ return trans
+
+
+def estimate_translation(S, joints_2d, focal_length=5000., img_size=224.):
+    """Find the camera translation that brings the 3D joints S closest to
+    the corresponding 2D joints joints_2d.
+
+ Input:
+ S: (B, 49, 3) 3D joint locations
+ joints: (B, 49, 3) 2D joint locations and confidence
+ Returns:
+ (B, 3) camera translation vectors
+ """
+
+ device = S.device
+ # Use only joints 25:49 (GT joints)
+ S = S[:, 25:, :].cpu().numpy()
+ joints_2d = joints_2d[:, 25:, :].cpu().numpy()
+ joints_conf = joints_2d[:, :, -1]
+ joints_2d = joints_2d[:, :, :-1]
+ trans = np.zeros((S.shape[0], 3), dtype=np.float32)
+ # Find the translation for each example in the batch
+ for i in range(S.shape[0]):
+ S_i = S[i]
+ joints_i = joints_2d[i]
+ conf_i = joints_conf[i]
+ trans[i] = estimate_translation_np(S_i,
+ joints_i,
+ conf_i,
+ focal_length=focal_length,
+ img_size=img_size)
+ return torch.from_numpy(trans).to(device)
+
+
+def project_points(points_3d, camera, focal_length, img_res):
+    """Project 3D points to the image plane with a predicted weak-perspective
+    camera, converted to a perspective camera with the given focal length.
+
+ Notes:
+ batch size: B
+ point number: N
+ Args:
+ points_3d (Tensor([B, N, 3])): 3D points.
+ camera (Tensor([B, 3])): camera parameters with the
+ 3 channel as (scale, translation_x, translation_y)
+ Returns:
+ points_2d (Tensor([B, N, 2])): projected 2D points
+ in image space.
+ """
+ batch_size = points_3d.shape[0]
+ device = points_3d.device
+ cam_t = torch.stack([
+ camera[:, 1], camera[:, 2], 2 * focal_length /
+ (img_res * camera[:, 0] + 1e-9)
+ ],
+ dim=-1)
+ camera_center = camera.new_zeros([batch_size, 2])
+ rot_t = torch.eye(3, device=device,
+ dtype=points_3d.dtype).unsqueeze(0).expand(
+ batch_size, -1, -1)
+ keypoints_2d = perspective_projection(points_3d,
+ rotation=rot_t,
+ translation=cam_t,
+ focal_length=focal_length,
+ camera_center=camera_center)
+ return keypoints_2d
+
+def project_points_new(points_3d, pred_cam, focal_length, camera_center):
+    """Project 3D points to the image plane with a predicted weak-perspective
+    camera, converted to a perspective camera with the given focal length.
+
+ Notes:
+ batch size: B
+ point number: N
+ Args:
+ points_3d (Tensor([B, N, 3])): 3D points.
+ camera (Tensor([B, 3])): camera parameters with the
+ 3 channel as (scale, translation_x, translation_y)
+ Returns:
+ points_2d (Tensor([B, N, 2])): projected 2D points
+ in image space.
+ """
+ batch_size = points_3d.shape[0]
+ device = points_3d.device
+
+ (s, tx, ty) = (pred_cam[:, 0] + 1e-9), pred_cam[:, 1], pred_cam[:, 2]
+ depth, dx, dy = 1./s, tx/s, ty/s
+ cam_t = torch.stack([dx, dy, depth], 1)
+
+ # cam_t = torch.stack([
+ # camera[:, 1], camera[:, 2], 2 * focal_length /
+ # (img_res * camera[:, 0] + 1e-9)
+ # ],
+ # dim=-1)
+ rot_t = torch.eye(3, device=device,
+ dtype=points_3d.dtype).unsqueeze(0).expand(
+ batch_size, -1, -1)
+ keypoints_2d = perspective_projection(points_3d,
+ rotation=rot_t,
+ translation=cam_t,
+ focal_length=focal_length,
+ camera_center=camera_center)
+ return keypoints_2d
+
+
+
+
+def weak_perspective_projection(points, scale, translation):
+ """This function computes the weak perspective projection of a set of
+ points.
+
+ Input:
+ points (bs, N, 3): 3D points
+ scale (bs,1): scalar
+ translation (bs, 2): point 2D translation
+ """
+ projected_points = scale.view(
+ -1, 1, 1) * (points[:, :, :2] + translation.view(-1, 1, 2))
+
+ return projected_points
+
+
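The weak-perspective model simply scales the translated x/y coordinates and ignores depth: u = s * (x + tx), v = s * (y + ty). A tiny numeric check (values are arbitrary):

```python
import torch

pts = torch.tensor([[[0.5, 0.5, 3.0]]])       # depth (3.0) is ignored
scale = torch.tensor([2.0])
transl = torch.tensor([[0.1, -0.1]])
uv = weak_perspective_projection(pts, scale, transl)
assert torch.allclose(uv, torch.tensor([[[1.2, 0.8]]]))
```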
+def estimate_cam_weakperspective(joints3d,
+ joints2d,
+ joints2d_conf,
+ joints3d_conf,
+ img_size) -> torch.Tensor:
+ '''
+ img_size: wh
+ '''
+ w, h = img_size
+ if joints2d_conf is not None:
+ valid_ids = torch.where(joints2d_conf.view(-1) > 0)[0]
+ joints2d = joints2d[valid_ids]
+ if joints3d_conf is not None:
+ valid_ids = torch.where(joints3d_conf.view(-1) > 0)[0]
+ joints3d = joints3d[valid_ids]
+ x1 = torch.min(joints3d[..., 0])
+ x2 = torch.max(joints3d[..., 0])
+
+ y1 = torch.min(joints3d[..., 1])
+ y2 = torch.max(joints3d[..., 1])
+
+ # img_size = img_size if isinstance(img_size, int) else int(img_size[0])
+
+ u1 = 2*torch.min(joints2d[..., 0]) / w -1
+ u2 = 2*torch.max(joints2d[..., 0]) / w -1
+ v1 = (2 * torch.min(joints2d[..., 1])-h)/max(w,h)
+ v2 = (2 * torch.max(joints2d[..., 1])-h)/max(w,h)
+
+ # u1 = torch.min(joints2d[..., 0]) / w
+ # u2 = torch.max(joints2d[..., 0]) / w
+ # v1 = torch.min(joints2d[..., 1]) / h
+ # v2 = torch.max(joints2d[..., 1]) / h
+
+ sx = (u1 - u2) / (x1 - x2)
+ sy = (v1 - v2) / (y1 - y2)
+ s = torch.sqrt(sx * sy)
+
+ tx_1 = u1 / s - x1 # u1 = s*(tx_1 + x1)
+ ty_1 = v1 / s - y1 # v1 = s*(ty_1 + y1)
+
+ tx_2 = u2 / s - x2 # u2 = s*(tx_2 + x2)
+ ty_2 = v2 / s - y2 # v2 = s*(ty_2 + y2)
+
+ tx = (tx_1 + tx_2) / 2
+ ty = (ty_1 + ty_2) / 2
+ cam = torch.Tensor([s, tx, ty]).view(3)
+ return cam
+
+def estimate_cam_weakperspective_batch(
+ joints3d, joints2d,
+ joints2d_conf, joints3d_conf,
+ img_size):
+ '''
+ img_size: b,w,h
+ '''
+ device = joints3d.device
+ joints2d = joints2d.detach().cpu()
+ joints3d = joints3d.detach().cpu()
+
+ assert joints2d.ndim == 3 # B, J, 2
+ assert joints3d.ndim == 3 # B, J, 3
+
+ cam = torch.zeros(joints3d.shape[0], 3)
+ for i in range(joints3d.shape[0]):
+ joints3d_i = joints3d[i]
+ joints2d_i = joints2d[i]
+ if joints2d_conf is not None:
+ conf2d_i = joints2d_conf[i].detach().cpu()
+ else:
+ conf2d_i = None
+
+ if joints3d_conf is not None:
+ conf3d_i = joints3d_conf[i].detach().cpu()
+ else:
+ conf3d_i = None
+ cam[i] = estimate_cam_weakperspective(joints3d=joints3d_i,
+ joints2d=joints2d_i,
+ joints2d_conf=conf2d_i,
+ joints3d_conf=conf3d_i,
+ img_size=img_size[i])
+ return cam.to(device)
+
+def pred_cam_to_transl(pred_camera, focal_length, img_size):
+ pred_cam_t = torch.stack([
+ pred_camera[:, 1], pred_camera[:, 2], 2 * focal_length /
+ (img_size * pred_camera[:, 0] + 1e-9)
+ ],
+ dim=-1)
+ return pred_cam_t
\ No newline at end of file
diff --git a/detrsmpl/utils/keypoint_utils.py b/detrsmpl/utils/keypoint_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cacfd85a2a31ff0503a25dbb99ed057a379f7d2
--- /dev/null
+++ b/detrsmpl/utils/keypoint_utils.py
@@ -0,0 +1,61 @@
+from typing import Optional, Tuple, Union
+
+import numpy as np
+
+from detrsmpl.core.conventions.keypoints_mapping import KEYPOINTS_FACTORY
+from detrsmpl.core.conventions.keypoints_mapping.human_data import (
+ HUMAN_DATA_LIMBS_INDEX,
+ HUMAN_DATA_PALETTE,
+)
+
+
+def search_limbs(
+ data_source: str,
+ mask: Optional[Union[np.ndarray, tuple, list]] = None,
+ keypoints_factory: dict = KEYPOINTS_FACTORY) -> Tuple[dict, dict]:
+    """Search the corresponding limbs following the basis human_data limbs.
+    The mask can be used to exclude unavailable keypoints.
+
+ Args:
+ data_source (str): data source type.
+ mask (Optional[Union[np.ndarray, tuple, list]], optional):
+ refer to keypoints_mapping. Defaults to None.
+ keypoints_factory (dict, optional): Dict of all the conventions.
+ Defaults to KEYPOINTS_FACTORY.
+ Returns:
+ Tuple[dict, dict]: (limbs_target, limbs_palette).
+ """
+ limbs_source = HUMAN_DATA_LIMBS_INDEX
+ limbs_palette = HUMAN_DATA_PALETTE
+ keypoints_source = keypoints_factory['human_data']
+ keypoints_target = keypoints_factory[data_source]
+ limbs_target = {}
+ for k, part_limbs in limbs_source.items():
+ limbs_target[k] = []
+ for limb in part_limbs:
+ flag = False
+ if (keypoints_source[limb[0]]
+ in keypoints_target) and (keypoints_source[limb[1]]
+ in keypoints_target):
+ if mask is not None:
+ if mask[keypoints_target.index(keypoints_source[
+ limb[0]])] != 0 and mask[keypoints_target.index(
+ keypoints_source[limb[1]])] != 0:
+ flag = True
+ else:
+ flag = True
+ if flag:
+ limbs_target.setdefault(k, []).append([
+ keypoints_target.index(keypoints_source[limb[0]]),
+ keypoints_target.index(keypoints_source[limb[1]])
+ ])
+ if k in limbs_target:
+ if k == 'body':
+ np.random.seed(0)
+ limbs_palette[k] = np.random.randint(0,
+ high=255,
+ size=(len(
+ limbs_target[k]), 3))
+ else:
+ limbs_palette[k] = np.array(limbs_palette[k])
+ return limbs_target, limbs_palette
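Hypothetical usage of `search_limbs`; 'coco' is used here as an example, but any convention registered in `KEYPOINTS_FACTORY` should work the same way:

```python
limbs, palette = search_limbs(data_source='coco')
print(limbs.keys())     # limb groups (body, face, hands, ...)
print(palette.keys())   # one color array per group
```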
diff --git a/detrsmpl/utils/logger.py b/detrsmpl/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ca1e9451b5ac8dc5278d448d5e916dcb7ed525c
--- /dev/null
+++ b/detrsmpl/utils/logger.py
@@ -0,0 +1,7 @@
+import logging
+
+from mmcv.utils import get_logger
+
+
+def get_root_logger(log_file=None, log_level=logging.INFO):
+ return get_logger('mmhuman3d', log_file, log_level)
diff --git a/detrsmpl/utils/mesh_utils.py b/detrsmpl/utils/mesh_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..00ed5f253ccc99d5cb15c6ed0a68729385dc9e4d
--- /dev/null
+++ b/detrsmpl/utils/mesh_utils.py
@@ -0,0 +1,236 @@
+import warnings
+from typing import List, Optional, Union
+
+import torch
+from pytorch3d.io import IO
+from pytorch3d.io import load_objs_as_meshes as _load_objs_as_meshes
+from pytorch3d.io import save_obj
+from pytorch3d.renderer import TexturesUV, TexturesVertex
+from pytorch3d.structures import (
+ Meshes,
+ Pointclouds,
+ join_meshes_as_batch,
+ join_meshes_as_scene,
+ padded_to_list,
+)
+
+from .path_utils import prepare_output_path
+
+
+def join_batch_meshes_as_scene(
+ meshes: List[Meshes],
+ include_textures: bool = True,
+) -> Meshes:
+    """Join `meshes` as a scene for each batch element. Only for PyTorch3D
+    `Meshes`. The Meshes must share the same batch size, and their topology
+    could be different. They must all be on the same device. If
+    `include_textures` is True, the textures should be of the same type;
+    having them all be None is not accepted. If `include_textures` is False,
+    textures are ignored and the returned meshes will have no textures.
+
+ Args:
+ meshes (List[Meshes]): A `list` of `Meshes` with the same batches.
+ Required.
+ include_textures: (bool) whether to try to join the textures.
+
+ Returns:
+        New Meshes in which the input Meshes are joined per batch element.
+ """
+ for mesh in meshes:
+ mesh._verts_list = padded_to_list(mesh.verts_padded(),
+ mesh.num_verts_per_mesh().tolist())
+ num_scene_size = len(meshes)
+ num_batch_size = len(meshes[0])
+ for i in range(num_scene_size):
+        assert len(meshes[i]) == num_batch_size, \
+            'Please make sure that the Meshes all have the same batch size.'
+ meshes_all = []
+ for j in range(num_batch_size):
+ meshes_batch = []
+ for i in range(num_scene_size):
+ meshes_batch.append(meshes[i][j])
+ meshes_all.append(join_meshes_as_scene(meshes_batch, include_textures))
+ meshes_final = join_meshes_as_batch(meshes_all, include_textures)
+ return meshes_final
+
+
+def mesh_to_pointcloud_vc(
+ meshes: Meshes,
+ include_textures: bool = True,
+ alpha: float = 1.0,
+) -> Pointclouds:
+ """Convert PyTorch3D vertex color `Meshes` to `PointClouds`.
+
+ Args:
+ meshes (Meshes): input meshes.
+ include_textures (bool, optional): Whether include colors.
+ Require the texture of input meshes is vertex color.
+ Defaults to True.
+ alpha (float, optional): transparency.
+ Defaults to 1.0.
+
+ Returns:
+ Pointclouds: output pointclouds.
+ """
+ assert isinstance(
+ meshes.textures,
+ TexturesVertex), 'textures of input meshes should be `TexturesVertex`'
+ vertices = meshes.verts_padded()
+ if include_textures:
+ verts_rgb = meshes.textures.verts_features_padded()
+ verts_rgba = torch.cat(
+ [verts_rgb,
+ torch.ones_like(verts_rgb)[..., 0:1] * alpha], dim=-1)
+ else:
+ verts_rgba = None
+ pointclouds = Pointclouds(points=vertices, features=verts_rgba)
+ return pointclouds
+
+
+def texture_uv2vc(meshes: Meshes) -> Meshes:
+    """Convert a PyTorch3D mesh's textures from TexturesUV to TexturesVertex.
+
+ Args:
+ meshes (Meshes): input Meshes.
+
+ Returns:
+ Meshes: converted Meshes.
+ """
+ assert isinstance(meshes.textures, TexturesUV)
+ device = meshes.device
+ vert_uv = meshes.textures.verts_uvs_padded()
+ batch_size = vert_uv.shape[0]
+ verts_features = []
+ num_verts = meshes.verts_padded().shape[1]
+ for index in range(batch_size):
+ face_uv = vert_uv[index][meshes.textures.faces_uvs_padded()
+ [index].view(-1)]
+
+ img = meshes.textures._maps_padded[index]
+ width, height, _ = img.shape
+
+ face_uv = face_uv * torch.Tensor([width - 1, height - 1
+ ]).long().to(device)
+
+ face_uv[:, 0] = torch.clip(face_uv[:, 0], 0, width - 1)
+ face_uv[:, 1] = torch.clip(face_uv[:, 1], 0, height - 1)
+ face_uv = face_uv.long()
+ faces = meshes.faces_padded()
+ verts_rgb = torch.zeros(1, num_verts, 3).to(device)
+ verts_rgb[:, faces.view(-1)] = img[height - 1 - face_uv[:, 1],
+ face_uv[:, 0]]
+ verts_features.append(verts_rgb)
+ verts_features = torch.cat(verts_features)
+
+ meshes = meshes.clone()
+ meshes.textures = TexturesVertex(verts_features)
+ return meshes
+
+
+def load_objs_as_meshes(files: List[str],
+ device: Optional[Union[torch.device, str]] = None,
+ load_textures: bool = True,
+ **kwargs) -> Meshes:
+ if not isinstance(files, list):
+ files = [files]
+ return _load_objs_as_meshes(files=files,
+ device=device,
+ load_textures=load_textures,
+ **kwargs)
+
+
+def load_plys_as_meshes(
+ files: List[str],
+ device: Optional[Union[torch.device, str]] = None,
+ load_textures: bool = True,
+) -> Meshes:
+ writer = IO()
+ meshes = []
+ if not isinstance(files, list):
+ files = [files]
+ for idx in range(len(files)):
+ assert files[idx].endswith('.ply'), 'Please input .ply files.'
+ mesh = writer.load_mesh(path=files[idx],
+ include_textures=load_textures,
+ device=device)
+ meshes.append(mesh)
+ meshes = join_meshes_as_batch(meshes, include_textures=load_textures)
+ return meshes
+
+
+def save_meshes_as_plys(files: List[str],
+ meshes: Meshes = None,
+ verts: torch.Tensor = None,
+ faces: torch.Tensor = None,
+ verts_rgb: torch.Tensor = None) -> None:
+ """Save meshes as .ply files. Mainly for vertex color meshes.
+
+ Args:
+ files (List[str]): Output .ply file list.
+ meshes (Meshes, optional): higher priority than
+ (verts & faces & verts_rgb). Defaults to None.
+ verts (torch.Tensor, optional): lower priority than meshes.
+ Defaults to None.
+ faces (torch.Tensor, optional): lower priority than meshes.
+ Defaults to None.
+ verts_rgb (torch.Tensor, optional): lower priority than meshes.
+ Defaults to None.
+ """
+ if meshes is None:
+ assert verts is not None and faces is not None, 'Not mesh input.'
+ meshes = Meshes(
+ verts=verts,
+ faces=faces,
+ textures=TexturesVertex(
+ verts_features=verts_rgb) if verts_rgb is not None else None)
+ else:
+ if verts is not None or faces is not None or verts_rgb is not None:
+ warnings.warn('Redundant input, will use meshes only.')
+ assert files is not None
+ if not isinstance(files, list):
+ files = [files]
+ assert len(files) >= len(meshes), 'Not enough output files.'
+ writer = IO()
+ for idx in range(len(meshes)):
+ assert files[idx].endswith('.ply'), 'Please save as .ply files.'
+ writer.save_mesh(meshes[idx],
+ files[idx],
+ colors_as_uint8=True,
+ binary=False)
+
+
+def save_meshes_as_objs(files: List[str], meshes: Meshes = None) -> None:
+    """Save meshes as .obj files. PyTorch3D will not save vertex color for
+    .obj; please use `save_meshes_as_plys` for vertex-colored meshes.
+
+ Args:
+ files (List[str]): Output .obj file list.
+ meshes (Meshes, optional):
+ Defaults to None.
+ """
+ if not isinstance(files, list):
+ files = [files]
+
+ assert len(files) >= len(meshes), 'Not enough output files.'
+
+ for idx in range(len(meshes)):
+        prepare_output_path(files[idx],
+                            allowed_suffix=['.obj'],
+                            path_type='file')
+ if isinstance(meshes.textures, TexturesUV):
+ verts_uvs = meshes.textures.verts_uvs_padded()[idx]
+ faces_uvs = meshes.textures.faces_uvs_padded()[idx]
+ texture_map = meshes.textures.maps_padded()[idx]
+ else:
+ verts_uvs = None
+ faces_uvs = None
+ texture_map = None
+ save_obj(f=files[idx],
+ verts=meshes.verts_padded()[idx],
+ faces=meshes.faces_padded()[idx],
+ verts_uvs=verts_uvs,
+ faces_uvs=faces_uvs,
+ texture_map=texture_map)
diff --git a/detrsmpl/utils/misc.py b/detrsmpl/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..e98eb2fb3983d9bcf726fd7fd172e84a93c5415a
--- /dev/null
+++ b/detrsmpl/utils/misc.py
@@ -0,0 +1,30 @@
+from functools import partial
+
+import torch
+
+
+def multi_apply(func, *args, **kwargs):
+ """Apply function to a list of arguments.
+
+ Note:
+        This function applies the ``func`` to multiple inputs and
+        maps the multiple outputs of the ``func`` into different
+        lists. Each list contains the same type of outputs corresponding
+        to different inputs.
+
+ Args:
+ func (Function): A function that will be applied to a list of
+ arguments
+
+ Returns:
+        tuple(list): A tuple containing multiple lists; each list contains \
+            one kind of result returned by the function.
+ """
+ pfunc = partial(func, **kwargs) if kwargs else func
+ map_results = map(pfunc, *args)
+ return tuple(map(list, zip(*map_results)))
+
+
+def torch_to_numpy(x):
+ assert isinstance(x, torch.Tensor)
+ return x.detach().cpu().numpy()
diff --git a/detrsmpl/utils/path_utils.py b/detrsmpl/utils/path_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..36fb7f69e79d691a81eb5b6faad592e961878bea
--- /dev/null
+++ b/detrsmpl/utils/path_utils.py
@@ -0,0 +1,232 @@
+import os
+import warnings
+from enum import Enum
+from pathlib import Path
+from typing import List, Union
+
+try:
+ from typing import Literal
+except ImportError:
+ from typing_extensions import Literal
+
+
+def check_path_suffix(path_str: str,
+ allowed_suffix: Union[str, List[str]] = '') -> bool:
+ """Check whether the suffix of the path is allowed.
+
+ Args:
+ path_str (str):
+ Path to check.
+ allowed_suffix (List[str], optional):
+ What extension names are allowed.
+            Offer a list like ['.jpg', '.jpeg'].
+            When it's [], any suffix is accepted.
+            Use [''] to allow a directory.
+            Defaults to ''.
+
+ Returns:
+ bool:
+ True: suffix test passed
+ False: suffix test failed
+ """
+ if isinstance(allowed_suffix, str):
+ allowed_suffix = [allowed_suffix]
+ pathinfo = Path(path_str)
+ suffix = pathinfo.suffix.lower()
+ if len(allowed_suffix) == 0:
+ return True
+ if pathinfo.is_dir():
+ if '' in allowed_suffix:
+ return True
+ else:
+ return False
+ else:
+ for index, tmp_suffix in enumerate(allowed_suffix):
+ if not tmp_suffix.startswith('.'):
+ tmp_suffix = '.' + tmp_suffix
+ allowed_suffix[index] = tmp_suffix.lower()
+ if suffix in allowed_suffix:
+ return True
+ else:
+ return False
+
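A few quick examples of the suffix check; the paths do not need to exist for this test, and the comparison is case-insensitive:

```python
assert check_path_suffix('demo.mp4', allowed_suffix=['.mp4', '.gif'])
assert not check_path_suffix('demo.avi', allowed_suffix=['.mp4', '.gif'])
assert check_path_suffix('demo.MP4', allowed_suffix='mp4')
assert check_path_suffix('anything.xyz', allowed_suffix=[])   # [] accepts all
```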
+
+class Existence(Enum):
+ """State of file existence."""
+ FileExist = 0
+ DirectoryExistEmpty = 1
+ DirectoryExistNotEmpty = 2
+ MissingParent = 3
+ DirectoryNotExist = 4
+ FileNotExist = 5
+
+
+def check_path_existence(
+ path_str: str,
+ path_type: Literal['file', 'dir', 'auto'] = 'auto',
+) -> Existence:
+ """Check whether a file or a directory exists at the expected path.
+
+ Args:
+ path_str (str):
+ Path to check.
+        path_type (Literal['file', 'dir', 'auto'], optional):
+            What kind of file do we expect at the path.
+            Choose among `file`, `dir`, `auto`.
+            Defaults to 'auto'.
+
+ Raises:
+ KeyError: if `path_type` conflicts with `path_str`
+
+ Returns:
+ Existence:
+ 0. FileExist: file at path_str exists.
+            1. DirectoryExistEmpty: folder at path_str exists and is empty.
+ 2. DirectoryExistNotEmpty: folder at path_str exists and not empty.
+ 3. MissingParent: its parent doesn't exist.
+ 4. DirectoryNotExist: expect a folder at path_str, but not found.
+ 5. FileNotExist: expect a file at path_str, but not found.
+ """
+ path_type = path_type.lower()
+ assert path_type in {'file', 'dir', 'auto'}
+ pathinfo = Path(path_str)
+ if not pathinfo.parent.is_dir():
+ return Existence.MissingParent
+ suffix = pathinfo.suffix.lower()
+ if path_type == 'dir' or\
+ path_type == 'auto' and suffix == '':
+ if pathinfo.is_dir():
+ if len(os.listdir(path_str)) == 0:
+ return Existence.DirectoryExistEmpty
+ else:
+ return Existence.DirectoryExistNotEmpty
+ else:
+ return Existence.DirectoryNotExist
+ elif path_type == 'file' or\
+ path_type == 'auto' and suffix != '':
+ if pathinfo.is_file():
+ return Existence.FileExist
+ elif pathinfo.is_dir():
+ if len(os.listdir(path_str)) == 0:
+ return Existence.DirectoryExistEmpty
+ else:
+ return Existence.DirectoryExistNotEmpty
+ if path_str.endswith('/'):
+ return Existence.DirectoryNotExist
+ else:
+ return Existence.FileNotExist
+
+
+def prepare_output_path(output_path: str,
+ allowed_suffix: List[str] = [],
+ tag: str = 'output file',
+ path_type: Literal['file', 'dir', 'auto'] = 'auto',
+ overwrite: bool = True) -> None:
+ """Check output folder or file.
+
+ Args:
+ output_path (str): could be folder or file.
+ allowed_suffix (List[str], optional):
+ Check the suffix of `output_path`. If folder, should be [] or [''].
+            If it could be either a folder or a file, should be [suffixes..., ''].
+ Defaults to [].
+ tag (str, optional): The `string` tag to specify the output type.
+ Defaults to 'output file'.
+        path_type (Literal['file', 'dir', 'auto'], optional):
+ Choose `file` for file and `dir` for folder.
+ Choose `auto` if allowed to be both.
+ Defaults to 'auto'.
+ overwrite (bool, optional):
+ Whether overwrite the existing file or folder.
+ Defaults to True.
+
+ Raises:
+ FileNotFoundError: suffix does not match.
+ FileExistsError: file or folder already exists and `overwrite` is
+ False.
+
+ Returns:
+ None
+ """
+ if path_type.lower() == 'dir':
+ allowed_suffix = []
+ exist_result = check_path_existence(output_path, path_type=path_type)
+ if exist_result == Existence.MissingParent:
+ warnings.warn(
+ f'The parent folder of {tag} does not exist: {output_path},' +
+ f' will make dir {Path(output_path).parent.absolute().__str__()}')
+ os.makedirs(Path(output_path).parent.absolute().__str__(),
+ exist_ok=True)
+
+ elif exist_result == Existence.DirectoryNotExist:
+ os.mkdir(output_path)
+ print(f'Making directory {output_path} for saving results.')
+ elif exist_result == Existence.FileNotExist:
+ suffix_matched = \
+ check_path_suffix(output_path, allowed_suffix=allowed_suffix)
+ if not suffix_matched:
+ raise FileNotFoundError(
+ f'The {tag} should be {", ".join(allowed_suffix)}: '
+ f'{output_path}.')
+ elif exist_result == Existence.FileExist:
+ if not overwrite:
+ raise FileExistsError(
+ f'{output_path} exists (set overwrite = True to overwrite).')
+ else:
+ print(f'Overwriting {output_path}.')
+ elif exist_result == Existence.DirectoryExistEmpty:
+ pass
+ elif exist_result == Existence.DirectoryExistNotEmpty:
+ if not overwrite:
+ raise FileExistsError(
+ f'{output_path} is not empty (set overwrite = '
+ 'True to overwrite the files).')
+ else:
+ print(f'Overwriting {output_path} and its files.')
+ else:
+ raise FileNotFoundError(f'No Existence type for {output_path}.')
+
+
+def check_input_path(
+ input_path: str,
+ allowed_suffix: List[str] = [],
+ tag: str = 'input file',
+ path_type: Literal['file', 'dir', 'auto'] = 'auto',
+):
+ """Check input folder or file.
+
+ Args:
+ input_path (str): input folder or file path.
+ allowed_suffix (List[str], optional):
+ Check the suffix of `input_path`. If folder, should be [] or [''].
+            If it could be either a folder or a file, should be
+            [suffixes..., ''].
+            Defaults to [].
+        tag (str, optional): The `string` tag to specify the input type.
+            Defaults to 'input file'.
+        path_type (Literal['file', 'dir', 'auto'], optional):
+            Choose `file` for file and `dir` for folder.
+ Choose `auto` if allowed to be both.
+ Defaults to 'auto'.
+
+ Raises:
+        FileNotFoundError: file does not exist or suffix does not match.
+
+ Returns:
+ None
+ """
+ if path_type.lower() == 'dir':
+ allowed_suffix = []
+ exist_result = check_path_existence(input_path, path_type=path_type)
+
+ if exist_result in [
+ Existence.FileExist, Existence.DirectoryExistEmpty,
+ Existence.DirectoryExistNotEmpty
+ ]:
+ suffix_matched = \
+ check_path_suffix(input_path, allowed_suffix=allowed_suffix)
+ if not suffix_matched:
+ raise FileNotFoundError(
+ f'The {tag} should be {", ".join(allowed_suffix)}:' +
+ f'{input_path}.')
+ else:
+ raise FileNotFoundError(f'The {tag} does not exist: {input_path}.')
diff --git a/detrsmpl/utils/tmp b/detrsmpl/utils/tmp
new file mode 100644
index 0000000000000000000000000000000000000000..9e101f41c51b0a7396301add5c27f04ba9f96044
--- /dev/null
+++ b/detrsmpl/utils/tmp
@@ -0,0 +1,1374 @@
+import glob
+import json
+import os
+import shutil
+import string
+import subprocess
+import sys
+from pathlib import Path
+from typing import Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+
+from mmhuman3d.utils.path_utils import check_input_path, prepare_output_path
+
+try:
+ from typing import Literal
+except ImportError:
+ from typing_extensions import Literal
+
+
+class video_writer:
+
+ def __init__(self,
+ output_path: str,
+ resolution: Iterable[int],
+ fps: float = 30.0,
+ num_frame: int = 1e9,
+ disable_log: bool = False) -> None:
+ prepare_output_path(
+ output_path,
+ allowed_suffix=['.mp4'],
+ tag='output video',
+ path_type='file',
+ overwrite=True)
+ height, width = resolution
+ width += width % 2
+ height += height % 2
+ command = [
+ 'ffmpeg',
+ '-y', # (optional) overwrite output file if it exists
+ '-f',
+ 'rawvideo',
+ '-pix_fmt',
+ 'bgr24',
+ '-s',
+ f'{int(width)}x{int(height)}',
+ '-r',
+ f'{fps}', # frames per second
+ '-loglevel',
+ 'error',
+ '-threads',
+ '1',
+ '-i',
+ '-', # The input comes from a pipe
+ '-vcodec',
+ 'libx264',
+ '-r',
+ f'{fps}', # frames per second
+ '-an', # Tells FFMPEG not to expect any audio
+ output_path,
+ ]
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ process = subprocess.Popen(
+ command,
+ stdin=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ )
+ if process.stdin is None or process.stderr is None:
+ raise BrokenPipeError('No buffer received.')
+ self.process = process
+ self.num_frame = num_frame
+ self.len = 0
+
+ def write(self, image_array: np.ndarray):
+ if self.len <= self.num_frame:
+ try:
+ self.process.stdin.write(image_array.tobytes())
+ self.len += 1
+ except KeyboardInterrupt:
+ self.__del__()
+
+ def __del__(self):
+ self.process.stdin.close()
+ self.process.stderr.close()
+ self.process.wait()
+
+
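A sketch of streaming frames straight into an .mp4 through the `video_writer` pipe; the frames here are synthetic noise and the output path is a placeholder (ffmpeg must be on the PATH):

```python
import numpy as np

writer = video_writer('demo_out/noise.mp4', resolution=(480, 640), fps=30)
for _ in range(60):
    frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)  # BGR
    writer.write(frame)
del writer   # closes the ffmpeg pipe and finalizes the file
```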
+def array_to_video(
+ image_array: np.ndarray,
+ output_path: str,
+ fps: Union[int, float] = 30,
+ resolution: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
+ disable_log: bool = False,
+) -> None:
+ """Convert an array to a video directly, gif not supported.
+
+ Args:
+ image_array (np.ndarray): shape should be (f * h * w * 3).
+ output_path (str): output video file path.
+        fps (Union[int, float], optional): fps. Defaults to 30.
+        resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+            optional): (height, width) of the output video.
+            Defaults to None.
+        disable_log (bool, optional): whether to suppress the ffmpeg command
+            info. Defaults to False.
+ Raises:
+ FileNotFoundError: check output path.
+ TypeError: check input array.
+
+ Returns:
+ None.
+ """
+ if not isinstance(image_array, np.ndarray):
+ raise TypeError('Input should be np.ndarray.')
+ assert image_array.ndim == 4
+ assert image_array.shape[-1] == 3
+ prepare_output_path(
+ output_path,
+ allowed_suffix=['.mp4'],
+ tag='output video',
+ path_type='file',
+ overwrite=True)
+ if resolution:
+ height, width = resolution
+ width += width % 2
+ height += height % 2
+ else:
+ image_array = pad_for_libx264(image_array)
+ height, width = image_array.shape[1], image_array.shape[2]
+ command = [
+ 'ffmpeg',
+ '-y', # (optional) overwrite output file if it exists
+ '-f',
+ 'rawvideo',
+ '-s',
+ f'{int(width)}x{int(height)}', # size of one frame
+ '-pix_fmt',
+ 'bgr24',
+ '-r',
+ f'{fps}', # frames per second
+ '-loglevel',
+ 'error',
+ '-threads',
+ '4',
+        '-i',
+ '-', # The input comes from a pipe
+ '-vcodec',
+ 'libx264',
+ '-an', # Tells FFMPEG not to expect any audio
+ output_path,
+ ]
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ process = subprocess.Popen(
+ command,
+ stdin=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ )
+ if process.stdin is None or process.stderr is None:
+ raise BrokenPipeError('No buffer received.')
+ index = 0
+ while True:
+ if index >= image_array.shape[0]:
+ break
+ process.stdin.write(image_array[index].tobytes())
+ index += 1
+ process.stdin.close()
+ process.stderr.close()
+ process.wait()
+
+
+def array_to_images(
+ image_array: np.ndarray,
+ output_folder: str,
+ img_format: str = '%06d.png',
+ resolution: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
+ disable_log: bool = False,
+) -> None:
+ """Convert an array to images directly.
+
+ Args:
+ image_array (np.ndarray): shape should be (f * h * w * 3).
+ output_folder (str): output folder for the images.
+ img_format (str, optional): format of the images.
+ Defaults to '%06d.png'.
+ resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+ optional): resolution(height, width) of output.
+ Defaults to None.
+        disable_log (bool, optional): whether to suppress the ffmpeg command
+            info. Defaults to False.
+
+ Raises:
+ FileNotFoundError: check output folder.
+ TypeError: check input array.
+
+ Returns:
+ None
+ """
+ prepare_output_path(
+ output_folder,
+ allowed_suffix=[],
+ tag='output image folder',
+ path_type='dir',
+ overwrite=True)
+
+ if not isinstance(image_array, np.ndarray):
+ raise TypeError('Input should be np.ndarray.')
+ assert image_array.ndim == 4
+ assert image_array.shape[-1] == 3
+ if resolution:
+ height, width = resolution
+ else:
+ height, width = image_array.shape[1], image_array.shape[2]
+ command = [
+ 'ffmpeg',
+ '-y', # (optional) overwrite output file if it exists
+ '-f',
+ 'rawvideo',
+ '-s',
+ f'{int(width)}x{int(height)}', # size of one frame
+ '-pix_fmt',
+ 'bgr24', # bgr24 for matching OpenCV
+ '-loglevel',
+ 'error',
+ '-threads',
+ '4',
+ '-i',
+ '-', # The input comes from a pipe
+ '-f',
+ 'image2',
+ '-start_number',
+ '0',
+ os.path.join(output_folder, img_format),
+ ]
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ process = subprocess.Popen(
+ command,
+ stdin=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ bufsize=10**8,
+ close_fds=True)
+ if process.stdin is None or process.stderr is None:
+ raise BrokenPipeError('No buffer received.')
+ index = 0
+ while True:
+ if index >= image_array.shape[0]:
+ break
+ process.stdin.write(image_array[index].tobytes())
+ index += 1
+ process.stdin.close()
+ process.stderr.close()
+ process.wait()
+
+
+def video_to_array(
+ input_path: str,
+ resolution: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
+ start: int = 0,
+ end: Optional[int] = None,
+ disable_log: bool = False,
+) -> np.ndarray:
+ """
+ Read a video/gif as an array of (f * h * w * 3).
+
+ Args:
+ input_path (str): input path.
+ resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+ optional): resolution(height, width) of output.
+ Defaults to None.
+ start (int, optional): start frame index. Inclusive.
+ If < 0, will be converted to frame_index range in [0, frame_num].
+ Defaults to 0.
+ end (int, optional): end frame index. Exclusive.
+ Could be positive int or negative int or None.
+ If None, all frames from start till the last frame are included.
+ Defaults to None.
+        disable_log (bool, optional): whether to suppress the ffmpeg command
+            info. Defaults to False.
+
+ Raises:
+ FileNotFoundError: check the input path.
+
+ Returns:
+ np.ndarray: shape will be (f * h * w * 3).
+ """
+ check_input_path(
+ input_path,
+ allowed_suffix=['.mp4', 'mkv', 'avi', '.gif'],
+ tag='input video',
+ path_type='file')
+
+ info = vid_info_reader(input_path)
+ if resolution:
+ height, width = resolution
+ else:
+ width, height = int(info['width']), int(info['height'])
+ num_frames = int(info['nb_frames'])
+ start = (min(start, num_frames - 1) + num_frames) % num_frames
+ end = (min(end, num_frames - 1) +
+ num_frames) % num_frames if end is not None else num_frames
+ command = [
+ 'ffmpeg',
+ '-i',
+ input_path,
+ '-filter_complex',
+ f'[0]trim=start_frame={start}:end_frame={end}[v0]',
+ '-map',
+ '[v0]',
+ '-pix_fmt',
+ 'bgr24', # bgr24 for matching OpenCV
+ '-s',
+ f'{int(width)}x{int(height)}',
+ '-f',
+ 'image2pipe',
+ '-vcodec',
+ 'rawvideo',
+ '-loglevel',
+ 'error',
+ 'pipe:'
+ ]
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ # Execute FFmpeg as sub-process with stdout as a pipe
+ process = subprocess.Popen(command, stdout=subprocess.PIPE, bufsize=10**8)
+ if process.stdout is None:
+ raise BrokenPipeError('No buffer received.')
+ # Read decoded video frames from the PIPE until no more frames to read
+ array = []
+ while True:
+ # Read decoded video frame (in raw video format) from stdout process.
+ buffer = process.stdout.read(int(width * height * 3))
+        # Break the loop if buffer length is not W*H*3
+        # (when FFmpeg streaming ends).
+ if len(buffer) != width * height * 3:
+ break
+ img = np.frombuffer(buffer, np.uint8).reshape(height, width, 3)
+ array.append(img[np.newaxis])
+ process.stdout.flush()
+ process.stdout.close()
+ process.wait()
+ return np.concatenate(array)
+
+
+def images_to_sorted_images(input_folder, output_folder, img_format='%06d'):
+ """Copy and rename a folder of images into a new folder following the
+ `img_format`.
+
+ Args:
+ input_folder (str): input folder.
+ output_folder (str): output folder.
+ img_format (str, optional): image format name, do not need extension.
+ Defaults to '%06d'.
+
+ Returns:
+ str: image format of the rename images.
+ """
+ img_format = img_format.rsplit('.', 1)[0]
+ file_list = []
+ os.makedirs(output_folder, exist_ok=True)
+ pngs = glob.glob(os.path.join(input_folder, '*.png'))
+ if pngs:
+ ext = 'png'
+ file_list.extend(pngs)
+ jpgs = glob.glob(os.path.join(input_folder, '*.jpg'))
+ if jpgs:
+ ext = 'jpg'
+ file_list.extend(jpgs)
+ file_list.sort()
+ for index, file_name in enumerate(file_list):
+ shutil.copy(
+ file_name,
+ os.path.join(output_folder, (img_format + '.%s') % (index, ext)))
+ return img_format + '.%s' % ext
+
+
+def images_to_array(
+ input_folder: str,
+ resolution: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
+ img_format: str = '%06d.png',
+ start: int = 0,
+ end: Optional[int] = None,
+ remove_raw_files: bool = False,
+ disable_log: bool = False,
+) -> np.ndarray:
+ """
+ Read a folder of images as an array of (f * h * w * 3).
+
+ Args:
+ input_folder (str): folder of input images.
+        resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+            optional): resolution (height, width) of output. Defaults to None.
+ img_format (str, optional): format of images to be read.
+ Defaults to '%06d.png'.
+ start (int, optional): start frame index. Inclusive.
+ If < 0, will be converted to frame_index range in [0, frame_num].
+ Defaults to 0.
+ end (int, optional): end frame index. Exclusive.
+ Could be positive int or negative int or None.
+ If None, all frames from start till the last frame are included.
+ Defaults to None.
+ remove_raw_files (bool, optional): whether remove raw images.
+ Defaults to False.
+        disable_log (bool, optional): whether to suppress the ffmpeg command info.
+ Defaults to False.
+ Raises:
+ FileNotFoundError: check the input path.
+
+ Returns:
+        np.ndarray: shape will be (f, h, w, 3).
+ """
+ check_input_path(
+ input_folder,
+ allowed_suffix=[''],
+ tag='input image folder',
+ path_type='dir')
+
+ input_folderinfo = Path(input_folder)
+
+ temp_input_folder = None
+ if img_format is None:
+ temp_input_folder = os.path.join(input_folderinfo.parent,
+ input_folderinfo.name + '_temp')
+ img_format = images_to_sorted_images(
+ input_folder=input_folder, output_folder=temp_input_folder)
+ input_folder = temp_input_folder
+
+ info = vid_info_reader(f'{input_folder}/{img_format}' % start)
+    if resolution:
+        height, width = resolution
+    else:
+        width, height = int(info['width']), int(info['height'])
+
+ num_frames = len(os.listdir(input_folder))
+ start = max(start, 0) % num_frames
+ end = min(end, num_frames) % (num_frames + 1) \
+ if end is not None else num_frames
+ command = [
+ 'ffmpeg',
+ '-y',
+ '-threads',
+ '1',
+ '-start_number',
+ f'{start}',
+ '-i',
+ f'{input_folder}/{img_format}',
+ '-frames:v',
+ f'{end - start}',
+ '-f',
+ 'rawvideo',
+ '-pix_fmt',
+ 'bgr24', # bgr24 for matching OpenCV
+ '-s',
+ f'{int(width)}x{int(height)}',
+ '-loglevel',
+ 'error',
+ '-'
+ ]
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ process = subprocess.Popen(command, stdout=subprocess.PIPE, bufsize=10**8)
+ if process.stdout is None:
+ raise BrokenPipeError('No buffer received.')
+ # Read decoded video frames from the PIPE until no more frames to read
+ array = []
+ while True:
+ # Read decoded video frame (in raw video format) from stdout process.
+ buffer = process.stdout.read(int(width * height * 3))
+ # Break the loop if buffer length is not W*H*3\
+ # (when FFmpeg streaming ends).
+
+ if len(buffer) != width * height * 3:
+ break
+ img = np.frombuffer(buffer, np.uint8).reshape(height, width, 3)
+ array.append(img[np.newaxis])
+ process.stdout.flush()
+ process.stdout.close()
+ process.wait()
+ if temp_input_folder is not None:
+ if Path(temp_input_folder).is_dir():
+ shutil.rmtree(temp_input_folder)
+ if remove_raw_files:
+ if Path(input_folder).is_dir():
+ shutil.rmtree(input_folder)
+
+ return np.concatenate(array)
+
+
+class vid_info_reader(object):
+
+ def __init__(self, input_path) -> None:
+ """Get video information from video, mimiced from ffmpeg-python.
+ https://github.com/kkroening/ffmpeg-python.
+
+ Args:
+            input_path (str): input video or image file path.
+
+ Raises:
+ FileNotFoundError: check the input path.
+
+ Returns:
+ None.
+ """
+ check_input_path(
+ input_path,
+ allowed_suffix=['.mp4', '.gif', '.png', '.jpg', '.jpeg'],
+ tag='input file',
+ path_type='file')
+ cmd = [
+ 'ffprobe', '-show_format', '-show_streams', '-of', 'json',
+ input_path
+ ]
+ process = subprocess.Popen(
+ cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ out, _ = process.communicate()
+ probe = json.loads(out.decode('utf-8'))
+ video_stream = next((stream for stream in probe['streams']
+ if stream['codec_type'] == 'video'), None)
+ if video_stream is None:
+ print('No video stream found', file=sys.stderr)
+ sys.exit(1)
+ self.video_stream = video_stream
+
+ def __getitem__(
+ self,
+ key: Literal['index', 'codec_name', 'codec_long_name', 'profile',
+ 'codec_type', 'codec_time_base', 'codec_tag_string',
+ 'codec_tag', 'width', 'height', 'coded_width',
+ 'coded_height', 'has_b_frames', 'pix_fmt', 'level',
+ 'chroma_location', 'refs', 'is_avc', 'nal_length_size',
+ 'r_frame_rate', 'avg_frame_rate', 'time_base',
+ 'start_pts', 'start_time', 'duration_ts', 'duration',
+ 'bit_rate', 'bits_per_raw_sample', 'nb_frames',
+ 'disposition', 'tags']):
+ """Key (str): select in ['index', 'codec_name', 'codec_long_name',
+ 'profile', 'codec_type', 'codec_time_base', 'codec_tag_string',
+ 'codec_tag', 'width', 'height', 'coded_width', 'coded_height',
+ 'has_b_frames', 'pix_fmt', 'level', 'chroma_location', 'refs',
+ 'is_avc', 'nal_length_size', 'r_frame_rate', 'avg_frame_rate',
+ 'time_base', 'start_pts', 'start_time', 'duration_ts', 'duration',
+ 'bit_rate', 'bits_per_raw_sample', 'nb_frames', 'disposition',
+ 'tags']"""
+ return self.video_stream[key]
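+
+
+# Minimal usage sketch of vid_info_reader (illustration only, never called by
+# this module). 'demo.mp4' is a hypothetical local file.
+def _example_vid_info_reader() -> None:
+    info = vid_info_reader('demo.mp4')
+    # ffprobe reports every field as a string, so cast before arithmetic.
+    width, height = int(info['width']), int(info['height'])
+    num_frames = int(info['nb_frames'])
+    print(f'{width}x{height}, {num_frames} frames')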
+
+
+def video_to_gif(
+ input_path: str,
+ output_path: str,
+ resolution: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
+ fps: Union[float, int] = 15,
+ disable_log: bool = False,
+) -> None:
+ """Convert a video to a gif file.
+
+ Args:
+ input_path (str): video file path.
+ output_path (str): gif file path.
+ resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+ optional): (height, width) of the output video.
+ Defaults to None.
+ fps (Union[float, int], optional): frames per second. Defaults to 15.
+        disable_log (bool, optional): whether to suppress the ffmpeg command info.
+ Defaults to False.
+
+ Raises:
+ FileNotFoundError: check the input path.
+ FileNotFoundError: check the output path.
+
+ Returns:
+ None.
+ """
+ check_input_path(
+ input_path,
+ allowed_suffix=['.mp4'],
+ tag='input video',
+ path_type='file')
+ prepare_output_path(
+ output_path,
+ allowed_suffix=['.gif'],
+ tag='output gif',
+ path_type='file',
+ overwrite=True)
+
+ info = vid_info_reader(input_path)
+ duration = info['duration']
+ if resolution:
+ height, width = resolution
+ else:
+ width, height = int(info['width']), int(info['height'])
+
+ command = [
+ 'ffmpeg', '-r',
+ str(info['r_frame_rate']), '-i', input_path, '-r', f'{fps}', '-s',
+ f'{width}x{height}', '-loglevel', 'error', '-t', f'{duration}',
+ '-threads', '4', '-y', output_path
+ ]
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ subprocess.call(command)
+
+
+def video_to_images(input_path: str,
+ output_folder: str,
+ resolution: Optional[Union[Tuple[int, int],
+ Tuple[float, float]]] = None,
+ img_format: str = '%06d.png',
+ start: int = 0,
+ end: Optional[int] = None,
+ disable_log: bool = False) -> None:
+ """Convert a video to a folder of images.
+
+ Args:
+ input_path (str): video file path
+ output_folder (str): output folder to store the images
+        resolution (Optional[Tuple[int, int]], optional):
+            (height, width) of output. Defaults to None.
+        img_format (str, optional): format of the saved images.
+            Defaults to '%06d.png'.
+ start (int, optional): start frame index. Inclusive.
+ If < 0, will be converted to frame_index range in [0, frame_num].
+ Defaults to 0.
+ end (int, optional): end frame index. Exclusive.
+ Could be positive int or negative int or None.
+ If None, all frames from start till the last frame are included.
+ Defaults to None.
+        disable_log (bool, optional): whether to suppress the ffmpeg command info.
+ Defaults to False.
+ Raises:
+ FileNotFoundError: check the input path
+ FileNotFoundError: check the output path
+
+ Returns:
+ None
+ """
+ check_input_path(
+ input_path,
+ allowed_suffix=['.mp4'],
+ tag='input video',
+ path_type='file')
+ prepare_output_path(
+ output_folder,
+ allowed_suffix=[],
+ tag='output image folder',
+ path_type='dir',
+ overwrite=True)
+ info = vid_info_reader(input_path)
+ num_frames = int(info['nb_frames'])
+ start = (min(start, num_frames - 1) + num_frames) % num_frames
+ end = (min(end, num_frames - 1) +
+ num_frames) % num_frames if end is not None else num_frames
+
+ command = [
+ 'ffmpeg', '-i', input_path, '-filter_complex',
+ f'[0]trim=start_frame={start}:end_frame={end}[v0]', '-map', '[v0]',
+ '-f', 'image2', '-v', 'error', '-start_number', '0', '-threads', '1',
+ f'{output_folder}/{img_format}'
+ ]
+ if resolution:
+ height, width = resolution
+ command.insert(3, '-s')
+ command.insert(4, '%dx%d' % (width, height))
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ subprocess.call(command)
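+
+
+# Minimal usage sketch of video_to_images (illustration only, never called by
+# this module). Paths are hypothetical.
+def _example_video_to_images() -> None:
+    # Dump every frame of the clip into demo_frames/000000.png, 000001.png, ...
+    video_to_images('demo.mp4', 'demo_frames', img_format='%06d.png')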
+
+
+def images_to_video(input_folder: str,
+ output_path: str,
+ remove_raw_file: bool = False,
+ img_format: str = '%06d.png',
+ fps: Union[int, float] = 30,
+ resolution: Optional[Union[Tuple[int, int],
+ Tuple[float, float]]] = None,
+ start: int = 0,
+ end: Optional[int] = None,
+ disable_log: bool = False) -> None:
+ """Convert a folder of images to a video.
+
+ Args:
+ input_folder (str): input image folder
+ output_path (str): output video file path
+ remove_raw_file (bool, optional): whether remove raw images.
+ Defaults to False.
+        img_format (str, optional): format to name the images.
+            Defaults to '%06d.png'.
+        fps (Union[int, float], optional): output video fps. Defaults to 30.
+        resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+            optional): (height, width) of output.
+            Defaults to None.
+ start (int, optional): start frame index. Inclusive.
+ If < 0, will be converted to frame_index range in [0, frame_num].
+ Defaults to 0.
+ end (int, optional): end frame index. Exclusive.
+ Could be positive int or negative int or None.
+ If None, all frames from start till the last frame are included.
+ Defaults to None.
+        disable_log (bool, optional): whether to suppress the ffmpeg command info.
+ Defaults to False.
+ Raises:
+ FileNotFoundError: check the input path.
+ FileNotFoundError: check the output path.
+
+ Returns:
+ None
+ """
+ check_input_path(
+ input_folder,
+ allowed_suffix=[],
+ tag='input image folder',
+ path_type='dir')
+ prepare_output_path(
+ output_path,
+ allowed_suffix=['.mp4'],
+ tag='output video',
+ path_type='file',
+ overwrite=True)
+ input_folderinfo = Path(input_folder)
+ num_frames = len(os.listdir(input_folder))
+ start = (min(start, num_frames - 1) + num_frames) % num_frames
+ end = (min(end, num_frames - 1) +
+ num_frames) % num_frames if end is not None else num_frames
+ temp_input_folder = None
+ if img_format is None:
+ temp_input_folder = os.path.join(input_folderinfo.parent,
+ input_folderinfo.name + '_temp')
+ img_format = images_to_sorted_images(input_folder, temp_input_folder)
+
+ command = [
+ 'ffmpeg',
+ '-y',
+ '-threads',
+ '4',
+ '-start_number',
+ f'{start}',
+ '-r',
+ f'{fps}',
+ '-i',
+ f'{input_folder}/{img_format}'
+ if temp_input_folder is None else f'{temp_input_folder}/{img_format}',
+ '-frames:v',
+ f'{end - start}',
+ '-profile:v',
+ 'baseline',
+ '-level',
+ '3.0',
+ '-c:v',
+ 'libx264',
+ '-pix_fmt',
+ 'yuv420p',
+ '-an',
+ '-v',
+ 'error',
+ '-loglevel',
+ 'error',
+ output_path,
+ ]
+ if resolution:
+ height, width = resolution
+ width += width % 2
+ height += height % 2
+ command.insert(1, '-s')
+ command.insert(2, '%dx%d' % (width, height))
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ subprocess.call(command)
+ if remove_raw_file:
+ if Path(input_folder).is_dir():
+ shutil.rmtree(input_folder)
+ if temp_input_folder is not None:
+ if Path(temp_input_folder).is_dir():
+ shutil.rmtree(temp_input_folder)
+
+
+def images_to_gif(
+ input_folder: str,
+ output_path: str,
+ remove_raw_file: bool = False,
+ img_format: str = '%06d.png',
+ fps: int = 15,
+ resolution: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
+ start: int = 0,
+ end: Optional[int] = None,
+ disable_log: bool = False,
+) -> None:
+ """Convert series of images to a video, similar to images_to_video, but
+ provide more suitable parameters.
+
+ Args:
+ input_folder (str): input image folder.
+ output_path (str): output gif file path.
+ remove_raw_file (bool, optional): whether remove raw images.
+ Defaults to False.
+ img_format (str, optional): format to name the images.
+ Defaults to '%06d.png'.
+ fps (int, optional): output video fps. Defaults to 15.
+ resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+ optional): (height, width) of output. Defaults to None.
+ start (int, optional): start frame index. Inclusive.
+ If < 0, will be converted to frame_index range in [0, frame_num].
+ Defaults to 0.
+ end (int, optional): end frame index. Exclusive.
+ Could be positive int or negative int or None.
+ If None, all frames from start till the last frame are included.
+ Defaults to None.
+        disable_log (bool, optional): whether to suppress the ffmpeg command info.
+ Defaults to False.
+ Raises:
+ FileNotFoundError: check the input path.
+ FileNotFoundError: check the output path.
+
+ Returns:
+ None
+ """
+ input_folderinfo = Path(input_folder)
+ check_input_path(
+ input_folder,
+ allowed_suffix=[],
+ tag='input image folder',
+ path_type='dir')
+ prepare_output_path(
+ output_path,
+ allowed_suffix=['.gif'],
+ tag='output gif',
+ path_type='file',
+ overwrite=True)
+ num_frames = len(os.listdir(input_folder))
+ start = (min(start, num_frames - 1) + num_frames) % num_frames
+ end = (min(end, num_frames - 1) +
+ num_frames) % num_frames if end is not None else num_frames
+ temp_input_folder = None
+ if img_format is None:
+ file_list = []
+ temp_input_folder = os.path.join(input_folderinfo.parent,
+ input_folderinfo.name + '_temp')
+ os.makedirs(temp_input_folder, exist_ok=True)
+ pngs = glob.glob(os.path.join(input_folder, '*.png'))
+ ext = 'png'
+ if pngs:
+ ext = 'png'
+ file_list.extend(pngs)
+ jpgs = glob.glob(os.path.join(input_folder, '*.jpg'))
+ if jpgs:
+ ext = 'jpg'
+ file_list.extend(jpgs)
+ file_list.sort()
+ for index, file_name in enumerate(file_list):
+ shutil.copy(
+ file_name,
+ os.path.join(temp_input_folder, '%06d.%s' % (index + 1, ext)))
+ input_folder = temp_input_folder
+ img_format = '%06d.' + ext
+
+ command = [
+ 'ffmpeg',
+ '-y',
+ '-threads',
+ '4',
+ '-start_number',
+ f'{start}',
+ '-r',
+ f'{fps}',
+ '-i',
+ f'{input_folder}/{img_format}',
+ '-frames:v',
+ f'{end - start}',
+ '-loglevel',
+ 'error',
+ '-v',
+ 'error',
+ output_path,
+ ]
+ if resolution:
+ height, width = resolution
+ command.insert(1, '-s')
+ command.insert(2, '%dx%d' % (width, height))
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ subprocess.call(command)
+ if remove_raw_file:
+ shutil.rmtree(input_folder)
+ if temp_input_folder is not None:
+ shutil.rmtree(temp_input_folder)
+
+
+def gif_to_video(input_path: str,
+ output_path: str,
+ fps: int = 30,
+ remove_raw_file: bool = False,
+ resolution: Optional[Union[Tuple[int, int],
+ Tuple[float, float]]] = None,
+ disable_log: bool = False) -> None:
+ """Convert a gif file to a video.
+
+ Args:
+ input_path (str): input gif file path.
+ output_path (str): output video file path.
+ fps (int, optional): fps. Defaults to 30.
+ remove_raw_file (bool, optional): whether remove original input file.
+ Defaults to False.
+ resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+ optional): (height, width) of output. Defaults to None.
+        disable_log (bool, optional): whether to suppress the ffmpeg command info.
+ Defaults to False.
+ Raises:
+ FileNotFoundError: check the input path.
+ FileNotFoundError: check the output path.
+
+ Returns:
+ None
+ """
+ check_input_path(
+ input_path, allowed_suffix=['.gif'], tag='input gif', path_type='file')
+ prepare_output_path(
+ output_path,
+ allowed_suffix=['.mp4'],
+ tag='output video',
+ path_type='file',
+ overwrite=True)
+ command = [
+ 'ffmpeg', '-i', input_path, '-r', f'{fps}', '-loglevel', 'error', '-y',
+ output_path, '-threads', '4'
+ ]
+ if resolution:
+ height, width = resolution
+ command.insert(3, '-s')
+ command.insert(4, '%dx%d' % (width, height))
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ subprocess.call(command)
+ if remove_raw_file:
+ subprocess.call(['rm', '-f', input_path])
+
+
+def gif_to_images(input_path: str,
+ output_folder: str,
+ fps: int = 30,
+ img_format: str = '%06d.png',
+ resolution: Optional[Union[Tuple[int, int],
+ Tuple[float, float]]] = None,
+ disable_log: bool = False) -> None:
+ """Convert a gif file to a folder of images.
+
+ Args:
+ input_path (str): input gif file path.
+ output_folder (str): output folder to save the images.
+ fps (int, optional): fps. Defaults to 30.
+ img_format (str, optional): output image name format.
+ Defaults to '%06d.png'.
+ resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+ optional): (height, width) of output.
+ Defaults to None.
+        disable_log (bool, optional): whether to suppress the ffmpeg command info.
+ Defaults to False.
+ Raises:
+ FileNotFoundError: check the input path.
+ FileNotFoundError: check the output path.
+
+ Returns:
+ None
+ """
+ check_input_path(
+ input_path, allowed_suffix=['.gif'], tag='input gif', path_type='file')
+ prepare_output_path(
+ output_folder,
+ allowed_suffix=[],
+ tag='output image folder',
+ path_type='dir',
+ overwrite=True)
+ command = [
+ 'ffmpeg', '-r', f'{fps}', '-i', input_path, '-loglevel', 'error', '-f',
+ 'image2', '-v', 'error', '-threads', '4', '-y', '-start_number', '0',
+ f'{output_folder}/{img_format}'
+ ]
+ if resolution:
+ height, width = resolution
+ command.insert(3, '-s')
+ command.insert(4, '%dx%d' % (width, height))
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ subprocess.call(command)
+
+
+def crop_video(
+ input_path: str,
+ output_path: str,
+ box: Optional[Union[List[int], Tuple[int, int, int, int]]] = None,
+ resolution: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
+ disable_log: bool = False,
+) -> None:
+ """Spatially or temporally crop a video or gif file.
+
+ Args:
+ input_path (str): input video or gif file path.
+ output_path (str): output video or gif file path.
+        box (Iterable[int], optional): [x, y, w, h], where (x, y) is the
+            top-left corner of the crop region. Defaults to None (full frame).
+ resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+ optional): (height, width) of output. Defaults to None.
+        disable_log (bool, optional): whether to suppress the ffmpeg command info.
+ Defaults to False.
+ Raises:
+ FileNotFoundError: check the input path.
+ FileNotFoundError: check the output path.
+
+ Returns:
+        None.
+ """
+ check_input_path(
+ input_path,
+ allowed_suffix=['.gif', '.mp4'],
+ tag='input video',
+ path_type='file')
+ prepare_output_path(
+ output_path,
+ allowed_suffix=['.gif', '.mp4'],
+ tag='output video',
+ path_type='file',
+ overwrite=True)
+
+ info = vid_info_reader(input_path)
+ width, height = int(info['width']), int(info['height'])
+
+ if box is None:
+ box = [0, 0, width, height]
+
+ assert len(box) == 4
+ x, y, w, h = box
+ assert (w > 0 and h > 0)
+ command = [
+ 'ffmpeg', '-i', input_path, '-vcodec', 'libx264', '-vf',
+ 'crop=%d:%d:%d:%d' % (w, h, x, y), '-loglevel', 'error', '-y',
+ output_path
+ ]
+ if resolution:
+ height, width = resolution
+ width += width % 2
+ height += height % 2
+ command.insert(-1, '-s')
+ command.insert(-1, '%dx%d' % (width, height))
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ subprocess.call(command)
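+
+
+# Minimal usage sketch of crop_video (illustration only, never called by this
+# module). Paths are hypothetical.
+def _example_crop_video() -> None:
+    # Keep a 640x360 window whose top-left corner sits at (x=100, y=50).
+    crop_video('demo.mp4', 'demo_crop.mp4', box=[100, 50, 640, 360])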
+
+
+def slice_video(input_path: str,
+ output_path: str,
+ start: int = 0,
+ end: Optional[int] = None,
+ resolution: Optional[Union[Tuple[int, int],
+ Tuple[float, float]]] = None,
+ disable_log: bool = False) -> None:
+ """Temporally crop a video/gif into another video/gif.
+
+ Args:
+ input_path (str): input video or gif file path.
+        output_path (str): output video or gif file path.
+ start (int, optional): start frame index. Defaults to 0.
+ end (int, optional): end frame index. Exclusive.
+ Could be positive int or negative int or None.
+ If None, all frames from start till the last frame are included.
+ Defaults to None.
+ resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+ optional): (height, width) of output. Defaults to None.
+        disable_log (bool, optional): whether to suppress the ffmpeg command info.
+ Defaults to False.
+ Raises:
+ FileNotFoundError: check the input path.
+ FileNotFoundError: check the output path.
+
+ Returns:
+        None
+ """
+ info = vid_info_reader(input_path)
+ num_frames = int(info['nb_frames'])
+ start = (min(start, num_frames - 1) + num_frames) % num_frames
+ end = (min(end, num_frames - 1) +
+ num_frames) % num_frames if end is not None else num_frames
+ command = [
+ 'ffmpeg', '-y', '-i', input_path, '-filter_complex',
+ f'[0]trim=start_frame={start}:end_frame={end}[v0]', '-map', '[v0]',
+ '-loglevel', 'error', '-vcodec', 'libx264', output_path
+ ]
+ if resolution:
+ height, width = resolution
+ width += width % 2
+ height += height % 2
+ command.insert(1, '-s')
+ command.insert(2, '%dx%d' % (width, height))
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ subprocess.call(command)
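+
+
+# Minimal usage sketch of slice_video (illustration only, never called by this
+# module). Paths are hypothetical.
+def _example_slice_video() -> None:
+    # Keep frames [30, 90) of the clip, i.e. 60 frames.
+    slice_video('demo.mp4', 'demo_slice.mp4', start=30, end=90)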
+
+
+def spatial_concat_video(input_path_list: List[str],
+ output_path: str,
+ array: List[int] = [1, 1],
+ direction: Literal['h', 'w'] = 'h',
+ resolution: Union[Tuple[int,
+ int], List[int], List[float],
+ Tuple[float, float]] = (512, 512),
+ remove_raw_files: bool = False,
+ padding: int = 0,
+ disable_log: bool = False) -> None:
+ """Spatially concat some videos as an array video.
+
+ Args:
+ input_path_list (list): input video or gif file list.
+ output_path (str): output video or gif file path.
+        array (List[int], optional): number of rows and columns of
+            the video array. Defaults to [1, 1].
+        direction (str, optional): 'h' to fill the grid row by row,
+            'w' to fill it column by column.
+            Defaults to 'h'.
+ resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
+ optional): (height, width) of output.
+ Defaults to (512, 512).
+ remove_raw_files (bool, optional): whether remove raw images.
+ Defaults to False.
+ padding (int, optional): width of pixels between videos.
+ Defaults to 0.
+        disable_log (bool, optional): whether to suppress the ffmpeg command info.
+ Defaults to False.
+ Raises:
+ FileNotFoundError: check the input path.
+ FileNotFoundError: check the output path.
+
+ Returns:
+ None
+ """
+ lowercase = string.ascii_lowercase
+ assert len(array) == 2
+ assert (array[0] * array[1]) >= len(input_path_list)
+ for path in input_path_list:
+ check_input_path(
+ path,
+ allowed_suffix=['.gif', '.mp4'],
+ tag='input video',
+ path_type='file')
+ prepare_output_path(
+ output_path,
+ allowed_suffix=['.gif', '.mp4'],
+ tag='output video',
+ path_type='file',
+ overwrite=True)
+
+ command = ['ffmpeg']
+ height, width = resolution
+ scale_command = []
+ for index, vid_file in enumerate(input_path_list):
+ command.append('-i')
+ command.append(vid_file)
+ scale_command.append(
+ '[%d:v]scale=%d:%d:force_original_aspect_ratio=0[v%d];' %
+ (index, width, height, index))
+
+ scale_command = ' '.join(scale_command)
+ pad_command = '[v%d]pad=%d:%d[%s];' % (0, width * array[1] + padding *
+ (array[1] - 1),
+ height * array[0] + padding *
+ (array[0] - 1), lowercase[0])
+ for index in range(1, len(input_path_list)):
+ if direction == 'h':
+ pad_width = index % array[1] * (width + padding)
+ pad_height = index // array[1] * (height + padding)
+ else:
+ pad_width = index % array[0] * (width + padding)
+ pad_height = index // array[0] * (height + padding)
+
+ pad_command += '[%s][v%d]overlay=%d:%d' % (lowercase[index - 1], index,
+ pad_width, pad_height)
+ if index != len(input_path_list) - 1:
+ pad_command += '[%s];' % lowercase[index]
+
+ command += [
+ '-filter_complex',
+ '%s%s' % (scale_command, pad_command), '-loglevel', 'error', '-y',
+ output_path
+ ]
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ subprocess.call(command)
+
+ if remove_raw_files:
+ command = ['rm', '-f'] + input_path_list
+ subprocess.call(command)
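+
+
+# Minimal usage sketch of spatial_concat_video (illustration only, never
+# called by this module). Paths are hypothetical.
+def _example_spatial_concat_video() -> None:
+    # Tile four clips into a 2x2 grid, each resized to 360x640 (h x w), with a
+    # 5-pixel gutter between neighbours.
+    spatial_concat_video(['a.mp4', 'b.mp4', 'c.mp4', 'd.mp4'],
+                         'grid.mp4',
+                         array=[2, 2],
+                         direction='h',
+                         resolution=(360, 640),
+                         padding=5)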
+
+
+def temporal_concat_video(input_path_list: List[str],
+ output_path: str,
+ resolution: Union[Tuple[int, int],
+ Tuple[float, float]] = (512, 512),
+ remove_raw_files: bool = False,
+ disable_log: bool = False) -> None:
+ """Concat no matter videos or gifs into a temporal sequence, and save as a
+ new video or gif file.
+
+ Args:
+ input_path_list (List[str]): list of input video paths.
+ output_path (str): output video file path.
+ resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]]
+            , optional): (height, width) of output.
+            Defaults to (512, 512).
+ remove_raw_files (bool, optional): whether remove the input videos.
+ Defaults to False.
+        disable_log (bool, optional): whether to suppress the ffmpeg command info.
+ Defaults to False.
+ Raises:
+ FileNotFoundError: check the input path.
+ FileNotFoundError: check the output path.
+
+ Returns:
+ None.
+ """
+ for path in input_path_list:
+ check_input_path(
+ path,
+ allowed_suffix=['.gif', '.mp4'],
+ tag='input video',
+ path_type='file')
+ prepare_output_path(
+ output_path,
+ allowed_suffix=['.gif', '.mp4'],
+ tag='output video',
+ path_type='file',
+ overwrite=True)
+
+ height, width = resolution
+ command = ['ffmpeg']
+ concat_command = []
+ scale_command = []
+ for index, vid_file in enumerate(input_path_list):
+ command.append('-i')
+ command.append(vid_file)
+ scale_command.append(
+ '[%d:v]scale=%d:%d:force_original_aspect_ratio=0[v%d];' %
+ (index, width, height, index))
+ concat_command.append('[v%d]' % index)
+ concat_command = ''.join(concat_command)
+ scale_command = ''.join(scale_command)
+ command += [
+ '-filter_complex',
+ '%s%sconcat=n=%d:v=1:a=0[v]' %
+ (scale_command, concat_command, len(input_path_list)), '-loglevel',
+ 'error', '-map', '[v]', '-c:v', 'libx264', '-y', output_path
+ ]
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ subprocess.call(command)
+
+ if remove_raw_files:
+ command = ['rm'] + input_path_list
+ subprocess.call(command)
+
+
+def compress_video(input_path: str,
+ output_path: str,
+ compress_rate: int = 1,
+ down_sample_scale: Union[float, int] = 1,
+ fps: int = 30,
+ disable_log: bool = False) -> None:
+ """Compress a video file.
+
+ Args:
+ input_path (str): input video file path.
+ output_path (str): output video file path.
+        compress_rate (int, optional): compression rate, affects the bit rate.
+ Defaults to 1.
+ down_sample_scale (Union[float, int], optional): spatial down sample
+ scale. Defaults to 1.
+ fps (int, optional): Frames per second. Defaults to 30.
+        disable_log (bool, optional): whether to suppress the ffmpeg command info.
+ Defaults to False.
+ Raises:
+ FileNotFoundError: check the input path.
+ FileNotFoundError: check the output path.
+
+ Returns:
+ None.
+ """
+ input_pathinfo = Path(input_path)
+
+ check_input_path(
+ input_path,
+ allowed_suffix=['.gif', '.mp4'],
+ tag='input video',
+ path_type='file')
+ prepare_output_path(
+ output_path,
+ allowed_suffix=['.gif', '.mp4'],
+ tag='output video',
+ path_type='file',
+ overwrite=True)
+
+ info = vid_info_reader(input_path)
+
+ width = int(info['width'])
+ height = int(info['height'])
+ bit_rate = int(info['bit_rate'])
+ duration = float(info['duration'])
+ if (output_path == input_path) or (not output_path):
+ temp_outpath = os.path.join(
+ os.path.abspath(input_pathinfo.parent),
+ 'temp_file' + input_pathinfo.suffix)
+ else:
+ temp_outpath = output_path
+ new_width = int(width / down_sample_scale)
+ new_width += new_width % 2
+ new_height = int(height / down_sample_scale)
+ new_height += new_height % 2
+ command = [
+ 'ffmpeg', '-y', '-r',
+ str(info['r_frame_rate']), '-i', input_path, '-loglevel', 'error',
+ '-b:v', f'{bit_rate / (compress_rate * down_sample_scale)}', '-r',
+ f'{fps}', '-t', f'{duration}', '-s',
+ '%dx%d' % (new_width, new_height), temp_outpath
+ ]
+ if not disable_log:
+ print(f'Running \"{" ".join(command)}\"')
+ subprocess.call(command)
+ if (output_path == input_path) or (not output_path):
+ subprocess.call(['mv', '-f', temp_outpath, input_path])
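+
+
+# Minimal usage sketch of compress_video (illustration only, never called by
+# this module). Paths are hypothetical.
+def _example_compress_video() -> None:
+    # Halve the spatial resolution and cut the bit rate to roughly a quarter
+    # (bit_rate / (compress_rate * down_sample_scale)).
+    compress_video('demo.mp4', 'demo_small.mp4', compress_rate=2,
+                   down_sample_scale=2, fps=25)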
+
+
+def pad_for_libx264(image_array):
+ """Pad zeros if width or height of image_array is not divisible by 2.
+    Otherwise libx264 will fail with:
+
+ \"[libx264 @ 0x1b1d560] width not divisible by 2 \"
+
+ Args:
+ image_array (np.ndarray):
+ Image or images load by cv2.imread().
+ Possible shapes:
+ 1. [height, width]
+ 2. [height, width, channels]
+ 3. [images, height, width]
+ 4. [images, height, width, channels]
+
+ Returns:
+ np.ndarray:
+            An image with both edges divisible by 2.
+ """
+ if image_array.ndim == 2 or \
+ (image_array.ndim == 3 and image_array.shape[2] == 3):
+ hei_index = 0
+ wid_index = 1
+ elif image_array.ndim == 4 or \
+ (image_array.ndim == 3 and image_array.shape[2] != 3):
+ hei_index = 1
+ wid_index = 2
+ else:
+ return image_array
+ hei_pad = image_array.shape[hei_index] % 2
+ wid_pad = image_array.shape[wid_index] % 2
+ if hei_pad + wid_pad > 0:
+ pad_width = []
+ for dim_index in range(image_array.ndim):
+ if dim_index == hei_index:
+ pad_width.append((0, hei_pad))
+ elif dim_index == wid_index:
+ pad_width.append((0, wid_pad))
+ else:
+ pad_width.append((0, 0))
+ values = 0
+ image_array = \
+ np.pad(image_array,
+ pad_width,
+ mode='constant', constant_values=values)
+ return image_array
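+
+
+# Minimal sketch of pad_for_libx264 (illustration only, never called by this
+# module): an odd-sized frame gains one row and one column of zero padding so
+# that libx264 accepts it.
+def _example_pad_for_libx264() -> None:
+    img = np.zeros((479, 639, 3), dtype=np.uint8)
+    padded = pad_for_libx264(img)
+    assert padded.shape == (480, 640, 3)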
diff --git a/detrsmpl/utils/transforms.py b/detrsmpl/utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e3b54ab62ac4ebd965aa726b466909c5263df5a
--- /dev/null
+++ b/detrsmpl/utils/transforms.py
@@ -0,0 +1,590 @@
+from typing import Union
+
+import numpy
+import torch
+
+from detrsmpl.core.conventions.joints_mapping.standard_joint_angles import (
+ TRANSFORMATION_AA_TO_SJA,
+ TRANSFORMATION_SJA_TO_AA,
+)
+from .logger import get_root_logger
+
+try:
+ from pytorch3d.transforms import (
+ axis_angle_to_matrix,
+ axis_angle_to_quaternion,
+ euler_angles_to_matrix,
+ matrix_to_euler_angles,
+ matrix_to_quaternion,
+ matrix_to_rotation_6d,
+ quaternion_to_axis_angle,
+ quaternion_to_matrix,
+ rotation_6d_to_matrix,
+ )
+except (ImportError, ModuleNotFoundError):
+ import traceback
+ logger = get_root_logger()
+ stack_str = ''
+ for line in traceback.format_stack():
+ if 'frozen' not in line:
+ stack_str += line + '\n'
+ import_exception = traceback.format_exc() + '\n'
+ warning_msg = stack_str + import_exception + \
+ 'If pytorch3d is not required,' +\
+ ' this warning could be ignored.'
+ logger.warning(warning_msg)
+
+
+class Compose:
+ def __init__(self, transforms: list):
+ """Composes several transforms together. This transform does not
+ support torchscript.
+
+ Args:
+ transforms (list): (list of transform functions)
+ """
+ self.transforms = transforms
+
+ def __call__(self,
+ rotation: Union[torch.Tensor, numpy.ndarray],
+ convention: str = 'xyz',
+ **kwargs):
+ convention = convention.lower()
+ if not (set(convention) == set('xyz') and len(convention) == 3):
+ raise ValueError(f'Invalid convention {convention}.')
+ if isinstance(rotation, numpy.ndarray):
+ data_type = 'numpy'
+ rotation = torch.FloatTensor(rotation)
+ elif isinstance(rotation, torch.Tensor):
+ data_type = 'tensor'
+ else:
+ raise TypeError(
+ 'Type of rotation should be torch.Tensor or numpy.ndarray')
+ for t in self.transforms:
+ if 'convention' in t.__code__.co_varnames:
+ rotation = t(rotation, convention.upper(), **kwargs)
+ else:
+ rotation = t(rotation, **kwargs)
+ if data_type == 'numpy':
+ rotation = rotation.detach().cpu().numpy()
+ return rotation
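+
+
+# Minimal usage sketch of Compose (illustration only, never called by this
+# module; requires pytorch3d): numpy input yields numpy output with the same
+# batch shape.
+def _example_compose() -> None:
+    aa = numpy.zeros((10, 3), dtype=numpy.float32)  # ten zero rotations
+    to_rotmat = Compose([axis_angle_to_matrix])
+    rotmat = to_rotmat(aa)  # numpy array of shape (10, 3, 3), all identities
+    assert rotmat.shape == (10, 3, 3)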
+
+
+def aa_to_rotmat(
+ axis_angle: Union[torch.Tensor, numpy.ndarray]
+) -> Union[torch.Tensor, numpy.ndarray]:
+ """
+    Convert axis_angle to rotation matrices.
+ Args:
+ axis_angle (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 3). ndim of input is unlimited.
+
+ Returns:
+ Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3, 3).
+ """
+ if axis_angle.shape[-1] != 3:
+ raise ValueError(
+ f'Invalid input axis angles shape f{axis_angle.shape}.')
+ t = Compose([axis_angle_to_matrix])
+ return t(axis_angle)
+
+
+def aa_to_quat(
+ axis_angle: Union[torch.Tensor, numpy.ndarray]
+) -> Union[torch.Tensor, numpy.ndarray]:
+ """
+ Convert axis_angle to quaternions.
+ Args:
+ axis_angle (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 3). ndim of input is unlimited.
+
+ Returns:
+ Union[torch.Tensor, numpy.ndarray]: shape would be (..., 4).
+ """
+ if axis_angle.shape[-1] != 3:
+ raise ValueError(f'Invalid input axis angles f{axis_angle.shape}.')
+ t = Compose([axis_angle_to_quaternion])
+ return t(axis_angle)
+
+
+def ee_to_rotmat(euler_angle: Union[torch.Tensor, numpy.ndarray],
+ convention='xyz') -> Union[torch.Tensor, numpy.ndarray]:
+ """Convert euler angle to rotation matrixs.
+
+ Args:
+ euler_angle (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 3). ndim of input is unlimited.
+ convention (str, optional): Convention string of three letters
+ from {“x”, “y”, and “z”}. Defaults to 'xyz'.
+ Returns:
+ Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3, 3).
+ """
+ if euler_angle.shape[-1] != 3:
+ raise ValueError(
+ f'Invalid input euler angles shape f{euler_angle.shape}.')
+ t = Compose([euler_angles_to_matrix])
+ return t(euler_angle, convention.upper())
+
+
+def rotmat_to_ee(
+ matrix: Union[torch.Tensor, numpy.ndarray],
+ convention: str = 'xyz') -> Union[torch.Tensor, numpy.ndarray]:
+ """Convert rotation matrixs to euler angle.
+
+ Args:
+ matrix (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 3, 3). ndim of input is unlimited.
+ convention (str, optional): Convention string of three letters
+ from {“x”, “y”, and “z”}. Defaults to 'xyz'.
+ Returns:
+ Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).
+ """
+ if matrix.shape[-1] != 3 or matrix.shape[-2] != 3:
+ raise ValueError(f'Invalid rotation matrix shape f{matrix.shape}.')
+ t = Compose([matrix_to_euler_angles])
+ return t(matrix, convention.upper())
+
+
+def rotmat_to_quat(
+ matrix: Union[torch.Tensor, numpy.ndarray]
+) -> Union[torch.Tensor, numpy.ndarray]:
+ """Convert rotation matrixs to quaternions.
+
+ Args:
+ matrix (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 3, 3). ndim of input is unlimited.
+ Returns:
+ Union[torch.Tensor, numpy.ndarray]: shape would be (..., 4).
+ """
+ if matrix.shape[-1] != 3 or matrix.shape[-2] != 3:
+ raise ValueError(f'Invalid rotation matrix shape f{matrix.shape}.')
+ t = Compose([matrix_to_quaternion])
+ return t(matrix)
+
+
+def rotmat_to_rot6d(
+ matrix: Union[torch.Tensor, numpy.ndarray]
+) -> Union[torch.Tensor, numpy.ndarray]:
+ """Convert rotation matrixs to rotation 6d representations.
+
+ Args:
+ matrix (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 3, 3). ndim of input is unlimited.
+ Returns:
+ Union[torch.Tensor, numpy.ndarray]: shape would be (..., 6).
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ """
+ if matrix.shape[-1] != 3 or matrix.shape[-2] != 3:
+ raise ValueError(f'Invalid rotation matrix shape f{matrix.shape}.')
+ t = Compose([matrix_to_rotation_6d])
+ return t(matrix)
+
+
+def quat_to_aa(
+ quaternions: Union[torch.Tensor, numpy.ndarray]
+) -> Union[torch.Tensor, numpy.ndarray]:
+ """Convert quaternions to axis angles.
+
+ Args:
+ quaternions (Union[torch.Tensor, numpy.ndarray]): input shape
+            should be (..., 4). ndim of input is unlimited.
+ Returns:
+ Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).
+ """
+ if quaternions.shape[-1] != 4:
+ raise ValueError(f'Invalid input quaternions f{quaternions.shape}.')
+ t = Compose([quaternion_to_axis_angle])
+ return t(quaternions)
+
+
+def quat_to_rotmat(
+ quaternions: Union[torch.Tensor, numpy.ndarray]
+) -> Union[torch.Tensor, numpy.ndarray]:
+ """Convert quaternions to rotation matrixs.
+
+ Args:
+ quaternions (Union[torch.Tensor, numpy.ndarray]): input shape
+            should be (..., 4). ndim of input is unlimited.
+ Returns:
+ Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3, 3).
+ """
+ if quaternions.shape[-1] != 4:
+ raise ValueError(
+ f'Invalid input quaternions shape f{quaternions.shape}.')
+ t = Compose([quaternion_to_matrix])
+ return t(quaternions)
+
+
+def rot6d_to_rotmat(
+ rotation_6d: Union[torch.Tensor, numpy.ndarray]
+) -> Union[torch.Tensor, numpy.ndarray]:
+ """Convert rotation 6d representations to rotation matrixs.
+
+ Args:
+ rotation_6d (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 6). ndim of input is unlimited.
+ Returns:
+ Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3, 3).
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ """
+ if rotation_6d.shape[-1] != 6:
+ raise ValueError(f'Invalid input rotation_6d f{rotation_6d.shape}.')
+ t = Compose([rotation_6d_to_matrix])
+ return t(rotation_6d)
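+
+
+# Minimal sketch (illustration only, never called by this module; requires
+# pytorch3d): the 6d representation round-trips back to the original rotation
+# matrices, which is the property motivating its use.
+def _example_rot6d_roundtrip() -> None:
+    rotmat = torch.eye(3).repeat(4, 1, 1)  # four identity rotations, (4, 3, 3)
+    rot6d = rotmat_to_rot6d(rotmat)        # (4, 6)
+    restored = rot6d_to_rotmat(rot6d)      # (4, 3, 3)
+    assert torch.allclose(restored, rotmat, atol=1e-6)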
+
+
+def aa_to_ee(axis_angle: Union[torch.Tensor, numpy.ndarray],
+ convention: str = 'xyz') -> Union[torch.Tensor, numpy.ndarray]:
+ """Convert axis angles to euler angle.
+
+ Args:
+ axis_angle (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 3). ndim of input is unlimited.
+ convention (str, optional): Convention string of three letters
+ from {“x”, “y”, and “z”}. Defaults to 'xyz'.
+
+ Returns:
+ Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).
+ """
+ if axis_angle.shape[-1] != 3:
+ raise ValueError(
+ f'Invalid input axis_angle shape f{axis_angle.shape}.')
+ t = Compose([axis_angle_to_matrix, matrix_to_euler_angles])
+ return t(axis_angle, convention)
+
+
+def aa_to_rot6d(
+ axis_angle: Union[torch.Tensor, numpy.ndarray]
+) -> Union[torch.Tensor, numpy.ndarray]:
+ """Convert axis angles to rotation 6d representations.
+
+ Args:
+ axis_angle (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 3). ndim of input is unlimited.
+
+ Returns:
+ Union[torch.Tensor, numpy.ndarray]: shape would be (..., 6).
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ """
+ if axis_angle.shape[-1] != 3:
+ raise ValueError(f'Invalid input axis_angle f{axis_angle.shape}.')
+ t = Compose([axis_angle_to_matrix, matrix_to_rotation_6d])
+ return t(axis_angle)
+
+
+def ee_to_aa(euler_angle: Union[torch.Tensor, numpy.ndarray],
+ convention: str = 'xyz') -> Union[torch.Tensor, numpy.ndarray]:
+ """Convert euler angles to axis angles.
+
+ Args:
+ euler_angle (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 3). ndim of input is unlimited.
+ convention (str, optional): Convention string of three letters
+ from {“x”, “y”, and “z”}. Defaults to 'xyz'.
+
+ Returns:
+ Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).
+ """
+ if euler_angle.shape[-1] != 3:
+ raise ValueError(f'Invalid input euler_angle f{euler_angle.shape}.')
+ t = Compose([
+ euler_angles_to_matrix, matrix_to_quaternion, quaternion_to_axis_angle
+ ])
+ return t(euler_angle, convention)
+
+
+def ee_to_quat(euler_angle: Union[torch.Tensor, numpy.ndarray],
+ convention='xyz') -> Union[torch.Tensor, numpy.ndarray]:
+ """Convert euler angles to quaternions.
+
+ Args:
+ euler_angle (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 3). ndim of input is unlimited.
+ convention (str, optional): Convention string of three letters
+ from {“x”, “y”, and “z”}. Defaults to 'xyz'.
+
+ Returns:
+ Union[torch.Tensor, numpy.ndarray]: shape would be (..., 4).
+ """
+ if euler_angle.shape[-1] != 3:
+ raise ValueError(f'Invalid input euler_angle f{euler_angle.shape}.')
+ t = Compose([euler_angles_to_matrix, matrix_to_quaternion])
+ return t(euler_angle, convention)
+
+
+def ee_to_rot6d(euler_angle: Union[torch.Tensor, numpy.ndarray],
+ convention='xyz') -> Union[torch.Tensor, numpy.ndarray]:
+ """Convert euler angles to rotation 6d representation.
+
+ Args:
+ euler_angle (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 3). ndim of input is unlimited.
+ convention (str, optional): Convention string of three letters
+ from {“x”, “y”, and “z”}. Defaults to 'xyz'.
+
+ Returns:
+ Union[torch.Tensor, numpy.ndarray]: shape would be (..., 6).
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ """
+ if euler_angle.shape[-1] != 3:
+ raise ValueError(f'Invalid input euler_angle f{euler_angle.shape}.')
+ t = Compose([euler_angles_to_matrix, matrix_to_rotation_6d])
+ return t(euler_angle, convention)
+
+
+def rotmat_to_aa(
+ matrix: Union[torch.Tensor, numpy.ndarray],
+ convention: str = 'xyz') -> Union[torch.Tensor, numpy.ndarray]:
+ """Convert rotation matrixs to axis angles.
+
+ Args:
+ matrix (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 3, 3). ndim of input is unlimited.
+ convention (str, optional): Convention string of three letters
+ from {“x”, “y”, and “z”}. Defaults to 'xyz'.
+
+ Returns:
+ Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).
+ """
+ if matrix.shape[-1] != 3 or matrix.shape[-2] != 3:
+ raise ValueError(f'Invalid rotation matrix shape f{matrix.shape}.')
+ t = Compose([matrix_to_quaternion, quaternion_to_axis_angle])
+ return t(matrix)
+
+
+def quat_to_ee(quaternions: Union[torch.Tensor, numpy.ndarray],
+ convention: str = 'xyz') -> Union[torch.Tensor, numpy.ndarray]:
+ """Convert quaternions to euler angles.
+
+ Args:
+ quaternions (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 4). ndim of input is unlimited.
+ convention (str, optional): Convention string of three letters
+ from {“x”, “y”, and “z”}. Defaults to 'xyz'.
+
+ Returns:
+ Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).
+ """
+ if quaternions.shape[-1] != 4:
+ raise ValueError(f'Invalid input quaternions f{quaternions.shape}.')
+ t = Compose([quaternion_to_matrix, matrix_to_euler_angles])
+ return t(quaternions, convention)
+
+
+def quat_to_rot6d(
+ quaternions: Union[torch.Tensor, numpy.ndarray]
+) -> Union[torch.Tensor, numpy.ndarray]:
+ """Convert quaternions to rotation 6d representations.
+
+ Args:
+ quaternions (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 4). ndim of input is unlimited.
+
+ Returns:
+ Union[torch.Tensor, numpy.ndarray]: shape would be (..., 6).
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ """
+ if quaternions.shape[-1] != 4:
+ raise ValueError(f'Invalid input quaternions f{quaternions.shape}.')
+ t = Compose([quaternion_to_matrix, matrix_to_rotation_6d])
+ return t(quaternions)
+
+
+def rot6d_to_aa(
+ rotation_6d: Union[torch.Tensor, numpy.ndarray]
+) -> Union[torch.Tensor, numpy.ndarray]:
+ """Convert rotation 6d representations to axis angles.
+
+ Args:
+ rotation_6d (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 6). ndim of input is unlimited.
+
+ Returns:
+ Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ """
+ if rotation_6d.shape[-1] != 6:
+ raise ValueError(f'Invalid input rotation_6d f{rotation_6d.shape}.')
+ t = Compose([
+ rotation_6d_to_matrix, matrix_to_quaternion, quaternion_to_axis_angle
+ ])
+ return t(rotation_6d)
+
+
+def rot6d_to_ee(rotation_6d: Union[torch.Tensor, numpy.ndarray],
+ convention: str = 'xyz') -> Union[torch.Tensor, numpy.ndarray]:
+ """Convert rotation 6d representations to euler angles.
+
+ Args:
+ rotation_6d (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 6). ndim of input is unlimited.
+
+ Returns:
+ Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ """
+ if rotation_6d.shape[-1] != 6:
+ raise ValueError(f'Invalid input rotation_6d f{rotation_6d.shape}.')
+ t = Compose([rotation_6d_to_matrix, matrix_to_euler_angles])
+ return t(rotation_6d, convention)
+
+
+def rot6d_to_quat(
+ rotation_6d: Union[torch.Tensor, numpy.ndarray]
+) -> Union[torch.Tensor, numpy.ndarray]:
+ """Convert rotation 6d representations to quaternions.
+
+ Args:
+        rotation_6d (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 6). ndim of input is unlimited.
+
+ Returns:
+ Union[torch.Tensor, numpy.ndarray]: shape would be (..., 4).
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ """
+ if rotation_6d.shape[-1] != 6:
+ raise ValueError(
+ f'Invalid input rotation_6d shape f{rotation_6d.shape}.')
+ t = Compose([rotation_6d_to_matrix, matrix_to_quaternion])
+ return t(rotation_6d)
+
+
+def aa_to_sja(
+ axis_angle: Union[torch.Tensor, numpy.ndarray],
+ R_t: Union[torch.Tensor, numpy.ndarray] = TRANSFORMATION_AA_TO_SJA,
+ R_t_inv: Union[torch.Tensor, numpy.ndarray] = TRANSFORMATION_SJA_TO_AA
+) -> Union[torch.Tensor, numpy.ndarray]:
+ """Convert axis-angles to standard joint angles.
+
+ Args:
+ axis_angle (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 21, 3), ndim of input is unlimited.
+ R_t (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 21, 3, 3). Transformation matrices from
+ original axis-angle coordinate system to
+ standard joint angle coordinate system,
+ ndim of input is unlimited.
+ R_t_inv (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 21, 3, 3). Transformation matrices from
+ standard joint angle coordinate system to
+ original axis-angle coordinate system,
+ ndim of input is unlimited.
+
+ Returns:
+        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 21, 3).
+ """
+ def _aa_to_sja(aa, R_t, R_t_inv):
+ R_aa = axis_angle_to_matrix(aa)
+ R_sja = R_t @ R_aa @ R_t_inv
+ sja = matrix_to_euler_angles(R_sja, convention='XYZ')
+ return sja
+
+ if axis_angle.shape[-2:] != (21, 3):
+ raise ValueError(
+ f'Invalid input axis angles shape f{axis_angle.shape}.')
+ if R_t.shape[-3:] != (21, 3, 3):
+ raise ValueError(f'Invalid input R_t shape f{R_t.shape}.')
+ if R_t_inv.shape[-3:] != (21, 3, 3):
+        raise ValueError(f'Invalid input R_t_inv shape {R_t_inv.shape}.')
+ t = Compose([_aa_to_sja])
+ return t(axis_angle, R_t=R_t, R_t_inv=R_t_inv)
+
+
+def sja_to_aa(
+ sja: Union[torch.Tensor, numpy.ndarray],
+ R_t: Union[torch.Tensor, numpy.ndarray] = TRANSFORMATION_AA_TO_SJA,
+ R_t_inv: Union[torch.Tensor, numpy.ndarray] = TRANSFORMATION_SJA_TO_AA
+) -> Union[torch.Tensor, numpy.ndarray]:
+ """Convert standard joint angles to axis angles.
+
+ Args:
+ sja (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 21, 3). ndim of input is unlimited.
+ R_t (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 21, 3, 3). Transformation matrices from
+ original axis-angle coordinate system to
+ standard joint angle coordinate system
+ R_t_inv (Union[torch.Tensor, numpy.ndarray]): input shape
+ should be (..., 21, 3, 3). Transformation matrices from
+ standard joint angle coordinate system to
+ original axis-angle coordinate system
+
+ Returns:
+        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 21, 3).
+ """
+ def _sja_to_aa(sja, R_t, R_t_inv):
+ R_sja = euler_angles_to_matrix(sja, convention='XYZ')
+ R_aa = R_t_inv @ R_sja @ R_t
+ aa = quaternion_to_axis_angle(matrix_to_quaternion(R_aa))
+ return aa
+
+ if sja.shape[-2:] != (21, 3):
+        raise ValueError(f'Invalid input sja shape {sja.shape}.')
+ if R_t.shape[-3:] != (21, 3, 3):
+ raise ValueError(f'Invalid input R_t shape f{R_t.shape}.')
+ if R_t_inv.shape[-3:] != (21, 3, 3):
+        raise ValueError(f'Invalid input R_t_inv shape {R_t_inv.shape}.')
+ t = Compose([_sja_to_aa])
+ return t(sja, R_t=R_t, R_t_inv=R_t_inv)
+
+
+def make_homegeneous_rotmat_batch(input: torch.Tensor) -> torch.Tensor:
+ """Appends a row of [0,0,0,1] to a batch size x 3 x 4 Tensor.
+
+ Parameters
+ ----------
+ :param input: A tensor of dimensions batch size x 3 x 4
+ :return: A tensor batch size x 4 x 4 (appended with 0,0,0,1)
+ """
+ batch_size = input.shape[0]
+ row_append = torch.tensor([0.0, 0.0, 0.0, 1.0], dtype=torch.float)
+ row_append.requires_grad = False
+ padded_tensor = torch.cat(
+ [input, row_append.view(1, 1, 4).repeat(batch_size, 1, 1)], dim=1)
+ return padded_tensor
+
+
+def make_homegeneous_rotmat(input: torch.Tensor) -> torch.Tensor:
+ """Appends a row of [0,0,0,1] to a 3 x 4 Tensor.
+
+ Parameters
+ ----------
+ :param input: A tensor of dimensions 3 x 4
+    :return: A tensor of dimensions 4 x 4 (appended with 0,0,0,1)
+ """
+ row_append = torch.tensor([0.0, 0.0, 0.0, 1.0], dtype=torch.float)
+ row_append.requires_grad = False
+    padded_tensor = torch.cat([input, row_append.view(1, 4)], dim=0)
+ return padded_tensor
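+
+
+# Minimal sketch of make_homegeneous_rotmat_batch (illustration only, never
+# called by this module): a batch of 3 x 4 [R|t] matrices becomes a batch of
+# 4 x 4 homogeneous transforms.
+def _example_make_homegeneous_rotmat_batch() -> None:
+    rt = torch.zeros(2, 3, 4)                # two [R|t] matrices
+    hom = make_homegeneous_rotmat_batch(rt)  # (2, 4, 4)
+    assert hom.shape == (2, 4, 4)
+    assert bool(torch.all(hom[:, 3, 3] == 1))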
diff --git a/detrsmpl/utils/util_mixins.py b/detrsmpl/utils/util_mixins.py
new file mode 100644
index 0000000000000000000000000000000000000000..923f1982237f3b3ac6613ba1376f55211014f551
--- /dev/null
+++ b/detrsmpl/utils/util_mixins.py
@@ -0,0 +1,104 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""This module defines the :class:`NiceRepr` mixin class, which defines a
+``__repr__`` and ``__str__`` method that only depend on a custom ``__nice__``
+method, which you must define. This means you only have to overload one
+function instead of two. Furthermore, if the object defines a ``__len__``
+method, then the ``__nice__`` method defaults to something sensible, otherwise
+it is treated as abstract and raises ``NotImplementedError``.
+
+To use simply have your object inherit from :class:`NiceRepr`
+(multi-inheritance should be ok).
+
+This code was copied from the ubelt library: https://github.com/Erotemic/ubelt
+
+Example:
+ >>> # Objects that define __nice__ have a default __str__ and __repr__
+ >>> class Student(NiceRepr):
+ ... def __init__(self, name):
+ ... self.name = name
+ ... def __nice__(self):
+ ... return self.name
+ >>> s1 = Student('Alice')
+ >>> s2 = Student('Bob')
+ >>> print(f's1 = {s1}')
+ >>> print(f's2 = {s2}')
+    s1 = <Student(Alice)>
+    s2 = <Student(Bob)>
+
+Example:
+ >>> # Objects that define __len__ have a default __nice__
+ >>> class Group(NiceRepr):
+ ... def __init__(self, data):
+ ... self.data = data
+ ... def __len__(self):
+ ... return len(self.data)
+ >>> g = Group([1, 2, 3])
+ >>> print(f'g = {g}')
+    g = <Group(3)>
+"""
+import warnings
+
+
+class NiceRepr:
+ """Inherit from this class and define ``__nice__`` to "nicely" print your
+ objects.
+
+ Defines ``__str__`` and ``__repr__`` in terms of ``__nice__`` function
+ Classes that inherit from :class:`NiceRepr` should redefine ``__nice__``.
+    If the inheriting class has a ``__len__`` method, then the default
+ ``__nice__`` method will return its length.
+
+ Example:
+ >>> class Foo(NiceRepr):
+ ... def __nice__(self):
+ ... return 'info'
+ >>> foo = Foo()
+        >>> assert str(foo) == '<Foo(info)>'
+        >>> assert repr(foo).startswith('<Foo(info) at ')
+
+    Example:
+        >>> class Bar(NiceRepr):
+ ... pass
+ >>> bar = Bar()
+ >>> import pytest
+ >>> with pytest.warns(None) as record:
+ >>> assert 'object at' in str(bar)
+ >>> assert 'object at' in repr(bar)
+
+ Example:
+ >>> class Baz(NiceRepr):
+ ... def __len__(self):
+ ... return 5
+ >>> baz = Baz()
+        >>> assert str(baz) == '<Baz(5)>'
+ """
+ def __nice__(self):
+ """str: a "nice" summary string describing this module"""
+ if hasattr(self, '__len__'):
+ # It is a common pattern for objects to use __len__ in __nice__
+ # As a convenience we define a default __nice__ for these objects
+ return str(len(self))
+ else:
+ # In all other cases force the subclass to overload __nice__
+ raise NotImplementedError(
+ f'Define the __nice__ method for {self.__class__!r}')
+
+ def __repr__(self):
+ """str: the string of the module"""
+ try:
+ nice = self.__nice__()
+ classname = self.__class__.__name__
+ return f'<{classname}({nice}) at {hex(id(self))}>'
+ except NotImplementedError as ex:
+ warnings.warn(str(ex), category=RuntimeWarning)
+ return object.__repr__(self)
+
+ def __str__(self):
+ """str: the string of the module"""
+ try:
+ classname = self.__class__.__name__
+ nice = self.__nice__()
+ return f'<{classname}({nice})>'
+ except NotImplementedError as ex:
+ warnings.warn(str(ex), category=RuntimeWarning)
+ return object.__repr__(self)
diff --git a/detrsmpl/version.py b/detrsmpl/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c64e135f874d571b38b00ed075df3035638073eb
--- /dev/null
+++ b/detrsmpl/version.py
@@ -0,0 +1,29 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+
+# __version__ = '0.9.0'
+__version__ = '0.10.0'
+
+
+def parse_version_info(version_str):
+ """Parse a version string into a tuple.
+
+ Args:
+ version_str (str): The version string.
+ Returns:
+ tuple[int | str]: The version info, e.g., "1.3.0" is parsed into
+ (1, 3, 0), and "2.0.0rc1" is parsed into (2, 0, 0, 'rc1').
+ """
+ version_info = []
+ for x in version_str.split('.'):
+ if x.isdigit():
+ version_info.append(int(x))
+ elif x.find('rc') != -1:
+ patch_version = x.split('rc')
+ version_info.append(int(patch_version[0]))
+ version_info.append(f'rc{patch_version[1]}')
+ return tuple(version_info)
+
+
+version_info = parse_version_info(__version__)
+
+__all__ = ['__version__', 'version_info', 'parse_version_info']
diff --git a/engine.py b/engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..14f7082674d6e89f9e4a40d66b8aef2a21c8f06f
--- /dev/null
+++ b/engine.py
@@ -0,0 +1,352 @@
+import math
+import os
+import time
+import datetime
+import sys
+from typing import Iterable
+import os.path as osp
+import torch
+import util.misc as utils
+from collections import OrderedDict
+import mmcv
+import numpy as np
+import torch.distributed as dist
+from mmcv.runner import get_dist_info
+from detrsmpl.apis.test import collect_results_cpu, collect_results_gpu
+from detrsmpl.utils.ffmpeg_utils import images_to_video
+from torch.utils.tensorboard import SummaryWriter
+import json
+from mmcv.runner import get_dist_info, init_dist
+
+def round_float(items):
+ if isinstance(items, list):
+ return [round_float(item) for item in items]
+ elif isinstance(items, float):
+ return round(items, 3)
+ elif isinstance(items, np.ndarray):
+ return round_float(float(items))
+ elif isinstance(items, torch.Tensor):
+ return round_float(items.detach().cpu().numpy())
+ else:
+ return items
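+
+# Small sketch of round_float (illustration only, never called by the training
+# loop): floats are rounded to 3 decimals, tensors and 0-d numpy arrays are
+# converted to plain floats first, and anything else passes through unchanged.
+def _example_round_float():
+    values = [0.123456, torch.tensor(1.98765), 'kept as-is']
+    return round_float(values)  # -> [0.123, 1.988, 'kept as-is']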
+
+def train_one_epoch(model: torch.nn.Module,
+ criterion: torch.nn.Module,
+ data_loader: Iterable,
+ optimizer: torch.optim.Optimizer,
+ device: torch.device,
+ epoch: int,
+ max_norm: float = 0,
+ wo_class_error=False,
+ lr_scheduler=None,
+ args=None,
+ logger=None,
+ ema_m=None,
+ tf_writer=None):
+ scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
+
+ try:
+ need_tgt_for_training = args.use_dn
+    except AttributeError:
+ need_tgt_for_training = False
+
+ model.train()
+ criterion.train()
+ metric_logger = utils.MetricLogger(delimiter=' ')
+ metric_logger.add_meter(
+ 'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ if not wo_class_error:
+ metric_logger.add_meter(
+ 'class_error', utils.SmoothedValue(window_size=1,
+ fmt='{value:.2f}'))
+ header = 'Epoch: [{}]'.format(epoch)
+ print_freq = 10
+
+ _cnt = 0
+
+ for step_i, data_batch in enumerate(metric_logger.log_every(data_loader,
+ print_freq,
+ header,
+ logger=logger)):
+ with torch.cuda.amp.autocast(enabled=args.amp):
+ if need_tgt_for_training:
+ outputs, targets, data_batch_nc = model(data_batch)
+ else:
+ outputs, targets, data_batch_nc = model(data_batch)
+
+            # The keys below are temporarily down-weighted by 10x while the
+            # loss is built and restored a few lines further down:
+            # ['hand_kp3d_4', 'face_kp3d_4', 'hand_kp2d_4']
+ loss_dict = criterion(outputs, targets, data_batch=data_batch_nc)
+ weight_dict = criterion.weight_dict
+
+ for k,v in weight_dict.items():
+ for n in ['hand_kp3d_4', 'face_kp3d_4', 'hand_kp2d_4']:
+ if n in k:
+ weight_dict[k] = weight_dict[k]/10
+
+ losses = sum(loss_dict[k] * weight_dict[k]
+ for k in loss_dict.keys() if k in weight_dict)
+
+ loss_dict_reduced = utils.reduce_dict(loss_dict)
+ loss_dict_reduced_unscaled = {
+ f'{k}_unscaled': v
+ for k, v in loss_dict_reduced.items()
+ }
+ loss_dict_reduced_scaled = {
+ k: v * weight_dict[k]
+ for k, v in loss_dict_reduced.items() if k in weight_dict
+ }
+ losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
+
+ loss_value = losses_reduced_scaled.item()
+        # Restore the hand/face keypoint loss weights scaled down above.
+        for k, v in weight_dict.items():
+            for n in ['hand_kp3d_4', 'face_kp3d_4', 'hand_kp2d_4']:
+                if n in k:
+                    weight_dict[k] = weight_dict[k] * 10
+ if not math.isfinite(loss_value):
+ print('Loss is {}, stopping training'.format(loss_value))
+ print(loss_dict_reduced)
+ sys.exit(1)
+
+ # amp backward function
+ if args.amp:
+ optimizer.zero_grad()
+ scaler.scale(losses).backward()
+ if max_norm > 0:
+ scaler.unscale_(optimizer)
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
+ scaler.step(optimizer)
+ scaler.update()
+ else:
+ optimizer.zero_grad()
+ losses.backward()
+ if max_norm > 0:
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
+ optimizer.step()
+ if args.onecyclelr:
+ lr_scheduler.step()
+ if args.use_ema:
+ if epoch >= args.ema_epoch:
+ ema_m.update(model)
+ rank, _ = get_dist_info()
+
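+        # Rank 0 writes TensorBoard scalars and appends one JSON line per step
+        # to train.log.json.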
+ if rank == 0:
+ tf_writer.add_scalar(
+ 'loss', round_float(loss_value), step_i + len(data_loader) * epoch)
+ for k, v in loss_dict_reduced_scaled.items():
+ tf_writer.add_scalar(
+ k, round_float(v), step_i + len(data_loader) * epoch)
+ for k, v in loss_dict_reduced_unscaled.items():
+ tf_writer.add_scalar(
+ k, round_float(v), step_i + len(data_loader) * epoch)
+ json_log = OrderedDict()
+ json_log['now_time'] = str(datetime.datetime.now())
+ json_log['epoch'] = epoch
+ json_log['lr'] = optimizer.param_groups[0]['lr']
+ json_log['loss'] = round_float(loss_value)
+ for k, v in loss_dict_reduced_scaled.items():
+ json_log[k] = round_float(v)
+
+ for k, v in loss_dict_reduced_unscaled.items():
+ json_log[k] = round_float(v)
+
+ if rank == 0:
+ log_path = os.path.join(args.output_dir, 'train.log.json')
+ with open(log_path, 'a+') as f:
+ mmcv.dump(json_log, f, file_format='json')
+ f.write('\n')
+
+ # metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled)
+ metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled)
+ if 'class_error' in loss_dict_reduced:
+ metric_logger.update(class_error=loss_dict_reduced['class_error'])
+ metric_logger.update(lr=optimizer.param_groups[0]['lr'])
+
+ _cnt += 1
+ if args.debug:
+ if _cnt % 15 == 0:
+ print('BREAK!' * 5)
+ break
+
+ if getattr(criterion, 'loss_weight_decay', False):
+ criterion.loss_weight_decay(epoch=epoch)
+ if getattr(criterion, 'tuning_matching', False):
+ criterion.tuning_matching(epoch)
+
+ metric_logger.synchronize_between_processes()
+ print('Averaged stats:', metric_logger)
+ resstat = {
+ k: meter.global_avg
+ for k, meter in metric_logger.meters.items() if meter.count > 0
+ }
+ if getattr(criterion, 'loss_weight_decay', False):
+ resstat.update(
+ {f'weight_{k}': v
+ for k, v in criterion.weight_dict.items()})
+ return resstat
+
+
+@torch.no_grad()
+def evaluate(model,
+ criterion,
+ postprocessors,
+ data_loader,
+ device,
+ output_dir,
+ wo_class_error=False,
+ tmpdir=None,
+ gpu_collect=False,
+ args=None,
+ logger=None):
+ try:
+ need_tgt_for_training = args.use_dn
+    except AttributeError:
+ need_tgt_for_training = False
+ model.eval()
+ criterion.eval()
+
+ metric_logger = utils.MetricLogger(delimiter=' ')
+ if not wo_class_error:
+ metric_logger.add_meter(
+ 'class_error', utils.SmoothedValue(window_size=1,
+ fmt='{value:.2f}'))
+ header = 'Test:'
+ iou_types = tuple(k for k in ('bbox', 'keypoints'))
+ try:
+ useCats = args.useCats
+    except AttributeError:
+ useCats = True
+ if not useCats:
+ print('useCats: {} !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'.format(
+ useCats))
+
+ _cnt = 0
+ results = []
+ dataset = data_loader.dataset
+ rank, world_size = get_dist_info()
+
+ if rank == 0:
+ # Check if tmpdir is valid for cpu_collect
+        if (not gpu_collect) and (tmpdir is not None and osp.exists(tmpdir)):
+            raise OSError(f'The tmpdir {tmpdir} already exists. '
+                          'Since tmpdir will be deleted after testing, '
+                          'please make sure you specify an empty one.')
+ prog_bar = mmcv.ProgressBar(len(dataset))
+ time.sleep(2)
+ # i=0
+ cur_sample_idx = 0
+ eval_result = {}
+ # print()
+ cur_eval_result_list = []
+
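+    # Each rank evaluates its own shard of the dataset; the per-sample results
+    # are gathered on CPU afterwards and rank 0 aggregates and prints metrics.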
+ for data_batch in metric_logger.log_every(
+ data_loader, 10, header, logger=logger):
+ # i = i+1
+ with torch.cuda.amp.autocast(enabled=args.amp):
+ if need_tgt_for_training:
+ # outputs = model(samples, targets)
+ outputs, targets, data_batch_nc = model(data_batch)
+ else:
+            outputs, targets, data_batch_nc = model(data_batch)
+
+ orig_target_sizes = torch.stack([t["size"] for t in targets], dim=0)
+        result = postprocessors['bbox'](outputs, orig_target_sizes, targets, data_batch_nc, dataset=dataset)
+
+ # DOING SMPLer-X Test
+ cur_eval_result = dataset.evaluate(result,cur_sample_idx)
+
+ cur_eval_result_list.append(cur_eval_result)
+ # for cur_eval_result in cur_eval_result_list:
+ # for k, v in cur_eval_result.items():
+ # if k in eval_result:
+ # eval_result[k] += v
+ # else:
+ # eval_result[k] = v
+ cur_sample_idx += len(result)
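+    # Gather the per-rank evaluation results onto CPU for rank-0 aggregation.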
+ cur_eval_result_new = collect_results_cpu(cur_eval_result_list, len(dataset))
+
+ if rank == 0:
+
+        # Merge the gathered per-sample metrics and print the final numbers.
+        for res in cur_eval_result_new:
+            for k, v in res.items():
+                if len(v) > 0 and k not in ('ann_idx', 'img_path'):
+                    if k in eval_result:
+                        eval_result[k].append(v)
+                    else:
+                        eval_result[k] = [v]
+
+        for k, v in eval_result.items():
+            eval_result[k] = np.concatenate(v)
+
+        dataset.print_eval_result(eval_result)
+ # print(len(cur_eval_result_new))
+
+
+@torch.no_grad()
+def inference(model,
+ criterion,
+ postprocessors,
+ data_loader,
+ device,
+ output_dir,
+ wo_class_error=False,
+ tmpdir=None,
+ gpu_collect=False,
+ args=None,
+ logger=None):
+ try:
+ need_tgt_for_training = args.use_dn
+    except AttributeError:
+ need_tgt_for_training = False
+ model.eval()
+ criterion.eval()
+
+ metric_logger = utils.MetricLogger(delimiter=' ')
+ if not wo_class_error:
+ metric_logger.add_meter(
+ 'class_error', utils.SmoothedValue(window_size=1,
+ fmt='{value:.2f}'))
+ header = 'Test:'
+ iou_types = tuple(k for k in ('bbox', 'keypoints'))
+ try:
+ useCats = args.useCats
+    except AttributeError:
+ useCats = True
+ if not useCats:
+ print('useCats: {} !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'.format(
+ useCats))
+
+ _cnt = 0
+ results = []
+ dataset = data_loader.dataset
+ rank, world_size = get_dist_info()
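+    # Demo path: run the model on each batch, post-process the detections and
+    # hand the results to the dataset object, which saves the per-image output.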
+ for data_batch in metric_logger.log_every(data_loader, 10, header, logger=logger):
+ with torch.cuda.amp.autocast(enabled=args.amp):
+ if need_tgt_for_training:
+ # outputs = model(samples, targets)
+ outputs, targets, data_batch_nc = model(data_batch)
+ else:
+            outputs, targets, data_batch_nc = model(data_batch)
+
+ orig_target_sizes = torch.stack([t["size"] for t in targets], dim=0)
+ result = postprocessors['bbox'](outputs, orig_target_sizes, targets, data_batch_nc)
+ dataset.inference(result)
+
+ time.sleep(3)
+    if rank == 0 and args.to_vid:
+        # Stitch the rendered frames into a demo video.
+        if hasattr(dataset, 'result_img_dir'):
+            import shutil
+            images_to_video(dataset.result_img_dir,
+                            os.path.join(dataset.output_path, 'demo_vid.mp4'),
+                            remove_raw_file=False, fps=30)
+ # shutil.rmtree(dataset.result_img_dir)
+ # shutil.rmtree(dataset.tmp_dir)
+
+
diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..130f88f53dbf2532d95305a945974a78b0d60d8b
--- /dev/null
+++ b/main.py
@@ -0,0 +1,396 @@
+import argparse
+import datetime
+import json
+import random
+import time
+from pathlib import Path
+import os, sys
+from util.get_param_dicts import get_param_dict
+from util.logger import setup_logger
+import numpy as np
+import torch
+
+import util.misc as utils
+from detrsmpl.data.datasets import build_dataloader
+from mmcv.parallel import MMDistributedDataParallel
+
+from engine import evaluate, train_one_epoch, inference
+from util.config import DictAction
+from util.utils import ModelEma
+
+import shutil
+import torchvision.transforms as transforms
+from torch.utils.tensorboard import SummaryWriter
+import config.config as cfg
+from datasets.dataset import MultipleDatasets
+
+def get_args_parser():
+ parser = argparse.ArgumentParser('Set transformer detector',
+ add_help=False)
+ parser.add_argument('--config_file', '-c', type=str, required=True)
+ parser.add_argument(
+ '--options',
+ nargs='+',
+ action=DictAction,
+        help='override some settings in the used config; key-value pairs '
+        'in xxx=yyy format will be merged into the config file.')
+ # parser.add_argument('--exp_name', default='data/log/smplx_test', type=str)
+ # dataset parameters
+
+ # training parameters
+ parser.add_argument('--output_dir',
+ default='',
+ help='path where to save, empty for no saving')
+ parser.add_argument('--device',
+ default='cuda',
+ help='device to use for training / testing')
+ parser.add_argument('--seed', default=42, type=int)
+ parser.add_argument('--resume', default='', help='resume from checkpoint')
+ parser.add_argument('--pretrain_model_path',
+ help='load from other checkpoint')
+ parser.add_argument('--finetune_ignore', type=str, nargs='+')
+ parser.add_argument('--start_epoch',
+ default=0,
+ type=int,
+ metavar='N',
+ help='start epoch')
+ parser.add_argument('--eval', action='store_true')
+ parser.add_argument('--num_workers', default=0, type=int)
+ parser.add_argument('--test', action='store_true')
+ parser.add_argument('--debug', action='store_true')
+ parser.add_argument('--find_unused_params', action='store_true')
+
+ parser.add_argument('--save_log', action='store_true')
+ parser.add_argument('--to_vid', action='store_true')
+ parser.add_argument('--inference', action='store_true')
+ # distributed training parameters
+
+ parser.add_argument('--world_size', default=1, type=int,
+ help='number of distributed processes')
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+    parser.add_argument('--rank', default=0, type=int,
+                        help='rank of the current process in distributed training')
+ parser.add_argument("--local_rank", default=0, type=int, help='local rank for DistributedDataParallel')
+ parser.add_argument('--amp', action='store_true',
+ help="Train with mixed precision")
+
+ parser.add_argument('--inference_input', default=None, type=str)
+ return parser
+
+
+def build_model_main(args, cfg):
+ print(args.modelname)
+ from models.registry import MODULE_BUILD_FUNCS
+ assert args.modelname in MODULE_BUILD_FUNCS._module_dict
+ build_func = MODULE_BUILD_FUNCS.get(args.modelname)
+ model, criterion, postprocessors, postprocessors_aios = build_func(
+ args, cfg)
+ return model, criterion, postprocessors, postprocessors_aios
+
+
+def main(args):
+ utils.init_distributed_mode(args)
+ print('Loading config file from {}'.format(args.config_file))
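+    # The selected config is copied over config/aios_smplx.py so that the shared
+    # config.config module imported below picks it up.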
+ shutil.copy2(args.config_file,'config/aios_smplx.py')
+ from config.config import cfg
+ if args.options is not None:
+ cfg.merge_from_dict(args.options)
+ if args.rank == 0:
+ save_cfg_path = os.path.join(args.output_dir, 'config_cfg.py')
+ cfg.dump(save_cfg_path)
+ save_json_path = os.path.join(args.output_dir, 'config_args_raw.json')
+ with open(save_json_path, 'w') as f:
+ json.dump(vars(args), f, indent=2)
+ cfg_dict = cfg._cfg_dict.to_dict()
+ args_vars = vars(args)
+    for k, v in cfg_dict.items():
+        if k not in args_vars:
+            setattr(args, k, v)
+        else:
+            # Command-line args take precedence over duplicated config keys.
+            continue
+
+ # update some new args temporally
+ if not getattr(args, 'use_ema', None):
+ args.use_ema = False
+ if not getattr(args, 'debug', None):
+ args.debug = False
+
+ # setup logger
+ os.makedirs(args.output_dir, exist_ok=True)
+ logger = setup_logger(output=os.path.join(args.output_dir, 'info.txt'),
+ distributed_rank=args.rank,
+ color=False,
+ name='detr')
+ logger.info('git:\n {}\n'.format(utils.get_sha()))
+ logger.info('Command: ' + ' '.join(sys.argv))
+ writer = None
+ if args.rank == 0:
+ writer = SummaryWriter(args.output_dir)
+ save_json_path = os.path.join(args.output_dir, 'config_args_all.json')
+ # print("args:", vars(args))
+ with open(save_json_path, 'w') as f:
+ json.dump(vars(args), f, indent=2)
+ logger.info('Full config saved to {}'.format(save_json_path))
+ logger.info('world size: {}'.format(args.world_size))
+ logger.info('rank: {}'.format(args.rank))
+ logger.info('local_rank: {}'.format(args.local_rank))
+ logger.info('args: ' + str(args) + '\n')
+
+ if args.frozen_weights is not None:
+ assert args.masks, 'Frozen training is meant for segmentation only'
+
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+
+ # build model
+ model, criterion, postprocessors, _ = build_model_main(
+ args, cfg)
+
+ wo_class_error = False
+ model.to(device)
+
+ # ema
+ if args.use_ema:
+ ema_m = ModelEma(model, args.ema_decay)
+ else:
+ ema_m = None
+
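+    # Wrap the model with MMDistributedDataParallel for multi-GPU training and
+    # keep a handle to the unwrapped model for checkpointing and param groups.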
+ model_without_ddp = model
+ if args.distributed:
+ model = MMDistributedDataParallel(
+ model,
+ device_ids=[args.gpu],
+ find_unused_parameters=args.find_unused_params)
+ model_without_ddp = model.module
+ n_parameters = sum(p.numel() for p in model.parameters()
+ if p.requires_grad)
+ logger.info('number of params:' + str(n_parameters))
+ logger.info('params:\n' + json.dumps(
+ {n: p.numel()
+ for n, p in model.named_parameters() if p.requires_grad},
+ indent=2))
+
+ param_dicts = get_param_dict(args, model_without_ddp)
+ optimizer = torch.optim.AdamW(param_dicts,
+ lr=args.lr,
+ weight_decay=args.weight_decay)
+
+ logger.info('Creating dataset...')
+ if not args.eval:
+        trainset = []
+        # Import each training dataset class listed in the config by name and
+        # instantiate it.
+        for trainset_i, v in cfg.trainset_partition.items():
+            exec('from datasets.' + trainset_i + ' import ' + trainset_i)
+            trainset.append(
+                eval(trainset_i)(transforms.ToTensor(), 'train'))
+        trainset_loader = MultipleDatasets(
+            trainset, make_same_len=False, partition=cfg.trainset_partition)
+
+ data_loader_train = build_dataloader(
+ trainset_loader,
+ args.batch_size,
+ 0 if 'workers_per_gpu' in args else 1,
+ dist=args.distributed)
+ exec('from datasets.' + cfg.testset +
+ ' import ' + cfg.testset)
+
+
+ if not args.inference:
+ dataset_val = eval(cfg.testset)(transforms.ToTensor(), "test")
+ else:
+ dataset_val = eval(cfg.testset)(args.inference_input, args.output_dir)
+
+ data_loader_val = build_dataloader(
+ dataset_val,
+ args.batch_size,
+ 0 if 'workers_per_gpu' in args else 2,
+ dist=args.distributed,
+ shuffle=False)
+
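+    # Learning-rate schedule: one-cycle, multi-step, or plain step decay.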
+ if args.onecyclelr:
+ lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
+ optimizer,
+ max_lr=args.lr,
+ steps_per_epoch=len(data_loader_train),
+ epochs=args.epochs,
+ pct_start=0.2)
+ elif args.multi_step_lr:
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+ optimizer, milestones=args.lr_drop_list)
+ else:
+ lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
+
+ if args.frozen_weights is not None:
+ checkpoint = torch.load(args.frozen_weights, map_location='cpu')
+ model_without_ddp.detr.load_state_dict(checkpoint['model'])
+
+ output_dir = Path(args.output_dir)
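+    # Auto-resume from checkpoint.pth in the output directory when it exists.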
+ if os.path.exists(os.path.join(args.output_dir, 'checkpoint.pth')):
+ args.resume = os.path.join(args.output_dir, 'checkpoint.pth')
+ if args.resume:
+ if args.resume.startswith('https'):
+ checkpoint = torch.hub.load_state_dict_from_url(args.resume,
+ map_location='cpu',
+ check_hash=True)
+ else:
+ checkpoint = torch.load(args.resume, map_location='cpu')
+ model_without_ddp.load_state_dict(checkpoint['model'])
+ if args.use_ema:
+ if 'ema_model' in checkpoint:
+ ema_m.module.load_state_dict(
+ utils.clean_state_dict(checkpoint['ema_model']))
+ else:
+ del ema_m
+ ema_m = ModelEma(model, args.ema_decay)
+
+ if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+ args.start_epoch = checkpoint['epoch'] + 1
+
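+    # Fine-tuning path: load pretrained weights, dropping any parameters whose
+    # names match a keyword in --finetune_ignore.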
+ if (not args.resume) and args.pretrain_model_path:
+ checkpoint = torch.load(args.pretrain_model_path,
+ map_location='cpu')['model']
+ from collections import OrderedDict
+ _ignorekeywordlist = args.finetune_ignore if args.finetune_ignore else []
+ ignorelist = []
+
+ def check_keep(keyname, ignorekeywordlist):
+ for keyword in ignorekeywordlist:
+ if keyword in keyname:
+ ignorelist.append(keyname)
+ return False
+ return True
+
+
+ _tmp_st = OrderedDict({
+ k: v
+ for k, v in utils.clean_state_dict(checkpoint).items()
+ if check_keep(k, _ignorekeywordlist)
+ })
+ logger.info('Ignore keys: {}'.format(json.dumps(ignorelist, indent=2)))
+ # Change This
+ _load_output = model_without_ddp.load_state_dict(_tmp_st, strict=False)
+ print('loading')
+ logger.info(str(_load_output))
+
+ if args.use_ema:
+ if 'ema_model' in checkpoint:
+ ema_m.module.load_state_dict(utils.clean_state_dict(checkpoint['ema_model']))
+ else:
+ del ema_m
+ ema_m = ModelEma(model, args.ema_decay)
+ _load_output = model_without_ddp.load_state_dict(_tmp_st, strict=False)
+ logger.info(str(_load_output))
+
+
+ if args.eval:
+ os.environ['EVAL_FLAG'] = 'TRUE'
+
+ if args.inference_input is not None and args.inference:
+ inference(model,
+ criterion,
+ postprocessors,
+ data_loader_val,
+ device,
+ args.output_dir,
+ wo_class_error=wo_class_error,
+ args=args)
+    else:
+        from config.config import cfg
+        cfg.result_dir = args.output_dir
+        cfg.exp_name = args.pretrain_model_path
+ evaluate(model,
+ criterion,
+ postprocessors,
+ data_loader_val,
+ device,
+ args.output_dir,
+ wo_class_error=wo_class_error,
+ args=args)
+
+ return
+
+ print('Start training')
+ start_time = time.time()
+ for epoch in range(args.start_epoch, args.epochs):
+ epoch_start_time = time.time()
+
+ train_stats = train_one_epoch(
+ model,
+ criterion,
+ data_loader_train,
+ optimizer,
+ device,
+ epoch,
+ args.clip_max_norm,
+ wo_class_error=wo_class_error,
+ lr_scheduler=lr_scheduler,
+ args=args,
+ logger=(logger if args.save_log else None),
+ ema_m=ema_m,
+ tf_writer=writer)
+
+ if not args.onecyclelr:
+ lr_scheduler.step()
+ if args.output_dir:
+ checkpoint_paths = [output_dir / 'checkpoint.pth']
+ # extra checkpoint before LR drop and every 100 epochs
+ if (epoch + 1) % args.lr_drop == 0 or (
+ epoch + 1) % args.save_checkpoint_interval == 0:
+ checkpoint_paths.append(output_dir /
+ f'checkpoint{epoch:04}.pth')
+ for checkpoint_path in checkpoint_paths:
+ weights = {
+ 'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'lr_scheduler': lr_scheduler.state_dict(),
+ 'epoch': epoch,
+ 'args': args,
+ }
+ if args.use_ema:
+ weights.update({
+ 'ema_model': ema_m.module.state_dict(),
+ })
+ utils.save_on_master(weights, checkpoint_path)
+ log_stats = {
+ **{f'train_{k}': v
+ for k, v in train_stats.items()},
+ }
+
+ ep_paras = {'epoch': epoch, 'n_parameters': n_parameters}
+ log_stats.update(ep_paras)
+        log_stats.update({'now_time': str(datetime.datetime.now())})
+
+ epoch_time = time.time() - epoch_start_time
+ epoch_time_str = str(datetime.timedelta(seconds=int(epoch_time)))
+ log_stats['epoch_time'] = epoch_time_str
+
+ if args.output_dir and utils.is_main_process():
+ with (output_dir / 'log.txt').open('a') as f:
+ f.write(json.dumps(log_stats) + '\n')
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser('DETR training and evaluation script',
+ parents=[get_args_parser()])
+ __spec__ = "ModuleSpec(name='builtins', loader=)"
+ args = parser.parse_args()
+ if args.output_dir:
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+ main(args)
diff --git a/mmcv/.circleci/config.yml b/mmcv/.circleci/config.yml
new file mode 100644
index 0000000000000000000000000000000000000000..8fbf916c0254efc459ca7abc5d078ae2bda1f43b
--- /dev/null
+++ b/mmcv/.circleci/config.yml
@@ -0,0 +1,173 @@
+version: 2.1
+jobs:
+ lint:
+ docker:
+ - image: cimg/python:3.7.4
+ steps:
+ - checkout
+ - run:
+ name: Install pre-commit hook
+ command: |
+ pip install pre-commit
+ pre-commit install
+ - run:
+ name: Linting
+ command: pre-commit run --all-files
+
+ build_cpu:
+ parameters:
+ # The python version must match available image tags in
+ # https://circleci.com/developer/images/image/cimg/python
+ python:
+ type: string
+ default: "3.7.0"
+ torch:
+ type: string
+ torchvision:
+ type: string
+ machine:
+ image: ubuntu-2004:202010-01
+ resource_class: large
+ steps:
+ - checkout
+ - run:
+ name: Install system dependencies
+ command: |
+ sudo apt-get update
+ sudo apt-get install -y ffmpeg libturbojpeg ninja-build
+ ffmpeg -version
+ - run:
+ # https://github.com/pytorch/vision/issues/2921
+ name: Install dependency of torchvision when using pyenv
+ command: sudo apt-get install -y liblzma-dev
+ - run:
+ # python3.7 should be re-installed due to the issue https://github.com/pytorch/vision/issues/2921
+ name: Select Python
+ command: |
+ pyenv uninstall -f << parameters.python >>
+ pyenv install << parameters.python >>
+ pyenv global << parameters.python >>
+ - run:
+ name: Upgrade pip
+ command: |
+ python -m pip install pip --upgrade
+ - run:
+ name: Install PyTorch
+ command: python -m pip install torch==<< parameters.torch >>+cpu torchvision==<< parameters.torchvision >>+cpu -f https://download.pytorch.org/whl/torch_stable.html
+ - run:
+ name: Install psutil
+ command: python -m pip install psutil
+ - run:
+ name: Build and install
+ command: |
+ rm -rf .eggs
+ python setup.py check -m -s
+ python -m pip install -e .
+ no_output_timeout: 20m
+ environment:
+ MMCV_WITH_OPS: 1
+ - run:
+ name: Install dependencies of unit test
+ command: |
+ python -m pip install -r requirements/test.txt
+ - run:
+ name: Run unittests and generate coverage report
+ command: |
+ python -m coverage run --branch --source mmcv -m pytest tests/
+ python -m coverage xml
+ python -m coverage report -m
+
+ build_cu102:
+ machine:
+ image: ubuntu-1604-cuda-10.1:201909-23 # the actual version of cuda is 10.2
+ resource_class: gpu.nvidia.small
+ steps:
+ - checkout
+ - run:
+ name: Set CUDA environment
+ command: |
+ echo 'export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH' >> $BASH_ENV
+ echo 'export PATH=/usr/local/cuda/bin:$PATH' >> $BASH_ENV
+ echo 'export CUDA_HOME=/usr/local/cuda' >> $BASH_ENV
+ source $BASH_ENV
+ nvidia-smi
+ nvcc --version
+ gcc --version
+ - run:
+ name: Install system dependencies
+ command: |
+ sudo apt-get update
+ sudo apt-get install -y libturbojpeg ninja-build
+ # the default version of ffmpeg is 2.8.7, which should be upgraded to 4+
+ sudo add-apt-repository -y ppa:jonathonf/ffmpeg-4
+ sudo apt-get update
+ sudo apt-get install -y ffmpeg
+ ffmpeg -version
+ sudo add-apt-repository --remove ppa:jonathonf/ffmpeg-4 -y
+ - run:
+ # https://github.com/pytorch/vision/issues/2921
+ name: Install dependency of torchvision when using pyenv
+ command: sudo apt-get install -y liblzma-dev
+ - run:
+ # python3.7 should be re-installed due to the issue https://github.com/pytorch/vision/issues/2921
+ name: Select python3.7
+ command: |
+ pyenv uninstall -f 3.7.0
+ pyenv install 3.7.0
+ pyenv global 3.7.0
+ - run:
+ name: Upgrade pip
+ command: |
+ python -m pip install pip --upgrade
+ - run:
+ name: Install PyTorch
+ command: python -m pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 -f https://download.pytorch.org/whl/torch_stable.html
+ - run:
+ name: Install psutil
+ command: python -m pip install psutil
+ - run:
+ name: Download onnxruntime library and install onnxruntime
+ command: |
+ wget https://github.com/microsoft/onnxruntime/releases/download/v1.8.1/onnxruntime-linux-x64-1.8.1.tgz
+ tar -zxvf onnxruntime-linux-x64-1.8.1.tgz
+ echo 'export ONNXRUNTIME_DIR=$(pwd)/onnxruntime-linux-x64-1.8.1' >> $BASH_ENV
+ echo 'export LD_LIBRARY_PATH=$ONNXRUNTIME_DIR/lib:$LD_LIBRARY_PATH' >> $BASH_ENV
+ source $BASH_ENV
+ python -m pip install onnxruntime==1.8.1
+ - run:
+ name: Build and install
+ command: |
+ rm -rf .eggs
+ python setup.py check -m -s
+ python -m pip install -e .
+ environment:
+ MMCV_WITH_OPS: 1
+ MMCV_WITH_ORT: 1
+ - run:
+ name: Install dependencies for unit test
+ command: |
+ python -m pip install -r requirements/test.txt
+ - run:
+ name: Run unittests and generate coverage report
+ command: |
+ python -m coverage run --branch --source mmcv -m pytest tests/
+ python -m coverage xml
+ python -m coverage report -m
+workflows:
+ unit_tests:
+ jobs:
+ - lint
+ - build_cpu:
+ name: build_py3.8_pt1.9_cpu
+ torch: 1.9.0
+ torchvision: 0.10.0
+ python: "3.8.0"
+ requires:
+ - lint
+ - hold:
+ type: approval # <<< This key-value pair will set your workflow to a status of "On Hold"
+ requires:
+ - build_py3.8_pt1.9_cpu
+ - build_cu102:
+ requires:
+ - hold
diff --git a/mmcv/.dev_scripts/check_installation.py b/mmcv/.dev_scripts/check_installation.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b771acc5b227bc584ad6dc46c5a85d16a16d6a2
--- /dev/null
+++ b/mmcv/.dev_scripts/check_installation.py
@@ -0,0 +1,44 @@
+import numpy as np
+import torch
+
+from mmcv.ops import box_iou_rotated
+from mmcv.utils import collect_env
+
+
+def check_installation():
+ """Check whether mmcv-full has been installed successfully."""
+ np_boxes1 = np.asarray(
+ [[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
+ [7.0, 7.0, 8.0, 8.0, 0.4]],
+ dtype=np.float32)
+ np_boxes2 = np.asarray(
+ [[0.0, 2.0, 2.0, 5.0, 0.3], [2.0, 1.0, 3.0, 3.0, 0.5],
+ [5.0, 5.0, 6.0, 7.0, 0.4]],
+ dtype=np.float32)
+ boxes1 = torch.from_numpy(np_boxes1)
+ boxes2 = torch.from_numpy(np_boxes2)
+
+ # test mmcv-full with CPU ops
+ box_iou_rotated(boxes1, boxes2)
+ print('CPU ops were compiled successfully.')
+
+ # test mmcv-full with both CPU and CUDA ops
+ if torch.cuda.is_available():
+ boxes1 = boxes1.cuda()
+ boxes2 = boxes2.cuda()
+ box_iou_rotated(boxes1, boxes2)
+ print('CUDA ops were compiled successfully.')
+ else:
+ print('No CUDA runtime is found, skipping the checking of CUDA ops.')
+
+
+if __name__ == '__main__':
+ print('Start checking the installation of mmcv-full ...')
+ check_installation()
+ print('mmcv-full has been installed successfully.\n')
+
+ env_info_dict = collect_env()
+ env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
+ dash_line = '-' * 60 + '\n'
+ print('Environment information:')
+ print(dash_line + env_info + '\n' + dash_line)
diff --git a/mmcv/.dev_scripts/visualize_lr.py b/mmcv/.dev_scripts/visualize_lr.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ca9aaa116e75b8f693589d1bcc0031d5ace0277
--- /dev/null
+++ b/mmcv/.dev_scripts/visualize_lr.py
@@ -0,0 +1,230 @@
+import argparse
+import json
+import os
+import os.path as osp
+import time
+import warnings
+from collections import OrderedDict
+from unittest.mock import patch
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch.nn as nn
+from torch.optim import SGD
+from torch.utils.data import DataLoader
+
+import mmcv
+from mmcv.runner import build_runner
+from mmcv.utils import get_logger
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Visualize the given config '
+                                     'of learning rate and momentum; this '
+                                     'script will overwrite the log_config')
+ parser.add_argument('config', help='train config file path')
+ parser.add_argument(
+ '--work-dir', default='./', help='the dir to save logs and models')
+    parser.add_argument(
+        '--num-iters', default=300, type=int,
+        help='The number of iters per epoch')
+    parser.add_argument(
+        '--num-epochs', default=300, type=int,
+        help='Only used in EpochBasedRunner')
+    parser.add_argument(
+        '--window-size',
+        default='12*14',
+        help='Size of the window to display images, in format of "$W*$H".')
+    parser.add_argument(
+        '--log-interval', default=10, type=int,
+        help='The interval of TextLoggerHook')
+ args = parser.parse_args()
+ return args
+
+
+class SimpleModel(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.conv = nn.Conv2d(1, 1, 1)
+
+ def train_step(self, *args, **kwargs):
+ return dict()
+
+ def val_step(self, *args, **kwargs):
+ return dict()
+
+
+def iter_train(self, data_loader, **kwargs):
+ self.mode = 'train'
+ self.data_loader = data_loader
+ self.call_hook('before_train_iter')
+ self.call_hook('after_train_iter')
+ self._inner_iter += 1
+ self._iter += 1
+
+
+def epoch_train(self, data_loader, **kwargs):
+ self.model.train()
+ self.mode = 'train'
+ self.data_loader = data_loader
+ self._max_iters = self._max_epochs * len(self.data_loader)
+ self.call_hook('before_train_epoch')
+ for i, data_batch in enumerate(self.data_loader):
+ self._inner_iter = i
+ self.call_hook('before_train_iter')
+ self.call_hook('after_train_iter')
+ self._iter += 1
+ self.call_hook('after_train_epoch')
+ self._epoch += 1
+
+
+def log(self, runner):
+ cur_iter = self.get_iter(runner, inner_iter=True)
+
+ log_dict = OrderedDict(
+ mode=self.get_mode(runner),
+ epoch=self.get_epoch(runner),
+ iter=cur_iter)
+
+ # only record lr of the first param group
+ cur_lr = runner.current_lr()
+ if isinstance(cur_lr, list):
+ log_dict['lr'] = cur_lr[0]
+ else:
+ assert isinstance(cur_lr, dict)
+ log_dict['lr'] = {}
+ for k, lr_ in cur_lr.items():
+ assert isinstance(lr_, list)
+ log_dict['lr'].update({k: lr_[0]})
+
+ cur_momentum = runner.current_momentum()
+ if isinstance(cur_momentum, list):
+ log_dict['momentum'] = cur_momentum[0]
+ else:
+ assert isinstance(cur_momentum, dict)
+ log_dict['momentum'] = {}
+ for k, lr_ in cur_momentum.items():
+ assert isinstance(lr_, list)
+ log_dict['momentum'].update({k: lr_[0]})
+ log_dict = dict(log_dict, **runner.log_buffer.output)
+ self._log_info(log_dict, runner)
+ self._dump_log(log_dict, runner)
+ return log_dict
+
+
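+# The runner's train loops and TextLoggerHook.log are patched with lightweight
+# stand-ins so the LR/momentum schedules can be traced without real training.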
+@patch('torch.cuda.is_available', lambda: False)
+@patch('mmcv.runner.EpochBasedRunner.train', epoch_train)
+@patch('mmcv.runner.IterBasedRunner.train', iter_train)
+@patch('mmcv.runner.hooks.TextLoggerHook.log', log)
+def run(cfg, logger):
+ momentum_config = cfg.get('momentum_config')
+ lr_config = cfg.get('lr_config')
+
+ model = SimpleModel()
+ optimizer = SGD(model.parameters(), 0.1, momentum=0.8)
+ cfg.work_dir = cfg.get('work_dir', './')
+ workflow = [('train', 1)]
+
+ if cfg.get('runner') is None:
+ cfg.runner = {
+ 'type': 'EpochBasedRunner',
+ 'max_epochs': cfg.get('total_epochs', cfg.num_epochs)
+ }
+ warnings.warn(
+ 'config is now expected to have a `runner` section, '
+ 'please set `runner` in your config.', UserWarning)
+ batch_size = 1
+ data = cfg.get('data')
+ if data:
+ batch_size = data.get('samples_per_gpu')
+ fake_dataloader = DataLoader(
+ list(range(cfg.num_iters)), batch_size=batch_size)
+ runner = build_runner(
+ cfg.runner,
+ default_args=dict(
+ model=model,
+ batch_processor=None,
+ optimizer=optimizer,
+ work_dir=cfg.work_dir,
+ logger=logger,
+ meta=None))
+ log_config = dict(
+ interval=cfg.log_interval, hooks=[
+ dict(type='TextLoggerHook'),
+ ])
+
+ runner.register_training_hooks(lr_config, log_config=log_config)
+ runner.register_momentum_hook(momentum_config)
+ runner.run([fake_dataloader], workflow)
+
+
+def plot_lr_curve(json_file, cfg):
+ data_dict = dict(LearningRate=[], Momentum=[])
+ assert os.path.isfile(json_file)
+ with open(json_file) as f:
+ for line in f:
+ log = json.loads(line.strip())
+ data_dict['LearningRate'].append(log['lr'])
+ data_dict['Momentum'].append(log['momentum'])
+
+ wind_w, wind_h = (int(size) for size in cfg.window_size.split('*'))
+ # if legend is None, use {filename}_{key} as legend
+ fig, axes = plt.subplots(2, 1, figsize=(wind_w, wind_h))
+ plt.subplots_adjust(hspace=0.5)
+ font_size = 20
+ for index, (updater_type, data_list) in enumerate(data_dict.items()):
+ ax = axes[index]
+ if cfg.runner.type == 'EpochBasedRunner':
+ ax.plot(data_list, linewidth=1)
+ ax.xaxis.tick_top()
+ ax.set_xlabel('Iters', fontsize=font_size)
+ ax.xaxis.set_label_position('top')
+ sec_ax = ax.secondary_xaxis(
+ 'bottom',
+ functions=(lambda x: x / cfg.num_iters * cfg.log_interval,
+ lambda y: y * cfg.num_iters / cfg.log_interval))
+ sec_ax.tick_params(labelsize=font_size)
+ sec_ax.set_xlabel('Epochs', fontsize=font_size)
+ else:
+ # plt.subplot(2, 1, index + 1)
+ x_list = np.arange(len(data_list)) * cfg.log_interval
+ ax.plot(x_list, data_list)
+ ax.set_xlabel('Iters', fontsize=font_size)
+ ax.set_ylabel(updater_type, fontsize=font_size)
+ if updater_type == 'LearningRate':
+ if cfg.get('lr_config'):
+ title = cfg.lr_config.type
+ else:
+ title = 'No learning rate scheduler'
+ else:
+ if cfg.get('momentum_config'):
+ title = cfg.momentum_config.type
+ else:
+ title = 'No momentum scheduler'
+ ax.set_title(title, fontsize=font_size)
+ ax.grid()
+ # set tick font size
+ ax.tick_params(labelsize=font_size)
+ save_path = osp.join(cfg.work_dir, 'visualization-result')
+ plt.savefig(save_path)
+ print(f'The learning rate graph is saved at {save_path}.png')
+ plt.show()
+
+
+def main():
+ args = parse_args()
+ timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+ cfg = mmcv.Config.fromfile(args.config)
+ cfg['num_iters'] = args.num_iters
+ cfg['num_epochs'] = args.num_epochs
+ cfg['log_interval'] = args.log_interval
+ cfg['window_size'] = args.window_size
+
+ log_path = osp.join(cfg.get('work_dir', './'), f'{timestamp}.log')
+ json_path = log_path + '.json'
+ logger = get_logger('mmcv', log_path)
+
+ run(cfg, logger)
+ plot_lr_curve(json_path, cfg)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/mmcv/.dockerignore b/mmcv/.dockerignore
new file mode 100644
index 0000000000000000000000000000000000000000..8c22f226d3e2d8a625515290691d2cfc6ed87f2e
--- /dev/null
+++ b/mmcv/.dockerignore
@@ -0,0 +1,6 @@
+.git
+.gitignore
+*.egg-info
+.eggs/
+.mypy-cache
+pip-wheel-metadata
diff --git a/mmcv/.github/ISSUE_TEMPLATE/config.yml b/mmcv/.github/ISSUE_TEMPLATE/config.yml
new file mode 100644
index 0000000000000000000000000000000000000000..9ca189206785218a14096a6f9563b1f976ffb12f
--- /dev/null
+++ b/mmcv/.github/ISSUE_TEMPLATE/config.yml
@@ -0,0 +1,9 @@
+blank_issues_enabled: false
+
+contact_links:
+ - name: Common Issues
+ url: https://mmcv.readthedocs.io/en/latest/trouble_shooting.html
+ about: Check if your issue already has solutions
+ - name: MMCV Documentation
+ url: https://mmcv.readthedocs.io/en/latest/
+ about: Check if your question is answered in docs
diff --git a/mmcv/.github/ISSUE_TEMPLATE/feature_request.md b/mmcv/.github/ISSUE_TEMPLATE/feature_request.md
new file mode 100644
index 0000000000000000000000000000000000000000..7bf92e8c912df6839eb755715c181f5fc7244f36
--- /dev/null
+++ b/mmcv/.github/ISSUE_TEMPLATE/feature_request.md
@@ -0,0 +1,21 @@
+---
+name: Feature request
+about: Suggest an idea for this project
+title: ''
+labels: ''
+assignees: ''
+---
+
+**Describe the feature**
+
+**Motivation**
+A clear and concise description of the motivation of the feature.
+Ex1. It is inconvenient when \[....\].
+Ex2. There is a recent paper \[....\], which is very helpful for \[....\].
+
+**Related resources**
+If there is an official code release or a third-party implementation, please also provide the information here; it would be very helpful.
+
+**Additional context**
+Add any other context or screenshots about the feature request here.
+If you would like to implement the feature and create a PR, please leave a comment here and that would be much appreciated.
diff --git a/mmcv/.github/ISSUE_TEMPLATE/general_questions.md b/mmcv/.github/ISSUE_TEMPLATE/general_questions.md
new file mode 100644
index 0000000000000000000000000000000000000000..b5eaf2610781037b0cbea9a146c034ebb36f2934
--- /dev/null
+++ b/mmcv/.github/ISSUE_TEMPLATE/general_questions.md
@@ -0,0 +1,12 @@
+---
+name: General questions
+about: Ask general questions to get help
+title: ''
+labels: ''
+assignees: ''
+---
+
+**Checklist**
+
+1. I have searched related issues but cannot get the expected help.
+2. I have read the FAQ documentation but cannot get the expected help.
diff --git a/mmcv/.github/ISSUE_TEMPLATE/unexpected_report.md b/mmcv/.github/ISSUE_TEMPLATE/unexpected_report.md
new file mode 100644
index 0000000000000000000000000000000000000000..b0ccc0fd18bc1548d517e05afe9ae183d32bb0f9
--- /dev/null
+++ b/mmcv/.github/ISSUE_TEMPLATE/unexpected_report.md
@@ -0,0 +1,45 @@
+---
+name: Unexpected Results
+about: Create a report to help us improve
+title: ''
+labels: ''
+assignees: ''
+---
+
+Thanks for reporting the unexpected results; we appreciate it a lot.
+
+**Checklist**
+
+1. I have searched related issues but cannot get the expected help.
+2. I have read the [FAQ documentation](https://mmcv.readthedocs.io/en/latest/trouble_shooting.html) but cannot get the expected help.
+3. The unexpected results still exist in the latest version.
+
+**Describe the Issue**
+A clear and concise description of what the bug is, including what results you expected and what results you actually got.
+
+**Reproduction**
+
+1. What command, code, or script did you run?
+
+```bash
+A placeholder for the command.
+```
+
+2. Did you make any modifications on the code? Did you understand what you have modified?
+
+**Environment**
+
+1. Please run `python -c "from mmcv.utils import collect_env; print(collect_env())"` to collect necessary environment information and paste it here.
+2. You may add additional information that may be helpful for locating the problem, such as
+ - How you installed PyTorch \[e.g., pip, conda, source\]
+ - Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)
+
+**Error traceback**
+If applicable, paste the error traceback here.
+
+```none
+A placeholder for traceback.
+```
+
+**Bug fix**
+If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!
diff --git a/mmcv/.github/pull_request_template.md b/mmcv/.github/pull_request_template.md
new file mode 100644
index 0000000000000000000000000000000000000000..0980b85db1c5fc90b2a8c32aa5fbdf923b25bf32
--- /dev/null
+++ b/mmcv/.github/pull_request_template.md
@@ -0,0 +1,33 @@
+Thanks for your contribution; we appreciate it a lot. The following instructions will help make your pull request healthier and get feedback more easily. If you do not understand some items, don't worry: just open the pull request and seek help from the maintainers.
+
+## Motivation
+
+Please describe the motivation of this PR and the goal you want to achieve through this PR.
+
+## Modification
+
+Please briefly describe what modification is made in this PR.
+
+## BC-breaking (Optional)
+
+Does the modification introduce changes that break the backward-compatibility of the downstream repositories?
+If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.
+
+## Use cases (Optional)
+
+If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.
+
+## Checklist
+
+**Before PR**:
+
+- [ ] I have read and followed the workflow indicated in the [CONTRIBUTING.md](https://github.com/open-mmlab/mmcv/blob/master/CONTRIBUTING.md) to create this PR.
+- [ ] Pre-commit or linting tools indicated in [CONTRIBUTING.md](https://github.com/open-mmlab/mmcv/blob/master/CONTRIBUTING.md) are used to fix the potential lint issues.
+- [ ] Bug fixes are covered by unit tests; the case that causes the bug should be added to the unit tests.
+- [ ] New functionalities are covered by complete unit tests. If not, please add more unit tests to ensure correctness.
+- [ ] The documentation has been modified accordingly, including docstring or example tutorials.
+
+**After PR**:
+
+- [ ] If the modification has potential influence on downstream or other related projects, this PR should be tested with some of those projects, like MMDet or MMCls.
+- [ ] CLA has been signed and all committers have signed the CLA in this PR.
diff --git a/mmcv/.github/workflows/build.yml b/mmcv/.github/workflows/build.yml
new file mode 100644
index 0000000000000000000000000000000000000000..e2ec9d8796e3309227abe85319e55486dac744bd
--- /dev/null
+++ b/mmcv/.github/workflows/build.yml
@@ -0,0 +1,404 @@
+name: build
+
+on:
+ push:
+ paths-ignore:
+ - 'README.md'
+ - 'README_zh-CN.md'
+ - 'docs/**'
+ - 'examples/**'
+ - '.dev_scripts/**'
+ - 'docker/**'
+
+ pull_request:
+ paths-ignore:
+ - 'README.md'
+ - 'README_zh-CN.md'
+ - 'docs/**'
+ - 'examples/**'
+ - '.dev_scripts/**'
+ - 'docker/**'
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}
+ cancel-in-progress: true
+
+env:
+ MMCV_WITH_OPS: 1
+
+jobs:
+ build_without_torch:
+ runs-on: ubuntu-18.04
+ strategy:
+ matrix:
+ python-version: [3.7]
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install system dependencies
+ run: sudo apt-get update && sudo apt-get install -y ffmpeg libturbojpeg
+ - name: Build and install
+ run: rm -rf .eggs && pip install -e .
+ - name: Validate the installation
+ run: python -c "import mmcv"
+ - name: Run unittests and generate coverage report
+ run: |
+ pip install -r requirements/test.txt
+ pytest tests/ \
+ --ignore=tests/test_runner \
+ --ignore=tests/test_device/test_ipu \
+ --ignore=tests/test_optimizer.py \
+ --ignore=tests/test_cnn \
+ --ignore=tests/test_parallel.py \
+ --ignore=tests/test_ops \
+ --ignore=tests/test_load_model_zoo.py \
+ --ignore=tests/test_utils/test_logging.py \
+ --ignore=tests/test_image/test_io.py \
+ --ignore=tests/test_utils/test_registry.py \
+ --ignore=tests/test_utils/test_parrots_jit.py \
+ --ignore=tests/test_utils/test_trace.py \
+ --ignore=tests/test_utils/test_hub.py \
+ --ignore=tests/test_device \
+ --ignore=tests/test_utils/test_torch_ops.py
+
+ build_without_ops:
+ runs-on: ubuntu-18.04
+ env:
+ MMCV_WITH_OPS: 0
+ strategy:
+ matrix:
+ python-version: [3.7]
+ torch: [1.7.0, 1.8.0, 1.9.0]
+ include:
+ - torch: 1.7.0
+ torchvision: 0.8.1
+ - torch: 1.8.0
+ torchvision: 0.9.0
+ - torch: 1.9.0
+ torchvision: 0.10.0
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install system dependencies
+ run: sudo apt-get update && sudo apt-get install -y ffmpeg libturbojpeg
+ - name: Install PyTorch
+ run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
+ - name: Build and install
+ run: rm -rf .eggs && pip install -e .
+ - name: Validate the installation
+ run: python -c "import mmcv"
+ - name: Run unittests
+ run: |
+ pip install -r requirements/test.txt
+ pytest tests/ --ignore=tests/test_ops
+
+ build_cpu:
+ runs-on: ubuntu-18.04
+ strategy:
+ matrix:
+ python-version: [3.7]
+ torch: [1.5.1, 1.6.0, 1.7.0, 1.8.0, 1.9.0]
+ include:
+ - torch: 1.5.1
+ torchvision: 0.6.1
+ - torch: 1.6.0
+ torchvision: 0.7.0
+ - torch: 1.7.0
+ torchvision: 0.8.1
+ - torch: 1.8.0
+ torchvision: 0.9.0
+ - torch: 1.9.0
+ torchvision: 0.10.0
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install system dependencies
+ run: sudo apt-get update && sudo apt-get install -y ffmpeg libturbojpeg
+ - name: Install PyTorch
+ run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
+      # psutil is an optional package used to detect the number of CPUs for compiling mmcv
+ - name: Install psutil
+ run: pip install psutil
+ - name: Create sdist and untar
+ run: |
+ MMCV_WITH_OPS=1 python setup.py sdist
+ tar zxvf dist/mmcv-full* -C /tmp
+ rm -r mmcv
+ - name: Build and install from sdist
+ run: |
+ pushd /tmp/mmcv-full*
+ pip install -e .
+ popd
+ - name: Validate the installation
+ run: python -c "import mmcv"
+ - name: Run unittests and generate coverage report
+ run: |
+ pip install -r requirements/test.txt
+ coverage run --branch --source=mmcv -m pytest tests/
+ coverage xml
+ coverage report -m
+
+ build_cu101:
+ runs-on: ubuntu-18.04
+ container:
+ image: pytorch/pytorch:1.6.0-cuda10.1-cudnn7-devel
+ env:
+ FORCE_CUDA: 1
+ MMCV_CUDA_ARGS: -gencode=arch=compute_61,code=sm_61
+ strategy:
+ matrix:
+ python-version: [3.7]
+ torch: [1.3.1, 1.5.1+cu101, 1.6.0+cu101, 1.7.0+cu101, 1.8.0+cu101]
+ include:
+ - torch: 1.3.1
+ torchvision: 0.4.2
+ - torch: 1.5.1+cu101
+ torchvision: 0.6.1+cu101
+ - torch: 1.6.0+cu101
+ torchvision: 0.7.0+cu101
+ - torch: 1.7.0+cu101
+ torchvision: 0.8.1+cu101
+ - torch: 1.8.0+cu101
+ torchvision: 0.9.0+cu101
+ - python-version: 3.6
+ torch: 1.8.0+cu101
+ torchvision: 0.9.0+cu101
+ - python-version: 3.8
+ torch: 1.8.0+cu101
+ torchvision: 0.9.0+cu101
+ - python-version: 3.9
+ torch: 1.8.0+cu101
+ torchvision: 0.9.0+cu101
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Fetch GPG keys
+ run: |
+ apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
+ apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
+ - name: Install python-dev
+ run: apt-get update && apt-get install -y python${{matrix.python-version}}-dev
+ if: ${{matrix.python-version != '3.9'}}
+ - name: Install Pillow
+ run: python -m pip install Pillow==6.2.2
+ if: ${{matrix.torchvision == '0.4.2'}}
+ # When we use a third-party container, we need to add python -m to call
+ # the user-installed pip when we use the pip command, otherwise it will
+ # call the system pip
+ - name: Install PyTorch
+ run: python -m pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html
+ - name: Install system dependencies
+ run: apt-get update && apt-get install -y ffmpeg libturbojpeg ninja-build
+ - name: Install dependencies for compiling onnx when python=3.9
+ run: python -m pip install protobuf && apt-get -y install libprotobuf-dev protobuf-compiler cmake
+ if: ${{matrix.python-version == '3.9'}}
+      # psutil is an optional package used to detect the number of CPUs for compiling mmcv
+ - name: Install psutil
+ run: python -m pip install psutil
+ - name: Build and install
+ run: rm -rf .eggs && python -m pip install -e .
+ - name: Validate the installation
+ run: python -c "import mmcv"
+ - name: Run unittests and generate coverage report
+ run: |
+ python -m pip install -r requirements/test.txt
+ coverage run --branch --source=mmcv -m pytest tests/
+ coverage xml
+ coverage report -m
+ # Only upload coverage report for python3.7 && pytorch1.6
+ - name: Upload coverage to Codecov
+ if: ${{matrix.torch == '1.6.0+cu101' && matrix.python-version == '3.7'}}
+ uses: codecov/codecov-action@v1.0.14
+ with:
+ file: ./coverage.xml
+ flags: unittests
+ env_vars: OS,PYTHON
+ name: codecov-umbrella
+ fail_ci_if_error: false
+
+ build_cu102:
+ runs-on: ubuntu-18.04
+ container:
+ image: pytorch/pytorch:1.9.0-cuda10.2-cudnn7-devel
+ env:
+ FORCE_CUDA: 1
+ MMCV_CUDA_ARGS: -gencode=arch=compute_61,code=sm_61
+ strategy:
+ matrix:
+ python-version: [3.7]
+ torch: [1.9.0+cu102, 1.10.0+cu102]
+ include:
+ - torch: 1.9.0+cu102
+ torchvision: 0.10.0+cu102
+ - torch: 1.10.0+cu102
+ torchvision: 0.11.0+cu102
+ - python-version: '3.10'
+ torch: 1.11.0+cu102
+ torchvision: 0.12.0+cu102
+ - python-version: '3.10'
+ torch: 1.12.0+cu102
+ torchvision: 0.13.0+cu102
+ - python-version: 3.6
+ torch: 1.9.0+cu102
+ torchvision: 0.10.0+cu102
+ - python-version: 3.8
+ torch: 1.9.0+cu102
+ torchvision: 0.10.0+cu102
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Fetch GPG keys
+ run: |
+ apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
+ apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
+ - name: Add PPA
+ run: |
+ apt-get update && apt-get install -y software-properties-common
+ add-apt-repository -y ppa:deadsnakes/ppa
+ - name: Install python-dev
+ run: apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y python${{matrix.python-version}}-dev
+      - name: Install PyTorch
+ run: python -m pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html
+ - name: Install system dependencies
+ run: apt-get update && apt-get install -y ffmpeg libturbojpeg ninja-build
+      # psutil is an optional package used to detect the number of CPUs for compiling mmcv
+ - name: Install psutil
+ run: python -m pip install psutil
+ - name: Build and install
+ run: rm -rf .eggs && python -m pip install -e .
+ - name: Validate the installation
+ run: python -c "import mmcv"
+ - name: Run unittests and generate coverage report
+ run: |
+ python -m pip install -r requirements/test.txt
+ coverage run --branch --source=mmcv -m pytest tests/
+ coverage xml
+ if: ${{matrix.python-version != '3.10'}}
+ # special treatment for python3.10 because onnx and onnxruntime don't provide python3.10 pre-built packages
+ - name: Run unittests and generate coverage report for python3.10
+ run: |
+ python -m pip install -r requirements/test.txt
+ coverage run --branch --source=mmcv -m pytest tests/ --ignore=tests/test_ops/test_onnx.py --ignore=tests/test_ops/test_tensorrt.py --ignore=tests/test_ops/test_tensorrt_preprocess.py
+ coverage xml
+ if: ${{matrix.python-version == '3.10'}}
+
+
+ build_windows_without_ops:
+ runs-on: windows-latest
+ env:
+ MMCV_WITH_OPS: 0
+ strategy:
+ matrix:
+ torch: [1.7.1, 1.8.0, 1.9.0]
+ include:
+ - torch: 1.7.1
+ torchvision: 0.8.2
+ - torch: 1.8.0
+ torchvision: 0.9.0
+ - torch: 1.9.0
+ torchvision: 0.10.0
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python 3.7
+ uses: actions/setup-python@v2
+ with:
+ python-version: 3.7
+ - name: Install PyTorch
+ run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu --no-cache-dir -f https://download.pytorch.org/whl/torch_stable.html
+ - name: Build and install
+ run: pip install -e .
+ - name: Validate the installation
+ run: python -c "import mmcv"
+ - name: Run unittests
+ run: |
+ pip install -r requirements/test.txt
+ pytest tests/ --ignore=tests/test_ops --ignore tests/test_utils/test_progressbar.py --ignore tests/test_utils/test_timer.py --ignore tests/test_image/test_io.py
+
+ build_windows:
+ runs-on: windows-latest
+ strategy:
+ matrix:
+ torch: [1.7.1, 1.8.0, 1.9.0]
+ include:
+ - torch: 1.7.1
+ torchvision: 0.8.2
+ - torch: 1.8.0
+ torchvision: 0.9.0
+ - torch: 1.9.0
+ torchvision: 0.10.0
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python 3.7
+ uses: actions/setup-python@v2
+ with:
+ python-version: 3.7
+ - name: Install PyTorch
+ run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu --no-cache-dir -f https://download.pytorch.org/whl/torch_stable.html
+ - name: Build and install
+ run: pip install -e .
+ - name: Validate the installation
+ run: python -c "import mmcv"
+ - name: Run unittests
+ run: |
+ pip install -r requirements/test.txt
+ pytest tests/ --ignore tests/test_utils/test_progressbar.py --ignore tests/test_utils/test_timer.py --ignore tests/test_image/test_io.py
+
+ build_macos:
+ runs-on: macos-latest
+ strategy:
+ matrix:
+ torch: [1.3.1, 1.5.1, 1.6.0, 1.7.0, 1.8.0, 1.9.0]
+ include:
+ - torch: 1.3.1
+ torchvision: 0.4.2
+ - torch: 1.5.1
+ torchvision: 0.6.1
+ - torch: 1.6.0
+ torchvision: 0.7.0
+ - torch: 1.7.0
+ torchvision: 0.8.1
+ - torch: 1.8.0
+ torchvision: 0.9.0
+ - torch: 1.9.0
+ torchvision: 0.10.0
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python 3.7
+ uses: actions/setup-python@v2
+ with:
+ python-version: 3.7
+ - name: Install system dependencies
+ run: brew install ffmpeg jpeg-turbo
+ - name: Install utils
+ run: pip install psutil
+ - name: Install Pillow
+ run: pip install Pillow==6.2.2
+ if: ${{matrix.torchvision == '0.4.2'}}
+ - name: Install PyTorch
+ run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} --no-cache-dir
+ - name: Build and install
+ run: |
+ rm -rf .eggs
+ CC=clang CXX=clang++ CFLAGS='-stdlib=libc++' pip install -e .
+ - name: Validate the installation
+ run: python -c "import mmcv"
+ - name: Run unittests
+ run: |
+ pip install -r requirements/test.txt
+ # The timing on macos VMs is not precise, so we skip the progressbar tests
+ pytest tests/ --ignore tests/test_utils/test_progressbar.py --ignore tests/test_utils/test_timer.py
diff --git a/mmcv/.github/workflows/build_pat.yml b/mmcv/.github/workflows/build_pat.yml
new file mode 100644
index 0000000000000000000000000000000000000000..9b02c3f41a546df213e5bf2c5e15e9047ed6c494
--- /dev/null
+++ b/mmcv/.github/workflows/build_pat.yml
@@ -0,0 +1,26 @@
+name: build_pat
+
+on: push
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}
+ cancel-in-progress: true
+
+env:
+ MMCV_WITH_OPS: 1
+
+jobs:
+ build_parrots:
+ runs-on: ubuntu-18.04
+ container:
+ image: ghcr.io/zhouzaida/parrots-mmcv:1.3.4
+ credentials:
+ username: zhouzaida
+ password: ${{ secrets.CR_PAT }}
+
+ steps:
+ - uses: actions/checkout@v2
+ - name: Install unittest dependencies
+ run: pip install -r requirements/test.txt
+ - name: Build and install
+ run: rm -rf .eggs && MMCV_WITH_OPS=1 pip install -e .
diff --git a/mmcv/.github/workflows/lint.yml b/mmcv/.github/workflows/lint.yml
new file mode 100644
index 0000000000000000000000000000000000000000..7f0550681d804d137182d40396d6d42973acc83b
--- /dev/null
+++ b/mmcv/.github/workflows/lint.yml
@@ -0,0 +1,29 @@
+name: lint
+
+on: [push, pull_request]
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ lint:
+ runs-on: ubuntu-18.04
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python 3.7
+ uses: actions/setup-python@v2
+ with:
+ python-version: 3.7
+ - name: Install pre-commit hook
+ run: |
+ pip install pre-commit
+ pre-commit install
+ - name: Linting
+ run: pre-commit run --all-files
+ - name: Format c/cuda codes with clang-format
+ uses: DoozyX/clang-format-lint-action@v0.11
+ with:
+ source: mmcv/ops/csrc
+ extensions: h,c,cpp,hpp,cu,cuh
+ style: google
diff --git a/mmcv/.github/workflows/publish-to-pypi.yml b/mmcv/.github/workflows/publish-to-pypi.yml
new file mode 100644
index 0000000000000000000000000000000000000000..04b0add31fd12d808f58dc45bc6e02eb2ad59623
--- /dev/null
+++ b/mmcv/.github/workflows/publish-to-pypi.yml
@@ -0,0 +1,46 @@
+name: deploy
+
+on: push
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ build-n-publish:
+ runs-on: ubuntu-18.04
+ if: startsWith(github.event.ref, 'refs/tags')
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python 3.7
+ uses: actions/setup-python@v1
+ with:
+ python-version: 3.7
+ - name: Upgrade Setuptools
+ run: pip install setuptools --upgrade
+ - name: Build MMCV
+ run: python setup.py sdist
+ - name: Publish distribution to PyPI
+ run: |
+ pip install twine
+ twine upload dist/* -u __token__ -p ${{ secrets.pypi_password }}
+
+ build-n-publish_with_ops:
+ runs-on: ubuntu-18.04
+ if: startsWith(github.event.ref, 'refs/tags')
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python 3.7
+ uses: actions/setup-python@v1
+ with:
+ python-version: 3.7
+ - name: Upgrade Setuptools
+ run: pip install setuptools --upgrade
+ - name: Build MMCV with ops
+ run: |
+ sed -i "s/os.getenv('MMCV_WITH_OPS', '0')/os.getenv('MMCV_WITH_OPS', '1')/g" setup.py
+ python setup.py sdist
+ - name: Publish distribution to PyPI
+ run: |
+ pip install twine
+ twine upload dist/* -u __token__ -p ${{ secrets.pypi_password }}
diff --git a/mmcv/.gitignore b/mmcv/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..10a38f688f192d041d9aac98a2ace4bb8b1afd62
--- /dev/null
+++ b/mmcv/.gitignore
@@ -0,0 +1,121 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# PyTorch checkpoint
+*.pth
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/en/_build/
+docs/zh_cn/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# editors and IDEs
+.idea/
+.vscode/
+
+# custom
+.DS_Store
+
+# datasets and logs and checkpoints
+data/
+work_dir/
+
+src/
diff --git a/mmcv/.owners.yml b/mmcv/.owners.yml
new file mode 100644
index 0000000000000000000000000000000000000000..8f7057cb339a36314e92be9b74d0b8ab1df2defc
--- /dev/null
+++ b/mmcv/.owners.yml
@@ -0,0 +1,14 @@
+assign:
+ strategy:
+ # random
+ daily-shift-based
+  schedule:
+ '*/1 * * * *'
+ assignees:
+ - zhouzaida
+ - ice-tong
+ - HAOCHENYE
+ - zhouzaida
+ - ice-tong
+ - HAOCHENYE
+ - zhouzaida
diff --git a/mmcv/.pre-commit-config-zh-cn.yaml b/mmcv/.pre-commit-config-zh-cn.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0b0f85cacda50d99f1de49dda04d799ce25234a6
--- /dev/null
+++ b/mmcv/.pre-commit-config-zh-cn.yaml
@@ -0,0 +1,72 @@
+exclude: ^tests/data/
+repos:
+ - repo: https://gitee.com/openmmlab/mirrors-flake8
+ rev: 3.8.3
+ hooks:
+ - id: flake8
+ - repo: https://gitee.com/openmmlab/mirrors-isort
+ rev: 5.10.1
+ hooks:
+ - id: isort
+ - repo: https://gitee.com/openmmlab/mirrors-yapf
+ rev: v0.30.0
+ hooks:
+ - id: yapf
+ - repo: https://gitee.com/openmmlab/mirrors-pre-commit-hooks
+ rev: v3.1.0
+ hooks:
+ - id: trailing-whitespace
+ - id: check-yaml
+ - id: end-of-file-fixer
+ - id: requirements-txt-fixer
+ - id: double-quote-string-fixer
+ - id: check-merge-conflict
+ - id: fix-encoding-pragma
+ args: ["--remove"]
+ - id: mixed-line-ending
+ args: ["--fix=lf"]
+ - repo: https://gitee.com/openmmlab/mirrors-codespell
+ rev: v2.1.0
+ hooks:
+ - id: codespell
+ - repo: https://gitee.com/openmmlab/mirrors-mdformat
+ rev: 0.7.9
+ hooks:
+ - id: mdformat
+ args: ["--number"]
+ additional_dependencies:
+ - mdformat-openmmlab
+ - mdformat_frontmatter
+ - linkify-it-py
+ - repo: https://gitee.com/openmmlab/mirrors-docformatter
+ rev: v1.3.1
+ hooks:
+ - id: docformatter
+ args: ["--in-place", "--wrap-descriptions", "79"]
+ - repo: https://github.com/asottile/pyupgrade
+ rev: v2.32.1
+ hooks:
+ - id: pyupgrade
+ args: ["--py36-plus"]
+ - repo: https://gitee.com/openmmlab/pre-commit-hooks
+ rev: v0.2.0 # Use the ref you want to point at
+ hooks:
+ - id: check-copyright
+ args: ["mmcv", "tests", "--excludes", "mmcv/ops"]
+ - repo: https://gitee.com/openmmlab/mirrors-mypy
+ rev: v0.812
+ hooks:
+ - id: mypy
+ exclude: |-
+ (?x)(
+ ^test
+ | ^docs
+ )
+ # - repo: local
+ # hooks:
+ # - id: clang-format
+ # name: clang-format
+ # description: Format files with ClangFormat
+ # entry: clang-format -style=google -i
+ # language: system
+ # files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
diff --git a/mmcv/.pre-commit-config.yaml b/mmcv/.pre-commit-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f4dd84c0b4689bf6ecec35ce39c80abef077426f
--- /dev/null
+++ b/mmcv/.pre-commit-config.yaml
@@ -0,0 +1,72 @@
+exclude: ^tests/data/
+repos:
+ - repo: https://github.com/PyCQA/flake8
+ rev: 3.8.3
+ hooks:
+ - id: flake8
+ - repo: https://github.com/PyCQA/isort
+ rev: 5.10.1
+ hooks:
+ - id: isort
+ - repo: https://github.com/pre-commit/mirrors-yapf
+ rev: v0.30.0
+ hooks:
+ - id: yapf
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v3.1.0
+ hooks:
+ - id: trailing-whitespace
+ - id: check-yaml
+ - id: end-of-file-fixer
+ - id: requirements-txt-fixer
+ - id: double-quote-string-fixer
+ - id: check-merge-conflict
+ - id: fix-encoding-pragma
+ args: ["--remove"]
+ - id: mixed-line-ending
+ args: ["--fix=lf"]
+ - repo: https://github.com/codespell-project/codespell
+ rev: v2.1.0
+ hooks:
+ - id: codespell
+ - repo: https://github.com/executablebooks/mdformat
+ rev: 0.7.9
+ hooks:
+ - id: mdformat
+ args: ["--number"]
+ additional_dependencies:
+ - mdformat-openmmlab
+ - mdformat_frontmatter
+ - linkify-it-py
+ - repo: https://github.com/myint/docformatter
+ rev: v1.3.1
+ hooks:
+ - id: docformatter
+ args: ["--in-place", "--wrap-descriptions", "79"]
+ - repo: https://github.com/asottile/pyupgrade
+ rev: v2.32.1
+ hooks:
+ - id: pyupgrade
+ args: ["--py36-plus"]
+ - repo: https://github.com/open-mmlab/pre-commit-hooks
+ rev: v0.2.0 # Use the ref you want to point at
+ hooks:
+ - id: check-copyright
+ args: ["mmcv", "tests", "--excludes", "mmcv/ops"]
+ - repo: https://github.com/pre-commit/mirrors-mypy
+ rev: v0.812
+ hooks:
+ - id: mypy
+ exclude: |-
+ (?x)(
+ ^test
+ | ^docs
+ )
+ # - repo: local
+ # hooks:
+ # - id: clang-format
+ # name: clang-format
+ # description: Format files with ClangFormat
+ # entry: clang-format -style=google -i
+ # language: system
+ # files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
diff --git a/mmcv/.readthedocs.yml b/mmcv/.readthedocs.yml
new file mode 100644
index 0000000000000000000000000000000000000000..7d5f1c2060a64e5cf9c2bec433cd24532a283164
--- /dev/null
+++ b/mmcv/.readthedocs.yml
@@ -0,0 +1,9 @@
+version: 2
+
+formats: all
+
+python:
+ version: 3.7
+ install:
+ - requirements: requirements/runtime.txt
+ - requirements: requirements/docs.txt
diff --git a/mmcv/CITATION.cff b/mmcv/CITATION.cff
new file mode 100644
index 0000000000000000000000000000000000000000..786117aac3e063efc18ad1b55e163d570a09e379
--- /dev/null
+++ b/mmcv/CITATION.cff
@@ -0,0 +1,8 @@
+cff-version: 1.2.0
+message: "If you use this software, please cite it as below."
+authors:
+ - name: "MMCV Contributors"
+title: "OpenMMLab Computer Vision Foundation"
+date-released: 2018-08-22
+url: "https://github.com/open-mmlab/mmcv"
+license: Apache-2.0
diff --git a/mmcv/CONTRIBUTING.md b/mmcv/CONTRIBUTING.md
new file mode 100644
index 0000000000000000000000000000000000000000..eea0b2544fd606d8593f1b2f12008a76673829d1
--- /dev/null
+++ b/mmcv/CONTRIBUTING.md
@@ -0,0 +1,59 @@
+## Contributing to OpenMMLab
+
+All kinds of contributions are welcome, including but not limited to the following.
+
+- Fix typo or bugs
+- Add documentation or translate the documentation into other languages
+- Add new features and components
+
+### Workflow
+
+1. fork and pull the latest OpenMMLab repository
+2. checkout a new branch (do not use master branch for PRs)
+3. commit your changes
+4. create a PR
+
+```{note}
+If you plan to add some new features that involve large changes, it is encouraged to open an issue for discussion first.
+```
+
+### Code style
+
+#### Python
+
+We adopt [PEP8](https://www.python.org/dev/peps/pep-0008/) as the preferred code style.
+
+We use the following tools for linting and formatting:
+
+- [flake8](https://github.com/PyCQA/flake8): A wrapper around some linter tools.
+- [isort](https://github.com/timothycrosley/isort): A Python utility to sort imports.
+- [yapf](https://github.com/google/yapf): A formatter for Python files.
+- [codespell](https://github.com/codespell-project/codespell): A Python utility to fix common misspellings in text files.
+- [mdformat](https://github.com/executablebooks/mdformat): Mdformat is an opinionated Markdown formatter that can be used to enforce a consistent style in Markdown files.
+- [docformatter](https://github.com/myint/docformatter): A formatter to format docstring.
+
+Style configurations of yapf and isort can be found in [setup.cfg](./setup.cfg).
+
+We use [pre-commit hooks](https://pre-commit.com/) that run `flake8`, `yapf` and `isort`, check `trailing whitespaces` and `markdown files`,
+fix `end-of-files`, `double-quoted-strings`, `python-encoding-pragma` and `mixed-line-ending`, and sort `requirements.txt` automatically on every commit.
+The config for a pre-commit hook is stored in [.pre-commit-config](./.pre-commit-config.yaml).
+
+After you clone the repository, you will need to install and initialize the pre-commit hook.
+
+```shell
+pip install -U pre-commit
+```
+
+From the repository folder
+
+```shell
+pre-commit install
+```
+
+After this, the code linters and formatter will be enforced on every commit.
+
+> Before you create a PR, make sure that your code lints and is formatted by yapf.
+
+#### C++ and CUDA
+
+We follow the [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html).
diff --git a/mmcv/Jenkinsfile b/mmcv/Jenkinsfile
new file mode 100644
index 0000000000000000000000000000000000000000..f0c19d9f3c3e0efc9ed218efa2259c598e383a06
--- /dev/null
+++ b/mmcv/Jenkinsfile
@@ -0,0 +1,56 @@
+def docker_images = ["registry.cn-hangzhou.aliyuncs.com/sensetime/openmmlab:cuda10.1-cudnn7-devel-ubuntu18.04-py37-pt1.3",
+ "registry.cn-hangzhou.aliyuncs.com/sensetime/openmmlab:cuda10.2-cudnn7-devel-ubuntu18.04-py37-pt1.5"]
+def torch_versions = ["1.3.0", "1.5.0"]
+def torchvision_versions = ["0.4.2", "0.6.0"]
+
+
+def get_stages(docker_image, folder) {
+ def pip_mirror = "-i https://mirrors.aliyun.com/pypi/simple"
+ stages = {
+ docker.image(docker_image).inside('-u root --gpus all --net host') {
+ sh "rm -rf ${env.WORKSPACE}-${folder} ${env.WORKSPACE}-${folder}@tmp"
+ sh "cp -r ${env.WORKSPACE} ${env.WORKSPACE}-${folder}"
+ try {
+ dir("${env.WORKSPACE}-${folder}") {
+ stage("before_install") {
+ sh "apt-get update && apt-get install -y ninja-build"
+ }
+ stage("dependencies") {
+ // torch and torchvision are pre-installed in dockers
+ sh "pip list | grep torch"
+ sh "apt-get install -y ffmpeg libturbojpeg"
+ sh "pip install pytest coverage lmdb PyTurboJPEG Cython ${pip_mirror}"
+ }
+ stage("build") {
+ sh "MMCV_WITH_OPS=1 pip install -e . ${pip_mirror}"
+ }
+ stage("test") {
+ sh "coverage run --branch --source=mmcv -m pytest tests/"
+ sh "coverage xml"
+ sh "coverage report -m"
+ }
+ }
+ } finally {
+ sh "rm -rf ${env.WORKSPACE}-${folder} ${env.WORKSPACE}-${folder}@tmp"
+ }
+ }
+ }
+ return stages
+}
+
+
+node('master') {
+ // fetch latest change from SCM (Source Control Management)
+ checkout scm
+
+ def stages = [:]
+ for (int i = 0; i < docker_images.size(); i++) {
+ def docker_image = docker_images[i]
+ def torch = torch_versions[i]
+ def torchvision = torchvision_versions[i]
+ def tag = docker_image + '_' + torch + '_' + torchvision
+ def folder = "${i}"
+ stages[tag] = get_stages(docker_image, folder)
+ }
+ parallel stages
+}
diff --git a/mmcv/LICENSE b/mmcv/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..f02314255d824c0816b0bf1648aac8ab78976199
--- /dev/null
+++ b/mmcv/LICENSE
@@ -0,0 +1,203 @@
+Copyright (c) OpenMMLab. All rights reserved
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2018-2020 Open-MMLab. All rights reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/mmcv/LICENSES.md b/mmcv/LICENSES.md
new file mode 100644
index 0000000000000000000000000000000000000000..5de8358331f4d21529e016807b86b66dc6ca29da
--- /dev/null
+++ b/mmcv/LICENSES.md
@@ -0,0 +1,8 @@
+# Licenses for special operations
+
+In this file, we list the operations that are covered by licenses other than Apache 2.0. Users should be careful about adopting these operations for any commercial use.
+
+| Operation | Files | License |
+| :--------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------: |
+| upfirdn2d | [mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu](https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu) | NVIDIA License |
+| fused_leaky_relu | [mmcv/ops/csrc/pytorch/cuda/fused_bias_leakyrelu_cuda.cu](https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/csrc/pytorch/cuda/fused_bias_leakyrelu_cuda.cu) | NVIDIA License |
diff --git a/mmcv/MANIFEST.in b/mmcv/MANIFEST.in
new file mode 100644
index 0000000000000000000000000000000000000000..5de8494b5df3656a4f6a09da26d9f4bb27ed69a5
--- /dev/null
+++ b/mmcv/MANIFEST.in
@@ -0,0 +1,7 @@
+include requirements/runtime.txt
+include mmcv/model_zoo/open_mmlab.json mmcv/model_zoo/deprecated.json mmcv/model_zoo/mmcls.json mmcv/model_zoo/torchvision_0.12.json
+include mmcv/ops/csrc/common/cuda/*.cuh mmcv/ops/csrc/common/cuda/*.hpp mmcv/ops/csrc/common/*.hpp
+include mmcv/ops/csrc/pytorch/*.cpp mmcv/ops/csrc/pytorch/cuda/*.cu mmcv/ops/csrc/pytorch/cuda/*.cpp mmcv/ops/csrc/pytorch/cpu/*.cpp
+include mmcv/ops/csrc/parrots/*.h mmcv/ops/csrc/parrots/*.cpp
+include mmcv/ops/csrc/pytorch/mps/*.mm mmcv/ops/csrc/common/mps/*.h mmcv/ops/csrc/common/mps/*.mm
+recursive-include mmcv/ops/csrc/ *.h *.hpp *.cpp *.cuh *.cu *.mm
diff --git a/mmcv/README.md b/mmcv/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1a6541a689a48944394db84b48d5b484e63a8708
--- /dev/null
+++ b/mmcv/README.md
@@ -0,0 +1,274 @@
+
+
+
+
+
+
+
+[![docs](https://img.shields.io/badge/docs-latest-blue)](https://mmcv.readthedocs.io/en/latest/)
+[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/mmcv)](https://pypi.org/project/mmcv/)
+[![PyPI](https://img.shields.io/pypi/v/mmcv)](https://pypi.org/project/mmcv)
+[![badge](https://github.com/open-mmlab/mmcv/workflows/build/badge.svg)](https://github.com/open-mmlab/mmcv/actions)
+[![codecov](https://codecov.io/gh/open-mmlab/mmcv/branch/master/graph/badge.svg)](https://codecov.io/gh/open-mmlab/mmcv)
+[![license](https://img.shields.io/github/license/open-mmlab/mmcv.svg)](https://github.com/open-mmlab/mmcv/blob/master/LICENSE)
+
+English | [简体中文](README_zh-CN.md)
+
+## Introduction
+
+MMCV is a foundational library for computer vision research and supports many
+research projects, as listed below:
+
+- [MIM](https://github.com/open-mmlab/mim): MIM installs OpenMMLab packages.
+- [MMClassification](https://github.com/open-mmlab/mmclassification): OpenMMLab image classification toolbox and benchmark.
+- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab detection toolbox and benchmark.
+- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab's next-generation platform for general 3D object detection.
+- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab rotated object detection toolbox and benchmark.
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab semantic segmentation toolbox and benchmark.
+- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab text detection, recognition, and understanding toolbox.
+- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab pose estimation toolbox and benchmark.
+- [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 3D human parametric model toolbox and benchmark.
+- [MMSelfSup](https://github.com/open-mmlab/mmselfsup): OpenMMLab self-supervised learning toolbox and benchmark.
+- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab model compression toolbox and benchmark.
+- [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab fewshot learning toolbox and benchmark.
+- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab's next-generation action understanding toolbox and benchmark.
+- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab video perception toolbox and benchmark.
+- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab optical flow toolbox and benchmark.
+- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab image and video editing toolbox.
+- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab image and video generative models toolbox.
+- [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab model deployment framework.
+
+It provides the following functionalities.
+
+- Universal IO APIs
+- Image/Video processing
+- Image and annotation visualization
+- Useful utilities (progress bar, timer, ...)
+- PyTorch runner with hooking mechanism
+- Various CNN architectures
+- High-quality implementation of common CUDA ops
+
+It supports the following systems.
+
+- Linux
+- Windows
+- macOS
+
+See the [documentation](http://mmcv.readthedocs.io/en/latest) for more features and usage.
+
+Note: MMCV requires Python 3.6+.
+
+## Installation
+
+There are two versions of MMCV:
+
+- **mmcv-full**: comprehensive, with full features and various CUDA ops out of the box. It takes a longer time to build.
+- **mmcv**: lite, without CUDA ops but with all other features, similar to mmcv\<1.0.0. It is useful when you do not need those CUDA ops.
+
+**Note**: Do not install both versions in the same environment; otherwise you may encounter errors like `ModuleNotFoundError`. You need to uninstall one before installing the other. *Installing the full version is highly recommended if CUDA is available.*
+
+a. Install the full version.
+
+Before installing mmcv-full, make sure that PyTorch has been successfully installed following the [official guide](https://pytorch.org/).
+
+We provide pre-built mmcv packages (recommended) with different PyTorch and CUDA versions to simplify the building for **Linux and Windows systems**. In addition, you can run [check_installation.py](.dev_scripts/check_installation.py) to check the installation of mmcv-full after running the installation commands.
+
+i. Install the latest version.
+
+The rule for installing the latest `mmcv-full` is as follows:
+
+```shell
+pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html
+```
+
+Please replace `{cu_version}` and `{torch_version}` in the URL with your desired versions. For example,
+to install the latest `mmcv-full` with `CUDA 11.1` and `PyTorch 1.9.0`, use the following command:
+
+```shell
+pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html
+```
+
+**Note**: mmcv-full is only compiled on PyTorch 1.x.0 because the compatibility usually holds between 1.x.0 and 1.x.1. If your PyTorch version is 1.x.1, you can install mmcv-full compiled with PyTorch 1.x.0 and it usually works well. For example, if your PyTorch version is 1.8.1 and CUDA version is 11.1, you can use the following command to install mmcv-full.
+
+```shell
+pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.8.0/index.html
+```
+
+For more details, please refer to the following tables and delete `=={mmcv_version}`.
+
+ii. Install a specified version.
+
+The rule for installing a specific version of `mmcv-full` is as follows:
+
+```shell
+pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html
+```
+
+First of all, please refer to the Releases and replace `{mmcv_version}` with a specific version, e.g. `1.3.9`.
+Then replace `{cu_version}` and `{torch_version}` in the URL with your desired versions. For example,
+to install `mmcv-full==1.3.9` with `CUDA 11.1` and `PyTorch 1.9.0`, use the following command:
+
+```shell
+pip install mmcv-full==1.3.9 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html
+```
+
+For more details, please refer to the following tables.
+
+Each non-empty cell below gives the `{cu_version}/{torch_version}` part of the download URL for a supported pre-built combination, i.e. install it with `pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html` (for example, `cu115/torch1.11.0` for CUDA 11.5 and torch 1.11).
+
+| CUDA | torch 1.11 | torch 1.10 | torch 1.9 | torch 1.8 | torch 1.7 | torch 1.6 | torch 1.5 |
+| :--: | :--------: | :--------: | :-------: | :-------: | :-------: | :-------: | :-------: |
+| 11.5 | cu115/torch1.11.0 | | | | | | |
+| 11.3 | cu113/torch1.11.0 | cu113/torch1.10.0 | | | | | |
+| 11.1 | | cu111/torch1.10.0 | cu111/torch1.9.0 | cu111/torch1.8.0 | | | |
+| 11.0 | | | | | cu110/torch1.7.0 | | |
+| 10.2 | cu102/torch1.11.0 | cu102/torch1.10.0 | cu102/torch1.9.0 | cu102/torch1.8.0 | cu102/torch1.7.0 | cu102/torch1.6.0 | cu102/torch1.5.0 |
+| 10.1 | | | | cu101/torch1.8.0 | cu101/torch1.7.0 | cu101/torch1.6.0 | cu101/torch1.5.0 |
+| 9.2 | | | | | cu92/torch1.7.0 | cu92/torch1.6.0 | cu92/torch1.5.0 |
+| cpu | cpu/torch1.11.0 | cpu/torch1.10.0 | cpu/torch1.9.0 | cpu/torch1.8.0 | cpu/torch1.7.0 | cpu/torch1.6.0 | cpu/torch1.5.0 |
+
+**Note**: The pre-built packages provided above do not include all versions of mmcv-full; you can click on the corresponding links to see the supported versions. For example, if you click [cu102-torch1.8.0](https://download.openmmlab.com/mmcv/dist/cu102/torch1.8.0/index.html), you can see that `cu102-torch1.8.0` only provides mmcv-full 1.3.0 and above. In addition, since v1.3.17 we no longer provide `mmcv-full` pre-built packages compiled with `PyTorch 1.3 & 1.4`. You can find previous versions compiled with PyTorch 1.3 & 1.4 [here](./docs/en/get_started/previous_versions.md). Their compatibility is still ensured in our CI, but support for PyTorch 1.3 & 1.4 will be dropped next year.
+
+**Note**: mmcv-full does not provide pre-built packages for `cu102-torch1.11` and `cu92-torch*` on Windows.
+
+Another way is to compile locally by running
+
+```shell
+pip install mmcv-full
+```
+
+Note that local compilation may take up to 10 minutes.
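+
+After installation, a quick sanity check (a minimal sketch, not an official script) is to import a CUDA op, which is only shipped with `mmcv-full`:
+
+```python
+# Minimal sanity check (a sketch): mmcv.ops is only available in mmcv-full,
+# so this import fails if only the lite `mmcv` package is installed.
+import mmcv
+from mmcv.ops import nms  # noqa: F401
+
+print(mmcv.__version__)
+```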
+
+b. Install the lite version.
+
+```shell
+pip install mmcv
+```
+
+c. Install the full version with custom operators for onnxruntime
+
+- Check [here](docs/en/deployment/onnxruntime_op.md) for detailed instruction.
+
+If you would like to build MMCV from source, please refer to the [guide](https://mmcv.readthedocs.io/en/latest/get_started/build.html).
+
+## FAQ
+
+If you face installation issues, CUDA-related issues, or RuntimeErrors,
+you may first refer to these [Frequently Asked Questions](https://mmcv.readthedocs.io/en/latest/faq.html).
+
+## Citation
+
+If you find this project useful in your research, please consider citing:
+
+```latex
+@misc{mmcv,
+ title={{MMCV: OpenMMLab} Computer Vision Foundation},
+ author={MMCV Contributors},
+ howpublished = {\url{https://github.com/open-mmlab/mmcv}},
+ year={2018}
+}
+```
+
+## Contributing
+
+We appreciate all contributions to improve MMCV. Please refer to [CONTRIBUTING.md](CONTRIBUTING.md) for the contributing guideline.
+
+## License
+
+MMCV is released under the Apache 2.0 license, while some specific operations in this library are covered by other licenses. Please refer to [LICENSES.md](LICENSES.md) for a careful check if you are using our code for commercial purposes.
diff --git a/mmcv/README_zh-CN.md b/mmcv/README_zh-CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..8c768c837ecddc7f6c4d7e036f590d9d2b96fa64
--- /dev/null
+++ b/mmcv/README_zh-CN.md
@@ -0,0 +1,276 @@
+
+
+
+
+
+
+
+[![docs](https://img.shields.io/badge/docs-latest-blue)](https://mmcv.readthedocs.io/zh_CN/latest/)
+[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/mmcv)](https://pypi.org/project/mmcv/)
+[![PyPI](https://img.shields.io/pypi/v/mmcv)](https://pypi.org/project/mmcv)
+[![badge](https://github.com/open-mmlab/mmcv/workflows/build/badge.svg)](https://github.com/open-mmlab/mmcv/actions)
+[![codecov](https://codecov.io/gh/open-mmlab/mmcv/branch/master/graph/badge.svg)](https://codecov.io/gh/open-mmlab/mmcv)
+[![license](https://img.shields.io/github/license/open-mmlab/mmcv.svg)](https://github.com/open-mmlab/mmcv/blob/master/LICENSE)
+
+[English](README.md) | 简体中文
+
+## 简介
+
+MMCV 是一个面向计算机视觉的基础库,它支持了很多开源项目,例如:
+
+- [MIM](https://github.com/open-mmlab/mim): MIM 是 OpenMMlab 项目、算法、模型的统一入口
+- [MMClassification](https://github.com/open-mmlab/mmclassification): OpenMMLab 图像分类工具箱
+- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab 目标检测工具箱
+- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab 新一代通用 3D 目标检测平台
+- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab 旋转框检测工具箱与测试基准
+- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab 语义分割工具箱
+- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab 全流程文字检测识别理解工具箱
+- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab 姿态估计工具箱
+- [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 人体参数化模型工具箱与测试基准
+- [MMSelfSup](https://github.com/open-mmlab/mmselfsup): OpenMMLab 自监督学习工具箱与测试基准
+- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab 模型压缩工具箱与测试基准
+- [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab 少样本学习工具箱与测试基准
+- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab 新一代视频理解工具箱
+- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab 一体化视频目标感知平台
+- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab 光流估计工具箱与测试基准
+- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab 图像视频编辑工具箱
+- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab 图片视频生成模型工具箱
+- [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab 模型部署框架
+
+MMCV 提供了如下众多功能:
+
+- 通用的 IO 接口
+- 图像和视频处理
+- 图像和标注结果可视化
+- 常用小工具(进度条,计时器等)
+- 基于 PyTorch 的通用训练框架
+- 多种 CNN 网络结构
+- 高质量实现的常见 CUDA 算子
+
+MMCV 支持以下的系统:
+
+- Linux
+- Windows
+- macOS
+
+如想了解更多特性和使用,请参考[文档](http://mmcv.readthedocs.io/zh_CN/latest)。
+
+提示: MMCV 需要 Python 3.6 以上版本。
+
+## 安装
+
+MMCV 有两个版本:
+
+- **mmcv-full**: 完整版,包含所有的特性以及丰富的开箱即用的 CUDA 算子。注意完整版本可能需要更长时间来编译。
+- **mmcv**: 精简版,不包含 CUDA 算子但包含其余所有特性和功能,类似 MMCV 1.0 之前的版本。如果你不需要使用 CUDA 算子的话,精简版可以作为一个考虑选项。
+
+**注意**: 请不要在同一个环境中安装两个版本,否则可能会遇到类似 `ModuleNotFound` 的错误。在安装一个版本之前,需要先卸载另一个。`如果CUDA可用,强烈推荐安装mmcv-full`。
+
+a. 安装完整版
+
+在安装 mmcv-full 之前,请确保 PyTorch 已经成功安装在环境中,可以参考 PyTorch [官方文档](https://pytorch.org/)。
+
+我们提供了 **Linux 和 Windows 平台** PyTorch 和 CUDA 版本组合的 mmcv-full 预编译包,可以大大简化用户安装编译过程。强烈推荐通过预编译包来安装。另外,安装完成后可以运行 [check_installation.py](.dev_scripts/check_installation.py) 脚本检查 mmcv-full 是否安装成功。
+
+i. 安装最新版本
+
+如下是安装最新版 `mmcv-full` 的命令
+
+```shell
+pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html
+```
+
+请将链接中的 `{cu_version}` 和 `{torch_version}` 根据自身需求替换成实际的版本号,例如想安装和 `CUDA 11.1`、`PyTorch 1.9.0` 兼容的最新版 `mmcv-full`,使用如下替换过的命令
+
+```shell
+pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html
+```
+
+**注意**: PyTorch 在 1.x.0 和 1.x.1 之间通常是兼容的,故 mmcv-full 只提供 1.x.0 的编译包。如果你的 PyTorch 版本是 1.x.1,你可以放心地安装在 1.x.0 版本编译的 mmcv-full。例如,如果你的 PyTorch 版本是 1.8.1、CUDA 版本是 11.1,你可以使用以下命令安装 mmcv-full。
+
+```shell
+pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.8.0/index.html
+```
+
+如果想知道更多 CUDA 和 PyTorch 版本的命令,可以参考下面的表格,将链接中的 `=={mmcv_version}` 删去即可。
+
+ii. 安装特定的版本
+
+如下是安装特定版本 `mmcv-full` 的命令
+
+```shell
+pip install mmcv-full=={mmcv_version} -f https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html
+```
+
+首先请参考版本发布信息找到想要安装的版本号,将 `{mmcv_version}` 替换成该版本号,例如 `1.3.9`。
+然后将链接中的 `{cu_version}` 和 `{torch_version}` 根据自身需求替换成实际的版本号,例如想安装和 `CUDA 11.1`、`PyTorch 1.9.0` 兼容的 `mmcv-full` 1.3.9 版本,使用如下替换过的命令
+
+```shell
+pip install mmcv-full==1.3.9 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html
+```
+
+对于更多的 PyTorch 和 CUDA 版本组合,请参考下表:
+
+表中非空单元格为对应组合的 `{cu_version}/{torch_version}`,将其代入上述安装命令即可(例如 `cu115/torch1.11.0` 对应 CUDA 11.5 和 torch 1.11)。
+
+| CUDA | torch 1.11 | torch 1.10 | torch 1.9 | torch 1.8 | torch 1.7 | torch 1.6 | torch 1.5 |
+| :--: | :--------: | :--------: | :-------: | :-------: | :-------: | :-------: | :-------: |
+| 11.5 | cu115/torch1.11.0 | | | | | | |
+| 11.3 | cu113/torch1.11.0 | cu113/torch1.10.0 | | | | | |
+| 11.1 | | cu111/torch1.10.0 | cu111/torch1.9.0 | cu111/torch1.8.0 | | | |
+| 11.0 | | | | | cu110/torch1.7.0 | | |
+| 10.2 | cu102/torch1.11.0 | cu102/torch1.10.0 | cu102/torch1.9.0 | cu102/torch1.8.0 | cu102/torch1.7.0 | cu102/torch1.6.0 | cu102/torch1.5.0 |
+| 10.1 | | | | cu101/torch1.8.0 | cu101/torch1.7.0 | cu101/torch1.6.0 | cu101/torch1.5.0 |
+| 9.2 | | | | | cu92/torch1.7.0 | cu92/torch1.6.0 | cu92/torch1.5.0 |
+| cpu | cpu/torch1.11.0 | cpu/torch1.10.0 | cpu/torch1.9.0 | cpu/torch1.8.0 | cpu/torch1.7.0 | cpu/torch1.6.0 | cpu/torch1.5.0 |
+
+**注意**:以上提供的预编译包并不囊括所有的 mmcv-full 版本,你可以点击对应链接查看支持的版本。例如,点击 [cu102-torch1.8.0](https://download.openmmlab.com/mmcv/dist/cu102/torch1.8.0/index.html),可以看到 `cu102-torch1.8.0` 只提供了 1.3.0 及以上的 mmcv-full 版本。另外,从 `mmcv v1.3.17` 开始,我们不再提供`PyTorch 1.3 & 1.4` 对应的 mmcv-full 预编译包。你可以在 [这](./docs/zh_cn/get_started/previous_versions.md) 找到 `PyTorch 1.3 & 1.4` 对应的预编包。虽然我们不再提供 `PyTorch 1.3 & 1.4` 对应的预编译包,但是我们依然在 CI 中保证对它们的兼容持续到下一年。
+
+**注意**:mmcv-full 没有提供 Windows 平台 `cu102-torch1.8.0` 和 `cu92-torch*` 的预编译包。
+
+除了使用预编译包之外,另一种方式是在本地进行编译,直接运行下述命令
+
+```shell
+pip install mmcv-full
+```
+
+但注意本地编译可能会耗时 10 分钟以上。
+
+b. 安装精简版
+
+```shell
+pip install mmcv
+```
+
+c. 安装完整版并且编译 onnxruntime 的自定义算子
+
+- 详细的指南请查看[这里](docs/zh_cn/deployment/onnxruntime_op.md)。
+
+如果想从源码编译 MMCV,请参考[该文档](https://mmcv.readthedocs.io/zh_CN/latest/get_started/build.html)。
+
+## FAQ
+
+如果你遇到了安装问题,CUDA 相关的问题或者 RuntimeErrors,可以首先参考[问题解决页面](https://mmcv.readthedocs.io/zh_CN/latest/faq.html) 看是否已经有解决方案。
+
+## 贡献指南
+
+我们感谢所有的贡献者为改进和提升 MMCV 所作出的努力。请参考[贡献指南](CONTRIBUTING.md)来了解参与项目贡献的相关指引。
+
+## 许可证
+
+`MMCV` 目前以 Apache 2.0 的许可证发布,但是其中有一部分功能并不是使用的 Apache2.0 许可证,我们在 [许可证](LICENSES.md) 中详细地列出了这些功能以及他们对应的许可证,如果您正在从事盈利性活动,请谨慎参考此文档。
+
+## 欢迎加入 OpenMMLab 社区
+
+扫描下方的二维码可关注 OpenMMLab 团队的 [知乎官方账号](https://www.zhihu.com/people/openmmlab),加入 OpenMMLab 团队的 [官方交流 QQ 群](https://jq.qq.com/?_wv=1027&k=3ijNTqfg),或添加微信小助手”OpenMMLabwx“加入官方交流微信群。
+
+
+
+我们会在 OpenMMLab 社区为大家
+
+- 📢 分享 AI 框架的前沿核心技术
+- 💻 解读 PyTorch 常用模块源码
+- 📰 发布 OpenMMLab 的相关新闻
+- 🚀 介绍 OpenMMLab 开发的前沿算法
+- 🏃 获取更高效的问题答疑和意见反馈
+- 🔥 提供与各行各业开发者充分交流的平台
+
+干货满满 📘,等你来撩 💗,OpenMMLab 社区期待您的加入 👬
diff --git a/mmcv/TERMINOLOGY.md b/mmcv/TERMINOLOGY.md
new file mode 100644
index 0000000000000000000000000000000000000000..07411b7774c2ed713f472c1287b98b871c7f4d02
--- /dev/null
+++ b/mmcv/TERMINOLOGY.md
@@ -0,0 +1,30 @@
+# English-Chinese terminology comparison (英汉术语对照)
+
+This document is used as a reference for English-Chinese terminology translation.
+
+该文档用作中英文翻译对照参考。
+
+| English | 中文 |
+| :---------------: | :----------: |
+| annotation | 标注 |
+| backbone | 主干网络 |
+| benchmark | 基准测试 |
+| checkpoint | 模型权重文件 |
+| classifier | 分类器 |
+| cls_head | 分类头 |
+| decoder | 解码器 |
+| detector | 检测器 |
+| encoder | 编码器 |
+| finetune | 微调 |
+| ground truth | 真实标签 |
+| hook | 钩子 |
+| localizer | 定位器 |
+| neck | 模型颈部 |
+| pipeline | 流水线 |
+| recognizer | 识别器 |
+| register | 注册器 |
+| schedule | 调整 |
+| scheduler | 调度器 |
+| segmentor | 分割器 |
+| tensor | 张量 |
+| training schedule | 训练策略 |
diff --git a/mmcv/docker/README.md b/mmcv/docker/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e9985b4ca645a14c9e3f18bf7afcc0cb4f52bf73
--- /dev/null
+++ b/mmcv/docker/README.md
@@ -0,0 +1,70 @@
+# Docker images
+
+There are two `Dockerfile` files to build docker images, one to build an image with the mmcv-full pre-built package and the other with the mmcv development environment.
+
+```text
+.
+|-- README.md
+|-- dev # build with mmcv development environment
+| `-- Dockerfile
+`-- release # build with mmcv pre-built package
+ `-- Dockerfile
+```
+
+## Build docker images
+
+### Build with mmcv pre-built package
+
+Build with local repository
+
+```bash
+git clone https://github.com/open-mmlab/mmcv.git && cd mmcv
+docker build -t mmcv -f docker/release/Dockerfile .
+```
+
+Or build with remote repository
+
+```bash
+docker build -t mmcv https://github.com/open-mmlab/mmcv.git#master:docker/release
+```
+
+The [Dockerfile](release/Dockerfile) installs the latest released version of mmcv-full by default, but you can pass the `MMCV` build argument to install a specific version.
+
+```bash
+docker image build -t mmcv -f docker/release/Dockerfile --build-arg MMCV=1.5.0 .
+```
+
+If you want to use other versions of PyTorch and CUDA, you can also pass them when building docker images.
+
+An example to build an image with PyTorch 1.9 and CUDA 11.1:
+
+```bash
+docker build -t mmcv -f docker/release/Dockerfile \
+ --build-arg PYTORCH=1.9.0 \
+ --build-arg CUDA=11.1 \
+ --build-arg CUDNN=8 \
+ --build-arg MMCV=1.5.0 .
+```
+
+More available versions of PyTorch and CUDA can be found at [dockerhub/pytorch](https://hub.docker.com/r/pytorch/pytorch/tags).
+
+### Build with mmcv development environment
+
+If you want to build a docker image with the mmcv development environment, you can use the following command
+
+```bash
+git clone https://github.com/open-mmlab/mmcv.git && cd mmcv
+docker build -t mmcv -f docker/dev/Dockerfile --build-arg CUDA_ARCH=7.5 .
+```
+
+Note that `CUDA_ARCH` is the compute capability of your GPU, which you can find at [Compute Capability](https://developer.nvidia.com/cuda-gpus#compute).
+
+The building process may take 10 minutes or more.
+
+## Run images
+
+```bash
+docker run --gpus all --shm-size=8g -it mmcv
+```
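+
+Once the container is running, a quick way to verify the image (a minimal sketch, assuming the image was built with mmcv-full) is to start Python inside it and import a CUDA op:
+
+```python
+# Run inside the container. A sketch only: mmcv.ops is shipped with mmcv-full,
+# so this import fails if only the lite `mmcv` package is present.
+import torch
+import mmcv
+from mmcv.ops import nms  # noqa: F401
+
+print(mmcv.__version__, torch.cuda.is_available())
+```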
+
+See [docker run](https://docs.docker.com/engine/reference/commandline/run/) for more usages.
diff --git a/mmcv/docker/dev/Dockerfile b/mmcv/docker/dev/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..0c673e958f2909cd80f589100c2b7cbfa726c499
--- /dev/null
+++ b/mmcv/docker/dev/Dockerfile
@@ -0,0 +1,32 @@
+ARG PYTORCH="1.8.1"
+ARG CUDA="10.2"
+ARG CUDNN="7"
+
+FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
+
+# To fix GPG key error when running apt-get update
+RUN rm /etc/apt/sources.list.d/cuda.list \
+ && rm /etc/apt/sources.list.d/nvidia-ml.list \
+ && apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub \
+ && apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
+
+# Install git and system dependencies for opencv-python
+RUN apt-get update && apt-get install -y git \
+ && apt-get update && apt-get install -y libgl1 libglib2.0-0
+
+# Install system dependencies for unit tests
+RUN apt-get install -y ffmpeg libturbojpeg \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*
+
+# build mmcv-full from source with develop mode
+ARG HTTPS_PROXY=""
+ENV https_proxy=${HTTPS_PROXY}
+ENV FORCE_CUDA="1"
+ENV MMCV_WITH_OPS="1"
+ARG CUDA_ARCH=""
+ENV TORCH_CUDA_ARCH_LIST=${CUDA_ARCH}
+RUN git clone https://github.com/open-mmlab/mmcv.git /mmcv
+WORKDIR /mmcv
+RUN git rev-parse --short HEAD
+RUN pip install --no-cache-dir -e .[all] -v && pip install pre-commit && pre-commit install
diff --git a/mmcv/docker/release/Dockerfile b/mmcv/docker/release/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..493aa6d1625c9bdee1b9f3bd8121c6ff2f723d4a
--- /dev/null
+++ b/mmcv/docker/release/Dockerfile
@@ -0,0 +1,20 @@
+ARG PYTORCH="1.8.1"
+ARG CUDA="10.2"
+ARG CUDNN="7"
+
+FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
+
+# To fix GPG key error when running apt-get update
+RUN rm /etc/apt/sources.list.d/cuda.list \
+ && rm /etc/apt/sources.list.d/nvidia-ml.list \
+ && apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub \
+ && apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
+
+# Install system dependencies for opencv-python
+RUN apt-get update && apt-get install -y libgl1 libglib2.0-0 \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*
+
+# Install mmcv-full
+ARG MMCV="1.5.1"
+RUN pip install openmim && mim install mmcv-full==${MMCV} && python -c 'import mmcv;print(mmcv.__version__)'
diff --git a/mmcv/examples/train.py b/mmcv/examples/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..b08d36bf621747354d0df30bd6d787fd2c12faf1
--- /dev/null
+++ b/mmcv/examples/train.py
@@ -0,0 +1,84 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+from torchvision.datasets import CIFAR10
+
+from mmcv.parallel import MMDataParallel
+from mmcv.runner import EpochBasedRunner
+from mmcv.utils import get_logger
+
+
+class Model(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.conv1 = nn.Conv2d(3, 6, 5)
+ self.pool = nn.MaxPool2d(2, 2)
+ self.conv2 = nn.Conv2d(6, 16, 5)
+ self.fc1 = nn.Linear(16 * 5 * 5, 120)
+ self.fc2 = nn.Linear(120, 84)
+ self.fc3 = nn.Linear(84, 10)
+ self.loss_fn = nn.CrossEntropyLoss()
+
+ def forward(self, x):
+ x = self.pool(F.relu(self.conv1(x)))
+ x = self.pool(F.relu(self.conv2(x)))
+ x = x.view(-1, 16 * 5 * 5)
+ x = F.relu(self.fc1(x))
+ x = F.relu(self.fc2(x))
+ x = self.fc3(x)
+ return x
+
+ def train_step(self, data, optimizer):
+ images, labels = data
+ predicts = self(images) # -> self.__call__() -> self.forward()
+ loss = self.loss_fn(predicts, labels)
+ return {'loss': loss}
+
+
+if __name__ == '__main__':
+ model = Model()
+ if torch.cuda.is_available():
+ # only use gpu:0 to train
+ # Solved issue https://github.com/open-mmlab/mmcv/issues/1470
+ model = MMDataParallel(model.cuda(), device_ids=[0])
+
+ # dataset and dataloader
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+ ])
+ trainset = CIFAR10(
+ root='data', train=True, download=True, transform=transform)
+ trainloader = DataLoader(
+ trainset, batch_size=128, shuffle=True, num_workers=2)
+
+ optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
+ logger = get_logger('mmcv')
+ # runner is a scheduler to manage the training
+ runner = EpochBasedRunner(
+ model,
+ optimizer=optimizer,
+ work_dir='./work_dir',
+ logger=logger,
+ max_epochs=4)
+
+ # learning rate scheduler config
+ lr_config = dict(policy='step', step=[2, 3])
+ # configuration of optimizer
+ optimizer_config = dict(grad_clip=None)
+ # configuration of saving checkpoints periodically
+ checkpoint_config = dict(interval=1)
+ # save log periodically and multiple hooks can be used simultaneously
+ log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])
+ # register hooks to runner and those hooks will be invoked automatically
+ runner.register_training_hooks(
+ lr_config=lr_config,
+ optimizer_config=optimizer_config,
+ checkpoint_config=checkpoint_config,
+ log_config=log_config)
+
+ runner.run([trainloader], [('train', 1)])
diff --git a/mmcv/mmcv/__init__.py b/mmcv/mmcv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..14c556acdf5832a1da569da6819a428f17adc328
--- /dev/null
+++ b/mmcv/mmcv/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# flake8: noqa
+from .arraymisc import *
+from .fileio import *
+from .image import *
+from .utils import *
+from .version import *
+from .video import *
+from .visualization import *
+
+# The following modules are not imported to this level, so mmcv may be used
+# without PyTorch.
+# - runner
+# - parallel
+# - op
+# - device
diff --git a/mmcv/mmcv/arraymisc/__init__.py b/mmcv/mmcv/arraymisc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b4700d6139ae3d604ff6e542468cce4200c020c
--- /dev/null
+++ b/mmcv/mmcv/arraymisc/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .quantization import dequantize, quantize
+
+__all__ = ['quantize', 'dequantize']
diff --git a/mmcv/mmcv/arraymisc/quantization.py b/mmcv/mmcv/arraymisc/quantization.py
new file mode 100644
index 0000000000000000000000000000000000000000..6182710d51787061304cfc7304ec97d565822536
--- /dev/null
+++ b/mmcv/mmcv/arraymisc/quantization.py
@@ -0,0 +1,65 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Union
+
+import numpy as np
+
+
+def quantize(arr: np.ndarray,
+ min_val: Union[int, float],
+ max_val: Union[int, float],
+ levels: int,
+             dtype=np.int64) -> np.ndarray:
+    """Quantize an array from (-inf, inf) to [0, levels-1].
+
+ Args:
+ arr (ndarray): Input array.
+ min_val (int or float): Minimum value to be clipped.
+ max_val (int or float): Maximum value to be clipped.
+ levels (int): Quantization levels.
+ dtype (np.type): The type of the quantized array.
+
+ Returns:
+        np.ndarray: Quantized array.
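+
+    Examples:
+        >>> # illustrative example: 10 levels over [0, 1]
+        >>> quantize(np.array([0., 0.5, 1.]), 0, 1, 10)
+        array([0, 5, 9])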
+ """
+ if not (isinstance(levels, int) and levels > 1):
+ raise ValueError(
+ f'levels must be a positive integer, but got {levels}')
+ if min_val >= max_val:
+ raise ValueError(
+ f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+ arr = np.clip(arr, min_val, max_val) - min_val
+ quantized_arr = np.minimum(
+ np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
+
+ return quantized_arr
+
+
+def dequantize(arr: np.ndarray,
+ min_val: Union[int, float],
+ max_val: Union[int, float],
+ levels: int,
+               dtype=np.float64) -> np.ndarray:
+ """Dequantize an array.
+
+ Args:
+ arr (ndarray): Input array.
+ min_val (int or float): Minimum value to be clipped.
+ max_val (int or float): Maximum value to be clipped.
+ levels (int): Quantization levels.
+ dtype (np.dtype): The type of the dequantized array.
+
+ Returns:
+ np.ndarray: Dequantized array.
+ """
+ if not (isinstance(levels, int) and levels > 1):
+ raise ValueError(
+ f'levels must be a positive integer, but got {levels}')
+ if min_val >= max_val:
+ raise ValueError(
+ f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+ dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
+ min_val) / levels + min_val
+
+ return dequantized_arr
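As a quick sanity check of the two helpers above, here is a minimal sketch (assuming `mmcv` and `numpy` are importable in the current environment) that quantizes a small array into 10 levels and maps it back; out-of-range values are clipped first, and dequantization returns the centre of each bin rather than the original values.

```python
import numpy as np
from mmcv.arraymisc import quantize, dequantize  # assumes mmcv is installed

arr = np.array([-1.2, 0.0, 0.4, 0.9, 2.5])
q = quantize(arr, min_val=-1.0, max_val=1.0, levels=10)
d = dequantize(q, min_val=-1.0, max_val=1.0, levels=10)

print(q)  # [0 5 7 9 9] -- values are clipped to [min_val, max_val] before binning
print(d)  # [-0.9  0.1  0.5  0.9  0.9] -- bin centres, not the original values
```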
diff --git a/mmcv/mmcv/cnn/__init__.py b/mmcv/mmcv/cnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7246c897430f0cc7ce12719ad8608824fc734446
--- /dev/null
+++ b/mmcv/mmcv/cnn/__init__.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .alexnet import AlexNet
+# yapf: disable
+from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
+ PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
+ ContextBlock, Conv2d, Conv3d, ConvAWS2d, ConvModule,
+ ConvTranspose2d, ConvTranspose3d, ConvWS2d,
+ DepthwiseSeparableConvModule, GeneralizedAttention,
+ HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d,
+ NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish,
+ build_activation_layer, build_conv_layer,
+ build_norm_layer, build_padding_layer, build_plugin_layer,
+ build_upsample_layer, conv_ws_2d, is_norm)
+from .builder import MODELS, build_model_from_cfg
+# yapf: enable
+from .resnet import ResNet, make_res_layer
+from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
+ NormalInit, PretrainedInit, TruncNormalInit, UniformInit,
+ XavierInit, bias_init_with_prob, caffe2_xavier_init,
+ constant_init, fuse_conv_bn, get_model_complexity_info,
+ initialize, kaiming_init, normal_init, trunc_normal_init,
+ uniform_init, xavier_init)
+from .vgg import VGG, make_vgg_layer
+
+__all__ = [
+ 'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
+ 'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
+ 'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
+ 'bias_init_with_prob', 'ConvModule', 'build_activation_layer',
+ 'build_conv_layer', 'build_norm_layer', 'build_padding_layer',
+ 'build_upsample_layer', 'build_plugin_layer', 'is_norm', 'NonLocal1d',
+ 'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'HSigmoid', 'Swish', 'HSwish',
+ 'GeneralizedAttention', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS',
+ 'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale',
+ 'get_model_complexity_info', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d',
+ 'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d',
+ 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d',
+ 'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+ 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+ 'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg'
+]
diff --git a/mmcv/mmcv/cnn/alexnet.py b/mmcv/mmcv/cnn/alexnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d45d96d86bdcb52a51f095c4571b21c8421cbfa
--- /dev/null
+++ b/mmcv/mmcv/cnn/alexnet.py
@@ -0,0 +1,63 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+
+class AlexNet(nn.Module):
+ """AlexNet backbone.
+
+ Args:
+ num_classes (int): Number of classes for classification. If it is
+ non-positive, no classification head is built and the module only
+ outputs convolutional features.
+ """
+
+ def __init__(self, num_classes: int = -1):
+ super().__init__()
+ self.num_classes = num_classes
+ self.features = nn.Sequential(
+ nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ nn.Conv2d(64, 192, kernel_size=5, padding=2),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ nn.Conv2d(192, 384, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(384, 256, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(256, 256, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ )
+ if self.num_classes > 0:
+ self.classifier = nn.Sequential(
+ nn.Dropout(),
+ nn.Linear(256 * 6 * 6, 4096),
+ nn.ReLU(inplace=True),
+ nn.Dropout(),
+ nn.Linear(4096, 4096),
+ nn.ReLU(inplace=True),
+ nn.Linear(4096, num_classes),
+ )
+
+ def init_weights(self, pretrained: Optional[str] = None) -> None:
+ if isinstance(pretrained, str):
+ logger = logging.getLogger()
+ from ..runner import load_checkpoint
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ # use default initializer
+ pass
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+
+ x = self.features(x)
+ if self.num_classes > 0:
+ x = x.view(x.size(0), 256 * 6 * 6)
+ x = self.classifier(x)
+
+ return x
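A minimal usage sketch for the backbone above (assuming `mmcv` and `torch` are installed): with `num_classes > 0` the classifier head expects 224x224 inputs, since the feature extractor then produces the 256x6x6 map the linear layers assume.

```python
import torch
from mmcv.cnn import AlexNet  # assumes mmcv is installed

model = AlexNet(num_classes=10)
model.init_weights()                 # no pretrained path -> default initialization
x = torch.randn(1, 3, 224, 224)      # 224x224 yields the expected 256x6x6 feature map
logits = model(x)
print(logits.shape)                  # torch.Size([1, 10])
```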
diff --git a/mmcv/mmcv/cnn/bricks/__init__.py b/mmcv/mmcv/cnn/bricks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f33124ed23fc6f27119a37bcb5ab004d3572be0
--- /dev/null
+++ b/mmcv/mmcv/cnn/bricks/__init__.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .activation import build_activation_layer
+from .context_block import ContextBlock
+from .conv import build_conv_layer
+from .conv2d_adaptive_padding import Conv2dAdaptivePadding
+from .conv_module import ConvModule
+from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d
+from .depthwise_separable_conv_module import DepthwiseSeparableConvModule
+from .drop import Dropout, DropPath
+from .generalized_attention import GeneralizedAttention
+from .hsigmoid import HSigmoid
+from .hswish import HSwish
+from .non_local import NonLocal1d, NonLocal2d, NonLocal3d
+from .norm import build_norm_layer, is_norm
+from .padding import build_padding_layer
+from .plugin import build_plugin_layer
+from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
+ PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS)
+from .scale import Scale
+from .swish import Swish
+from .upsample import build_upsample_layer
+from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
+ Linear, MaxPool2d, MaxPool3d)
+
+__all__ = [
+ 'ConvModule', 'build_activation_layer', 'build_conv_layer',
+ 'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
+ 'build_plugin_layer', 'is_norm', 'HSigmoid', 'HSwish', 'NonLocal1d',
+ 'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention',
+ 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
+ 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
+ 'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
+ 'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
+ 'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'Dropout', 'DropPath'
+]
diff --git a/mmcv/mmcv/cnn/bricks/activation.py b/mmcv/mmcv/cnn/bricks/activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..23e62722776d18b764cffe4a76e646e3103f8fb7
--- /dev/null
+++ b/mmcv/mmcv/cnn/bricks/activation.py
@@ -0,0 +1,95 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmcv.utils import TORCH_VERSION, build_from_cfg, digit_version
+from .registry import ACTIVATION_LAYERS
+
+for module in [
+ nn.ReLU, nn.LeakyReLU, nn.PReLU, nn.RReLU, nn.ReLU6, nn.ELU,
+ nn.Sigmoid, nn.Tanh
+]:
+ ACTIVATION_LAYERS.register_module(module=module)
+
+
+@ACTIVATION_LAYERS.register_module(name='Clip')
+@ACTIVATION_LAYERS.register_module()
+class Clamp(nn.Module):
+ """Clamp activation layer.
+
+ This activation function clamps the feature map values to the range
+ :math:`[min, max]`. More details can be found in ``torch.clamp()``.
+
+ Args:
+ min (float, optional): Lower bound of the clamping range.
+ Defaults to -1.
+ max (float, optional): Upper bound of the clamping range.
+ Defaults to 1.
+ """
+
+ def __init__(self, min: float = -1., max: float = 1.):
+ super().__init__()
+ self.min = min
+ self.max = max
+
+ def forward(self, x) -> torch.Tensor:
+ """Forward function.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: Clamped tensor.
+ """
+ return torch.clamp(x, min=self.min, max=self.max)
+
+
+class GELU(nn.Module):
+ r"""Applies the Gaussian Error Linear Units function:
+
+ .. math::
+ \text{GELU}(x) = x * \Phi(x)
+ where :math:`\Phi(x)` is the Cumulative Distribution Function for
+ Gaussian Distribution.
+
+ Shape:
+ - Input: :math:`(N, *)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(N, *)`, same shape as the input
+
+ .. image:: scripts/activation_images/GELU.png
+
+ Examples::
+
+ >>> m = nn.GELU()
+ >>> input = torch.randn(2)
+ >>> output = m(input)
+ """
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return F.gelu(input)
+
+
+if (TORCH_VERSION == 'parrots'
+ or digit_version(TORCH_VERSION) < digit_version('1.4')):
+ ACTIVATION_LAYERS.register_module(module=GELU)
+else:
+ ACTIVATION_LAYERS.register_module(module=nn.GELU)
+
+
+def build_activation_layer(cfg: Dict) -> nn.Module:
+ """Build activation layer.
+
+ Args:
+ cfg (dict): The activation layer config, which should contain:
+
+ - type (str): Layer type.
+ - layer args: Args needed to instantiate an activation layer.
+
+ Returns:
+ nn.Module: Created activation layer.
+ """
+ return build_from_cfg(cfg, ACTIVATION_LAYERS)
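A short sketch of how the registry above is used in practice (assuming `mmcv` is installed): `build_activation_layer` looks up the `type` key and forwards the remaining keys as constructor arguments, so the registered PyTorch activations and the custom `Clamp` layer are built the same way.

```python
import torch
from mmcv.cnn import build_activation_layer  # assumes mmcv is installed

relu = build_activation_layer(dict(type='ReLU'))
clamp = build_activation_layer(dict(type='Clamp', min=0.0, max=0.5))

x = torch.tensor([-1.0, 0.2, 2.0])
print(relu(x))   # tensor([0.0000, 0.2000, 2.0000])
print(clamp(x))  # tensor([0.0000, 0.2000, 0.5000])
```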
diff --git a/mmcv/mmcv/cnn/bricks/context_block.py b/mmcv/mmcv/cnn/bricks/context_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..15669cab35dcdc98a95df006788f78f84b88dc44
--- /dev/null
+++ b/mmcv/mmcv/cnn/bricks/context_block.py
@@ -0,0 +1,127 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Union
+
+import torch
+from torch import nn
+
+from ..utils import constant_init, kaiming_init
+from .registry import PLUGIN_LAYERS
+
+
+def last_zero_init(m: Union[nn.Module, nn.Sequential]) -> None:
+ if isinstance(m, nn.Sequential):
+ constant_init(m[-1], val=0)
+ else:
+ constant_init(m, val=0)
+
+
+@PLUGIN_LAYERS.register_module()
+class ContextBlock(nn.Module):
+ """ContextBlock module in GCNet.
+
+ See 'GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond'
+ (https://arxiv.org/abs/1904.11492) for details.
+
+ Args:
+ in_channels (int): Channels of the input feature map.
+ ratio (float): Ratio of channels of the transform bottleneck.
+ pooling_type (str): Pooling method for context modeling.
+ Options are 'att' and 'avg', which stand for attention pooling and
+ average pooling respectively. Default: 'att'.
+ fusion_types (Sequence[str]): Fusion methods for feature fusion.
+ Options are 'channel_add' and 'channel_mul', which stand for
+ channel-wise addition and multiplication respectively.
+ Default: ('channel_add',).
+ """
+
+ _abbr_ = 'context_block'
+
+ def __init__(self,
+ in_channels: int,
+ ratio: float,
+ pooling_type: str = 'att',
+ fusion_types: tuple = ('channel_add', )):
+ super().__init__()
+ assert pooling_type in ['avg', 'att']
+ assert isinstance(fusion_types, (list, tuple))
+ valid_fusion_types = ['channel_add', 'channel_mul']
+ assert all([f in valid_fusion_types for f in fusion_types])
+ assert len(fusion_types) > 0, 'at least one fusion should be used'
+ self.in_channels = in_channels
+ self.ratio = ratio
+ self.planes = int(in_channels * ratio)
+ self.pooling_type = pooling_type
+ self.fusion_types = fusion_types
+ if pooling_type == 'att':
+ self.conv_mask = nn.Conv2d(in_channels, 1, kernel_size=1)
+ self.softmax = nn.Softmax(dim=2)
+ else:
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ if 'channel_add' in fusion_types:
+ self.channel_add_conv = nn.Sequential(
+ nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
+ nn.LayerNorm([self.planes, 1, 1]),
+ nn.ReLU(inplace=True), # yapf: disable
+ nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
+ else:
+ self.channel_add_conv = None
+ if 'channel_mul' in fusion_types:
+ self.channel_mul_conv = nn.Sequential(
+ nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
+ nn.LayerNorm([self.planes, 1, 1]),
+ nn.ReLU(inplace=True), # yapf: disable
+ nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
+ else:
+ self.channel_mul_conv = None
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ if self.pooling_type == 'att':
+ kaiming_init(self.conv_mask, mode='fan_in')
+ self.conv_mask.inited = True
+
+ if self.channel_add_conv is not None:
+ last_zero_init(self.channel_add_conv)
+ if self.channel_mul_conv is not None:
+ last_zero_init(self.channel_mul_conv)
+
+ def spatial_pool(self, x: torch.Tensor) -> torch.Tensor:
+ batch, channel, height, width = x.size()
+ if self.pooling_type == 'att':
+ input_x = x
+ # [N, C, H * W]
+ input_x = input_x.view(batch, channel, height * width)
+ # [N, 1, C, H * W]
+ input_x = input_x.unsqueeze(1)
+ # [N, 1, H, W]
+ context_mask = self.conv_mask(x)
+ # [N, 1, H * W]
+ context_mask = context_mask.view(batch, 1, height * width)
+ # [N, 1, H * W]
+ context_mask = self.softmax(context_mask)
+ # [N, 1, H * W, 1]
+ context_mask = context_mask.unsqueeze(-1)
+ # [N, 1, C, 1]
+ context = torch.matmul(input_x, context_mask)
+ # [N, C, 1, 1]
+ context = context.view(batch, channel, 1, 1)
+ else:
+ # [N, C, 1, 1]
+ context = self.avg_pool(x)
+
+ return context
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # [N, C, 1, 1]
+ context = self.spatial_pool(x)
+
+ out = x
+ if self.channel_mul_conv is not None:
+ # [N, C, 1, 1]
+ channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
+ out = out * channel_mul_term
+ if self.channel_add_conv is not None:
+ # [N, C, 1, 1]
+ channel_add_term = self.channel_add_conv(context)
+ out = out + channel_add_term
+
+ return out
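A usage sketch for the block above (assuming `mmcv` is installed). Because `last_zero_init` zero-initialises the final conv of each fusion branch, the block behaves as an identity mapping right after construction, which makes it safe to drop into a pretrained backbone.

```python
import torch
from mmcv.cnn import ContextBlock  # assumes mmcv is installed

gc_block = ContextBlock(in_channels=64, ratio=1. / 4)  # bottleneck of 16 channels
x = torch.randn(2, 64, 32, 32)
out = gc_block(x)

print(out.shape)               # torch.Size([2, 64, 32, 32]) -- shape preserved
print(torch.allclose(out, x))  # True at initialisation: the fusion branch starts at zero
```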
diff --git a/mmcv/mmcv/cnn/bricks/conv.py b/mmcv/mmcv/cnn/bricks/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..147517ef4ecdee16d26b535fa49c26a2fcbdd48e
--- /dev/null
+++ b/mmcv/mmcv/cnn/bricks/conv.py
@@ -0,0 +1,46 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Optional
+
+from torch import nn
+
+from .registry import CONV_LAYERS
+
+CONV_LAYERS.register_module('Conv1d', module=nn.Conv1d)
+CONV_LAYERS.register_module('Conv2d', module=nn.Conv2d)
+CONV_LAYERS.register_module('Conv3d', module=nn.Conv3d)
+CONV_LAYERS.register_module('Conv', module=nn.Conv2d)
+
+
+def build_conv_layer(cfg: Optional[Dict], *args, **kwargs) -> nn.Module:
+ """Build convolution layer.
+
+ Args:
+ cfg (None or dict): The conv layer config, which should contain:
+ - type (str): Layer type.
+ - layer args: Args needed to instantiate a conv layer.
+ args (argument list): Arguments passed to the `__init__`
+ method of the corresponding conv layer.
+ kwargs (keyword arguments): Keyword arguments passed to the `__init__`
+ method of the corresponding conv layer.
+
+ Returns:
+ nn.Module: Created conv layer.
+ """
+ if cfg is None:
+ cfg_ = dict(type='Conv2d')
+ else:
+ if not isinstance(cfg, dict):
+ raise TypeError('cfg must be a dict')
+ if 'type' not in cfg:
+ raise KeyError('the cfg dict must contain the key "type"')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type not in CONV_LAYERS:
+ raise KeyError(f'Unrecognized layer type {layer_type}')
+ else:
+ conv_layer = CONV_LAYERS.get(layer_type)
+
+ layer = conv_layer(*args, **kwargs, **cfg_)
+
+ return layer
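A small sketch of `build_conv_layer` (assuming `mmcv` is installed): passing `cfg=None` falls back to a plain `Conv2d`, while an explicit `type` selects any conv registered in `CONV_LAYERS`, with the remaining cfg keys forwarded as keyword arguments.

```python
import torch
from mmcv.cnn import build_conv_layer  # assumes mmcv is installed

conv = build_conv_layer(None, 3, 16, kernel_size=3, padding=1)  # plain nn.Conv2d
conv_ws = build_conv_layer(dict(type='ConvWS', eps=1e-6),       # weight-standardized conv
                           3, 16, kernel_size=3, padding=1)

x = torch.randn(1, 3, 8, 8)
print(conv(x).shape, conv_ws(x).shape)  # both torch.Size([1, 16, 8, 8])
```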
diff --git a/mmcv/mmcv/cnn/bricks/conv2d_adaptive_padding.py b/mmcv/mmcv/cnn/bricks/conv2d_adaptive_padding.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a7a1d2844db097c21e5ecc55a579e0b9b95c816
--- /dev/null
+++ b/mmcv/mmcv/cnn/bricks/conv2d_adaptive_padding.py
@@ -0,0 +1,64 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .registry import CONV_LAYERS
+
+
+@CONV_LAYERS.register_module()
+class Conv2dAdaptivePadding(nn.Conv2d):
+ """Implementation of 2D convolution in tensorflow with `padding` as "same",
+ which applies padding to input (if needed) so that input image gets fully
+ covered by filter and stride you specified. For stride 1, this will ensure
+ that output image size is same as input. For stride of 2, output dimensions
+ will be half, for example.
+
+ Args:
+ in_channels (int): Number of channels in the input image
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the convolving kernel
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): Zero-padding added to both sides of
+ the input. Default: 0
+ dilation (int or tuple, optional): Spacing between kernel elements.
+ Default: 1
+ groups (int, optional): Number of blocked connections from input
+ channels to output channels. Default: 1
+ bias (bool, optional): If ``True``, adds a learnable bias to the
+ output. Default: ``True``
+ """
+
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]],
+ stride: Union[int, Tuple[int, int]] = 1,
+ padding: Union[int, Tuple[int, int]] = 0,
+ dilation: Union[int, Tuple[int, int]] = 1,
+ groups: int = 1,
+ bias: bool = True):
+ super().__init__(in_channels, out_channels, kernel_size, stride, 0,
+ dilation, groups, bias)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ img_h, img_w = x.size()[-2:]
+ kernel_h, kernel_w = self.weight.size()[-2:]
+ stride_h, stride_w = self.stride
+ output_h = math.ceil(img_h / stride_h)
+ output_w = math.ceil(img_w / stride_w)
+ pad_h = (
+ max((output_h - 1) * self.stride[0] +
+ (kernel_h - 1) * self.dilation[0] + 1 - img_h, 0))
+ pad_w = (
+ max((output_w - 1) * self.stride[1] +
+ (kernel_w - 1) * self.dilation[1] + 1 - img_w, 0))
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, [
+ pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
+ ])
+ return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
+ self.dilation, self.groups)
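A quick sketch of the "same"-padding behaviour described above (assuming `mmcv` is installed): for an odd input size and stride 2, the layer pads just enough so that the output size equals `ceil(input / stride)`.

```python
import torch
from mmcv.cnn.bricks import Conv2dAdaptivePadding  # assumes mmcv is installed

conv = Conv2dAdaptivePadding(3, 8, kernel_size=3, stride=2)
x = torch.randn(1, 3, 15, 15)
print(conv(x).shape)  # torch.Size([1, 8, 8, 8]) -- ceil(15 / 2) == 8
```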
diff --git a/mmcv/mmcv/cnn/bricks/conv_module.py b/mmcv/mmcv/cnn/bricks/conv_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5d4a8c2760ea81656d3eefdad86e8dd43488447
--- /dev/null
+++ b/mmcv/mmcv/cnn/bricks/conv_module.py
@@ -0,0 +1,212 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from mmcv.utils import _BatchNorm, _InstanceNorm
+from ..utils import constant_init, kaiming_init
+from .activation import build_activation_layer
+from .conv import build_conv_layer
+from .norm import build_norm_layer
+from .padding import build_padding_layer
+from .registry import PLUGIN_LAYERS
+
+
+@PLUGIN_LAYERS.register_module()
+class ConvModule(nn.Module):
+ """A conv block that bundles conv/norm/activation layers.
+
+ This block simplifies the usage of convolution layers, which are commonly
+ used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
+ It is based upon three build methods: `build_conv_layer()`,
+ `build_norm_layer()` and `build_activation_layer()`.
+
+ Besides, this module adds some extra features:
+ 1. Automatically sets the `bias` of the conv layer.
+ 2. Supports spectral norm.
+ 3. Supports more padding modes. Before PyTorch 1.5, nn.Conv2d only
+ supported zero and circular padding; this module adds a "reflect"
+ padding mode.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ Same as that in ``nn._ConvNd``.
+ out_channels (int): Number of channels produced by the convolution.
+ Same as that in ``nn._ConvNd``.
+ kernel_size (int | tuple[int]): Size of the convolving kernel.
+ Same as that in ``nn._ConvNd``.
+ stride (int | tuple[int]): Stride of the convolution.
+ Same as that in ``nn._ConvNd``.
+ padding (int | tuple[int]): Zero-padding added to both sides of
+ the input. Same as that in ``nn._ConvNd``.
+ dilation (int | tuple[int]): Spacing between kernel elements.
+ Same as that in ``nn._ConvNd``.
+ groups (int): Number of blocked connections from input channels to
+ output channels. Same as that in ``nn._ConvNd``.
+ bias (bool | str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
+ False. Default: "auto".
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU').
+ inplace (bool): Whether to use inplace mode for activation.
+ Default: True.
+ with_spectral_norm (bool): Whether use spectral norm in conv module.
+ Default: False.
+ padding_mode (str): If the `padding_mode` has not been supported by
+ current `Conv2d` in PyTorch, we will use our own padding layer
+ instead. Currently, we support ['zeros', 'circular'] with official
+ implementation and ['reflect'] with our own implementation.
+ Default: 'zeros'.
+ order (tuple[str]): The order of conv/norm/activation layers. It is a
+ sequence of "conv", "norm" and "act". Common examples are
+ ("conv", "norm", "act") and ("act", "conv", "norm").
+ Default: ('conv', 'norm', 'act').
+ """
+
+ _abbr_ = 'conv_block'
+
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]],
+ stride: Union[int, Tuple[int, int]] = 1,
+ padding: Union[int, Tuple[int, int]] = 0,
+ dilation: Union[int, Tuple[int, int]] = 1,
+ groups: int = 1,
+ bias: Union[bool, str] = 'auto',
+ conv_cfg: Optional[Dict] = None,
+ norm_cfg: Optional[Dict] = None,
+ act_cfg: Optional[Dict] = dict(type='ReLU'),
+ inplace: bool = True,
+ with_spectral_norm: bool = False,
+ padding_mode: str = 'zeros',
+ order: tuple = ('conv', 'norm', 'act')):
+ super().__init__()
+ assert conv_cfg is None or isinstance(conv_cfg, dict)
+ assert norm_cfg is None or isinstance(norm_cfg, dict)
+ assert act_cfg is None or isinstance(act_cfg, dict)
+ official_padding_mode = ['zeros', 'circular']
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.inplace = inplace
+ self.with_spectral_norm = with_spectral_norm
+ self.with_explicit_padding = padding_mode not in official_padding_mode
+ self.order = order
+ assert isinstance(self.order, tuple) and len(self.order) == 3
+ assert set(order) == {'conv', 'norm', 'act'}
+
+ self.with_norm = norm_cfg is not None
+ self.with_activation = act_cfg is not None
+ # if the conv layer is before a norm layer, bias is unnecessary.
+ if bias == 'auto':
+ bias = not self.with_norm
+ self.with_bias = bias
+
+ if self.with_explicit_padding:
+ pad_cfg = dict(type=padding_mode)
+ self.padding_layer = build_padding_layer(pad_cfg, padding)
+
+ # reset padding to 0 for conv module
+ conv_padding = 0 if self.with_explicit_padding else padding
+ # build convolution layer
+ self.conv = build_conv_layer(
+ conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=conv_padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias)
+ # export the attributes of self.conv to a higher level for convenience
+ self.in_channels = self.conv.in_channels
+ self.out_channels = self.conv.out_channels
+ self.kernel_size = self.conv.kernel_size
+ self.stride = self.conv.stride
+ self.padding = padding
+ self.dilation = self.conv.dilation
+ self.transposed = self.conv.transposed
+ self.output_padding = self.conv.output_padding
+ self.groups = self.conv.groups
+
+ if self.with_spectral_norm:
+ self.conv = nn.utils.spectral_norm(self.conv)
+
+ # build normalization layers
+ if self.with_norm:
+ # norm layer is after conv layer
+ if order.index('norm') > order.index('conv'):
+ norm_channels = out_channels
+ else:
+ norm_channels = in_channels
+ self.norm_name, norm = build_norm_layer(
+ norm_cfg, norm_channels) # type: ignore
+ self.add_module(self.norm_name, norm)
+ if self.with_bias:
+ if isinstance(norm, (_BatchNorm, _InstanceNorm)):
+ warnings.warn(
+ 'Unnecessary conv bias before batch/instance norm')
+ else:
+ self.norm_name = None # type: ignore
+
+ # build activation layer
+ if self.with_activation:
+ act_cfg_ = act_cfg.copy() # type: ignore
+ # nn.Tanh has no 'inplace' argument
+ if act_cfg_['type'] not in [
+ 'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish', 'GELU'
+ ]:
+ act_cfg_.setdefault('inplace', inplace)
+ self.activate = build_activation_layer(act_cfg_)
+
+ # Use msra init by default
+ self.init_weights()
+
+ @property
+ def norm(self):
+ if self.norm_name:
+ return getattr(self, self.norm_name)
+ else:
+ return None
+
+ def init_weights(self):
+ # 1. It is mainly for customized conv layers with their own
+ # initialization manners by calling their own ``init_weights()``,
+ # and we do not want ConvModule to override the initialization.
+ # 2. For customized conv layers without their own initialization
+ # manners (that is, they don't have their own ``init_weights()``)
+ # and PyTorch's conv layers, they will be initialized by
+ # this method with default ``kaiming_init``.
+ # Note: For PyTorch's conv layers, they will be overwritten by our
+ # initialization implementation using default ``kaiming_init``.
+ if not hasattr(self.conv, 'init_weights'):
+ if self.with_activation and self.act_cfg['type'] == 'LeakyReLU':
+ nonlinearity = 'leaky_relu'
+ a = self.act_cfg.get('negative_slope', 0.01)
+ else:
+ nonlinearity = 'relu'
+ a = 0
+ kaiming_init(self.conv, a=a, nonlinearity=nonlinearity)
+ if self.with_norm:
+ constant_init(self.norm, 1, bias=0)
+
+ def forward(self,
+ x: torch.Tensor,
+ activate: bool = True,
+ norm: bool = True) -> torch.Tensor:
+ for layer in self.order:
+ if layer == 'conv':
+ if self.with_explicit_padding:
+ x = self.padding_layer(x)
+ x = self.conv(x)
+ elif layer == 'norm' and norm and self.with_norm:
+ x = self.norm(x)
+ elif layer == 'act' and activate and self.with_activation:
+ x = self.activate(x)
+ return x
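A usage sketch for `ConvModule` (assuming `mmcv` is installed): with `bias='auto'` the conv bias is dropped automatically whenever a norm layer follows, and the `order` tuple controls whether normalization and activation run before or after the convolution.

```python
import torch
from mmcv.cnn import ConvModule  # assumes mmcv is installed

block = ConvModule(
    3, 16, kernel_size=3, padding=1,
    norm_cfg=dict(type='BN'),        # conv -> BN -> ReLU (the default order)
    act_cfg=dict(type='ReLU'))

x = torch.randn(2, 3, 32, 32)
print(block(x).shape)            # torch.Size([2, 16, 32, 32])
print(block.conv.bias is None)   # True -- bias='auto' disables it before BN
```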
diff --git a/mmcv/mmcv/cnn/bricks/conv_ws.py b/mmcv/mmcv/cnn/bricks/conv_ws.py
new file mode 100644
index 0000000000000000000000000000000000000000..6569f920fea942a9345ff509c7dbdb6ace1f3741
--- /dev/null
+++ b/mmcv/mmcv/cnn/bricks/conv_ws.py
@@ -0,0 +1,154 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .registry import CONV_LAYERS
+
+
+def conv_ws_2d(input: torch.Tensor,
+ weight: torch.Tensor,
+ bias: Optional[torch.Tensor] = None,
+ stride: Union[int, Tuple[int, int]] = 1,
+ padding: Union[int, Tuple[int, int]] = 0,
+ dilation: Union[int, Tuple[int, int]] = 1,
+ groups: int = 1,
+ eps: float = 1e-5) -> torch.Tensor:
+ c_in = weight.size(0)
+ weight_flat = weight.view(c_in, -1)
+ mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
+ std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1)
+ weight = (weight - mean) / (std + eps)
+ return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
+
+
+@CONV_LAYERS.register_module('ConvWS')
+class ConvWS2d(nn.Conv2d):
+
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]],
+ stride: Union[int, Tuple[int, int]] = 1,
+ padding: Union[int, Tuple[int, int]] = 0,
+ dilation: Union[int, Tuple[int, int]] = 1,
+ groups: int = 1,
+ bias: bool = True,
+ eps: float = 1e-5):
+ super().__init__(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias)
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
+ self.dilation, self.groups, self.eps)
+
+
+@CONV_LAYERS.register_module(name='ConvAWS')
+class ConvAWS2d(nn.Conv2d):
+ """AWS (Adaptive Weight Standardization)
+
+ This is a variant of Weight Standardization
+ (https://arxiv.org/pdf/1903.10520.pdf)
+ It is used in DetectoRS to avoid NaN
+ (https://arxiv.org/pdf/2006.02334.pdf)
+
+ Args:
+ in_channels (int): Number of channels in the input image
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the conv kernel
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): Zero-padding added to both sides of
+ the input. Default: 0
+ dilation (int or tuple, optional): Spacing between kernel elements.
+ Default: 1
+ groups (int, optional): Number of blocked connections from input
+ channels to output channels. Default: 1
+ bias (bool, optional): If set True, adds a learnable bias to the
+ output. Default: True
+ """
+
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]],
+ stride: Union[int, Tuple[int, int]] = 1,
+ padding: Union[int, Tuple[int, int]] = 0,
+ dilation: Union[int, Tuple[int, int]] = 1,
+ groups: int = 1,
+ bias: bool = True):
+ super().__init__(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias)
+ self.register_buffer('weight_gamma',
+ torch.ones(self.out_channels, 1, 1, 1))
+ self.register_buffer('weight_beta',
+ torch.zeros(self.out_channels, 1, 1, 1))
+
+ def _get_weight(self, weight: torch.Tensor) -> torch.Tensor:
+ weight_flat = weight.view(weight.size(0), -1)
+ mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
+ std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
+ weight = (weight - mean) / std
+ weight = self.weight_gamma * weight + self.weight_beta
+ return weight
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ weight = self._get_weight(self.weight)
+ return F.conv2d(x, weight, self.bias, self.stride, self.padding,
+ self.dilation, self.groups)
+
+ def _load_from_state_dict(self, state_dict: OrderedDict, prefix: str,
+ local_metadata: Dict, strict: bool,
+ missing_keys: List[str],
+ unexpected_keys: List[str],
+ error_msgs: List[str]) -> None:
+ """Override default load function.
+
+ AWS overrides the function _load_from_state_dict to recover
+ weight_gamma and weight_beta if they are missing. If weight_gamma and
+ weight_beta are found in the checkpoint, this function will return
+ after super()._load_from_state_dict. Otherwise, it will compute the
+ mean and std of the pretrained weights and store them in weight_beta
+ and weight_gamma.
+ """
+
+ self.weight_gamma.data.fill_(-1)
+ local_missing_keys: List = []
+ super()._load_from_state_dict(state_dict, prefix, local_metadata,
+ strict, local_missing_keys,
+ unexpected_keys, error_msgs)
+ if self.weight_gamma.data.mean() > 0:
+ for k in local_missing_keys:
+ missing_keys.append(k)
+ return
+ weight = self.weight.data
+ weight_flat = weight.view(weight.size(0), -1)
+ mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
+ std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
+ self.weight_beta.data.copy_(mean)
+ self.weight_gamma.data.copy_(std)
+ missing_gamma_beta = [
+ k for k in local_missing_keys
+ if k.endswith('weight_gamma') or k.endswith('weight_beta')
+ ]
+ for k in missing_gamma_beta:
+ local_missing_keys.remove(k)
+ for k in local_missing_keys:
+ missing_keys.append(k)
diff --git a/mmcv/mmcv/cnn/bricks/depthwise_separable_conv_module.py b/mmcv/mmcv/cnn/bricks/depthwise_separable_conv_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf1fe4cad3812007573211fa2bede28b23822122
--- /dev/null
+++ b/mmcv/mmcv/cnn/bricks/depthwise_separable_conv_module.py
@@ -0,0 +1,99 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from .conv_module import ConvModule
+
+
+class DepthwiseSeparableConvModule(nn.Module):
+ """Depthwise separable convolution module.
+
+ See https://arxiv.org/pdf/1704.04861.pdf for details.
+
+ This module can replace a ConvModule, with its single conv block replaced
+ by two conv blocks: a depthwise conv block and a pointwise conv block.
+ The depthwise conv block contains depthwise-conv/norm/activation layers,
+ and the pointwise conv block contains pointwise-conv/norm/activation
+ layers. Note that the depthwise conv block will contain norm/activation
+ layers if `norm_cfg` and `act_cfg` are specified.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ Same as that in ``nn._ConvNd``.
+ out_channels (int): Number of channels produced by the convolution.
+ Same as that in ``nn._ConvNd``.
+ kernel_size (int | tuple[int]): Size of the convolving kernel.
+ Same as that in ``nn._ConvNd``.
+ stride (int | tuple[int]): Stride of the convolution.
+ Same as that in ``nn._ConvNd``. Default: 1.
+ padding (int | tuple[int]): Zero-padding added to both sides of
+ the input. Same as that in ``nn._ConvNd``. Default: 0.
+ dilation (int | tuple[int]): Spacing between kernel elements.
+ Same as that in ``nn._ConvNd``. Default: 1.
+ norm_cfg (dict): Default norm config for both depthwise ConvModule and
+ pointwise ConvModule. Default: None.
+ act_cfg (dict): Default activation config for both depthwise ConvModule
+ and pointwise ConvModule. Default: dict(type='ReLU').
+ dw_norm_cfg (dict): Norm config of depthwise ConvModule. If it is
+ 'default', it will be the same as `norm_cfg`. Default: 'default'.
+ dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is
+ 'default', it will be the same as `act_cfg`. Default: 'default'.
+ pw_norm_cfg (dict): Norm config of pointwise ConvModule. If it is
+ 'default', it will be the same as `norm_cfg`. Default: 'default'.
+ pw_act_cfg (dict): Activation config of pointwise ConvModule. If it is
+ 'default', it will be the same as `act_cfg`. Default: 'default'.
+ kwargs (optional): Other shared arguments for depthwise and pointwise
+ ConvModule. See ConvModule for ref.
+ """
+
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]],
+ stride: Union[int, Tuple[int, int]] = 1,
+ padding: Union[int, Tuple[int, int]] = 0,
+ dilation: Union[int, Tuple[int, int]] = 1,
+ norm_cfg: Optional[Dict] = None,
+ act_cfg: Dict = dict(type='ReLU'),
+ dw_norm_cfg: Union[Dict, str] = 'default',
+ dw_act_cfg: Union[Dict, str] = 'default',
+ pw_norm_cfg: Union[Dict, str] = 'default',
+ pw_act_cfg: Union[Dict, str] = 'default',
+ **kwargs):
+ super().__init__()
+ assert 'groups' not in kwargs, 'groups should not be specified'
+
+ # if norm/activation config of depthwise/pointwise ConvModule is not
+ # specified, use default config.
+ dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg # type: ignore # noqa E501
+ dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg
+ pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg # type: ignore # noqa E501
+ pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg
+
+ # depthwise convolution
+ self.depthwise_conv = ConvModule(
+ in_channels,
+ in_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=in_channels,
+ norm_cfg=dw_norm_cfg, # type: ignore
+ act_cfg=dw_act_cfg, # type: ignore
+ **kwargs)
+
+ self.pointwise_conv = ConvModule(
+ in_channels,
+ out_channels,
+ 1,
+ norm_cfg=pw_norm_cfg, # type: ignore
+ act_cfg=pw_act_cfg, # type: ignore
+ **kwargs)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.depthwise_conv(x)
+ x = self.pointwise_conv(x)
+ return x
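A minimal sketch of the module above (assuming `mmcv` is installed): a 3x3 depthwise conv (one filter per input channel) followed by a 1x1 pointwise conv that mixes channels, which needs far fewer conv weights than a single dense 3x3 convolution.

```python
import torch
from mmcv.cnn import DepthwiseSeparableConvModule  # assumes mmcv is installed

dsconv = DepthwiseSeparableConvModule(
    32, 64, kernel_size=3, padding=1, norm_cfg=dict(type='BN'))

x = torch.randn(1, 32, 28, 28)
print(dsconv(x).shape)  # torch.Size([1, 64, 28, 28])
# rough weight count: 32*3*3 depthwise + 32*64 pointwise,
# versus 32*64*3*3 for a dense 3x3 conv (norm parameters not counted)
```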
diff --git a/mmcv/mmcv/cnn/bricks/drop.py b/mmcv/mmcv/cnn/bricks/drop.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea05221d854592a5d885efbef002cb673c65f778
--- /dev/null
+++ b/mmcv/mmcv/cnn/bricks/drop.py
@@ -0,0 +1,69 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn as nn
+
+from mmcv import build_from_cfg
+from .registry import DROPOUT_LAYERS
+
+
+def drop_path(x: torch.Tensor,
+ drop_prob: float = 0.,
+ training: bool = False) -> torch.Tensor:
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
+ residual blocks).
+
+ We follow the implementation
+ https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ # handle tensors with different dimensions, not just 4D tensors.
+ shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
+ random_tensor = keep_prob + torch.rand(
+ shape, dtype=x.dtype, device=x.device)
+ output = x.div(keep_prob) * random_tensor.floor()
+ return output
+
+
+@DROPOUT_LAYERS.register_module()
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
+ residual blocks).
+
+ We follow the implementation
+ https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
+
+ Args:
+ drop_prob (float): Probability of the path to be zeroed. Default: 0.1
+ """
+
+ def __init__(self, drop_prob: float = 0.1):
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return drop_path(x, self.drop_prob, self.training)
+
+
+@DROPOUT_LAYERS.register_module()
+class Dropout(nn.Dropout):
+ """A wrapper for ``torch.nn.Dropout``, We rename the ``p`` of
+ ``torch.nn.Dropout`` to ``drop_prob`` so as to be consistent with
+ ``DropPath``
+
+ Args:
+ drop_prob (float): Probability of the elements to be
+ zeroed. Default: 0.5.
+ inplace (bool): Do the operation inplace or not. Default: False.
+ """
+
+ def __init__(self, drop_prob: float = 0.5, inplace: bool = False):
+ super().__init__(p=drop_prob, inplace=inplace)
+
+
+def build_dropout(cfg: Dict, default_args: Optional[Dict] = None) -> Any:
+ """Builder for drop out layers."""
+ return build_from_cfg(cfg, DROPOUT_LAYERS, default_args)
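A short sketch of the stochastic-depth behaviour above (assuming `mmcv` is installed): in training mode, whole samples in the batch are zeroed with probability `drop_prob` and the survivors are rescaled by `1 / (1 - drop_prob)`; in eval mode the layer is an identity.

```python
import torch
from mmcv.cnn.bricks import DropPath  # assumes mmcv is installed

dp = DropPath(drop_prob=0.2)
x = torch.ones(8, 4)

dp.train()
y = dp(x)
print(y[:, 0])                # each row is either all-zero or scaled to 1 / 0.8 = 1.25

dp.eval()
print(torch.equal(dp(x), x))  # True -- identity at inference time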
diff --git a/mmcv/mmcv/cnn/bricks/generalized_attention.py b/mmcv/mmcv/cnn/bricks/generalized_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..118e39c7ea2d9f24a97f22878dfbe753c4afef0b
--- /dev/null
+++ b/mmcv/mmcv/cnn/bricks/generalized_attention.py
@@ -0,0 +1,412 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..utils import kaiming_init
+from .registry import PLUGIN_LAYERS
+
+
+@PLUGIN_LAYERS.register_module()
+class GeneralizedAttention(nn.Module):
+ """GeneralizedAttention module.
+
+ See 'An Empirical Study of Spatial Attention Mechanisms in Deep Networks'
+ (https://arxiv.org/abs/1904.05873) for details.
+
+ Args:
+ in_channels (int): Channels of the input feature map.
+ spatial_range (int): The spatial range. -1 indicates no spatial range
+ constraint. Default: -1.
+ num_heads (int): The head number of empirical_attention module.
+ Default: 9.
+ position_embedding_dim (int): The position embedding dimension.
+ Default: -1.
+ position_magnitude (int): A multiplier acting on coord difference.
+ Default: 1.
+ kv_stride (int): The feature stride acting on key/value feature map.
+ Default: 2.
+ q_stride (int): The feature stride acting on query feature map.
+ Default: 1.
+ attention_type (str): A binary indicator string for indicating which
+ items in generalized empirical_attention module are used.
+ Default: '1111'.
+
+ - '1000' indicates 'query and key content' (appr - appr) item,
+ - '0100' indicates 'query content and relative position'
+ (appr - position) item,
+ - '0010' indicates 'key content only' (bias - appr) item,
+ - '0001' indicates 'relative position only' (bias - position) item.
+ """
+
+ _abbr_ = 'gen_attention_block'
+
+ def __init__(self,
+ in_channels: int,
+ spatial_range: int = -1,
+ num_heads: int = 9,
+ position_embedding_dim: int = -1,
+ position_magnitude: int = 1,
+ kv_stride: int = 2,
+ q_stride: int = 1,
+ attention_type: str = '1111'):
+
+ super().__init__()
+
+ # hard range means local range for non-local operation
+ self.position_embedding_dim = (
+ position_embedding_dim
+ if position_embedding_dim > 0 else in_channels)
+
+ self.position_magnitude = position_magnitude
+ self.num_heads = num_heads
+ self.in_channels = in_channels
+ self.spatial_range = spatial_range
+ self.kv_stride = kv_stride
+ self.q_stride = q_stride
+ self.attention_type = [bool(int(_)) for _ in attention_type]
+ self.qk_embed_dim = in_channels // num_heads
+ out_c = self.qk_embed_dim * num_heads
+
+ if self.attention_type[0] or self.attention_type[1]:
+ self.query_conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_c,
+ kernel_size=1,
+ bias=False)
+ self.query_conv.kaiming_init = True
+
+ if self.attention_type[0] or self.attention_type[2]:
+ self.key_conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_c,
+ kernel_size=1,
+ bias=False)
+ self.key_conv.kaiming_init = True
+
+ self.v_dim = in_channels // num_heads
+ self.value_conv = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=self.v_dim * num_heads,
+ kernel_size=1,
+ bias=False)
+ self.value_conv.kaiming_init = True
+
+ if self.attention_type[1] or self.attention_type[3]:
+ self.appr_geom_fc_x = nn.Linear(
+ self.position_embedding_dim // 2, out_c, bias=False)
+ self.appr_geom_fc_x.kaiming_init = True
+
+ self.appr_geom_fc_y = nn.Linear(
+ self.position_embedding_dim // 2, out_c, bias=False)
+ self.appr_geom_fc_y.kaiming_init = True
+
+ if self.attention_type[2]:
+ stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
+ appr_bias_value = -2 * stdv * torch.rand(out_c) + stdv
+ self.appr_bias = nn.Parameter(appr_bias_value)
+
+ if self.attention_type[3]:
+ stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
+ geom_bias_value = -2 * stdv * torch.rand(out_c) + stdv
+ self.geom_bias = nn.Parameter(geom_bias_value)
+
+ self.proj_conv = nn.Conv2d(
+ in_channels=self.v_dim * num_heads,
+ out_channels=in_channels,
+ kernel_size=1,
+ bias=True)
+ self.proj_conv.kaiming_init = True
+ self.gamma = nn.Parameter(torch.zeros(1))
+
+ if self.spatial_range >= 0:
+ # only works when non local is after 3*3 conv
+ if in_channels == 256:
+ max_len = 84
+ elif in_channels == 512:
+ max_len = 42
+
+ max_len_kv = int((max_len - 1.0) / self.kv_stride + 1)
+ local_constraint_map = np.ones(
+ (max_len, max_len, max_len_kv, max_len_kv), dtype=int)
+ for iy in range(max_len):
+ for ix in range(max_len):
+ local_constraint_map[
+ iy, ix,
+ max((iy - self.spatial_range) //
+ self.kv_stride, 0):min((iy + self.spatial_range +
+ 1) // self.kv_stride +
+ 1, max_len),
+ max((ix - self.spatial_range) //
+ self.kv_stride, 0):min((ix + self.spatial_range +
+ 1) // self.kv_stride +
+ 1, max_len)] = 0
+
+ self.local_constraint_map = nn.Parameter(
+ torch.from_numpy(local_constraint_map).byte(),
+ requires_grad=False)
+
+ if self.q_stride > 1:
+ self.q_downsample = nn.AvgPool2d(
+ kernel_size=1, stride=self.q_stride)
+ else:
+ self.q_downsample = None
+
+ if self.kv_stride > 1:
+ self.kv_downsample = nn.AvgPool2d(
+ kernel_size=1, stride=self.kv_stride)
+ else:
+ self.kv_downsample = None
+
+ self.init_weights()
+
+ def get_position_embedding(self,
+ h,
+ w,
+ h_kv,
+ w_kv,
+ q_stride,
+ kv_stride,
+ device,
+ dtype,
+ feat_dim,
+ wave_length=1000):
+ # the default type of Tensor is float32, leading to type mismatch
+ # in fp16 mode. Cast it to support fp16 mode.
+ h_idxs = torch.linspace(0, h - 1, h).to(device=device, dtype=dtype)
+ h_idxs = h_idxs.view((h, 1)) * q_stride
+
+ w_idxs = torch.linspace(0, w - 1, w).to(device=device, dtype=dtype)
+ w_idxs = w_idxs.view((w, 1)) * q_stride
+
+ h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).to(
+ device=device, dtype=dtype)
+ h_kv_idxs = h_kv_idxs.view((h_kv, 1)) * kv_stride
+
+ w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).to(
+ device=device, dtype=dtype)
+ w_kv_idxs = w_kv_idxs.view((w_kv, 1)) * kv_stride
+
+ # (h, h_kv, 1)
+ h_diff = h_idxs.unsqueeze(1) - h_kv_idxs.unsqueeze(0)
+ h_diff *= self.position_magnitude
+
+ # (w, w_kv, 1)
+ w_diff = w_idxs.unsqueeze(1) - w_kv_idxs.unsqueeze(0)
+ w_diff *= self.position_magnitude
+
+ feat_range = torch.arange(0, feat_dim / 4).to(
+ device=device, dtype=dtype)
+
+ dim_mat = torch.Tensor([wave_length]).to(device=device, dtype=dtype)
+ dim_mat = dim_mat**((4. / feat_dim) * feat_range)
+ dim_mat = dim_mat.view((1, 1, -1))
+
+ embedding_x = torch.cat(
+ ((w_diff / dim_mat).sin(), (w_diff / dim_mat).cos()), dim=2)
+
+ embedding_y = torch.cat(
+ ((h_diff / dim_mat).sin(), (h_diff / dim_mat).cos()), dim=2)
+
+ return embedding_x, embedding_y
+
+ def forward(self, x_input: torch.Tensor) -> torch.Tensor:
+ num_heads = self.num_heads
+
+ # use empirical_attention
+ if self.q_downsample is not None:
+ x_q = self.q_downsample(x_input)
+ else:
+ x_q = x_input
+ n, _, h, w = x_q.shape
+
+ if self.kv_downsample is not None:
+ x_kv = self.kv_downsample(x_input)
+ else:
+ x_kv = x_input
+ _, _, h_kv, w_kv = x_kv.shape
+
+ if self.attention_type[0] or self.attention_type[1]:
+ proj_query = self.query_conv(x_q).view(
+ (n, num_heads, self.qk_embed_dim, h * w))
+ proj_query = proj_query.permute(0, 1, 3, 2)
+
+ if self.attention_type[0] or self.attention_type[2]:
+ proj_key = self.key_conv(x_kv).view(
+ (n, num_heads, self.qk_embed_dim, h_kv * w_kv))
+
+ if self.attention_type[1] or self.attention_type[3]:
+ position_embed_x, position_embed_y = self.get_position_embedding(
+ h, w, h_kv, w_kv, self.q_stride, self.kv_stride,
+ x_input.device, x_input.dtype, self.position_embedding_dim)
+ # (n, num_heads, w, w_kv, dim)
+ position_feat_x = self.appr_geom_fc_x(position_embed_x).\
+ view(1, w, w_kv, num_heads, self.qk_embed_dim).\
+ permute(0, 3, 1, 2, 4).\
+ repeat(n, 1, 1, 1, 1)
+
+ # (n, num_heads, h, h_kv, dim)
+ position_feat_y = self.appr_geom_fc_y(position_embed_y).\
+ view(1, h, h_kv, num_heads, self.qk_embed_dim).\
+ permute(0, 3, 1, 2, 4).\
+ repeat(n, 1, 1, 1, 1)
+
+ position_feat_x /= math.sqrt(2)
+ position_feat_y /= math.sqrt(2)
+
+ # accelerate for saliency only
+ if (np.sum(self.attention_type) == 1) and self.attention_type[2]:
+ appr_bias = self.appr_bias.\
+ view(1, num_heads, 1, self.qk_embed_dim).\
+ repeat(n, 1, 1, 1)
+
+ energy = torch.matmul(appr_bias, proj_key).\
+ view(n, num_heads, 1, h_kv * w_kv)
+
+ h = 1
+ w = 1
+ else:
+ # (n, num_heads, h*w, h_kv*w_kv), query before key, 540mb for
+ if not self.attention_type[0]:
+ energy = torch.zeros(
+ n,
+ num_heads,
+ h,
+ w,
+ h_kv,
+ w_kv,
+ dtype=x_input.dtype,
+ device=x_input.device)
+
+ # attention_type[0]: appr - appr
+ # attention_type[1]: appr - position
+ # attention_type[2]: bias - appr
+ # attention_type[3]: bias - position
+ if self.attention_type[0] or self.attention_type[2]:
+ if self.attention_type[0] and self.attention_type[2]:
+ appr_bias = self.appr_bias.\
+ view(1, num_heads, 1, self.qk_embed_dim)
+ energy = torch.matmul(proj_query + appr_bias, proj_key).\
+ view(n, num_heads, h, w, h_kv, w_kv)
+
+ elif self.attention_type[0]:
+ energy = torch.matmul(proj_query, proj_key).\
+ view(n, num_heads, h, w, h_kv, w_kv)
+
+ elif self.attention_type[2]:
+ appr_bias = self.appr_bias.\
+ view(1, num_heads, 1, self.qk_embed_dim).\
+ repeat(n, 1, 1, 1)
+
+ energy += torch.matmul(appr_bias, proj_key).\
+ view(n, num_heads, 1, 1, h_kv, w_kv)
+
+ if self.attention_type[1] or self.attention_type[3]:
+ if self.attention_type[1] and self.attention_type[3]:
+ geom_bias = self.geom_bias.\
+ view(1, num_heads, 1, self.qk_embed_dim)
+
+ proj_query_reshape = (proj_query + geom_bias).\
+ view(n, num_heads, h, w, self.qk_embed_dim)
+
+ energy_x = torch.matmul(
+ proj_query_reshape.permute(0, 1, 3, 2, 4),
+ position_feat_x.permute(0, 1, 2, 4, 3))
+ energy_x = energy_x.\
+ permute(0, 1, 3, 2, 4).unsqueeze(4)
+
+ energy_y = torch.matmul(
+ proj_query_reshape,
+ position_feat_y.permute(0, 1, 2, 4, 3))
+ energy_y = energy_y.unsqueeze(5)
+
+ energy += energy_x + energy_y
+
+ elif self.attention_type[1]:
+ proj_query_reshape = proj_query.\
+ view(n, num_heads, h, w, self.qk_embed_dim)
+ proj_query_reshape = proj_query_reshape.\
+ permute(0, 1, 3, 2, 4)
+ position_feat_x_reshape = position_feat_x.\
+ permute(0, 1, 2, 4, 3)
+ position_feat_y_reshape = position_feat_y.\
+ permute(0, 1, 2, 4, 3)
+
+ energy_x = torch.matmul(proj_query_reshape,
+ position_feat_x_reshape)
+ energy_x = energy_x.permute(0, 1, 3, 2, 4).unsqueeze(4)
+
+ energy_y = torch.matmul(proj_query_reshape,
+ position_feat_y_reshape)
+ energy_y = energy_y.unsqueeze(5)
+
+ energy += energy_x + energy_y
+
+ elif self.attention_type[3]:
+ geom_bias = self.geom_bias.\
+ view(1, num_heads, self.qk_embed_dim, 1).\
+ repeat(n, 1, 1, 1)
+
+ position_feat_x_reshape = position_feat_x.\
+ view(n, num_heads, w * w_kv, self.qk_embed_dim)
+
+ position_feat_y_reshape = position_feat_y.\
+ view(n, num_heads, h * h_kv, self.qk_embed_dim)
+
+ energy_x = torch.matmul(position_feat_x_reshape, geom_bias)
+ energy_x = energy_x.view(n, num_heads, 1, w, 1, w_kv)
+
+ energy_y = torch.matmul(position_feat_y_reshape, geom_bias)
+ energy_y = energy_y.view(n, num_heads, h, 1, h_kv, 1)
+
+ energy += energy_x + energy_y
+
+ energy = energy.view(n, num_heads, h * w, h_kv * w_kv)
+
+ if self.spatial_range >= 0:
+ cur_local_constraint_map = \
+ self.local_constraint_map[:h, :w, :h_kv, :w_kv].\
+ contiguous().\
+ view(1, 1, h*w, h_kv*w_kv)
+
+ energy = energy.masked_fill_(cur_local_constraint_map,
+ float('-inf'))
+
+ attention = F.softmax(energy, 3)
+
+ proj_value = self.value_conv(x_kv)
+ proj_value_reshape = proj_value.\
+ view((n, num_heads, self.v_dim, h_kv * w_kv)).\
+ permute(0, 1, 3, 2)
+
+ out = torch.matmul(attention, proj_value_reshape).\
+ permute(0, 1, 3, 2).\
+ contiguous().\
+ view(n, self.v_dim * self.num_heads, h, w)
+
+ out = self.proj_conv(out)
+
+ # output is downsampled, upsample back to input size
+ if self.q_downsample is not None:
+ out = F.interpolate(
+ out,
+ size=x_input.shape[2:],
+ mode='bilinear',
+ align_corners=False)
+
+ out = self.gamma * out + x_input
+ return out
+
+ def init_weights(self):
+ for m in self.modules():
+ if hasattr(m, 'kaiming_init') and m.kaiming_init:
+ kaiming_init(
+ m,
+ mode='fan_in',
+ nonlinearity='leaky_relu',
+ bias=0,
+ distribution='uniform',
+ a=1)
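A shape-level sketch of the module above (assuming `mmcv` is installed). `in_channels` must be divisible by `num_heads`, and since `gamma` is zero-initialised the block starts as an identity mapping, matching its residual design.

```python
import torch
from mmcv.cnn import GeneralizedAttention  # assumes mmcv is installed

attn = GeneralizedAttention(in_channels=16, num_heads=8, attention_type='1111')
x = torch.randn(1, 16, 20, 20)
out = attn(x)

print(out.shape)               # torch.Size([1, 16, 20, 20]) -- residual output
print(torch.allclose(out, x))  # True at initialisation: gamma starts at zero
```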
diff --git a/mmcv/mmcv/cnn/bricks/hsigmoid.py b/mmcv/mmcv/cnn/bricks/hsigmoid.py
new file mode 100644
index 0000000000000000000000000000000000000000..5eb97e8ab13e76c6916a7ebba15cb50f8b846897
--- /dev/null
+++ b/mmcv/mmcv/cnn/bricks/hsigmoid.py
@@ -0,0 +1,51 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+import torch.nn as nn
+
+from .registry import ACTIVATION_LAYERS
+
+
+@ACTIVATION_LAYERS.register_module()
+class HSigmoid(nn.Module):
+ """Hard Sigmoid Module. Apply the hard sigmoid function:
+ Hsigmoid(x) = min(max((x + bias) / divisor, min_value), max_value)
+ Default: Hsigmoid(x) = min(max((x + 3) / 6, 0), 1)
+
+ Note:
+ In MMCV v1.4.4, we modified the default value of args to align with
+ PyTorch official.
+
+ Args:
+ bias (float): Bias of the input feature map. Default: 3.0.
+ divisor (float): Divisor of the input feature map. Default: 6.0.
+ min_value (float): Lower bound value. Default: 0.0.
+ max_value (float): Upper bound value. Default: 1.0.
+
+ Returns:
+ Tensor: The output tensor.
+ """
+
+ def __init__(self,
+ bias: float = 3.0,
+ divisor: float = 6.0,
+ min_value: float = 0.0,
+ max_value: float = 1.0):
+ super().__init__()
+ warnings.warn(
+ 'In MMCV v1.4.4, we modified the default value of args to align '
+ 'with PyTorch official. Previous Implementation: '
+ 'Hsigmoid(x) = min(max((x + 1) / 2, 0), 1). '
+ 'Current Implementation: '
+ 'Hsigmoid(x) = min(max((x + 3) / 6, 0), 1).')
+ self.bias = bias
+ self.divisor = divisor
+ assert self.divisor != 0
+ self.min_value = min_value
+ self.max_value = max_value
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = (x + self.bias) / self.divisor
+
+ return x.clamp_(self.min_value, self.max_value)
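A tiny numeric check of the definition above (assuming `mmcv` is installed); note that constructing the layer also emits the warning about the changed default arguments.

```python
import torch
from mmcv.cnn import HSigmoid  # assumes mmcv is installed

hsig = HSigmoid()              # min(max((x + 3) / 6, 0), 1)
x = torch.tensor([-4.0, 0.0, 4.0])
print(hsig(x))                 # tensor([0.0000, 0.5000, 1.0000])
```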
diff --git a/mmcv/mmcv/cnn/bricks/hswish.py b/mmcv/mmcv/cnn/bricks/hswish.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f6cc276c10a5c49bd9c0e30a1ffad4a1b6018d4
--- /dev/null
+++ b/mmcv/mmcv/cnn/bricks/hswish.py
@@ -0,0 +1,39 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+from mmcv.utils import TORCH_VERSION, digit_version
+from .registry import ACTIVATION_LAYERS
+
+
+class HSwish(nn.Module):
+ """Hard Swish Module.
+
+ This module applies the hard swish function:
+
+ .. math::
+ Hswish(x) = x * ReLU6(x + 3) / 6
+
+ Args:
+ inplace (bool): can optionally do the operation in-place.
+ Default: False.
+
+ Returns:
+ Tensor: The output tensor.
+ """
+
+ def __init__(self, inplace: bool = False):
+ super().__init__()
+ self.act = nn.ReLU6(inplace)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return x * self.act(x + 3) / 6
+
+
+if (TORCH_VERSION == 'parrots'
+ or digit_version(TORCH_VERSION) < digit_version('1.7')):
+ # nn.Hardswish was only added in PyTorch 1.6, and the 1.6 version does
+ # not support `inplace`, so the custom HSwish is registered for PyTorch
+ # versions below 1.7.
+ ACTIVATION_LAYERS.register_module(module=HSwish)
+else:
+ ACTIVATION_LAYERS.register_module(module=nn.Hardswish, name='HSwish')
diff --git a/mmcv/mmcv/cnn/bricks/non_local.py b/mmcv/mmcv/cnn/bricks/non_local.py
new file mode 100644
index 0000000000000000000000000000000000000000..159db245e80950d9b94e2744361bca2a09e67c13
--- /dev/null
+++ b/mmcv/mmcv/cnn/bricks/non_local.py
@@ -0,0 +1,308 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta
+from typing import Dict, Optional
+
+import torch
+import torch.nn as nn
+
+from ..utils import constant_init, normal_init
+from .conv_module import ConvModule
+from .registry import PLUGIN_LAYERS
+
+
+class _NonLocalNd(nn.Module, metaclass=ABCMeta):
+ """Basic Non-local module.
+
+ This module is proposed in
+ "Non-local Neural Networks"
+ Paper reference: https://arxiv.org/abs/1711.07971
+ Code reference: https://github.com/AlexHex7/Non-local_pytorch
+
+ Args:
+ in_channels (int): Channels of the input feature map.
+ reduction (int): Channel reduction ratio. Default: 2.
+ use_scale (bool): Whether to scale pairwise_weight by
+ `1/sqrt(inter_channels)` when the mode is `embedded_gaussian`.
+ Default: True.
+ conv_cfg (None | dict): The config dict for convolution layers.
+ If not specified, it will use `nn.Conv2d` for convolution layers.
+ Default: None.
+ norm_cfg (None | dict): The config dict for normalization layers.
+ Default: None. (This parameter is only applicable to conv_out.)
+ mode (str): Options are `gaussian`, `concatenation`,
+ `embedded_gaussian` and `dot_product`. Default: embedded_gaussian.
+ """
+
+ def __init__(self,
+ in_channels: int,
+ reduction: int = 2,
+ use_scale: bool = True,
+ conv_cfg: Optional[Dict] = None,
+ norm_cfg: Optional[Dict] = None,
+ mode: str = 'embedded_gaussian',
+ **kwargs):
+ super().__init__()
+ self.in_channels = in_channels
+ self.reduction = reduction
+ self.use_scale = use_scale
+ self.inter_channels = max(in_channels // reduction, 1)
+ self.mode = mode
+
+ if mode not in [
+ 'gaussian', 'embedded_gaussian', 'dot_product', 'concatenation'
+ ]:
+ raise ValueError("Mode should be in 'gaussian', 'concatenation', "
+ f"'embedded_gaussian' or 'dot_product', but got "
+ f'{mode} instead.')
+
+ # g, theta, phi are defaulted as `nn.ConvNd`.
+ # Here we use ConvModule for potential usage.
+ self.g = ConvModule(
+ self.in_channels,
+ self.inter_channels,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ act_cfg=None) # type: ignore
+ self.conv_out = ConvModule(
+ self.inter_channels,
+ self.in_channels,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+
+ if self.mode != 'gaussian':
+ self.theta = ConvModule(
+ self.in_channels,
+ self.inter_channels,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ act_cfg=None)
+ self.phi = ConvModule(
+ self.in_channels,
+ self.inter_channels,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ act_cfg=None)
+
+ if self.mode == 'concatenation':
+ self.concat_project = ConvModule(
+ self.inter_channels * 2,
+ 1,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ act_cfg=dict(type='ReLU'))
+
+ self.init_weights(**kwargs)
+
+ def init_weights(self, std: float = 0.01, zeros_init: bool = True) -> None:
+ if self.mode != 'gaussian':
+ for m in [self.g, self.theta, self.phi]:
+ normal_init(m.conv, std=std)
+ else:
+ normal_init(self.g.conv, std=std)
+ if zeros_init:
+ if self.conv_out.norm_cfg is None:
+ constant_init(self.conv_out.conv, 0)
+ else:
+ constant_init(self.conv_out.norm, 0)
+ else:
+ if self.conv_out.norm_cfg is None:
+ normal_init(self.conv_out.conv, std=std)
+ else:
+ normal_init(self.conv_out.norm, std=std)
+
+ def gaussian(self, theta_x: torch.Tensor,
+ phi_x: torch.Tensor) -> torch.Tensor:
+ # NonLocal1d pairwise_weight: [N, H, H]
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+ pairwise_weight = torch.matmul(theta_x, phi_x)
+ pairwise_weight = pairwise_weight.softmax(dim=-1)
+ return pairwise_weight
+
+ def embedded_gaussian(self, theta_x: torch.Tensor,
+ phi_x: torch.Tensor) -> torch.Tensor:
+ # NonLocal1d pairwise_weight: [N, H, H]
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+ pairwise_weight = torch.matmul(theta_x, phi_x)
+ if self.use_scale:
+ # theta_x.shape[-1] is `self.inter_channels`
+ pairwise_weight /= theta_x.shape[-1]**0.5
+ pairwise_weight = pairwise_weight.softmax(dim=-1)
+ return pairwise_weight
+
+ def dot_product(self, theta_x: torch.Tensor,
+ phi_x: torch.Tensor) -> torch.Tensor:
+ # NonLocal1d pairwise_weight: [N, H, H]
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+ pairwise_weight = torch.matmul(theta_x, phi_x)
+ pairwise_weight /= pairwise_weight.shape[-1]
+ return pairwise_weight
+
+ def concatenation(self, theta_x: torch.Tensor,
+ phi_x: torch.Tensor) -> torch.Tensor:
+ # NonLocal1d pairwise_weight: [N, H, H]
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+ h = theta_x.size(2)
+ w = phi_x.size(3)
+ theta_x = theta_x.repeat(1, 1, 1, w)
+ phi_x = phi_x.repeat(1, 1, h, 1)
+
+ concat_feature = torch.cat([theta_x, phi_x], dim=1)
+ pairwise_weight = self.concat_project(concat_feature)
+ n, _, h, w = pairwise_weight.size()
+ pairwise_weight = pairwise_weight.view(n, h, w)
+ pairwise_weight /= pairwise_weight.shape[-1]
+
+ return pairwise_weight
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Assume `reduction = 1`, then `inter_channels = C`.
+        # In `gaussian` mode theta/phi act on the raw input, so the channel
+        # dimension below is the original `C` in that case as well.
+
+ # NonLocal1d x: [N, C, H]
+ # NonLocal2d x: [N, C, H, W]
+ # NonLocal3d x: [N, C, T, H, W]
+ n = x.size(0)
+
+ # NonLocal1d g_x: [N, H, C]
+ # NonLocal2d g_x: [N, HxW, C]
+ # NonLocal3d g_x: [N, TxHxW, C]
+ g_x = self.g(x).view(n, self.inter_channels, -1)
+ g_x = g_x.permute(0, 2, 1)
+
+ # NonLocal1d theta_x: [N, H, C], phi_x: [N, C, H]
+ # NonLocal2d theta_x: [N, HxW, C], phi_x: [N, C, HxW]
+ # NonLocal3d theta_x: [N, TxHxW, C], phi_x: [N, C, TxHxW]
+ if self.mode == 'gaussian':
+ theta_x = x.view(n, self.in_channels, -1)
+ theta_x = theta_x.permute(0, 2, 1)
+ if self.sub_sample:
+ phi_x = self.phi(x).view(n, self.in_channels, -1)
+ else:
+ phi_x = x.view(n, self.in_channels, -1)
+ elif self.mode == 'concatenation':
+ theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
+ phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
+ else:
+ theta_x = self.theta(x).view(n, self.inter_channels, -1)
+ theta_x = theta_x.permute(0, 2, 1)
+ phi_x = self.phi(x).view(n, self.inter_channels, -1)
+
+ pairwise_func = getattr(self, self.mode)
+ # NonLocal1d pairwise_weight: [N, H, H]
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+ pairwise_weight = pairwise_func(theta_x, phi_x)
+
+ # NonLocal1d y: [N, H, C]
+ # NonLocal2d y: [N, HxW, C]
+ # NonLocal3d y: [N, TxHxW, C]
+ y = torch.matmul(pairwise_weight, g_x)
+ # NonLocal1d y: [N, C, H]
+ # NonLocal2d y: [N, C, H, W]
+ # NonLocal3d y: [N, C, T, H, W]
+ y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
+ *x.size()[2:])
+
+ output = x + self.conv_out(y)
+
+ return output
+
+
+class NonLocal1d(_NonLocalNd):
+ """1D Non-local module.
+
+ Args:
+ in_channels (int): Same as `NonLocalND`.
+ sub_sample (bool): Whether to apply max pooling after pairwise
+ function (Note that the `sub_sample` is applied on spatial only).
+ Default: False.
+ conv_cfg (None | dict): Same as `NonLocalND`.
+ Default: dict(type='Conv1d').
+ """
+
+ def __init__(self,
+ in_channels: int,
+ sub_sample: bool = False,
+ conv_cfg: Dict = dict(type='Conv1d'),
+ **kwargs):
+ super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
+
+ self.sub_sample = sub_sample
+
+ if sub_sample:
+ max_pool_layer = nn.MaxPool1d(kernel_size=2)
+ self.g = nn.Sequential(self.g, max_pool_layer)
+ if self.mode != 'gaussian':
+ self.phi = nn.Sequential(self.phi, max_pool_layer)
+ else:
+ self.phi = max_pool_layer
+
+
+@PLUGIN_LAYERS.register_module()
+class NonLocal2d(_NonLocalNd):
+ """2D Non-local module.
+
+ Args:
+ in_channels (int): Same as `NonLocalND`.
+ sub_sample (bool): Whether to apply max pooling after pairwise
+ function (Note that the `sub_sample` is applied on spatial only).
+ Default: False.
+ conv_cfg (None | dict): Same as `NonLocalND`.
+ Default: dict(type='Conv2d').
+ """
+
+ _abbr_ = 'nonlocal_block'
+
+ def __init__(self,
+ in_channels: int,
+ sub_sample: bool = False,
+ conv_cfg: Dict = dict(type='Conv2d'),
+ **kwargs):
+ super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
+
+ self.sub_sample = sub_sample
+
+ if sub_sample:
+ max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
+ self.g = nn.Sequential(self.g, max_pool_layer)
+ if self.mode != 'gaussian':
+ self.phi = nn.Sequential(self.phi, max_pool_layer)
+ else:
+ self.phi = max_pool_layer
+
+
+class NonLocal3d(_NonLocalNd):
+ """3D Non-local module.
+
+ Args:
+ in_channels (int): Same as `NonLocalND`.
+ sub_sample (bool): Whether to apply max pooling after pairwise
+ function (Note that the `sub_sample` is applied on spatial only).
+ Default: False.
+ conv_cfg (None | dict): Same as `NonLocalND`.
+ Default: dict(type='Conv3d').
+ """
+
+ def __init__(self,
+ in_channels: int,
+ sub_sample: bool = False,
+ conv_cfg: Dict = dict(type='Conv3d'),
+ **kwargs):
+ super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
+ self.sub_sample = sub_sample
+
+ if sub_sample:
+ max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
+ self.g = nn.Sequential(self.g, max_pool_layer)
+ if self.mode != 'gaussian':
+ self.phi = nn.Sequential(self.phi, max_pool_layer)
+ else:
+ self.phi = max_pool_layer
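
For orientation, a minimal usage sketch of the 2D variant defined above (assuming the vendored `mmcv` package directory is on the import path; the block also relies on `ConvModule` from the wider package):

```python
import torch
from mmcv.cnn.bricks.non_local import NonLocal2d

# Embedded-Gaussian non-local block; reduction=2 gives inter_channels=8
# for the internal g/theta/phi 1x1 projections.
block = NonLocal2d(in_channels=16, reduction=2, mode='embedded_gaussian')

x = torch.rand(2, 16, 8, 8)
out = block(x)  # residual output: x + conv_out(attention-weighted g(x))
assert out.shape == x.shape
```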
diff --git a/mmcv/mmcv/cnn/bricks/norm.py b/mmcv/mmcv/cnn/bricks/norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6281a7c697483fbdaaba5a37d88a00f3c259d31
--- /dev/null
+++ b/mmcv/mmcv/cnn/bricks/norm.py
@@ -0,0 +1,148 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+from typing import Dict, Tuple, Union
+
+import torch.nn as nn
+
+from mmcv.utils import is_tuple_of
+from mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm, _InstanceNorm
+from .registry import NORM_LAYERS
+
+NORM_LAYERS.register_module('BN', module=nn.BatchNorm2d)
+NORM_LAYERS.register_module('BN1d', module=nn.BatchNorm1d)
+NORM_LAYERS.register_module('BN2d', module=nn.BatchNorm2d)
+NORM_LAYERS.register_module('BN3d', module=nn.BatchNorm3d)
+NORM_LAYERS.register_module('SyncBN', module=SyncBatchNorm)
+NORM_LAYERS.register_module('GN', module=nn.GroupNorm)
+NORM_LAYERS.register_module('LN', module=nn.LayerNorm)
+NORM_LAYERS.register_module('IN', module=nn.InstanceNorm2d)
+NORM_LAYERS.register_module('IN1d', module=nn.InstanceNorm1d)
+NORM_LAYERS.register_module('IN2d', module=nn.InstanceNorm2d)
+NORM_LAYERS.register_module('IN3d', module=nn.InstanceNorm3d)
+
+
+def infer_abbr(class_type):
+ """Infer abbreviation from the class name.
+
+ When we build a norm layer with `build_norm_layer()`, we want to preserve
+ the norm type in variable names, e.g, self.bn1, self.gn. This method will
+ infer the abbreviation to map class types to abbreviations.
+
+ Rule 1: If the class has the property "_abbr_", return the property.
+ Rule 2: If the parent class is _BatchNorm, GroupNorm, LayerNorm or
+ InstanceNorm, the abbreviation of this layer will be "bn", "gn", "ln" and
+ "in" respectively.
+ Rule 3: If the class name contains "batch", "group", "layer" or "instance",
+ the abbreviation of this layer will be "bn", "gn", "ln" and "in"
+ respectively.
+    Rule 4: Otherwise, the abbreviation falls back to "norm_layer".
+
+ Args:
+ class_type (type): The norm layer type.
+
+ Returns:
+ str: The inferred abbreviation.
+ """
+ if not inspect.isclass(class_type):
+ raise TypeError(
+ f'class_type must be a type, but got {type(class_type)}')
+ if hasattr(class_type, '_abbr_'):
+ return class_type._abbr_
+ if issubclass(class_type, _InstanceNorm): # IN is a subclass of BN
+ return 'in'
+ elif issubclass(class_type, _BatchNorm):
+ return 'bn'
+ elif issubclass(class_type, nn.GroupNorm):
+ return 'gn'
+ elif issubclass(class_type, nn.LayerNorm):
+ return 'ln'
+ else:
+ class_name = class_type.__name__.lower()
+ if 'batch' in class_name:
+ return 'bn'
+ elif 'group' in class_name:
+ return 'gn'
+ elif 'layer' in class_name:
+ return 'ln'
+ elif 'instance' in class_name:
+ return 'in'
+ else:
+ return 'norm_layer'
+
+
+def build_norm_layer(cfg: Dict,
+ num_features: int,
+ postfix: Union[int, str] = '') -> Tuple[str, nn.Module]:
+ """Build normalization layer.
+
+ Args:
+ cfg (dict): The norm layer config, which should contain:
+
+ - type (str): Layer type.
+ - layer args: Args needed to instantiate a norm layer.
+            - requires_grad (bool, optional): Whether the layer's
+              parameters require gradient updates. Default: True.
+ num_features (int): Number of input channels.
+ postfix (int | str): The postfix to be appended into norm abbreviation
+ to create named layer.
+
+ Returns:
+ tuple[str, nn.Module]: The first element is the layer name consisting
+ of abbreviation and postfix, e.g., bn1, gn. The second element is the
+ created norm layer.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError('cfg must be a dict')
+ if 'type' not in cfg:
+ raise KeyError('the cfg dict must contain the key "type"')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type not in NORM_LAYERS:
+ raise KeyError(f'Unrecognized norm type {layer_type}')
+
+ norm_layer = NORM_LAYERS.get(layer_type)
+ abbr = infer_abbr(norm_layer)
+
+ assert isinstance(postfix, (int, str))
+ name = abbr + str(postfix)
+
+ requires_grad = cfg_.pop('requires_grad', True)
+ cfg_.setdefault('eps', 1e-5)
+ if layer_type != 'GN':
+ layer = norm_layer(num_features, **cfg_)
+ if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
+ layer._specify_ddp_gpu_num(1)
+ else:
+ assert 'num_groups' in cfg_
+ layer = norm_layer(num_channels=num_features, **cfg_)
+
+ for param in layer.parameters():
+ param.requires_grad = requires_grad
+
+ return name, layer
+
+
+def is_norm(layer: nn.Module,
+ exclude: Union[type, tuple, None] = None) -> bool:
+ """Check if a layer is a normalization layer.
+
+ Args:
+ layer (nn.Module): The layer to be checked.
+ exclude (type | tuple[type]): Types to be excluded.
+
+ Returns:
+ bool: Whether the layer is a norm layer.
+ """
+ if exclude is not None:
+ if not isinstance(exclude, tuple):
+ exclude = (exclude, )
+ if not is_tuple_of(exclude, type):
+ raise TypeError(
+ f'"exclude" must be either None or type or a tuple of types, '
+ f'but got {type(exclude)}: {exclude}')
+
+ if exclude and isinstance(layer, exclude):
+ return False
+
+ all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
+ return isinstance(layer, all_norm_bases)
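
A short sketch of how `build_norm_layer` is typically called (assuming the vendored package is importable):

```python
import torch.nn as nn
from mmcv.cnn.bricks.norm import build_norm_layer

# 'BN' maps to nn.BatchNorm2d; the returned name is the inferred
# abbreviation plus the postfix.
name, bn = build_norm_layer(dict(type='BN', requires_grad=True), 64, postfix=1)
assert name == 'bn1' and isinstance(bn, nn.BatchNorm2d)

# 'GN' requires num_groups and receives num_features as num_channels.
name, gn = build_norm_layer(dict(type='GN', num_groups=8), 64)
assert name == 'gn' and isinstance(gn, nn.GroupNorm)
```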
diff --git a/mmcv/mmcv/cnn/bricks/padding.py b/mmcv/mmcv/cnn/bricks/padding.py
new file mode 100644
index 0000000000000000000000000000000000000000..8412b0c6576fd220eca52382943ad5889f0dfd1f
--- /dev/null
+++ b/mmcv/mmcv/cnn/bricks/padding.py
@@ -0,0 +1,38 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict
+
+import torch.nn as nn
+
+from .registry import PADDING_LAYERS
+
+PADDING_LAYERS.register_module('zero', module=nn.ZeroPad2d)
+PADDING_LAYERS.register_module('reflect', module=nn.ReflectionPad2d)
+PADDING_LAYERS.register_module('replicate', module=nn.ReplicationPad2d)
+
+
+def build_padding_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
+ """Build padding layer.
+
+ Args:
+ cfg (dict): The padding layer config, which should contain:
+ - type (str): Layer type.
+ - layer args: Args needed to instantiate a padding layer.
+
+ Returns:
+ nn.Module: Created padding layer.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError('cfg must be a dict')
+ if 'type' not in cfg:
+ raise KeyError('the cfg dict must contain the key "type"')
+
+ cfg_ = cfg.copy()
+ padding_type = cfg_.pop('type')
+ if padding_type not in PADDING_LAYERS:
+ raise KeyError(f'Unrecognized padding type {padding_type}.')
+ else:
+ padding_layer = PADDING_LAYERS.get(padding_type)
+
+ layer = padding_layer(*args, **kwargs, **cfg_)
+
+ return layer
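
A minimal sketch of the padding builder; positional arguments are forwarded to the padding layer's constructor:

```python
import torch
from mmcv.cnn.bricks.padding import build_padding_layer

# 'reflect' maps to nn.ReflectionPad2d; extra args (here the pad width)
# are passed through to its constructor.
pad = build_padding_layer(dict(type='reflect'), 2)

x = torch.rand(1, 3, 8, 8)
assert pad(x).shape == (1, 3, 12, 12)  # 2 pixels reflected on every side
```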
diff --git a/mmcv/mmcv/cnn/bricks/plugin.py b/mmcv/mmcv/cnn/bricks/plugin.py
new file mode 100644
index 0000000000000000000000000000000000000000..095ef9234501d0bca54373d4422244b80f818341
--- /dev/null
+++ b/mmcv/mmcv/cnn/bricks/plugin.py
@@ -0,0 +1,94 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+import platform
+from typing import Dict, Tuple, Union
+
+import torch.nn as nn
+
+from .registry import PLUGIN_LAYERS
+
+if platform.system() == 'Windows':
+ import regex as re # type: ignore
+else:
+ import re # type: ignore
+
+
+def infer_abbr(class_type: type) -> str:
+ """Infer abbreviation from the class name.
+
+ This method will infer the abbreviation to map class types to
+ abbreviations.
+
+    Rule 1: If the class has the property "_abbr_", return the property.
+ Rule 2: Otherwise, the abbreviation falls back to snake case of class
+ name, e.g. the abbreviation of ``FancyBlock`` will be ``fancy_block``.
+
+ Args:
+ class_type (type): The norm layer type.
+
+ Returns:
+ str: The inferred abbreviation.
+ """
+
+ def camel2snack(word):
+ """Convert camel case word into snack case.
+
+ Modified from `inflection lib
+ `_.
+
+ Example::
+
+ >>> camel2snack("FancyBlock")
+ 'fancy_block'
+ """
+
+ word = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', word)
+ word = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', word)
+ word = word.replace('-', '_')
+ return word.lower()
+
+ if not inspect.isclass(class_type):
+ raise TypeError(
+ f'class_type must be a type, but got {type(class_type)}')
+ if hasattr(class_type, '_abbr_'):
+ return class_type._abbr_ # type: ignore
+ else:
+ return camel2snack(class_type.__name__)
+
+
+def build_plugin_layer(cfg: Dict,
+ postfix: Union[int, str] = '',
+ **kwargs) -> Tuple[str, nn.Module]:
+ """Build plugin layer.
+
+ Args:
+ cfg (dict): cfg should contain:
+
+ - type (str): identify plugin layer type.
+ - layer args: args needed to instantiate a plugin layer.
+ postfix (int, str): appended into norm abbreviation to
+ create named layer. Default: ''.
+
+ Returns:
+ tuple[str, nn.Module]: The first one is the concatenation of
+ abbreviation and postfix. The second is the created plugin layer.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError('cfg must be a dict')
+ if 'type' not in cfg:
+ raise KeyError('the cfg dict must contain the key "type"')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type not in PLUGIN_LAYERS:
+ raise KeyError(f'Unrecognized plugin type {layer_type}')
+
+ plugin_layer = PLUGIN_LAYERS.get(layer_type)
+ abbr = infer_abbr(plugin_layer)
+
+ assert isinstance(postfix, (int, str))
+ name = abbr + str(postfix)
+
+ layer = plugin_layer(**kwargs, **cfg_)
+
+ return name, layer
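
For example, `NonLocal2d` (registered into `PLUGIN_LAYERS` earlier in this diff with `_abbr_ = 'nonlocal_block'`) can be built like this — a sketch, assuming the full vendored package is importable:

```python
from mmcv.cnn.bricks.plugin import build_plugin_layer

name, layer = build_plugin_layer(
    dict(type='NonLocal2d', in_channels=16, reduction=2), postfix='_1')
# The returned name is the inferred abbreviation plus the postfix.
assert name == 'nonlocal_block_1'
```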
diff --git a/mmcv/mmcv/cnn/bricks/registry.py b/mmcv/mmcv/cnn/bricks/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..c29279776dd523e706b6af8f9b9de700bed05ba7
--- /dev/null
+++ b/mmcv/mmcv/cnn/bricks/registry.py
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.utils import Registry
+
+CONV_LAYERS = Registry('conv layer')
+NORM_LAYERS = Registry('norm layer')
+ACTIVATION_LAYERS = Registry('activation layer')
+PADDING_LAYERS = Registry('padding layer')
+UPSAMPLE_LAYERS = Registry('upsample layer')
+PLUGIN_LAYERS = Registry('plugin layer')
+
+DROPOUT_LAYERS = Registry('drop out layers')
+POSITIONAL_ENCODING = Registry('position encoding')
+ATTENTION = Registry('attention')
+FEEDFORWARD_NETWORK = Registry('feed-forward Network')
+TRANSFORMER_LAYER = Registry('transformerLayer')
+TRANSFORMER_LAYER_SEQUENCE = Registry('transformer-layers sequence')
diff --git a/mmcv/mmcv/cnn/bricks/scale.py b/mmcv/mmcv/cnn/bricks/scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbd07c6a445e116bd6f32c96d8b52079ccf9b28a
--- /dev/null
+++ b/mmcv/mmcv/cnn/bricks/scale.py
@@ -0,0 +1,21 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+
+class Scale(nn.Module):
+ """A learnable scale parameter.
+
+ This layer scales the input by a learnable factor. It multiplies a
+ learnable scale parameter of shape (1,) with input of any shape.
+
+ Args:
+ scale (float): Initial value of scale factor. Default: 1.0
+ """
+
+ def __init__(self, scale: float = 1.0):
+ super().__init__()
+ self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return x * self.scale
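
A tiny usage sketch; the learnable factor starts at the given initial value:

```python
import torch
from mmcv.cnn.bricks.scale import Scale

# One Scale per branch is a common pattern, e.g. per-level rescaling of
# regression outputs in dense detection heads.
scale = Scale(scale=1.0)
x = torch.rand(4, 256, 7, 7)
assert torch.allclose(scale(x), x)  # the initial factor is 1.0
```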
diff --git a/mmcv/mmcv/cnn/bricks/swish.py b/mmcv/mmcv/cnn/bricks/swish.py
new file mode 100644
index 0000000000000000000000000000000000000000..b297adff068661859265a5057c1b2204ac8eefa7
--- /dev/null
+++ b/mmcv/mmcv/cnn/bricks/swish.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+from .registry import ACTIVATION_LAYERS
+
+
+@ACTIVATION_LAYERS.register_module()
+class Swish(nn.Module):
+ """Swish Module.
+
+ This module applies the swish function:
+
+ .. math::
+ Swish(x) = x * Sigmoid(x)
+
+ Returns:
+ Tensor: The output tensor.
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return x * torch.sigmoid(x)
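
Since `Swish` is registered as an activation layer, it can also be built from a config dict; a quick sketch:

```python
import torch
from mmcv.utils import build_from_cfg
from mmcv.cnn.bricks.registry import ACTIVATION_LAYERS

act = build_from_cfg(dict(type='Swish'), ACTIVATION_LAYERS)

x = torch.randn(8)
assert torch.allclose(act(x), x * torch.sigmoid(x))  # Swish(x) = x * sigmoid(x)
```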
diff --git a/mmcv/mmcv/cnn/bricks/transformer.py b/mmcv/mmcv/cnn/bricks/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7ba4d9f836609cec8526607db98c4b03ec4fee3
--- /dev/null
+++ b/mmcv/mmcv/cnn/bricks/transformer.py
@@ -0,0 +1,944 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import math
+import warnings
+from typing import Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer,
+ build_norm_layer)
+from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
+from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning,
+ to_2tuple)
+from .drop import build_dropout
+from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
+ TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
+
+# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
+try:
+ from mmcv.ops.multi_scale_deform_attn import \
+ MultiScaleDeformableAttention # noqa F401
+ warnings.warn(
+ ImportWarning(
+ '``MultiScaleDeformableAttention`` has been moved to '
+ '``mmcv.ops.multi_scale_deform_attn``, please change original path ' # noqa E501
+ '``from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` ' # noqa E501
+ 'to ``from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` ' # noqa E501
+ ))
+
+except ImportError:
+ warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '
+ '``mmcv.ops.multi_scale_deform_attn``, '
+ 'You should install ``mmcv-full`` if you need this module. ')
+
+
+def build_positional_encoding(cfg, default_args=None):
+ """Builder for Position Encoding."""
+ return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args)
+
+
+def build_attention(cfg, default_args=None):
+ """Builder for attention."""
+ return build_from_cfg(cfg, ATTENTION, default_args)
+
+
+def build_feedforward_network(cfg, default_args=None):
+ """Builder for feed-forward network (FFN)."""
+ return build_from_cfg(cfg, FEEDFORWARD_NETWORK, default_args)
+
+
+def build_transformer_layer(cfg, default_args=None):
+ """Builder for transformer layer."""
+ return build_from_cfg(cfg, TRANSFORMER_LAYER, default_args)
+
+
+def build_transformer_layer_sequence(cfg, default_args=None):
+ """Builder for transformer encoder and transformer decoder."""
+ return build_from_cfg(cfg, TRANSFORMER_LAYER_SEQUENCE, default_args)
+
+
+class AdaptivePadding(nn.Module):
+ """Applies padding adaptively to the input.
+
+    This module pads the input so that it is fully covered by the filter
+    you specify. It supports two modes, "same" and "corner": the "same"
+    mode matches the "SAME" padding mode in TensorFlow and pads zeros
+    evenly around the input, while the "corner" mode pads zeros to the
+    bottom right.
+
+ Args:
+ kernel_size (int | tuple): Size of the kernel. Default: 1.
+ stride (int | tuple): Stride of the filter. Default: 1.
+ dilation (int | tuple): Spacing between kernel elements.
+ Default: 1.
+ padding (str): Support "same" and "corner", "corner" mode
+ would pad zero to bottom right, and "same" mode would
+ pad zero around input. Default: "corner".
+
+ Example:
+ >>> kernel_size = 16
+ >>> stride = 16
+ >>> dilation = 1
+ >>> input = torch.rand(1, 1, 15, 17)
+ >>> adap_pad = AdaptivePadding(
+ >>> kernel_size=kernel_size,
+ >>> stride=stride,
+ >>> dilation=dilation,
+ >>> padding="corner")
+ >>> out = adap_pad(input)
+ >>> assert (out.shape[2], out.shape[3]) == (16, 32)
+ >>> input = torch.rand(1, 1, 16, 17)
+ >>> out = adap_pad(input)
+ >>> assert (out.shape[2], out.shape[3]) == (16, 32)
+ """
+
+ def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
+ super().__init__()
+ assert padding in ('same', 'corner')
+
+ kernel_size = to_2tuple(kernel_size)
+ stride = to_2tuple(stride)
+ dilation = to_2tuple(dilation)
+
+ self.padding = padding
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.dilation = dilation
+
+ def get_pad_shape(self, input_shape):
+ """Calculate the padding size of input.
+
+ Args:
+ input_shape (:obj:`torch.Size`): arrange as (H, W).
+
+ Returns:
+ Tuple[int]: The padding size along the
+ original H and W directions
+ """
+ input_h, input_w = input_shape
+ kernel_h, kernel_w = self.kernel_size
+ stride_h, stride_w = self.stride
+ output_h = math.ceil(input_h / stride_h)
+ output_w = math.ceil(input_w / stride_w)
+ pad_h = max((output_h - 1) * stride_h +
+ (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
+ pad_w = max((output_w - 1) * stride_w +
+ (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
+ return pad_h, pad_w
+
+ def forward(self, x):
+ """Add padding to `x`
+
+ Args:
+ x (Tensor): Input tensor has shape (B, C, H, W).
+
+ Returns:
+ Tensor: The tensor with adaptive padding
+ """
+ pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
+ if pad_h > 0 or pad_w > 0:
+ if self.padding == 'corner':
+ x = F.pad(x, [0, pad_w, 0, pad_h])
+ elif self.padding == 'same':
+ x = F.pad(x, [
+ pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
+ pad_h - pad_h // 2
+ ])
+ return x
+
+
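Following the docstring example, a quick check of the padding arithmetic (a sketch, assuming the vendored package is importable):

```python
import torch
from mmcv.cnn.bricks.transformer import AdaptivePadding

# With a 16x16 kernel and stride 16, a 15x17 input needs (1, 15) extra
# rows/columns so the filter covers it exactly; "corner" pads bottom/right.
adap_pad = AdaptivePadding(kernel_size=16, stride=16, dilation=1,
                           padding='corner')
assert adap_pad.get_pad_shape((15, 17)) == (1, 15)

out = adap_pad(torch.rand(1, 1, 15, 17))
assert out.shape[-2:] == (16, 32)
```
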
+class PatchEmbed(BaseModule):
+ """Image to Patch Embedding.
+
+ We use a conv layer to implement PatchEmbed.
+
+ Args:
+ in_channels (int): The num of input channels. Default: 3
+ embed_dims (int): The dimensions of embedding. Default: 768
+ conv_type (str): The type of convolution
+ to generate patch embedding. Default: "Conv2d".
+ kernel_size (int): The kernel_size of embedding conv. Default: 16.
+ stride (int): The slide stride of embedding conv.
+ Default: 16.
+ padding (int | tuple | string): The padding length of
+ embedding conv. When it is a string, it means the mode
+ of adaptive padding, support "same" and "corner" now.
+ Default: "corner".
+ dilation (int): The dilation rate of embedding conv. Default: 1.
+ bias (bool): Bias of embed conv. Default: True.
+ norm_cfg (dict, optional): Config dict for normalization layer.
+ Default: None.
+ input_size (int | tuple | None): The size of input, which will be
+ used to calculate the out size. Only works when `dynamic_size`
+ is False. Default: None.
+ init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ in_channels=3,
+ embed_dims=768,
+ conv_type='Conv2d',
+ kernel_size=16,
+ stride=16,
+ padding='corner',
+ dilation=1,
+ bias=True,
+ norm_cfg=None,
+ input_size=None,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+
+ self.embed_dims = embed_dims
+ if stride is None:
+ stride = kernel_size
+
+ kernel_size = to_2tuple(kernel_size)
+ stride = to_2tuple(stride)
+ dilation = to_2tuple(dilation)
+
+ if isinstance(padding, str):
+ self.adaptive_padding = AdaptivePadding(
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation,
+ padding=padding)
+ # disable the padding of conv
+ padding = 0
+ else:
+ self.adaptive_padding = None
+ padding = to_2tuple(padding)
+
+ self.projection = build_conv_layer(
+ dict(type=conv_type),
+ in_channels=in_channels,
+ out_channels=embed_dims,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias)
+
+ if norm_cfg is not None:
+ self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
+ else:
+ self.norm = None
+
+ if input_size:
+ input_size = to_2tuple(input_size)
+ # `init_out_size` would be used outside to
+ # calculate the num_patches
+ # e.g. when `use_abs_pos_embed` outside
+ self.init_input_size = input_size
+ if self.adaptive_padding:
+ pad_h, pad_w = self.adaptive_padding.get_pad_shape(input_size)
+ input_h, input_w = input_size
+ input_h = input_h + pad_h
+ input_w = input_w + pad_w
+ input_size = (input_h, input_w)
+
+ # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
+ h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
+ (kernel_size[0] - 1) - 1) // stride[0] + 1
+ w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
+ (kernel_size[1] - 1) - 1) // stride[1] + 1
+ self.init_out_size = (h_out, w_out)
+ else:
+ self.init_input_size = None
+ self.init_out_size = None
+
+ def forward(self, x):
+ """
+ Args:
+ x (Tensor): Has shape (B, C, H, W). In most case, C is 3.
+
+ Returns:
+ tuple: Contains merged results and its spatial shape.
+
+ - x (Tensor): Has shape (B, out_h * out_w, embed_dims)
+ - out_size (tuple[int]): Spatial shape of x, arrange as
+ (out_h, out_w).
+ """
+
+ if self.adaptive_padding:
+ x = self.adaptive_padding(x)
+
+ x = self.projection(x)
+ out_size = (x.shape[2], x.shape[3])
+ x = x.flatten(2).transpose(1, 2)
+ if self.norm is not None:
+ x = self.norm(x)
+ return x, out_size
+
+
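A minimal sketch of the patch-embedding call with ViT-style settings (relies on `build_conv_layer` from the wider package):

```python
import torch
from mmcv.cnn.bricks.transformer import PatchEmbed

# 16x16 patches with stride 16 turn a 224x224 image into a 14x14 token grid.
patch_embed = PatchEmbed(in_channels=3, embed_dims=768,
                         kernel_size=16, stride=16)

x = torch.rand(2, 3, 224, 224)
tokens, out_size = patch_embed(x)
assert tokens.shape == (2, 14 * 14, 768) and out_size == (14, 14)
```
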
+class PatchMerging(BaseModule):
+ """Merge patch feature map.
+
+    This layer groups the feature map by kernel_size, and applies norm and
+    linear layers to the grouped feature map (used in Swin Transformer).
+ Our implementation uses `nn.Unfold` to
+ merge patches, which is about 25% faster than the original
+ implementation. However, we need to modify pretrained
+ models for compatibility.
+
+ Args:
+        in_channels (int): The num of input channels.
+ out_channels (int): The num of output channels.
+ kernel_size (int | tuple, optional): the kernel size in the unfold
+ layer. Defaults to 2.
+ stride (int | tuple, optional): the stride of the sliding blocks in the
+ unfold layer. Default: None. (Would be set as `kernel_size`)
+ padding (int | tuple | string ): The padding length of
+ embedding conv. When it is a string, it means the mode
+ of adaptive padding, support "same" and "corner" now.
+ Default: "corner".
+ dilation (int | tuple, optional): dilation parameter in the unfold
+ layer. Default: 1.
+ bias (bool, optional): Whether to add bias in linear layer or not.
+ Defaults: False.
+ norm_cfg (dict, optional): Config dict for normalization layer.
+ Default: dict(type='LN').
+ init_cfg (dict, optional): The extra config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=2,
+ stride=None,
+ padding='corner',
+ dilation=1,
+ bias=False,
+ norm_cfg=dict(type='LN'),
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ if stride:
+ stride = stride
+ else:
+ stride = kernel_size
+
+ kernel_size = to_2tuple(kernel_size)
+ stride = to_2tuple(stride)
+ dilation = to_2tuple(dilation)
+
+ if isinstance(padding, str):
+ self.adaptive_padding = AdaptivePadding(
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation,
+ padding=padding)
+ # disable the padding of unfold
+ padding = 0
+ else:
+ self.adaptive_padding = None
+
+ padding = to_2tuple(padding)
+ self.sampler = nn.Unfold(
+ kernel_size=kernel_size,
+ dilation=dilation,
+ padding=padding,
+ stride=stride)
+
+ sample_dim = kernel_size[0] * kernel_size[1] * in_channels
+
+ if norm_cfg is not None:
+ self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
+ else:
+ self.norm = None
+
+ self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
+
+ def forward(self, x, input_size):
+ """
+ Args:
+ x (Tensor): Has shape (B, H*W, C_in).
+ input_size (tuple[int]): The spatial shape of x, arrange as (H, W).
+ Default: None.
+
+ Returns:
+ tuple: Contains merged results and its spatial shape.
+
+ - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
+ - out_size (tuple[int]): Spatial shape of x, arrange as
+ (Merged_H, Merged_W).
+ """
+ B, L, C = x.shape
+        assert isinstance(input_size, Sequence), \
+            f'Expected input_size to be a Sequence, but got {input_size}'
+
+ H, W = input_size
+ assert L == H * W, 'input feature has wrong size'
+
+ x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
+
+ if self.adaptive_padding:
+ x = self.adaptive_padding(x)
+ H, W = x.shape[-2:]
+
+ # Use nn.Unfold to merge patch. About 25% faster than original method,
+ # but need to modify pretrained model for compatibility
+ # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2)
+ x = self.sampler(x)
+
+ out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
+ (self.sampler.kernel_size[0] - 1) -
+ 1) // self.sampler.stride[0] + 1
+ out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
+ (self.sampler.kernel_size[1] - 1) -
+ 1) // self.sampler.stride[1] + 1
+
+ output_size = (out_h, out_w)
+ x = x.transpose(1, 2) # B, H/2*W/2, 4*C
+ x = self.norm(x) if self.norm else x
+ x = self.reduction(x)
+ return x, output_size
+
+
+@ATTENTION.register_module()
+class MultiheadAttention(BaseModule):
+ """A wrapper for ``torch.nn.MultiheadAttention``.
+
+ This module implements MultiheadAttention with identity connection,
+ and positional encoding is also passed as input.
+
+ Args:
+ embed_dims (int): The embedding dimension.
+ num_heads (int): Parallel attention heads.
+ attn_drop (float): A Dropout layer on attn_output_weights.
+ Default: 0.0.
+ proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
+ Default: 0.0.
+ dropout_layer (obj:`ConfigDict`): The dropout_layer used
+ when adding the shortcut.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ batch_first (bool): When it is True, Key, Query and Value are shape of
+ (batch, n, embed_dim), otherwise (n, batch, embed_dim).
+ Default to False.
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ attn_drop=0.,
+ proj_drop=0.,
+ dropout_layer=dict(type='Dropout', drop_prob=0.),
+ init_cfg=None,
+ batch_first=False,
+ **kwargs):
+ super().__init__(init_cfg)
+ if 'dropout' in kwargs:
+ warnings.warn(
+                'The argument `dropout` in MultiheadAttention '
+                'has been deprecated; now you can separately '
+                'set `attn_drop` (float), `proj_drop` (float), '
+                'and `dropout_layer` (dict). ', DeprecationWarning)
+ attn_drop = kwargs['dropout']
+ dropout_layer['drop_prob'] = kwargs.pop('dropout')
+
+ self.embed_dims = embed_dims
+ self.num_heads = num_heads
+ self.batch_first = batch_first
+
+ self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
+ **kwargs)
+
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.dropout_layer = build_dropout(
+ dropout_layer) if dropout_layer else nn.Identity()
+
+ @deprecated_api_warning({'residual': 'identity'},
+ cls_name='MultiheadAttention')
+ def forward(self,
+ query,
+ key=None,
+ value=None,
+ identity=None,
+ query_pos=None,
+ key_pos=None,
+ attn_mask=None,
+ key_padding_mask=None,
+ **kwargs):
+ """Forward function for `MultiheadAttention`.
+
+ **kwargs allow passing a more general data flow when combining
+ with other operations in `transformerlayer`.
+
+ Args:
+ query (Tensor): The input query with shape [num_queries, bs,
+ embed_dims] if self.batch_first is False, else
+                [bs, num_queries, embed_dims].
+ key (Tensor): The key tensor with shape [num_keys, bs,
+ embed_dims] if self.batch_first is False, else
+ [bs, num_keys, embed_dims] .
+ If None, the ``query`` will be used. Defaults to None.
+ value (Tensor): The value tensor with same shape as `key`.
+ Same in `nn.MultiheadAttention.forward`. Defaults to None.
+ If None, the `key` will be used.
+ identity (Tensor): This tensor, with the same shape as x,
+ will be used for the identity link.
+ If None, `x` will be used. Defaults to None.
+ query_pos (Tensor): The positional encoding for query, with
+ the same shape as `x`. If not None, it will
+ be added to `x` before forward function. Defaults to None.
+ key_pos (Tensor): The positional encoding for `key`, with the
+ same shape as `key`. Defaults to None. If not None, it will
+ be added to `key` before forward function. If None, and
+ `query_pos` has the same shape as `key`, then `query_pos`
+ will be used for `key_pos`. Defaults to None.
+ attn_mask (Tensor): ByteTensor mask with shape [num_queries,
+ num_keys]. Same in `nn.MultiheadAttention.forward`.
+ Defaults to None.
+ key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
+ Defaults to None.
+
+ Returns:
+ Tensor: forwarded results with shape
+ [num_queries, bs, embed_dims]
+ if self.batch_first is False, else
+                [bs, num_queries, embed_dims].
+ """
+
+ if key is None:
+ key = query
+ if value is None:
+ value = key
+ if identity is None:
+ identity = query
+ if key_pos is None:
+ if query_pos is not None:
+ # use query_pos if key_pos is not available
+ if query_pos.shape == key.shape:
+ key_pos = query_pos
+ else:
+                    warnings.warn(f'position encoding of key is '
+                                  f'missing in {self.__class__.__name__}.')
+ if query_pos is not None:
+ query = query + query_pos
+ if key_pos is not None:
+ key = key + key_pos
+
+ # Because the dataflow('key', 'query', 'value') of
+ # ``torch.nn.MultiheadAttention`` is (num_query, batch,
+ # embed_dims), We should adjust the shape of dataflow from
+ # batch_first (batch, num_query, embed_dims) to num_query_first
+ # (num_query ,batch, embed_dims), and recover ``attn_output``
+ # from num_query_first to batch_first.
+ if self.batch_first:
+ query = query.transpose(0, 1)
+ key = key.transpose(0, 1)
+ value = value.transpose(0, 1)
+
+ out = self.attn(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ key_padding_mask=key_padding_mask)[0]
+
+ if self.batch_first:
+ out = out.transpose(0, 1)
+
+ return identity + self.dropout_layer(self.proj_drop(out))
+
+
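A short sketch of the wrapper with `batch_first=True`; the residual (identity) connection and the positional-encoding addition are handled inside `forward`:

```python
import torch
from mmcv.cnn.bricks.transformer import MultiheadAttention

self_attn = MultiheadAttention(embed_dims=256, num_heads=8, batch_first=True)

query = torch.rand(2, 100, 256)       # (bs, num_queries, embed_dims)
query_pos = torch.rand(2, 100, 256)
out = self_attn(query, query_pos=query_pos)  # key/value default to query
assert out.shape == (2, 100, 256)
```
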
+@FEEDFORWARD_NETWORK.register_module()
+class FFN(BaseModule):
+ """Implements feed-forward networks (FFNs) with identity connection.
+
+ Args:
+ embed_dims (int): The feature dimension. Same as
+ `MultiheadAttention`. Defaults: 256.
+ feedforward_channels (int): The hidden dimension of FFNs.
+ Defaults: 1024.
+ num_fcs (int, optional): The number of fully-connected layers in
+ FFNs. Default: 2.
+ act_cfg (dict, optional): The activation config for FFNs.
+ Default: dict(type='ReLU')
+ ffn_drop (float, optional): Probability of an element to be
+ zeroed in FFN. Default 0.0.
+ add_identity (bool, optional): Whether to add the
+ identity connection. Default: `True`.
+ dropout_layer (obj:`ConfigDict`): The dropout_layer used
+ when adding the shortcut.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ @deprecated_api_warning(
+ {
+ 'dropout': 'ffn_drop',
+ 'add_residual': 'add_identity'
+ },
+ cls_name='FFN')
+ def __init__(self,
+ embed_dims=256,
+ feedforward_channels=1024,
+ num_fcs=2,
+ act_cfg=dict(type='ReLU', inplace=True),
+ ffn_drop=0.,
+ dropout_layer=None,
+ add_identity=True,
+ init_cfg=None,
+ **kwargs):
+ super().__init__(init_cfg)
+ assert num_fcs >= 2, 'num_fcs should be no less ' \
+ f'than 2. got {num_fcs}.'
+ self.embed_dims = embed_dims
+ self.feedforward_channels = feedforward_channels
+ self.num_fcs = num_fcs
+ self.act_cfg = act_cfg
+ self.activate = build_activation_layer(act_cfg)
+
+ layers = []
+ in_channels = embed_dims
+ for _ in range(num_fcs - 1):
+ layers.append(
+ Sequential(
+ Linear(in_channels, feedforward_channels), self.activate,
+ nn.Dropout(ffn_drop)))
+ in_channels = feedforward_channels
+ layers.append(Linear(feedforward_channels, embed_dims))
+ layers.append(nn.Dropout(ffn_drop))
+ self.layers = Sequential(*layers)
+ self.dropout_layer = build_dropout(
+ dropout_layer) if dropout_layer else torch.nn.Identity()
+ self.add_identity = add_identity
+
+ @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
+ def forward(self, x, identity=None):
+ """Forward function for `FFN`.
+
+        Adds `x` (the identity) to the output tensor when `identity` is None.
+ """
+ out = self.layers(x)
+ if not self.add_identity:
+ return self.dropout_layer(out)
+ if identity is None:
+ identity = x
+ return identity + self.dropout_layer(out)
+
+
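A minimal sketch of the FFN block; with `num_fcs=2` it is the usual two-layer MLP plus identity shortcut:

```python
import torch
from mmcv.cnn.bricks.transformer import FFN

ffn = FFN(embed_dims=256, feedforward_channels=1024, num_fcs=2, ffn_drop=0.1)

x = torch.rand(2, 100, 256)
assert ffn(x).shape == (2, 100, 256)  # identity defaults to x itself
```
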
+@TRANSFORMER_LAYER.register_module()
+class BaseTransformerLayer(BaseModule):
+ """Base `TransformerLayer` for vision transformer.
+
+ It can be built from `mmcv.ConfigDict` and support more flexible
+ customization, for example, using any number of `FFN or LN ` and
+ use different kinds of `attention` by specifying a list of `ConfigDict`
+    named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
+    when you specify `norm` as the first element of `operation_order`.
+    More details about `prenorm` can be found in `On Layer Normalization
+    in the Transformer Architecture <https://arxiv.org/abs/2002.04745>`_.
+
+ Args:
+ attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
+ Configs for `self_attention` or `cross_attention` modules,
+ The order of the configs in the list should be consistent with
+ corresponding attentions in operation_order.
+ If it is a dict, all of the attention modules in operation_order
+ will be built with this config. Default: None.
+ ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
+ Configs for FFN, The order of the configs in the list should be
+ consistent with corresponding ffn in operation_order.
+ If it is a dict, all of the attention modules in operation_order
+ will be built with this config.
+ operation_order (tuple[str]): The execution order of operation
+ in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
+            Supports `prenorm` when the first element is `norm`.
+            Default: None.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN').
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+        batch_first (bool): If True, Key, Query and Value have shape
+            (batch, n, embed_dim); otherwise (n, batch, embed_dim).
+            Defaults to False.
+ """
+
+ def __init__(self,
+ attn_cfgs=None,
+ ffn_cfgs=dict(
+ type='FFN',
+ embed_dims=256,
+ feedforward_channels=1024,
+ num_fcs=2,
+ ffn_drop=0.,
+ act_cfg=dict(type='ReLU', inplace=True),
+ ),
+ operation_order=None,
+ norm_cfg=dict(type='LN'),
+ init_cfg=None,
+ batch_first=False,
+ **kwargs):
+
+ deprecated_args = dict(
+ feedforward_channels='feedforward_channels',
+ ffn_dropout='ffn_drop',
+ ffn_num_fcs='num_fcs')
+ for ori_name, new_name in deprecated_args.items():
+ if ori_name in kwargs:
+ warnings.warn(
+ f'The arguments `{ori_name}` in BaseTransformerLayer '
+ f'has been deprecated, now you should set `{new_name}` '
+ f'and other FFN related arguments '
+ f'to a dict named `ffn_cfgs`. ', DeprecationWarning)
+ ffn_cfgs[new_name] = kwargs[ori_name]
+
+ super().__init__(init_cfg)
+
+ self.batch_first = batch_first
+
+ assert set(operation_order) & {
+ 'self_attn', 'norm', 'ffn', 'cross_attn'} == \
+            set(operation_order), f'The operation_order of' \
+            f' {self.__class__.__name__} should only ' \
+            f'contain operations in ' \
+            f"{['self_attn', 'norm', 'ffn', 'cross_attn']}"
+
+ num_attn = operation_order.count('self_attn') + operation_order.count(
+ 'cross_attn')
+ if isinstance(attn_cfgs, dict):
+ attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
+ else:
+            assert num_attn == len(attn_cfgs), f'The length ' \
+                f'of attn_cfgs {len(attn_cfgs)} is ' \
+                f'not consistent with the number of attentions ' \
+                f'{num_attn} in operation_order {operation_order}.'
+
+ self.num_attn = num_attn
+ self.operation_order = operation_order
+ self.norm_cfg = norm_cfg
+ self.pre_norm = operation_order[0] == 'norm'
+ self.attentions = ModuleList()
+
+ index = 0
+ for operation_name in operation_order:
+ if operation_name in ['self_attn', 'cross_attn']:
+ if 'batch_first' in attn_cfgs[index]:
+ assert self.batch_first == attn_cfgs[index]['batch_first']
+ else:
+ attn_cfgs[index]['batch_first'] = self.batch_first
+ attention = build_attention(attn_cfgs[index])
+ # Some custom attentions used as `self_attn`
+ # or `cross_attn` can have different behavior.
+ attention.operation_name = operation_name
+ self.attentions.append(attention)
+ index += 1
+
+ self.embed_dims = self.attentions[0].embed_dims
+
+ self.ffns = ModuleList()
+ num_ffns = operation_order.count('ffn')
+ if isinstance(ffn_cfgs, dict):
+ ffn_cfgs = ConfigDict(ffn_cfgs)
+ if isinstance(ffn_cfgs, dict):
+ ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
+ assert len(ffn_cfgs) == num_ffns
+ for ffn_index in range(num_ffns):
+ if 'embed_dims' not in ffn_cfgs[ffn_index]:
+ ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
+ else:
+ assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
+ self.ffns.append(
+ build_feedforward_network(ffn_cfgs[ffn_index],
+ dict(type='FFN')))
+
+ self.norms = ModuleList()
+ num_norms = operation_order.count('norm')
+ for _ in range(num_norms):
+ self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])
+
+ def forward(self,
+ query,
+ key=None,
+ value=None,
+ query_pos=None,
+ key_pos=None,
+ attn_masks=None,
+ query_key_padding_mask=None,
+ key_padding_mask=None,
+ **kwargs):
+ """Forward function for `TransformerDecoderLayer`.
+
+ **kwargs contains some specific arguments of attentions.
+
+ Args:
+ query (Tensor): The input query with shape
+ [num_queries, bs, embed_dims] if
+ self.batch_first is False, else
+                [bs, num_queries, embed_dims].
+ key (Tensor): The key tensor with shape [num_keys, bs,
+ embed_dims] if self.batch_first is False, else
+ [bs, num_keys, embed_dims] .
+ value (Tensor): The value tensor with same shape as `key`.
+ query_pos (Tensor): The positional encoding for `query`.
+ Default: None.
+ key_pos (Tensor): The positional encoding for `key`.
+ Default: None.
+ attn_masks (List[Tensor] | None): 2D Tensor used in
+ calculation of corresponding attention. The length of
+ it should equal to the number of `attention` in
+ `operation_order`. Default: None.
+ query_key_padding_mask (Tensor): ByteTensor for `query`, with
+ shape [bs, num_queries]. Only used in `self_attn` layer.
+ Defaults to None.
+            key_padding_mask (Tensor): ByteTensor for `key`, with
+ shape [bs, num_keys]. Default: None.
+
+ Returns:
+ Tensor: forwarded results with shape [num_queries, bs, embed_dims].
+ """
+
+ norm_index = 0
+ attn_index = 0
+ ffn_index = 0
+ identity = query
+ if attn_masks is None:
+ attn_masks = [None for _ in range(self.num_attn)]
+ elif isinstance(attn_masks, torch.Tensor):
+ attn_masks = [
+ copy.deepcopy(attn_masks) for _ in range(self.num_attn)
+ ]
+ warnings.warn(f'Use same attn_mask in all attentions in '
+ f'{self.__class__.__name__} ')
+ else:
+ assert len(attn_masks) == self.num_attn, f'The length of ' \
+ f'attn_masks {len(attn_masks)} must be equal ' \
+ f'to the number of attention in ' \
+ f'operation_order {self.num_attn}'
+
+ for layer in self.operation_order:
+ if layer == 'self_attn':
+ temp_key = temp_value = query
+ query = self.attentions[attn_index](
+ query,
+ temp_key,
+ temp_value,
+ identity if self.pre_norm else None,
+ query_pos=query_pos,
+ key_pos=query_pos,
+ attn_mask=attn_masks[attn_index],
+ key_padding_mask=query_key_padding_mask,
+ **kwargs)
+ attn_index += 1
+ identity = query
+
+ elif layer == 'norm':
+ query = self.norms[norm_index](query)
+ norm_index += 1
+
+ elif layer == 'cross_attn':
+ query = self.attentions[attn_index](
+ query,
+ key,
+ value,
+ identity if self.pre_norm else None,
+ query_pos=query_pos,
+ key_pos=key_pos,
+ attn_mask=attn_masks[attn_index],
+ key_padding_mask=key_padding_mask,
+ **kwargs)
+ attn_index += 1
+ identity = query
+
+ elif layer == 'ffn':
+ query = self.ffns[ffn_index](
+ query, identity if self.pre_norm else None)
+ ffn_index += 1
+
+ return query
+
+
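A sketch of a post-norm, encoder-style layer built from config (the attention and FFN settings shown are illustrative values, not defaults mandated by the class):

```python
import torch
from mmcv.cnn.bricks.transformer import BaseTransformerLayer

layer = BaseTransformerLayer(
    attn_cfgs=dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
    ffn_cfgs=dict(type='FFN', embed_dims=256, feedforward_channels=1024),
    operation_order=('self_attn', 'norm', 'ffn', 'norm'))

query = torch.rand(100, 2, 256)       # (num_queries, bs, embed_dims)
assert layer(query).shape == (100, 2, 256)
```
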
+@TRANSFORMER_LAYER_SEQUENCE.register_module()
+class TransformerLayerSequence(BaseModule):
+ """Base class for TransformerEncoder and TransformerDecoder in vision
+ transformer.
+
+ As base-class of Encoder and Decoder in vision transformer.
+ Support customization such as specifying different kind
+ of `transformer_layer` in `transformer_coder`.
+
+ Args:
+        transformerlayers (list[obj:`mmcv.ConfigDict`] |
+            obj:`mmcv.ConfigDict`): Config of the transformer layers
+            in TransformerCoder. If it is an obj:`mmcv.ConfigDict`,
+            it would be repeated `num_layers` times to a
+ list[`mmcv.ConfigDict`]. Default: None.
+ num_layers (int): The number of `TransformerLayer`. Default: None.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
+ super().__init__(init_cfg)
+ if isinstance(transformerlayers, dict):
+ transformerlayers = [
+ copy.deepcopy(transformerlayers) for _ in range(num_layers)
+ ]
+ else:
+ assert isinstance(transformerlayers, list) and \
+ len(transformerlayers) == num_layers
+ self.num_layers = num_layers
+ self.layers = ModuleList()
+ for i in range(num_layers):
+ self.layers.append(build_transformer_layer(transformerlayers[i]))
+ self.embed_dims = self.layers[0].embed_dims
+ self.pre_norm = self.layers[0].pre_norm
+
+ def forward(self,
+ query,
+ key,
+ value,
+ query_pos=None,
+ key_pos=None,
+ attn_masks=None,
+ query_key_padding_mask=None,
+ key_padding_mask=None,
+ **kwargs):
+ """Forward function for `TransformerCoder`.
+
+ Args:
+ query (Tensor): Input query with shape
+ `(num_queries, bs, embed_dims)`.
+ key (Tensor): The key tensor with shape
+ `(num_keys, bs, embed_dims)`.
+ value (Tensor): The value tensor with shape
+ `(num_keys, bs, embed_dims)`.
+ query_pos (Tensor): The positional encoding for `query`.
+ Default: None.
+ key_pos (Tensor): The positional encoding for `key`.
+ Default: None.
+ attn_masks (List[Tensor], optional): Each element is 2D Tensor
+ which is used in calculation of corresponding attention in
+ operation_order. Default: None.
+ query_key_padding_mask (Tensor): ByteTensor for `query`, with
+ shape [bs, num_queries]. Only used in self-attention
+ Default: None.
+            key_padding_mask (Tensor): ByteTensor for `key`, with
+ shape [bs, num_keys]. Default: None.
+
+ Returns:
+ Tensor: results with shape [num_queries, bs, embed_dims].
+ """
+ for layer in self.layers:
+ query = layer(
+ query,
+ key,
+ value,
+ query_pos=query_pos,
+ key_pos=key_pos,
+ attn_masks=attn_masks,
+ query_key_padding_mask=query_key_padding_mask,
+ key_padding_mask=key_padding_mask,
+ **kwargs)
+ return query
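
Finally, a sketch of stacking the layer above into a small encoder via the sequence builder; the single layer config is deep-copied `num_layers` times:

```python
import torch
from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence

encoder = build_transformer_layer_sequence(
    dict(
        type='TransformerLayerSequence',
        num_layers=2,
        transformerlayers=dict(
            type='BaseTransformerLayer',
            attn_cfgs=dict(
                type='MultiheadAttention', embed_dims=256, num_heads=8),
            ffn_cfgs=dict(
                type='FFN', embed_dims=256, feedforward_channels=1024),
            operation_order=('self_attn', 'norm', 'ffn', 'norm'))))

x = torch.rand(100, 2, 256)           # (num_queries, bs, embed_dims)
out = encoder(query=x, key=None, value=None)   # self-attention only
assert out.shape == (100, 2, 256)
```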
diff --git a/mmcv/mmcv/cnn/bricks/upsample.py b/mmcv/mmcv/cnn/bricks/upsample.py
new file mode 100644
index 0000000000000000000000000000000000000000..d86c5f54a22ed26b09f66bd59659ff7ab1f5b3d9
--- /dev/null
+++ b/mmcv/mmcv/cnn/bricks/upsample.py
@@ -0,0 +1,87 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..utils import xavier_init
+from .registry import UPSAMPLE_LAYERS
+
+UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample)
+UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample)
+
+
+@UPSAMPLE_LAYERS.register_module(name='pixel_shuffle')
+class PixelShufflePack(nn.Module):
+ """Pixel Shuffle upsample layer.
+
+ This module packs `F.pixel_shuffle()` and a nn.Conv2d module together to
+ achieve a simple upsampling with pixel shuffle.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ scale_factor (int): Upsample ratio.
+ upsample_kernel (int): Kernel size of the conv layer to expand the
+ channels.
+ """
+
+ def __init__(self, in_channels: int, out_channels: int, scale_factor: int,
+ upsample_kernel: int):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.scale_factor = scale_factor
+ self.upsample_kernel = upsample_kernel
+ self.upsample_conv = nn.Conv2d(
+ self.in_channels,
+ self.out_channels * scale_factor * scale_factor,
+ self.upsample_kernel,
+ padding=(self.upsample_kernel - 1) // 2)
+ self.init_weights()
+
+ def init_weights(self):
+ xavier_init(self.upsample_conv, distribution='uniform')
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.upsample_conv(x)
+ x = F.pixel_shuffle(x, self.scale_factor)
+ return x
+
+
+def build_upsample_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
+ """Build upsample layer.
+
+ Args:
+ cfg (dict): The upsample layer config, which should contain:
+
+ - type (str): Layer type.
+ - scale_factor (int): Upsample ratio, which is not applicable to
+ deconv.
+            - layer args: Args needed to instantiate an upsample layer.
+ args (argument list): Arguments passed to the ``__init__``
+ method of the corresponding conv layer.
+ kwargs (keyword arguments): Keyword arguments passed to the
+ ``__init__`` method of the corresponding conv layer.
+
+ Returns:
+ nn.Module: Created upsample layer.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
+ if 'type' not in cfg:
+ raise KeyError(
+ f'the cfg dict must contain the key "type", but got {cfg}')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type not in UPSAMPLE_LAYERS:
+ raise KeyError(f'Unrecognized upsample type {layer_type}')
+ else:
+ upsample = UPSAMPLE_LAYERS.get(layer_type)
+
+ if upsample is nn.Upsample:
+ cfg_['mode'] = layer_type
+ layer = upsample(*args, **kwargs, **cfg_)
+ return layer
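
A sketch of the two build paths: for 'nearest'/'bilinear' the type string doubles as the `nn.Upsample` mode, while 'pixel_shuffle' builds the conv + pixel-shuffle pack above:

```python
import torch
from mmcv.cnn.bricks.upsample import build_upsample_layer

up = build_upsample_layer(
    dict(type='bilinear', scale_factor=2, align_corners=False))
assert up(torch.rand(1, 8, 16, 16)).shape == (1, 8, 32, 32)

ps = build_upsample_layer(
    dict(type='pixel_shuffle', in_channels=8, out_channels=4,
         scale_factor=2, upsample_kernel=3))
assert ps(torch.rand(1, 8, 16, 16)).shape == (1, 4, 32, 32)
```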
diff --git a/mmcv/mmcv/cnn/bricks/wrappers.py b/mmcv/mmcv/cnn/bricks/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..a07eff00e49970c7692ee3f2625c7f7aba9d7b22
--- /dev/null
+++ b/mmcv/mmcv/cnn/bricks/wrappers.py
@@ -0,0 +1,180 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+r"""Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/wrappers.py # noqa: E501
+
+Wrap some nn modules to support empty tensor input. Currently, these wrappers
+are mainly used in mask heads like fcn_mask_head and maskiou_heads since mask
+heads are trained on only positive RoIs.
+"""
+import math
+
+import torch
+import torch.nn as nn
+from torch.nn.modules.utils import _pair, _triple
+
+from .registry import CONV_LAYERS, UPSAMPLE_LAYERS
+
+if torch.__version__ == 'parrots':
+ TORCH_VERSION = torch.__version__
+else:
+ # torch.__version__ could be 1.3.1+cu92, we only need the first two
+ # for comparison
+ TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
+
+
+def obsolete_torch_version(torch_version, version_threshold) -> bool:
+ return torch_version == 'parrots' or torch_version <= version_threshold
+
+
+class NewEmptyTensorOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x: torch.Tensor, new_shape: tuple) -> torch.Tensor:
+ ctx.shape = x.shape
+ return x.new_empty(new_shape)
+
+ @staticmethod
+ def backward(ctx, grad: torch.Tensor) -> tuple:
+ shape = ctx.shape
+ return NewEmptyTensorOp.apply(grad, shape), None
+
+
+@CONV_LAYERS.register_module('Conv', force=True)
+class Conv2d(nn.Conv2d):
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+ out_shape = [x.shape[0], self.out_channels]
+ for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
+ self.padding, self.stride, self.dilation):
+ o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
+ out_shape.append(o)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
+
+
+@CONV_LAYERS.register_module('Conv3d', force=True)
+class Conv3d(nn.Conv3d):
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+ out_shape = [x.shape[0], self.out_channels]
+ for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
+ self.padding, self.stride, self.dilation):
+ o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
+ out_shape.append(o)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
+
+
+@CONV_LAYERS.register_module()
+@CONV_LAYERS.register_module('deconv')
+@UPSAMPLE_LAYERS.register_module('deconv', force=True)
+class ConvTranspose2d(nn.ConvTranspose2d):
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+ out_shape = [x.shape[0], self.out_channels]
+ for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
+ self.padding, self.stride,
+ self.dilation, self.output_padding):
+ out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
+
+
+@CONV_LAYERS.register_module()
+@CONV_LAYERS.register_module('deconv3d')
+@UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
+class ConvTranspose3d(nn.ConvTranspose3d):
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+ out_shape = [x.shape[0], self.out_channels]
+ for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
+ self.padding, self.stride,
+ self.dilation, self.output_padding):
+ out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
+
+
+class MaxPool2d(nn.MaxPool2d):
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # PyTorch 1.9 does not support empty tensor inference yet
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
+ out_shape = list(x.shape[:2])
+ for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
+ _pair(self.padding), _pair(self.stride),
+ _pair(self.dilation)):
+ o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
+ o = math.ceil(o) if self.ceil_mode else math.floor(o)
+ out_shape.append(o)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ return empty
+
+ return super().forward(x)
+
+
+class MaxPool3d(nn.MaxPool3d):
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # PyTorch 1.9 does not support empty tensor inference yet
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
+ out_shape = list(x.shape[:2])
+ for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size),
+ _triple(self.padding),
+ _triple(self.stride),
+ _triple(self.dilation)):
+ o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
+ o = math.ceil(o) if self.ceil_mode else math.floor(o)
+ out_shape.append(o)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ return empty
+
+ return super().forward(x)
+
+
+class Linear(torch.nn.Linear):
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # empty tensor forward of Linear layer is supported in PyTorch 1.6
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)):
+ out_shape = [x.shape[0], self.out_features]
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
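
A sketch of the empty-tensor behaviour these wrappers exist for (e.g. a batch with zero positive RoIs); on recent PyTorch the native empty-batch support is used, on old versions the `NewEmptyTensorOp` path above kicks in:

```python
import torch
from mmcv.cnn.bricks.wrappers import Conv2d, Linear

conv = Conv2d(256, 128, kernel_size=3, padding=1)
feats = torch.rand(0, 256, 7, 7)      # zero RoIs, e.g. no positives
assert conv(feats).shape == (0, 128, 7, 7)

fc = Linear(256, 81)
assert fc(torch.rand(0, 256)).shape == (0, 81)
```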
diff --git a/mmcv/mmcv/cnn/builder.py b/mmcv/mmcv/cnn/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7567316c566bd3aca6d8f65a84b00e9e890948a7
--- /dev/null
+++ b/mmcv/mmcv/cnn/builder.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..runner import Sequential
+from ..utils import Registry, build_from_cfg
+
+
+def build_model_from_cfg(cfg, registry, default_args=None):
+ """Build a PyTorch model from config dict(s). Different from
+ ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.
+
+ Args:
+        cfg (dict, list[dict]): The config of modules; it is either a config
+            dict or a list of config dicts. If cfg is a list, the built
+            modules will be wrapped with ``nn.Sequential``.
+ registry (:obj:`Registry`): A registry the module belongs to.
+ default_args (dict, optional): Default arguments to build the module.
+ Defaults to None.
+
+ Returns:
+ nn.Module: A built nn module.
+ """
+ if isinstance(cfg, list):
+ modules = [
+ build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
+ ]
+ return Sequential(*modules)
+ else:
+ return build_from_cfg(cfg, registry, default_args)
+
+
+MODELS = Registry('model', build_func=build_model_from_cfg)
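+
+# Illustrative sketch: a list config builds an ``nn.Sequential`` while a
+# single dict builds one module. ``FooNet`` is a hypothetical class assumed
+# to be registered in ``MODELS``.
+#
+#   cfg = [dict(type='FooNet', in_ch=3, out_ch=16),
+#          dict(type='FooNet', in_ch=16, out_ch=32)]
+#   model = build_model_from_cfg(cfg, MODELS)   # -> Sequential of two FooNets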
diff --git a/mmcv/mmcv/cnn/resnet.py b/mmcv/mmcv/cnn/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb29e6256280b671acfbf73fd9a01f079749b260
--- /dev/null
+++ b/mmcv/mmcv/cnn/resnet.py
@@ -0,0 +1,322 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+from typing import Optional, Sequence, Tuple, Union
+
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from torch import Tensor
+
+from .utils import constant_init, kaiming_init
+
+
+def conv3x3(in_planes: int,
+ out_planes: int,
+ stride: int = 1,
+ dilation: int = 1):
+ """3x3 convolution with padding."""
+ return nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ dilation: int = 1,
+ downsample: Optional[nn.Module] = None,
+ style: str = 'pytorch',
+ with_cp: bool = False):
+ super().__init__()
+ assert style in ['pytorch', 'caffe']
+ self.conv1 = conv3x3(inplanes, planes, stride, dilation)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ assert not with_cp
+
+ def forward(self, x: Tensor) -> Tensor:
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ dilation: int = 1,
+ downsample: Optional[nn.Module] = None,
+ style: str = 'pytorch',
+ with_cp: bool = False):
+ """Bottleneck block.
+
+ If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
+ it is "caffe", the stride-two layer is the first 1x1 conv layer.
+ """
+ super().__init__()
+ assert style in ['pytorch', 'caffe']
+ if style == 'pytorch':
+ conv1_stride = 1
+ conv2_stride = stride
+ else:
+ conv1_stride = stride
+ conv2_stride = 1
+ self.conv1 = nn.Conv2d(
+ inplanes, planes, kernel_size=1, stride=conv1_stride, bias=False)
+ self.conv2 = nn.Conv2d(
+ planes,
+ planes,
+ kernel_size=3,
+ stride=conv2_stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(
+ planes, planes * self.expansion, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ self.with_cp = with_cp
+
+ def forward(self, x: Tensor) -> Tensor:
+
+ def _inner_forward(x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+def make_res_layer(block: nn.Module,
+ inplanes: int,
+ planes: int,
+ blocks: int,
+ stride: int = 1,
+ dilation: int = 1,
+ style: str = 'pytorch',
+ with_cp: bool = False) -> nn.Module:
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(
+ block(
+ inplanes,
+ planes,
+ stride,
+ dilation,
+ downsample,
+ style=style,
+ with_cp=with_cp))
+ inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(
+ block(inplanes, planes, 1, dilation, style=style, with_cp=with_cp))
+
+ return nn.Sequential(*layers)
+
+
+class ResNet(nn.Module):
+ """ResNet backbone.
+
+ Args:
+ depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
+ num_stages (int): Resnet stages, normally 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+ not freezing any parameters.
+ bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
+ running stats (mean and var).
+ bn_frozen (bool): Whether to freeze weight and bias of BN layers.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ """
+
+ arch_settings = {
+ 18: (BasicBlock, (2, 2, 2, 2)),
+ 34: (BasicBlock, (3, 4, 6, 3)),
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self,
+ depth: int,
+ num_stages: int = 4,
+ strides: Sequence[int] = (1, 2, 2, 2),
+ dilations: Sequence[int] = (1, 1, 1, 1),
+ out_indices: Sequence[int] = (0, 1, 2, 3),
+ style: str = 'pytorch',
+ frozen_stages: int = -1,
+ bn_eval: bool = True,
+ bn_frozen: bool = False,
+ with_cp: bool = False):
+ super().__init__()
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for resnet')
+ assert num_stages >= 1 and num_stages <= 4
+ block, stage_blocks = self.arch_settings[depth]
+ stage_blocks = stage_blocks[:num_stages] # type: ignore
+ assert len(strides) == len(dilations) == num_stages
+ assert max(out_indices) < num_stages
+
+ self.out_indices = out_indices
+ self.style = style
+ self.frozen_stages = frozen_stages
+ self.bn_eval = bn_eval
+ self.bn_frozen = bn_frozen
+ self.with_cp = with_cp
+
+ self.inplanes: int = 64
+ self.conv1 = nn.Conv2d(
+ 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ self.res_layers = []
+ for i, num_blocks in enumerate(stage_blocks):
+ stride = strides[i]
+ dilation = dilations[i]
+ planes = 64 * 2**i
+ res_layer = make_res_layer(
+ block,
+ self.inplanes,
+ planes,
+ num_blocks,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ with_cp=with_cp)
+ self.inplanes = planes * block.expansion # type: ignore
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ self.feat_dim = block.expansion * 64 * 2**( # type: ignore
+ len(stage_blocks) - 1)
+
+ def init_weights(self, pretrained: Optional[str] = None) -> None:
+ if isinstance(pretrained, str):
+ logger = logging.getLogger()
+ from ..runner import load_checkpoint
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, nn.BatchNorm2d):
+ constant_init(m, 1)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor]]:
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ x = res_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ if len(outs) == 1:
+ return outs[0]
+ else:
+ return tuple(outs)
+
+ def train(self, mode: bool = True) -> None:
+ super().train(mode)
+ if self.bn_eval:
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ if self.bn_frozen:
+ for params in m.parameters():
+ params.requires_grad = False
+ if mode and self.frozen_stages >= 0:
+ for param in self.conv1.parameters():
+ param.requires_grad = False
+ for param in self.bn1.parameters():
+ param.requires_grad = False
+ self.bn1.eval()
+ self.bn1.weight.requires_grad = False
+ self.bn1.bias.requires_grad = False
+ for i in range(1, self.frozen_stages + 1):
+ mod = getattr(self, f'layer{i}')
+ mod.eval()
+ for param in mod.parameters():
+ param.requires_grad = False
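+
+# Illustrative sketch: typical standalone use of this backbone, returning the
+# four stage feature maps for a 224x224 input (channel counts assume the
+# Bottleneck block, i.e. depth >= 50).
+#
+#   model = ResNet(depth=50, out_indices=(0, 1, 2, 3))
+#   model.init_weights()                       # Kaiming conv / constant BN
+#   feats = model(torch.rand(1, 3, 224, 224))  # tuple of 4 tensors with
+#                                              # 256/512/1024/2048 channels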
diff --git a/mmcv/mmcv/cnn/utils/__init__.py b/mmcv/mmcv/cnn/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a263e31c1e3977712827ca229bbc04910b4e928e
--- /dev/null
+++ b/mmcv/mmcv/cnn/utils/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .flops_counter import get_model_complexity_info
+from .fuse_conv_bn import fuse_conv_bn
+from .sync_bn import revert_sync_batchnorm
+from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
+ KaimingInit, NormalInit, PretrainedInit,
+ TruncNormalInit, UniformInit, XavierInit,
+ bias_init_with_prob, caffe2_xavier_init,
+ constant_init, initialize, kaiming_init, normal_init,
+ trunc_normal_init, uniform_init, xavier_init)
+
+__all__ = [
+ 'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init',
+ 'constant_init', 'kaiming_init', 'normal_init', 'trunc_normal_init',
+ 'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize',
+ 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+ 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+ 'Caffe2XavierInit', 'revert_sync_batchnorm'
+]
diff --git a/mmcv/mmcv/cnn/utils/flops_counter.py b/mmcv/mmcv/cnn/utils/flops_counter.py
new file mode 100644
index 0000000000000000000000000000000000000000..150a55992a9561073626d26df503ba4ef37efa18
--- /dev/null
+++ b/mmcv/mmcv/cnn/utils/flops_counter.py
@@ -0,0 +1,603 @@
+# Modified from flops-counter.pytorch by Vladislav Sovrasov
+# original repo: https://github.com/sovrasov/flops-counter.pytorch
+
+# MIT License
+
+# Copyright (c) 2018 Vladislav Sovrasov
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import sys
+import warnings
+from functools import partial
+from typing import Any, Callable, Dict, Optional, TextIO, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+import mmcv
+
+
+def get_model_complexity_info(model: nn.Module,
+ input_shape: tuple,
+ print_per_layer_stat: bool = True,
+ as_strings: bool = True,
+ input_constructor: Optional[Callable] = None,
+ flush: bool = False,
+ ost: TextIO = sys.stdout) -> tuple:
+ """Get complexity information of a model.
+
+ This method can calculate FLOPs and parameter counts of a model with
+ corresponding input shape. It can also print complexity information for
+ each layer in a model.
+
+ Supported layers are listed as below:
+ - Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``.
+ - Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``,
+ ``nn.LeakyReLU``, ``nn.ReLU6``.
+ - Poolings: ``nn.MaxPool1d``, ``nn.MaxPool2d``, ``nn.MaxPool3d``,
+ ``nn.AvgPool1d``, ``nn.AvgPool2d``, ``nn.AvgPool3d``,
+ ``nn.AdaptiveMaxPool1d``, ``nn.AdaptiveMaxPool2d``,
+ ``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``,
+ ``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``.
+ - BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``,
+ ``nn.BatchNorm3d``, ``nn.GroupNorm``, ``nn.InstanceNorm1d``,
+ ``InstanceNorm2d``, ``InstanceNorm3d``, ``nn.LayerNorm``.
+ - Linear: ``nn.Linear``.
+ - Deconvolution: ``nn.ConvTranspose2d``.
+ - Upsample: ``nn.Upsample``.
+
+ Args:
+ model (nn.Module): The model for complexity calculation.
+ input_shape (tuple): Input shape used for calculation.
+ print_per_layer_stat (bool): Whether to print complexity information
+ for each layer in a model. Default: True.
+ as_strings (bool): Output FLOPs and params counts in a string form.
+ Default: True.
+        input_constructor (None | callable): If specified, it is a callable
+            that generates the input; otherwise, a random tensor with the
+            given input shape is used to calculate FLOPs. Default: None.
+ flush (bool): same as that in :func:`print`. Default: False.
+ ost (stream): same as ``file`` param in :func:`print`.
+ Default: sys.stdout.
+
+ Returns:
+ tuple[float | str]: If ``as_strings`` is set to True, it will return
+            FLOPs and parameter counts in a string format; otherwise, it will
+            return them as floats.
+ """
+ assert type(input_shape) is tuple
+ assert len(input_shape) >= 1
+ assert isinstance(model, nn.Module)
+ flops_model = add_flops_counting_methods(model)
+ flops_model.eval()
+ flops_model.start_flops_count()
+ if input_constructor:
+ input = input_constructor(input_shape)
+ _ = flops_model(**input)
+ else:
+ try:
+ batch = torch.ones(()).new_empty(
+ (1, *input_shape),
+ dtype=next(flops_model.parameters()).dtype,
+ device=next(flops_model.parameters()).device)
+ except StopIteration:
+ # Avoid StopIteration for models which have no parameters,
+ # like `nn.Relu()`, `nn.AvgPool2d`, etc.
+ batch = torch.ones(()).new_empty((1, *input_shape))
+
+ _ = flops_model(batch)
+
+ flops_count, params_count = flops_model.compute_average_flops_cost()
+ if print_per_layer_stat:
+ print_model_with_flops(
+ flops_model, flops_count, params_count, ost=ost, flush=flush)
+ flops_model.stop_flops_count()
+
+ if as_strings:
+ return flops_to_string(flops_count), params_to_string(params_count)
+
+ return flops_count, params_count
+
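+# Illustrative sketch: typical usage; the input shape excludes the batch
+# dimension and ``model`` is any hypothetical ``nn.Module``.
+#
+#   flops, params = get_model_complexity_info(model, (3, 224, 224),
+#                                             print_per_layer_stat=False)
+#   # for a torchvision ResNet-50 this yields roughly '4.12 GFLOPs' and
+#   # '25.56 M' (one multiply-add counted as one FLOP)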
+
+def flops_to_string(flops: float,
+ units: Optional[str] = 'GFLOPs',
+ precision: int = 2) -> str:
+ """Convert FLOPs number into a string.
+
+    Note that here one multiply-add is counted as one FLOP.
+
+ Args:
+ flops (float): FLOPs number to be converted.
+ units (str | None): Converted FLOPs units. Options are None, 'GFLOPs',
+ 'MFLOPs', 'KFLOPs', 'FLOPs'. If set to None, it will automatically
+ choose the most suitable unit for FLOPs. Default: 'GFLOPs'.
+ precision (int): Digit number after the decimal point. Default: 2.
+
+ Returns:
+ str: The converted FLOPs number with units.
+
+ Examples:
+ >>> flops_to_string(1e9)
+ '1.0 GFLOPs'
+ >>> flops_to_string(2e5, 'MFLOPs')
+ '0.2 MFLOPs'
+ >>> flops_to_string(3e-9, None)
+ '3e-09 FLOPs'
+ """
+ if units is None:
+ if flops // 10**9 > 0:
+ return str(round(flops / 10.**9, precision)) + ' GFLOPs'
+ elif flops // 10**6 > 0:
+ return str(round(flops / 10.**6, precision)) + ' MFLOPs'
+ elif flops // 10**3 > 0:
+ return str(round(flops / 10.**3, precision)) + ' KFLOPs'
+ else:
+ return str(flops) + ' FLOPs'
+ else:
+ if units == 'GFLOPs':
+ return str(round(flops / 10.**9, precision)) + ' ' + units
+ elif units == 'MFLOPs':
+ return str(round(flops / 10.**6, precision)) + ' ' + units
+ elif units == 'KFLOPs':
+ return str(round(flops / 10.**3, precision)) + ' ' + units
+ else:
+ return str(flops) + ' FLOPs'
+
+
+def params_to_string(num_params: float,
+ units: Optional[str] = None,
+ precision: int = 2) -> str:
+ """Convert parameter number into a string.
+
+ Args:
+ num_params (float): Parameter number to be converted.
+ units (str | None): Converted FLOPs units. Options are None, 'M',
+ 'K' and ''. If set to None, it will automatically choose the most
+ suitable unit for Parameter number. Default: None.
+ precision (int): Digit number after the decimal point. Default: 2.
+
+ Returns:
+ str: The converted parameter number with units.
+
+ Examples:
+ >>> params_to_string(1e9)
+ '1000.0 M'
+ >>> params_to_string(2e5)
+ '200.0 k'
+ >>> params_to_string(3e-9)
+ '3e-09'
+ """
+ if units is None:
+ if num_params // 10**6 > 0:
+ return str(round(num_params / 10**6, precision)) + ' M'
+ elif num_params // 10**3:
+ return str(round(num_params / 10**3, precision)) + ' k'
+ else:
+ return str(num_params)
+ else:
+ if units == 'M':
+ return str(round(num_params / 10.**6, precision)) + ' ' + units
+ elif units == 'K':
+ return str(round(num_params / 10.**3, precision)) + ' ' + units
+ else:
+ return str(num_params)
+
+
+def print_model_with_flops(model: nn.Module,
+ total_flops: float,
+ total_params: float,
+ units: Optional[str] = 'GFLOPs',
+ precision: int = 3,
+ ost: TextIO = sys.stdout,
+ flush: bool = False) -> None:
+ """Print a model with FLOPs for each layer.
+
+ Args:
+ model (nn.Module): The model to be printed.
+ total_flops (float): Total FLOPs of the model.
+ total_params (float): Total parameter counts of the model.
+ units (str | None): Converted FLOPs units. Default: 'GFLOPs'.
+ precision (int): Digit number after the decimal point. Default: 3.
+ ost (stream): same as `file` param in :func:`print`.
+ Default: sys.stdout.
+ flush (bool): same as that in :func:`print`. Default: False.
+
+ Example:
+ >>> class ExampleModel(nn.Module):
+
+ >>> def __init__(self):
+ >>> super().__init__()
+ >>> self.conv1 = nn.Conv2d(3, 8, 3)
+ >>> self.conv2 = nn.Conv2d(8, 256, 3)
+ >>> self.conv3 = nn.Conv2d(256, 8, 3)
+ >>> self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
+ >>> self.flatten = nn.Flatten()
+ >>> self.fc = nn.Linear(8, 1)
+
+ >>> def forward(self, x):
+ >>> x = self.conv1(x)
+ >>> x = self.conv2(x)
+ >>> x = self.conv3(x)
+ >>> x = self.avg_pool(x)
+ >>> x = self.flatten(x)
+ >>> x = self.fc(x)
+ >>> return x
+
+ >>> model = ExampleModel()
+ >>> x = (3, 16, 16)
+ to print the complexity information state for each layer, you can use
+ >>> get_model_complexity_info(model, x)
+ or directly use
+ >>> print_model_with_flops(model, 4579784.0, 37361)
+ ExampleModel(
+ 0.037 M, 100.000% Params, 0.005 GFLOPs, 100.000% FLOPs,
+ (conv1): Conv2d(0.0 M, 0.600% Params, 0.0 GFLOPs, 0.959% FLOPs, 3, 8, kernel_size=(3, 3), stride=(1, 1)) # noqa: E501
+ (conv2): Conv2d(0.019 M, 50.020% Params, 0.003 GFLOPs, 58.760% FLOPs, 8, 256, kernel_size=(3, 3), stride=(1, 1))
+ (conv3): Conv2d(0.018 M, 49.356% Params, 0.002 GFLOPs, 40.264% FLOPs, 256, 8, kernel_size=(3, 3), stride=(1, 1))
+ (avg_pool): AdaptiveAvgPool2d(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.017% FLOPs, output_size=(1, 1))
+ (flatten): Flatten(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.000% FLOPs, )
+ (fc): Linear(0.0 M, 0.024% Params, 0.0 GFLOPs, 0.000% FLOPs, in_features=8, out_features=1, bias=True)
+ )
+ """
+
+ def accumulate_params(self):
+ if is_supported_instance(self):
+ return self.__params__
+ else:
+ sum = 0
+ for m in self.children():
+ sum += m.accumulate_params()
+ return sum
+
+ def accumulate_flops(self):
+ if is_supported_instance(self):
+ return self.__flops__ / model.__batch_counter__
+ else:
+ sum = 0
+ for m in self.children():
+ sum += m.accumulate_flops()
+ return sum
+
+ def flops_repr(self):
+ accumulated_num_params = self.accumulate_params()
+ accumulated_flops_cost = self.accumulate_flops()
+ return ', '.join([
+ params_to_string(
+ accumulated_num_params, units='M', precision=precision),
+ f'{accumulated_num_params / total_params:.3%} Params',
+ flops_to_string(
+ accumulated_flops_cost, units=units, precision=precision),
+ f'{accumulated_flops_cost / total_flops:.3%} FLOPs',
+ self.original_extra_repr()
+ ])
+
+ def add_extra_repr(m):
+ m.accumulate_flops = accumulate_flops.__get__(m)
+ m.accumulate_params = accumulate_params.__get__(m)
+ flops_extra_repr = flops_repr.__get__(m)
+ if m.extra_repr != flops_extra_repr:
+ m.original_extra_repr = m.extra_repr
+ m.extra_repr = flops_extra_repr
+ assert m.extra_repr != m.original_extra_repr
+
+ def del_extra_repr(m):
+ if hasattr(m, 'original_extra_repr'):
+ m.extra_repr = m.original_extra_repr
+ del m.original_extra_repr
+ if hasattr(m, 'accumulate_flops'):
+ del m.accumulate_flops
+
+ model.apply(add_extra_repr)
+ print(model, file=ost, flush=flush)
+ model.apply(del_extra_repr)
+
+
+def get_model_parameters_number(model: nn.Module) -> float:
+ """Calculate parameter number of a model.
+
+ Args:
+ model (nn.module): The model for parameter number calculation.
+
+ Returns:
+ float: Parameter number of the model.
+ """
+ num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ return num_params
+
+
+def add_flops_counting_methods(net_main_module: nn.Module) -> nn.Module:
+    # Add additional methods to the existing module object. This is done so
+    # that each bound function has access to the module (``self``) object.
+ net_main_module.start_flops_count = start_flops_count.__get__( # type: ignore # noqa E501
+ net_main_module)
+ net_main_module.stop_flops_count = stop_flops_count.__get__( # type: ignore # noqa E501
+ net_main_module)
+ net_main_module.reset_flops_count = reset_flops_count.__get__( # type: ignore # noqa E501
+ net_main_module)
+ net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__( # type: ignore # noqa E501
+ net_main_module)
+
+ net_main_module.reset_flops_count()
+
+ return net_main_module
+
+
+def compute_average_flops_cost(self) -> Tuple[float, float]:
+ """Compute average FLOPs cost.
+
+ A method to compute average FLOPs cost, which will be available after
+ `add_flops_counting_methods()` is called on a desired net object.
+
+ Returns:
+        tuple[float, float]: Current mean flops consumption per image and
+            the parameter count of the model.
+ """
+ batches_count = self.__batch_counter__
+ flops_sum = 0
+ for module in self.modules():
+ if is_supported_instance(module):
+ flops_sum += module.__flops__
+ params_sum = get_model_parameters_number(self)
+ return flops_sum / batches_count, params_sum
+
+
+def start_flops_count(self) -> None:
+ """Activate the computation of mean flops consumption per image.
+
+    A method to activate the computation of mean flops consumption per image,
+    which will be available after ``add_flops_counting_methods()`` is called on
+ a desired net object. It should be called before running the network.
+ """
+ add_batch_counter_hook_function(self)
+
+ def add_flops_counter_hook_function(module: nn.Module) -> None:
+ if is_supported_instance(module):
+ if hasattr(module, '__flops_handle__'):
+ return
+
+ else:
+ handle = module.register_forward_hook(
+ get_modules_mapping()[type(module)])
+
+ module.__flops_handle__ = handle
+
+ self.apply(partial(add_flops_counter_hook_function))
+
+
+def stop_flops_count(self) -> None:
+ """Stop computing the mean flops consumption per image.
+
+ A method to stop computing the mean flops consumption per image, which will
+ be available after ``add_flops_counting_methods()`` is called on a desired
+    net object. It can be called to pause the computation at any time.
+ """
+ remove_batch_counter_hook_function(self)
+ self.apply(remove_flops_counter_hook_function)
+
+
+def reset_flops_count(self) -> None:
+ """Reset statistics computed so far.
+
+    A method to reset the computed statistics, which will be available after
+ `add_flops_counting_methods()` is called on a desired net object.
+ """
+ add_batch_counter_variables_or_reset(self)
+ self.apply(add_flops_counter_variable_or_reset)
+
+
+# ---- Internal functions
+def empty_flops_counter_hook(module: nn.Module, input: tuple,
+ output: Any) -> None:
+ module.__flops__ += 0
+
+
+def upsample_flops_counter_hook(module: nn.Module, input: tuple,
+ output: torch.Tensor) -> None:
+ output_size = output[0]
+ batch_size = output_size.shape[0]
+ output_elements_count = batch_size
+ for val in output_size.shape[1:]:
+ output_elements_count *= val
+ module.__flops__ += int(output_elements_count)
+
+
+def relu_flops_counter_hook(module: nn.Module, input: tuple,
+ output: torch.Tensor) -> None:
+ active_elements_count = output.numel()
+ module.__flops__ += int(active_elements_count)
+
+
+def linear_flops_counter_hook(module: nn.Module, input: tuple,
+ output: torch.Tensor) -> None:
+ output_last_dim = output.shape[
+ -1] # pytorch checks dimensions, so here we don't care much
+ module.__flops__ += int(np.prod(input[0].shape) * output_last_dim)
+
+
+def pool_flops_counter_hook(module: nn.Module, input: tuple,
+ output: torch.Tensor) -> None:
+ module.__flops__ += int(np.prod(input[0].shape))
+
+
+def norm_flops_counter_hook(module: nn.Module, input: tuple,
+ output: torch.Tensor) -> None:
+ batch_flops = np.prod(input[0].shape)
+ if (getattr(module, 'affine', False)
+ or getattr(module, 'elementwise_affine', False)):
+ batch_flops *= 2
+ module.__flops__ += int(batch_flops)
+
+
+def deconv_flops_counter_hook(conv_module: nn.Module, input: tuple,
+ output: torch.Tensor) -> None:
+ # Can have multiple inputs, getting the first one
+ batch_size = input[0].shape[0]
+ input_height, input_width = input[0].shape[2:]
+
+ kernel_height, kernel_width = conv_module.kernel_size
+ in_channels = conv_module.in_channels
+ out_channels = conv_module.out_channels
+ groups = conv_module.groups
+
+ filters_per_channel = out_channels // groups
+ conv_per_position_flops = (
+ kernel_height * kernel_width * in_channels * filters_per_channel)
+
+ active_elements_count = batch_size * input_height * input_width
+ overall_conv_flops = conv_per_position_flops * active_elements_count
+ bias_flops = 0
+ if conv_module.bias is not None:
+ output_height, output_width = output.shape[2:]
+ bias_flops = out_channels * batch_size * output_height * output_width
+ overall_flops = overall_conv_flops + bias_flops
+
+ conv_module.__flops__ += int(overall_flops)
+
+
+def conv_flops_counter_hook(conv_module: nn.Module, input: tuple,
+ output: torch.Tensor) -> None:
+ # Can have multiple inputs, getting the first one
+ batch_size = input[0].shape[0]
+ output_dims = list(output.shape[2:])
+
+ kernel_dims = list(conv_module.kernel_size)
+ in_channels = conv_module.in_channels
+ out_channels = conv_module.out_channels
+ groups = conv_module.groups
+
+ filters_per_channel = out_channels // groups
+ conv_per_position_flops = int(
+ np.prod(kernel_dims)) * in_channels * filters_per_channel
+
+ active_elements_count = batch_size * int(np.prod(output_dims))
+
+ overall_conv_flops = conv_per_position_flops * active_elements_count
+
+ bias_flops = 0
+
+ if conv_module.bias is not None:
+
+ bias_flops = out_channels * active_elements_count
+
+ overall_flops = overall_conv_flops + bias_flops
+
+ conv_module.__flops__ += int(overall_flops)
+
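+# Worked example of the count above: a bias-free 3x3 conv with 64 input and
+# 128 output channels (groups=1) producing a 56x56 map at batch size 1 gives
+#   conv_per_position_flops = 3 * 3 * 64 * 128 = 73,728
+#   active_elements_count   = 1 * 56 * 56      = 3,136
+#   overall_flops           = 73,728 * 3,136   ≈ 2.31e8 multiply-adds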
+
+def batch_counter_hook(module: nn.Module, input: tuple, output: Any) -> None:
+ batch_size = 1
+ if len(input) > 0:
+ # Can have multiple inputs, getting the first one
+ batch_size = len(input[0])
+ else:
+ warnings.warn('No positional inputs found for a module, '
+ 'assuming batch size is 1.')
+ module.__batch_counter__ += batch_size
+
+
+def add_batch_counter_variables_or_reset(module: nn.Module) -> None:
+
+ module.__batch_counter__ = 0
+
+
+def add_batch_counter_hook_function(module: nn.Module) -> None:
+ if hasattr(module, '__batch_counter_handle__'):
+ return
+
+ handle = module.register_forward_hook(batch_counter_hook)
+ module.__batch_counter_handle__ = handle
+
+
+def remove_batch_counter_hook_function(module: nn.Module) -> None:
+ if hasattr(module, '__batch_counter_handle__'):
+ module.__batch_counter_handle__.remove()
+ del module.__batch_counter_handle__
+
+
+def add_flops_counter_variable_or_reset(module: nn.Module) -> None:
+ if is_supported_instance(module):
+ if hasattr(module, '__flops__') or hasattr(module, '__params__'):
+ warnings.warn('variables __flops__ or __params__ are already '
+                          'defined for the module ' + type(module).__name__ +
+                          '. ptflops can affect your code!')
+ module.__flops__ = 0
+ module.__params__ = get_model_parameters_number(module)
+
+
+def is_supported_instance(module: nn.Module) -> bool:
+ if type(module) in get_modules_mapping():
+ return True
+ return False
+
+
+def remove_flops_counter_hook_function(module: nn.Module) -> None:
+ if is_supported_instance(module):
+ if hasattr(module, '__flops_handle__'):
+ module.__flops_handle__.remove()
+ del module.__flops_handle__
+
+
+def get_modules_mapping() -> Dict:
+ return {
+ # convolutions
+ nn.Conv1d: conv_flops_counter_hook,
+ nn.Conv2d: conv_flops_counter_hook,
+ mmcv.cnn.bricks.Conv2d: conv_flops_counter_hook,
+ nn.Conv3d: conv_flops_counter_hook,
+ mmcv.cnn.bricks.Conv3d: conv_flops_counter_hook,
+ # activations
+ nn.ReLU: relu_flops_counter_hook,
+ nn.PReLU: relu_flops_counter_hook,
+ nn.ELU: relu_flops_counter_hook,
+ nn.LeakyReLU: relu_flops_counter_hook,
+ nn.ReLU6: relu_flops_counter_hook,
+ # poolings
+ nn.MaxPool1d: pool_flops_counter_hook,
+ nn.AvgPool1d: pool_flops_counter_hook,
+ nn.AvgPool2d: pool_flops_counter_hook,
+ nn.MaxPool2d: pool_flops_counter_hook,
+ mmcv.cnn.bricks.MaxPool2d: pool_flops_counter_hook,
+ nn.MaxPool3d: pool_flops_counter_hook,
+ mmcv.cnn.bricks.MaxPool3d: pool_flops_counter_hook,
+ nn.AvgPool3d: pool_flops_counter_hook,
+ nn.AdaptiveMaxPool1d: pool_flops_counter_hook,
+ nn.AdaptiveAvgPool1d: pool_flops_counter_hook,
+ nn.AdaptiveMaxPool2d: pool_flops_counter_hook,
+ nn.AdaptiveAvgPool2d: pool_flops_counter_hook,
+ nn.AdaptiveMaxPool3d: pool_flops_counter_hook,
+ nn.AdaptiveAvgPool3d: pool_flops_counter_hook,
+ # normalizations
+ nn.BatchNorm1d: norm_flops_counter_hook,
+ nn.BatchNorm2d: norm_flops_counter_hook,
+ nn.BatchNorm3d: norm_flops_counter_hook,
+ nn.GroupNorm: norm_flops_counter_hook,
+ nn.InstanceNorm1d: norm_flops_counter_hook,
+ nn.InstanceNorm2d: norm_flops_counter_hook,
+ nn.InstanceNorm3d: norm_flops_counter_hook,
+ nn.LayerNorm: norm_flops_counter_hook,
+ # FC
+ nn.Linear: linear_flops_counter_hook,
+ mmcv.cnn.bricks.Linear: linear_flops_counter_hook,
+ # Upscale
+ nn.Upsample: upsample_flops_counter_hook,
+ # Deconvolution
+ nn.ConvTranspose2d: deconv_flops_counter_hook,
+ mmcv.cnn.bricks.ConvTranspose2d: deconv_flops_counter_hook,
+ }
diff --git a/mmcv/mmcv/cnn/utils/fuse_conv_bn.py b/mmcv/mmcv/cnn/utils/fuse_conv_bn.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ccaab3bf1eb3ce615bad910d6dc45a467bb1fe4
--- /dev/null
+++ b/mmcv/mmcv/cnn/utils/fuse_conv_bn.py
@@ -0,0 +1,59 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+
+def _fuse_conv_bn(conv: nn.Module, bn: nn.Module) -> nn.Module:
+ """Fuse conv and bn into one module.
+
+ Args:
+ conv (nn.Module): Conv to be fused.
+ bn (nn.Module): BN to be fused.
+
+ Returns:
+ nn.Module: Fused module.
+ """
+ conv_w = conv.weight
+ conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
+ bn.running_mean)
+
+ factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
+ conv.weight = nn.Parameter(conv_w *
+ factor.reshape([conv.out_channels, 1, 1, 1]))
+ conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
+ return conv
+
+
+def fuse_conv_bn(module: nn.Module) -> nn.Module:
+ """Recursively fuse conv and bn in a module.
+
+    During inference, the functionality of batch norm layers is turned off
+    and only the per-channel running mean and variance are used, which makes
+    it possible to fuse them with the preceding conv layers to save
+    computation and simplify the network structure.
+
+ Args:
+ module (nn.Module): Module to be fused.
+
+ Returns:
+ nn.Module: Fused module.
+ """
+ last_conv = None
+ last_conv_name = None
+
+ for name, child in module.named_children():
+ if isinstance(child,
+ (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
+ if last_conv is None: # only fuse BN that is after Conv
+ continue
+ fused_conv = _fuse_conv_bn(last_conv, child)
+ module._modules[last_conv_name] = fused_conv
+ # To reduce changes, set BN as Identity instead of deleting it.
+ module._modules[name] = nn.Identity()
+ last_conv = None
+ elif isinstance(child, nn.Conv2d):
+ last_conv = child
+ last_conv_name = name
+ else:
+ fuse_conv_bn(child)
+ return module
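+
+# Illustrative sketch: fusion is intended for inference only, since the fused
+# conv bakes in the current running statistics. ``build_detector`` is a
+# hypothetical model factory whose output contains Conv2d + BN pairs.
+#
+#   model = build_detector(cfg)
+#   model.eval()
+#   model = fuse_conv_bn(model)   # each fused BN is replaced by nn.Identity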
diff --git a/mmcv/mmcv/cnn/utils/sync_bn.py b/mmcv/mmcv/cnn/utils/sync_bn.py
new file mode 100644
index 0000000000000000000000000000000000000000..c534fc0e17506dde31c20529ce7bef64eef87140
--- /dev/null
+++ b/mmcv/mmcv/cnn/utils/sync_bn.py
@@ -0,0 +1,61 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+import mmcv
+
+
+class _BatchNormXd(nn.modules.batchnorm._BatchNorm):
+ """A general BatchNorm layer without input dimension check.
+
+ Reproduced from @kapily's work:
+ (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+    The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc.
+    is `_check_input_dim`, which is designed for tensor sanity checks.
+ The check has been bypassed in this class for the convenience of converting
+ SyncBatchNorm.
+ """
+
+ def _check_input_dim(self, input: torch.Tensor):
+ return
+
+
+def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
+ """Helper function to convert all `SyncBatchNorm` (SyncBN) and
+    `mmcv.ops.sync_bn.SyncBatchNorm` (MMSyncBN) layers in the model to
+ `BatchNormXd` layers.
+
+ Adapted from @kapily's work:
+ (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+
+ Args:
+ module (nn.Module): The module containing `SyncBatchNorm` layers.
+
+ Returns:
+ module_output: The converted module with `BatchNormXd` layers.
+ """
+ module_output = module
+ module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
+ if hasattr(mmcv, 'ops'):
+ module_checklist.append(mmcv.ops.SyncBatchNorm)
+ if isinstance(module, tuple(module_checklist)):
+ module_output = _BatchNormXd(module.num_features, module.eps,
+ module.momentum, module.affine,
+ module.track_running_stats)
+ if module.affine:
+ # no_grad() may not be needed here but
+ # just to be consistent with `convert_sync_batchnorm()`
+ with torch.no_grad():
+ module_output.weight = module.weight
+ module_output.bias = module.bias
+ module_output.running_mean = module.running_mean
+ module_output.running_var = module.running_var
+ module_output.num_batches_tracked = module.num_batches_tracked
+ module_output.training = module.training
+ # qconfig exists in quantized models
+ if hasattr(module, 'qconfig'):
+ module_output.qconfig = module.qconfig
+ for name, child in module.named_children():
+ module_output.add_module(name, revert_sync_batchnorm(child))
+ del module
+ return module_output
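+
+# Illustrative sketch: handy when running a SyncBN-trained checkpoint on a
+# single GPU or CPU, where ``SyncBatchNorm`` would otherwise require an
+# initialized process group. ``build_model`` is a hypothetical factory.
+#
+#   model = build_model(cfg)
+#   model = revert_sync_batchnorm(model)   # SyncBN -> _BatchNormXd
+#   out = model(inputs)                    # plain batch-norm semantics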
diff --git a/mmcv/mmcv/cnn/utils/weight_init.py b/mmcv/mmcv/cnn/utils/weight_init.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e0d293ad4fb315462e34d5899ae6fccc4a7ba86
--- /dev/null
+++ b/mmcv/mmcv/cnn/utils/weight_init.py
@@ -0,0 +1,708 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import math
+import warnings
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from mmcv.utils import Registry, build_from_cfg, get_logger, print_log
+
+INITIALIZERS = Registry('initializer')
+
+
+def update_init_info(module: nn.Module, init_info: str) -> None:
+ """Update the `_params_init_info` in the module if the value of parameters
+ are changed.
+
+ Args:
+ module (obj:`nn.Module`): The module of PyTorch with a user-defined
+ attribute `_params_init_info` which records the initialization
+ information.
+ init_info (str): The string that describes the initialization.
+ """
+ assert hasattr(
+ module,
+ '_params_init_info'), f'Can not find `_params_init_info` in {module}'
+ for name, param in module.named_parameters():
+
+ assert param in module._params_init_info, (
+ f'Find a new :obj:`Parameter` '
+ f'named `{name}` during executing the '
+ f'`init_weights` of '
+ f'`{module.__class__.__name__}`. '
+ f'Please do not add or '
+ f'replace parameters during executing '
+ f'the `init_weights`. ')
+
+ # The parameter has been changed during executing the
+ # `init_weights` of module
+ mean_value = param.data.mean()
+ if module._params_init_info[param]['tmp_mean_value'] != mean_value:
+ module._params_init_info[param]['init_info'] = init_info
+ module._params_init_info[param]['tmp_mean_value'] = mean_value
+
+
+def constant_init(module: nn.Module, val: float, bias: float = 0) -> None:
+ if hasattr(module, 'weight') and module.weight is not None:
+ nn.init.constant_(module.weight, val)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def xavier_init(module: nn.Module,
+ gain: float = 1,
+ bias: float = 0,
+ distribution: str = 'normal') -> None:
+ assert distribution in ['uniform', 'normal']
+ if hasattr(module, 'weight') and module.weight is not None:
+ if distribution == 'uniform':
+ nn.init.xavier_uniform_(module.weight, gain=gain)
+ else:
+ nn.init.xavier_normal_(module.weight, gain=gain)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def normal_init(module: nn.Module,
+ mean: float = 0,
+ std: float = 1,
+ bias: float = 0) -> None:
+ if hasattr(module, 'weight') and module.weight is not None:
+ nn.init.normal_(module.weight, mean, std)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def trunc_normal_init(module: nn.Module,
+ mean: float = 0,
+ std: float = 1,
+ a: float = -2,
+ b: float = 2,
+ bias: float = 0) -> None:
+ if hasattr(module, 'weight') and module.weight is not None:
+ trunc_normal_(module.weight, mean, std, a, b) # type: ignore
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias) # type: ignore
+
+
+def uniform_init(module: nn.Module,
+ a: float = 0,
+ b: float = 1,
+ bias: float = 0) -> None:
+ if hasattr(module, 'weight') and module.weight is not None:
+ nn.init.uniform_(module.weight, a, b)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def kaiming_init(module: nn.Module,
+ a: float = 0,
+ mode: str = 'fan_out',
+ nonlinearity: str = 'relu',
+ bias: float = 0,
+ distribution: str = 'normal') -> None:
+ assert distribution in ['uniform', 'normal']
+ if hasattr(module, 'weight') and module.weight is not None:
+ if distribution == 'uniform':
+ nn.init.kaiming_uniform_(
+ module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
+ else:
+ nn.init.kaiming_normal_(
+ module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def caffe2_xavier_init(module: nn.Module, bias: float = 0) -> None:
+ # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
+ # Acknowledgment to FAIR's internal code
+ kaiming_init(
+ module,
+ a=1,
+ mode='fan_in',
+ nonlinearity='leaky_relu',
+ bias=bias,
+ distribution='uniform')
+
+
+def bias_init_with_prob(prior_prob: float) -> float:
+ """initialize conv/fc bias value according to a given probability value."""
+ bias_init = float(-np.log((1 - prior_prob) / prior_prob))
+ return bias_init
+
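+# Worked example: for a focal-loss style prior of 0.01,
+# bias_init_with_prob(0.01) = -log(0.99 / 0.01) ≈ -4.595, so a sigmoid over
+# the freshly initialized logits outputs roughly 0.01 for every class.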
+
+def _get_bases_name(m: nn.Module) -> List[str]:
+ return [b.__name__ for b in m.__class__.__bases__]
+
+
+class BaseInit:
+
+ def __init__(self,
+ *,
+ bias: float = 0,
+ bias_prob: Optional[float] = None,
+ layer: Union[str, List, None] = None):
+ self.wholemodule = False
+ if not isinstance(bias, (int, float)):
+ raise TypeError(f'bias must be a number, but got a {type(bias)}')
+
+ if bias_prob is not None:
+ if not isinstance(bias_prob, float):
+ raise TypeError(f'bias_prob type must be float, \
+ but got {type(bias_prob)}')
+
+ if layer is not None:
+ if not isinstance(layer, (str, list)):
+ raise TypeError(f'layer must be a str or a list of str, \
+ but got a {type(layer)}')
+ else:
+ layer = []
+
+ if bias_prob is not None:
+ self.bias = bias_init_with_prob(bias_prob)
+ else:
+ self.bias = bias
+ self.layer = [layer] if isinstance(layer, str) else layer
+
+ def _get_init_info(self) -> str:
+ info = f'{self.__class__.__name__}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='Constant')
+class ConstantInit(BaseInit):
+ """Initialize module parameters with constant values.
+
+ Args:
+ val (int | float): the value to fill the weights in the module with
+ bias (int | float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+ Defaults to None.
+ """
+
+ def __init__(self, val: Union[int, float], **kwargs):
+ super().__init__(**kwargs)
+ self.val = val
+
+ def __call__(self, module: nn.Module) -> None:
+
+ def init(m):
+ if self.wholemodule:
+ constant_init(m, self.val, self.bias)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ constant_init(m, self.val, self.bias)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self) -> str:
+ info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='Xavier')
+class XavierInit(BaseInit):
+ r"""Initialize module parameters with values according to the method
+    described in "Understanding the difficulty of training deep feedforward
+    neural networks" - Glorot, X. & Bengio, Y. (2010).
+
+ Args:
+ gain (int | float): an optional scaling factor. Defaults to 1.
+ bias (int | float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+ distribution (str): distribution either be ``'normal'``
+ or ``'uniform'``. Defaults to ``'normal'``.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ gain: float = 1,
+ distribution: str = 'normal',
+ **kwargs):
+ super().__init__(**kwargs)
+ self.gain = gain
+ self.distribution = distribution
+
+ def __call__(self, module: nn.Module) -> None:
+
+ def init(m):
+ if self.wholemodule:
+ xavier_init(m, self.gain, self.bias, self.distribution)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ xavier_init(m, self.gain, self.bias, self.distribution)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self) -> str:
+ info = f'{self.__class__.__name__}: gain={self.gain}, ' \
+ f'distribution={self.distribution}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='Normal')
+class NormalInit(BaseInit):
+ r"""Initialize module parameters with the values drawn from the normal
+ distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
+
+ Args:
+ mean (int | float):the mean of the normal distribution. Defaults to 0.
+ std (int | float): the standard deviation of the normal distribution.
+ Defaults to 1.
+ bias (int | float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+ Defaults to None.
+
+ """
+
+ def __init__(self, mean: float = 0, std: float = 1, **kwargs):
+ super().__init__(**kwargs)
+ self.mean = mean
+ self.std = std
+
+ def __call__(self, module: nn.Module) -> None:
+
+ def init(m):
+ if self.wholemodule:
+ normal_init(m, self.mean, self.std, self.bias)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ normal_init(m, self.mean, self.std, self.bias)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self) -> str:
+ info = f'{self.__class__.__name__}: mean={self.mean},' \
+ f' std={self.std}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='TruncNormal')
+class TruncNormalInit(BaseInit):
+ r"""Initialize module parameters with the values drawn from the normal
+ distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
+    outside :math:`[a, b]` redrawn until they are within the bounds.
+
+ Args:
+ mean (float): the mean of the normal distribution. Defaults to 0.
+ std (float): the standard deviation of the normal distribution.
+ Defaults to 1.
+ a (float): The minimum cutoff value.
+        b (float): The maximum cutoff value.
+ bias (float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+ Defaults to None.
+
+ """
+
+ def __init__(self,
+ mean: float = 0,
+ std: float = 1,
+ a: float = -2,
+ b: float = 2,
+ **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.mean = mean
+ self.std = std
+ self.a = a
+ self.b = b
+
+ def __call__(self, module: nn.Module) -> None:
+
+ def init(m):
+ if self.wholemodule:
+ trunc_normal_init(m, self.mean, self.std, self.a, self.b,
+ self.bias)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ trunc_normal_init(m, self.mean, self.std, self.a, self.b,
+ self.bias)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self):
+ info = f'{self.__class__.__name__}: a={self.a}, b={self.b},' \
+ f' mean={self.mean}, std={self.std}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='Uniform')
+class UniformInit(BaseInit):
+ r"""Initialize module parameters with values drawn from the uniform
+ distribution :math:`\mathcal{U}(a, b)`.
+
+ Args:
+ a (int | float): the lower bound of the uniform distribution.
+ Defaults to 0.
+ b (int | float): the upper bound of the uniform distribution.
+ Defaults to 1.
+ bias (int | float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+ Defaults to None.
+ """
+
+ def __init__(self, a: float = 0., b: float = 1., **kwargs):
+ super().__init__(**kwargs)
+ self.a = a
+ self.b = b
+
+ def __call__(self, module: nn.Module) -> None:
+
+ def init(m):
+ if self.wholemodule:
+ uniform_init(m, self.a, self.b, self.bias)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ uniform_init(m, self.a, self.b, self.bias)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self) -> str:
+ info = f'{self.__class__.__name__}: a={self.a},' \
+ f' b={self.b}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='Kaiming')
+class KaimingInit(BaseInit):
+ r"""Initialize module parameters with the values according to the method
+    described in "Delving deep into rectifiers: Surpassing human-level
+    performance on ImageNet classification" - He, K. et al. (2015).
+
+ Args:
+ a (int | float): the negative slope of the rectifier used after this
+ layer (only used with ``'leaky_relu'``). Defaults to 0.
+ mode (str): either ``'fan_in'`` or ``'fan_out'``. Choosing
+ ``'fan_in'`` preserves the magnitude of the variance of the weights
+ in the forward pass. Choosing ``'fan_out'`` preserves the
+ magnitudes in the backwards pass. Defaults to ``'fan_out'``.
+ nonlinearity (str): the non-linear function (`nn.functional` name),
+ recommended to use only with ``'relu'`` or ``'leaky_relu'`` .
+ Defaults to 'relu'.
+ bias (int | float): the value to fill the bias. Defaults to 0.
+ bias_prob (float, optional): the probability for bias initialization.
+ Defaults to None.
+ distribution (str): distribution either be ``'normal'`` or
+ ``'uniform'``. Defaults to ``'normal'``.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ a: float = 0,
+ mode: str = 'fan_out',
+ nonlinearity: str = 'relu',
+ distribution: str = 'normal',
+ **kwargs):
+ super().__init__(**kwargs)
+ self.a = a
+ self.mode = mode
+ self.nonlinearity = nonlinearity
+ self.distribution = distribution
+
+ def __call__(self, module: nn.Module) -> None:
+
+ def init(m):
+ if self.wholemodule:
+ kaiming_init(m, self.a, self.mode, self.nonlinearity,
+ self.bias, self.distribution)
+ else:
+ layername = m.__class__.__name__
+ basesname = _get_bases_name(m)
+ if len(set(self.layer) & set([layername] + basesname)):
+ kaiming_init(m, self.a, self.mode, self.nonlinearity,
+ self.bias, self.distribution)
+
+ module.apply(init)
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self) -> str:
+ info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \
+ f'nonlinearity={self.nonlinearity}, ' \
+ f'distribution ={self.distribution}, bias={self.bias}'
+ return info
+
+
+@INITIALIZERS.register_module(name='Caffe2Xavier')
+class Caffe2XavierInit(KaimingInit):
+ # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
+ # Acknowledgment to FAIR's internal code
+ def __init__(self, **kwargs):
+ super().__init__(
+ a=1,
+ mode='fan_in',
+ nonlinearity='leaky_relu',
+ distribution='uniform',
+ **kwargs)
+
+ def __call__(self, module: nn.Module) -> None:
+ super().__call__(module)
+
+
+@INITIALIZERS.register_module(name='Pretrained')
+class PretrainedInit:
+ """Initialize module by loading a pretrained model.
+
+ Args:
+        checkpoint (str): the checkpoint file of the pretrained model to be
+            loaded.
+        prefix (str, optional): the prefix of a sub-module in the pretrained
+            model. It is used to load only a part of the pretrained model to
+ initialize. For example, if we would like to only load the
+ backbone of a detector model, we can set ``prefix='backbone.'``.
+ Defaults to None.
+ map_location (str): map tensors into proper locations.
+ """
+
+ def __init__(self,
+ checkpoint: str,
+ prefix: Optional[str] = None,
+ map_location: Optional[str] = None):
+ self.checkpoint = checkpoint
+ self.prefix = prefix
+ self.map_location = map_location
+
+ def __call__(self, module: nn.Module) -> None:
+ from mmcv.runner import (_load_checkpoint_with_prefix, load_checkpoint,
+ load_state_dict)
+ logger = get_logger('mmcv')
+ if self.prefix is None:
+ print_log(f'load model from: {self.checkpoint}', logger=logger)
+ load_checkpoint(
+ module,
+ self.checkpoint,
+ map_location=self.map_location,
+ strict=False,
+ logger=logger)
+ else:
+ print_log(
+ f'load {self.prefix} in model from: {self.checkpoint}',
+ logger=logger)
+ state_dict = _load_checkpoint_with_prefix(
+ self.prefix, self.checkpoint, map_location=self.map_location)
+ load_state_dict(module, state_dict, strict=False, logger=logger)
+
+ if hasattr(module, '_params_init_info'):
+ update_init_info(module, init_info=self._get_init_info())
+
+ def _get_init_info(self) -> str:
+ info = f'{self.__class__.__name__}: load from {self.checkpoint}'
+ return info
+
+
+def _initialize(module: nn.Module,
+ cfg: Dict,
+ wholemodule: bool = False) -> None:
+ func = build_from_cfg(cfg, INITIALIZERS)
+ # wholemodule flag is for override mode, there is no layer key in override
+ # and initializer will give init values for the whole module with the name
+ # in override.
+ func.wholemodule = wholemodule
+ func(module)
+
+
+def _initialize_override(module: nn.Module, override: Union[Dict, List],
+ cfg: Dict) -> None:
+ if not isinstance(override, (dict, list)):
+ raise TypeError(f'override must be a dict or a list of dict, \
+ but got {type(override)}')
+
+ override = [override] if isinstance(override, dict) else override
+
+ for override_ in override:
+
+ cp_override = copy.deepcopy(override_)
+ name = cp_override.pop('name', None)
+ if name is None:
+ raise ValueError('`override` must contain the key "name",'
+ f'but got {cp_override}')
+ # if override only has name key, it means use args in init_cfg
+ if not cp_override:
+ cp_override.update(cfg)
+ # if override has name key and other args except type key, it will
+ # raise error
+ elif 'type' not in cp_override.keys():
+ raise ValueError(
+ f'`override` need "type" key, but got {cp_override}')
+
+ if hasattr(module, name):
+ _initialize(getattr(module, name), cp_override, wholemodule=True)
+ else:
+ raise RuntimeError(f'module did not have attribute {name}, '
+ f'but init_cfg is {cp_override}.')
+
+
+def initialize(module: nn.Module, init_cfg: Union[Dict, List[dict]]) -> None:
+ r"""Initialize a module.
+
+ Args:
+        module (``torch.nn.Module``): the module to be initialized.
+ init_cfg (dict | list[dict]): initialization configuration dict to
+ define initializer. OpenMMLab has implemented 6 initializers
+ including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
+ ``Kaiming``, and ``Pretrained``.
+
+ Example:
+ >>> module = nn.Linear(2, 3, bias=True)
+ >>> init_cfg = dict(type='Constant', layer='Linear', val =1 , bias =2)
+ >>> initialize(module, init_cfg)
+
+ >>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1,2))
+ >>> # define key ``'layer'`` for initializing layer with different
+ >>> # configuration
+ >>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
+ dict(type='Constant', layer='Linear', val=2)]
+ >>> initialize(module, init_cfg)
+
+ >>> # define key``'override'`` to initialize some specific part in
+ >>> # module
+ >>> class FooNet(nn.Module):
+ >>> def __init__(self):
+ >>> super().__init__()
+ >>> self.feat = nn.Conv2d(3, 16, 3)
+ >>> self.reg = nn.Conv2d(16, 10, 3)
+ >>> self.cls = nn.Conv2d(16, 5, 3)
+ >>> model = FooNet()
+ >>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d',
+ >>> override=dict(type='Constant', name='reg', val=3, bias=4))
+ >>> initialize(model, init_cfg)
+
+ >>> model = ResNet(depth=50)
+ >>> # Initialize weights with the pretrained model.
+ >>> init_cfg = dict(type='Pretrained',
+ checkpoint='torchvision://resnet50')
+ >>> initialize(model, init_cfg)
+
+ >>> # Initialize weights of a sub-module with the specific part of
+ >>> # a pretrained model by using "prefix".
+ >>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\
+ >>> 'retinanet_r50_fpn_1x_coco/'\
+ >>> 'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth'
+ >>> init_cfg = dict(type='Pretrained',
+ checkpoint=url, prefix='backbone.')
+ """
+ if not isinstance(init_cfg, (dict, list)):
+ raise TypeError(f'init_cfg must be a dict or a list of dict, \
+ but got {type(init_cfg)}')
+
+ if isinstance(init_cfg, dict):
+ init_cfg = [init_cfg]
+
+ for cfg in init_cfg:
+ # should deeply copy the original config because cfg may be used by
+ # other modules, e.g., one init_cfg shared by multiple bottleneck
+ # blocks, the expected cfg will be changed after pop and will change
+ # the initialization behavior of other modules
+ cp_cfg = copy.deepcopy(cfg)
+ override = cp_cfg.pop('override', None)
+ _initialize(module, cp_cfg)
+
+ if override is not None:
+ cp_cfg.pop('layer', None)
+ _initialize_override(module, override, cp_cfg)
+ else:
+ # All attributes in module have same initialization.
+ pass
+
+
+def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
+ b: float) -> Tensor:
+ # Method based on
+ # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ # Modified from
+ # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+ 'The distribution of values may be incorrect.',
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ lower = norm_cdf((a - mean) / std)
+ upper = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [lower, upper], then translate
+ # to [2lower-1, 2upper-1].
+ tensor.uniform_(2 * lower - 1, 2 * upper - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor: Tensor,
+ mean: float = 0.,
+ std: float = 1.,
+ a: float = -2.,
+ b: float = 2.) -> Tensor:
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+
+ Modified from
+ https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+
+ Args:
+ tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
+ mean (float): the mean of the normal distribution.
+ std (float): the standard deviation of the normal distribution.
+ a (float): the minimum cutoff value.
+ b (float): the maximum cutoff value.
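+
+    Example (illustrative usage sketch; ``w`` is an assumed tensor name):
+        >>> w = torch.empty(3, 5)
+        >>> trunc_normal_(w, mean=0., std=1., a=-2., b=2.)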
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
diff --git a/mmcv/mmcv/cnn/vgg.py b/mmcv/mmcv/cnn/vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1d9ba211eb4b0056eb4127e19159e9ed5d5251f
--- /dev/null
+++ b/mmcv/mmcv/cnn/vgg.py
@@ -0,0 +1,177 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+from typing import List, Optional, Sequence, Tuple, Union
+
+import torch.nn as nn
+from torch import Tensor
+
+from .utils import constant_init, kaiming_init, normal_init
+
+
+def conv3x3(in_planes: int, out_planes: int, dilation: int = 1) -> nn.Module:
+ """3x3 convolution with padding."""
+ return nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ padding=dilation,
+ dilation=dilation)
+
+
+def make_vgg_layer(inplanes: int,
+ planes: int,
+ num_blocks: int,
+ dilation: int = 1,
+ with_bn: bool = False,
+ ceil_mode: bool = False) -> List[nn.Module]:
+ layers = []
+ for _ in range(num_blocks):
+ layers.append(conv3x3(inplanes, planes, dilation))
+ if with_bn:
+ layers.append(nn.BatchNorm2d(planes))
+ layers.append(nn.ReLU(inplace=True))
+ inplanes = planes
+ layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode))
+
+ return layers
+
+
+class VGG(nn.Module):
+ """VGG backbone.
+
+ Args:
+ depth (int): Depth of vgg, from {11, 13, 16, 19}.
+ with_bn (bool): Use BatchNorm or not.
+ num_classes (int): number of classes for classification.
+ num_stages (int): VGG stages, normally 5.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+ not freezing any parameters.
+ bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
+ running stats (mean and var).
+ bn_frozen (bool): Whether to freeze weight and bias of BN layers.
+ """
+
+ arch_settings = {
+ 11: (1, 1, 2, 2, 2),
+ 13: (2, 2, 2, 2, 2),
+ 16: (2, 2, 3, 3, 3),
+ 19: (2, 2, 4, 4, 4)
+ }
+
+ def __init__(self,
+ depth: int,
+ with_bn: bool = False,
+ num_classes: int = -1,
+ num_stages: int = 5,
+ dilations: Sequence[int] = (1, 1, 1, 1, 1),
+ out_indices: Sequence[int] = (0, 1, 2, 3, 4),
+ frozen_stages: int = -1,
+ bn_eval: bool = True,
+ bn_frozen: bool = False,
+ ceil_mode: bool = False,
+ with_last_pool: bool = True):
+ super().__init__()
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for vgg')
+ assert num_stages >= 1 and num_stages <= 5
+ stage_blocks = self.arch_settings[depth]
+ self.stage_blocks = stage_blocks[:num_stages]
+ assert len(dilations) == num_stages
+ assert max(out_indices) <= num_stages
+
+ self.num_classes = num_classes
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.bn_eval = bn_eval
+ self.bn_frozen = bn_frozen
+
+ self.inplanes = 3
+ start_idx = 0
+ vgg_layers = []
+ self.range_sub_modules = []
+ for i, num_blocks in enumerate(self.stage_blocks):
+ num_modules = num_blocks * (2 + with_bn) + 1
+ end_idx = start_idx + num_modules
+ dilation = dilations[i]
+ planes = 64 * 2**i if i < 4 else 512
+ vgg_layer = make_vgg_layer(
+ self.inplanes,
+ planes,
+ num_blocks,
+ dilation=dilation,
+ with_bn=with_bn,
+ ceil_mode=ceil_mode)
+ vgg_layers.extend(vgg_layer)
+ self.inplanes = planes
+ self.range_sub_modules.append([start_idx, end_idx])
+ start_idx = end_idx
+ if not with_last_pool:
+ vgg_layers.pop(-1)
+ self.range_sub_modules[-1][1] -= 1
+ self.module_name = 'features'
+ self.add_module(self.module_name, nn.Sequential(*vgg_layers))
+
+ if self.num_classes > 0:
+ self.classifier = nn.Sequential(
+ nn.Linear(512 * 7 * 7, 4096),
+ nn.ReLU(True),
+ nn.Dropout(),
+ nn.Linear(4096, 4096),
+ nn.ReLU(True),
+ nn.Dropout(),
+ nn.Linear(4096, num_classes),
+ )
+
+ def init_weights(self, pretrained: Optional[str] = None) -> None:
+ if isinstance(pretrained, str):
+ logger = logging.getLogger()
+ from ..runner import load_checkpoint
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, nn.BatchNorm2d):
+ constant_init(m, 1)
+ elif isinstance(m, nn.Linear):
+ normal_init(m, std=0.01)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, ...]]:
+ outs = []
+ vgg_layers = getattr(self, self.module_name)
+ for i in range(len(self.stage_blocks)):
+ for j in range(*self.range_sub_modules[i]):
+ vgg_layer = vgg_layers[j]
+ x = vgg_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ if self.num_classes > 0:
+ x = x.view(x.size(0), -1)
+ x = self.classifier(x)
+ outs.append(x)
+ if len(outs) == 1:
+ return outs[0]
+ else:
+ return tuple(outs)
+
+ def train(self, mode: bool = True) -> None:
+ super().train(mode)
+ if self.bn_eval:
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ if self.bn_frozen:
+ for params in m.parameters():
+ params.requires_grad = False
+ vgg_layers = getattr(self, self.module_name)
+ if mode and self.frozen_stages >= 0:
+ for i in range(self.frozen_stages):
+ for j in range(*self.range_sub_modules[i]):
+ mod = vgg_layers[j]
+ mod.eval()
+ for param in mod.parameters():
+ param.requires_grad = False
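+
+
+# Usage sketch (illustrative, not part of the original module; the dummy
+# input tensor below is an assumption for demonstration only):
+#
+#   import torch
+#   model = VGG(depth=16, num_classes=-1, out_indices=(4, ))
+#   model.init_weights()  # random init when no pretrained path is given
+#   feat = model(torch.randn(1, 3, 224, 224))  # single requested stage -> Tensor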
diff --git a/mmcv/mmcv/device/__init__.py b/mmcv/mmcv/device/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba217b0771bcfada461d7c61a78f41a274e5aa6a
--- /dev/null
+++ b/mmcv/mmcv/device/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from . import ipu, mlu, mps
+from .scatter_gather import scatter, scatter_kwargs
+from .utils import get_device
+
+__all__ = ['mlu', 'ipu', 'mps', 'get_device', 'scatter', 'scatter_kwargs']
diff --git a/mmcv/mmcv/device/_functions.py b/mmcv/mmcv/device/_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..462a7e4ddca14685047b7937e3054108e164cf91
--- /dev/null
+++ b/mmcv/mmcv/device/_functions.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Union
+
+import torch
+
+from mmcv.utils import deprecated_api_warning
+from .utils import get_device
+
+
+def scatter(input: Union[List, torch.Tensor], devices: List) -> List:
+ """scatter copies tensor to devices directly."""
+ current_device = get_device()
+ if isinstance(input, list):
+ outputs = [scatter(_input, devices) for _input in input]
+ return outputs
+ elif isinstance(input, torch.Tensor):
+ output = input.contiguous()
+ return output.to(current_device) if devices != [-1] else output
+ else:
+ raise Exception(f'Unknown type {type(input)}.')
+
+
+class Scatter:
+
+ @staticmethod
+ @deprecated_api_warning({'target_mlus': 'target_devices'},
+ cls_name='Scatter')
+ def forward(target_devices, input):
+ outputs = scatter(input, target_devices)
+ return tuple(outputs) if isinstance(outputs, list) else (outputs, )
diff --git a/mmcv/mmcv/device/ipu/__init__.py b/mmcv/mmcv/device/ipu/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..d550865ad20790f0eb79015abc866548c0f2f83b
--- /dev/null
+++ b/mmcv/mmcv/device/ipu/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.utils import IS_IPU_AVAILABLE
+
+if IS_IPU_AVAILABLE:
+ from .dataloader import IPUDataLoader
+ from .hook_wrapper import IPUFp16OptimizerHook
+ from .model_wrapper import ipu_model_wrapper
+ from .runner import IPUBaseRunner, IPUEpochBasedRunner, IPUIterBasedRunner
+ from .utils import cfg2options
+ __all__ = [
+ 'cfg2options', 'ipu_model_wrapper', 'IPUFp16OptimizerHook',
+ 'IPUDataLoader', 'IPUBaseRunner', 'IPUEpochBasedRunner',
+ 'IPUIterBasedRunner'
+ ]
diff --git a/mmcv/mmcv/device/ipu/dataloader.py b/mmcv/mmcv/device/ipu/dataloader.py
new file mode 100755
index 0000000000000000000000000000000000000000..1485df2f31facff79238c70d89fdd9030fddcbce
--- /dev/null
+++ b/mmcv/mmcv/device/ipu/dataloader.py
@@ -0,0 +1,157 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections.abc import Mapping, Sequence
+from functools import partial
+
+import poptorch
+from torch.utils.data.dataloader import default_collate
+
+from mmcv.parallel import DataContainer
+
+
+def collate(batch, samples_per_gpu=1):
+ """Put each data field into a tensor/DataContainer with outer dimension
+ batch size.
+
+ TODO support for
+ :type:`~mmcv.parallel.DataContainer`. Currently, it will be ignored.
+ There are 3 cases.
+
+ 1. cpu_only = True, e.g., meta data.
+ 2. cpu_only = False, stack = True, e.g., images tensors.
+ 3. cpu_only = False, stack = False, e.g., gt bboxes.
+ """
+
+ if not isinstance(batch, Sequence):
+ raise TypeError(
+ f'`batch` should be a sequence, but got {type(batch)}.')
+
+ if isinstance(batch[0], DataContainer):
+ # TODO `DataContainer` will be supported in the future.
+ raise TypeError('DataContainer is not supported in ipu data loader.')
+ elif isinstance(batch[0], Sequence):
+ transposed = zip(*batch)
+ collated_batch = []
+ for samples in transposed:
+ if not isinstance(samples[0], DataContainer):
+                # At present, DataContainer processing is skipped,
+                # which reduces the performance of the IPU DataLoader
+ collated_batch.append(collate(samples, samples_per_gpu))
+ return collated_batch
+ elif isinstance(batch[0], Mapping):
+ collated_batch = {}
+ for key in batch[0]:
+ if not isinstance(batch[0][key], DataContainer):
+                # At present, DataContainer processing is skipped,
+                # which reduces the performance of the IPU DataLoader
+ collated_batch[key] = collate([d[key] for d in batch])
+ return collated_batch
+ else:
+ return default_collate(batch)
+
+
+class IPUDataLoader(poptorch.DataLoader):
+ """Thin wrapper of `torch.utils.data.DataLoader`.
+
+    Compared with the PyTorch DataLoader, this DataLoader changes the way the
+    batch size is calculated and adds an AsynchronousDataAccessor to
+    load and release data faster in CPU mode.
+
+ If this data loader is used in a distributed execution environment, it will
+ ensure that each process uses a different subset of the dataset, providing
+ you first call ``options.randomSeed(N)`` with an integer N which is the
+ same across all hosts.
+
+ Args:
+ dataset (torch.utils.data.Dataset): The dataset to get the data from.
+ options (poptorch.Options): Options that will be used to compile
+ and run the model.
+ batch_size (int, optional): This is the batch size in the conventional
+ sense of being the size that runs through an operation in the model
+ at any given time.
+ shuffle (bool, optional): set to ``True`` to have the data reshuffled
+ at every epoch (default: ``False``).
+ num_workers (int, optional): how many subprocesses to use for data
+ loading. ``0`` means that the data will be loaded in the main
+ process. (default: ``0``)
+ drop_last (bool, optional): If True and the number of elements in the
+ dataset is not a multiple of the combined batch size then the
+ incomplete batch at the end will be dropped.
+ persistent_workers (bool, optional): Re-use workers between
+ iterations if True.
+ auto_distributed_partitioning (bool, optional): If True, partitions the
+ dataset for distributed execution automatically. Otherwise, it is
+ assumed that partitioning has been handled manually.
+ mode (poptorch.DataLoaderMode, optional): If `DataLoaderMode.Async`,
+ uses an :py:class:`~poptorch.AsynchronousDataAccessor` to access
+ the dataset. If `DataLoaderMode.Sync`, accesses the dataset
+ synchronously.
+ async_options (Dict[str, Any], optional): Options to pass to
+ :py:class:`~poptorch.AsynchronousDataAccessor`.
+ rebatched_worker_size (int, optional): When using AsyncRebatched: batch
+ size of the tensors loaded by the workers.
+ Default to the combined batch size.
+ If specified the ``rebatched_worker_size`` must be less than
+ or equal to the combined batch size.
+ kwargs (Dict[str, Any], optional): Other options to pass to PyTorch's
+ ``DataLoader`` constructor.
+ """
+
+ def __init__(self,
+ dataset,
+ options,
+ batch_size=1,
+ shuffle=False,
+ num_workers=0,
+ drop_last=True,
+ persistent_workers=True,
+ auto_distributed_partitioning=True,
+ mode='sync',
+ async_options=None,
+ rebatched_worker_size=None,
+ **kwargs):
+ """Lazy init:
+
+        In many frameworks, the dataloader is constructed before the IPU
+        options are initialized, so lazy initialization is used here: the
+        real initialization is deferred until the dataloader needs to be
+        used and the options are passed in.
+ """
+ # lazy init: sometimes, we cannot get IPU options when build data
+ # loader
+ self.kwargs = {
+ 'dataset': dataset,
+ 'batch_size': batch_size,
+ 'shuffle': shuffle,
+ 'num_workers': num_workers,
+ 'drop_last': drop_last,
+ 'persistent_workers': persistent_workers,
+ 'auto_distributed_partitioning': auto_distributed_partitioning,
+ 'mode': mode,
+ 'collate_fn': partial(collate, samples_per_gpu=batch_size),
+ 'async_options': async_options,
+ 'rebatched_worker_size': rebatched_worker_size,
+ **kwargs
+ }
+ self.dataset = dataset
+ self.initialized = False
+ if options:
+ self.init(options=options)
+
+ def init(self, options, **kwargs):
+ if not self.initialized:
+ kwargs = {**self.kwargs, **kwargs, 'options': options}
+ if kwargs['mode'] == 'sync':
+ kwargs['mode'] = poptorch.DataLoaderMode.Sync
+ elif kwargs['mode'] == 'async':
+ kwargs['mode'] = poptorch.DataLoaderMode.AsyncRebatched
+ if kwargs['async_options'] is None:
+ kwargs['async_options'] = {
+ 'load_indefinitely': True,
+ 'buffer_size': 8
+ }
+ if kwargs['rebatched_worker_size'] is None:
+ kwargs['rebatched_worker_size'] = 128
+ super().__init__(**kwargs)
+ self.initialized = True
+
+ return self
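+
+
+# Usage sketch (illustrative; `my_dataset` and `ipu_options` are assumed
+# names, not defined in this module):
+#
+#   loader = IPUDataLoader(my_dataset, options=None, batch_size=4)
+#   ...  # build the poptorch options later, e.g. via cfg2options
+#   loader.init(options=ipu_options)  # real initialization happens here
+#   for batch in loader:
+#       ...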
diff --git a/mmcv/mmcv/device/ipu/hierarchical_data_manager.py b/mmcv/mmcv/device/ipu/hierarchical_data_manager.py
new file mode 100755
index 0000000000000000000000000000000000000000..a6f3b3cd2a139bcbc7852e7849071ab4b9fbb76f
--- /dev/null
+++ b/mmcv/mmcv/device/ipu/hierarchical_data_manager.py
@@ -0,0 +1,243 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import numpy as np
+import torch
+
+from mmcv.parallel import DataContainer
+
+# A customized None type for HierarchicalDataManager
+HierarchicalDataNone = object()
+
+
+class HierarchicalDataManager:
+ """A class manage all the tensors in the hierarchical data.
+
+ At present, the input data structure accepted by IPU is limited,
+ when the input data structure of mmcv varies.
+ Here, an intermediate class is needed to get and update tensors
+ from the original data.
+
+ HierarchicalDataManager will record a hierarchical input/output data in
+ self._hierarchical_data. For example, we have an input data:
+ {'img': tensorA, 'label': tensorB, 'img_metas': [tensorC, tensorD]}
+ To enable IPU to use the input, HierarchicalDataManager will collect
+ the torch tensors from self._hierarchical_data into a tuple like:
+ (tensorA, tensorB, tensorC, tensorD).
+    Meanwhile, since the IPU returns a tuple of tensors, HierarchicalDataManager
+    also has a function named update_all_tensors to put the tensors back into
+    self._hierarchical_data, which is the output for upper calls.
+
+ Args:
+ logger (:obj:`logging.Logger`): Logger used during running.
+ Defaults to None.
+ """
+
+ def __init__(self, logger=None):
+ self.atomic_types = (int, str, float, np.ndarray, type(None))
+ self.warning = warnings.warn if logger is None else logger.warning
+ # enable or disable input data's shape and value check
+ self.quick_mode = False
+ self._hierarchical_data = None
+
+ def quick(self):
+ self.quick_mode = True
+
+ def compare_atomic_type(self, a, b):
+ """Compare data, supported datatypes are numpy array and python basic
+ types."""
+ if isinstance(a, np.ndarray):
+ return np.all(a == b)
+ else:
+ return a == b
+
+ def record_hierarchical_data(self, data):
+ """Record a hierarchical data."""
+ if self._hierarchical_data is not None:
+ if isinstance(data, torch.Tensor):
+ assert isinstance(self._hierarchical_data, torch.Tensor), \
+ 'original hierarchical data is not torch.tensor'
+ self._hierarchical_data = data
+ else:
+ self.update_hierarchical_data(data)
+ else:
+ self._hierarchical_data = data
+
+ @property
+ def hierarchical_data(self):
+ return self._hierarchical_data
+
+ def update_hierarchical_data(self,
+ dataA,
+ dataB=HierarchicalDataNone,
+ strict=True,
+ address='data'):
+ """Update dataB with dataA in-place.
+
+ Args:
+ dataA (list or dict or tuple): New hierarchical data.
+ dataB (list or dict or tuple): hierarchical data to update.
+                If not specified, self.hierarchical_data will be updated.
+ strict (bool, optional): If true, an error will be reported
+ when the following conditions occur:
+ 1. Non-torch.Tensor data changed.
+ 2. Torch.Tensor data shape changed.
+ address (str): Record the address of current data to be updated.
+ Default: 'data'.
+ """
+ if dataB is HierarchicalDataNone:
+ dataB = self.hierarchical_data
+
+        # Update with data that has the same structure
+        # but different values (tensors and basic python data types)
+ if isinstance(dataA, (tuple, list)):
+ for idx, node in enumerate(dataA):
+ new_address = ''
+ if not self.quick_mode:
+ new_address = address + f'[{str(idx)}]'
+ assert isinstance(node, type(dataB[idx])),\
+ f'data structure changed: {new_address}'
+ if isinstance(node, torch.Tensor):
+ dataB[idx] = node
+ else:
+ self.update_hierarchical_data(
+ node, dataB[idx], strict, address=new_address)
+ elif isinstance(dataA, dict):
+ for k, v in dataA.items():
+ new_address = ''
+ if not self.quick_mode:
+ new_address = address + f'[{str(k)}]'
+ assert isinstance(v, type(dataB[k])),\
+ f'data structure changed: {new_address}'
+ if isinstance(v, torch.Tensor):
+ dataB[k] = v
+ else:
+ self.update_hierarchical_data(
+ v, dataB[k], strict, address=new_address)
+ elif isinstance(dataA, self.atomic_types):
+ if not self.quick_mode:
+ is_equal = self.compare_atomic_type(dataA, dataB)
+ if not is_equal:
+ if strict:
+ raise ValueError(
+ 'all data except torch.Tensor should be same, '
+ f'but data({address}) is changed.')
+ else:
+ self.warning(
+ f'find a non-torch.Tensor data({type(dataA)}) '
+ f'changed, and the address is {address}')
+ elif isinstance(dataA, DataContainer):
+ if not self.quick_mode:
+ assert isinstance(dataB, DataContainer)
+ new_address = address + '.data'
+ self.update_hierarchical_data(
+ dataA.data, dataB.data, False, address=new_address)
+ else:
+ raise NotImplementedError(
+ f'not supported datatype:{type(dataA)}, address is {address}')
+
+ def collect_all_tensors(self, hierarchical_data=None):
+ """Collect torch.Tensor data from self.hierarchical_data to a list and
+ return."""
+ # get a list of tensor from self._hierarchical_data
+ if hierarchical_data is None:
+ hierarchical_data = self._hierarchical_data
+ tensors = []
+ if isinstance(hierarchical_data, torch.Tensor):
+ tensors = [hierarchical_data]
+ else:
+ self._collect_tensors(hierarchical_data, tensors)
+ return tensors
+
+ def _collect_tensors(self, data, tensors):
+ if isinstance(data, (tuple, list)):
+ for node in data:
+ if isinstance(node, torch.Tensor):
+ tensors.append(node)
+ else:
+ self._collect_tensors(node, tensors)
+ elif isinstance(data, dict):
+ for v in data.values():
+ if isinstance(v, torch.Tensor):
+ tensors.append(v)
+ else:
+ self._collect_tensors(v, tensors)
+ elif isinstance(data, self.atomic_types):
+ pass
+ elif isinstance(data, DataContainer):
+ self._collect_tensors(data.data, tensors)
+ else:
+ raise NotImplementedError(f'not supported datatype:{type(data)}')
+
+ def update_all_tensors(self, tensors):
+ """Put tensors from tuple back to self.hierarchical_data."""
+ if isinstance(self._hierarchical_data, torch.Tensor):
+ assert len(tensors) == 1
+ assert isinstance(tensors[0], torch.Tensor)
+ self._hierarchical_data = tensors[0]
+ else:
+ # convert to list if tensors is tuple
+ tensors = list(tensors)
+ self._set_tensors(self._hierarchical_data, tensors)
+ return self.hierarchical_data
+
+ def _set_tensors(self, data, tensors):
+ if isinstance(data, tuple):
+ data = list(data)
+ for idx in range(len(data)):
+ if isinstance(data[idx], torch.Tensor):
+ data[idx] = tensors.pop(0)
+ else:
+ self._set_tensors(data[idx], tensors)
+ data = tuple(data)
+ elif isinstance(data, list):
+ for idx in range(len(data)):
+ if isinstance(data[idx], torch.Tensor):
+ data[idx] = tensors.pop(0)
+ else:
+ self._set_tensors(data[idx], tensors)
+ elif isinstance(data, dict):
+ for k, v in data.items():
+ if isinstance(v, torch.Tensor):
+ data[k] = tensors.pop(0)
+ else:
+ self._set_tensors(v, tensors)
+ elif isinstance(data, self.atomic_types):
+ pass
+ elif isinstance(data, DataContainer):
+ self._set_tensors(data.data, tensors)
+ else:
+ raise NotImplementedError(f'not supported datatype:{type(data)}')
+
+ def clean_all_tensors(self):
+ """Delete tensors from self.hierarchical_data."""
+ self._clean_tensors(self._hierarchical_data)
+
+ def _clean_tensors(self, data):
+ if isinstance(data, tuple):
+ data = list(data)
+ for idx in range(len(data)):
+ if isinstance(data[idx], torch.Tensor):
+ data[idx] = None
+ else:
+ self._clean_tensors(data[idx])
+ data = tuple(data)
+ elif isinstance(data, list):
+ for idx in range(len(data)):
+ if isinstance(data[idx], torch.Tensor):
+ data[idx] = None
+ else:
+ self._clean_tensors(data[idx])
+ elif isinstance(data, dict):
+ for k, v in data.items():
+ if isinstance(v, torch.Tensor):
+ data[k] = None
+ else:
+ self._clean_tensors(v)
+ elif isinstance(data, self.atomic_types):
+ pass
+ elif isinstance(data, DataContainer):
+ self._clean_tensors(data.data)
+ else:
+ raise NotImplementedError(f'not supported datatype:{type(data)}')
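+
+
+# Usage sketch (illustrative; the variable names are assumptions):
+#
+#   manager = HierarchicalDataManager()
+#   manager.record_hierarchical_data({'img': img, 'scale': 1.0})
+#   flat = tuple(manager.collect_all_tensors())  # tensors only, in order
+#   # ... run the compiled executor on `flat`, obtaining `new_flat` ...
+#   outputs = manager.update_all_tensors(new_flat)  # tensors put back in place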
diff --git a/mmcv/mmcv/device/ipu/hook_wrapper.py b/mmcv/mmcv/device/ipu/hook_wrapper.py
new file mode 100755
index 0000000000000000000000000000000000000000..141afb86d05a42c06fb5c4355cb47cae18e9bb2f
--- /dev/null
+++ b/mmcv/mmcv/device/ipu/hook_wrapper.py
@@ -0,0 +1,105 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.runner import HOOKS, LrUpdaterHook, OptimizerHook
+from mmcv.utils import TORCH_VERSION, digit_version
+
+
+def wrap_lr_updater_hook(lr_hook_class):
+ """A wrapper function to wrap any subclass of LrUpdaterHook.
+
+    IPU needs extra operations to upload optimizer settings. This wrapper
+    overrides the ``_set_lr`` function of a subclass of LrUpdaterHook.
+ """
+ assert issubclass(lr_hook_class, LrUpdaterHook)
+
+ class ipu_lr_hook_class(lr_hook_class):
+
+ def _set_lr(self, runner, *args, **kwargs):
+ super()._set_lr(runner, *args, **kwargs)
+ # convert torch optimizer to poptorch optimizer
+ runner.model.setOptimizer(runner.optimizer)
+
+ return ipu_lr_hook_class
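+
+# Usage sketch (illustrative; ``SomeLrUpdaterHook`` stands for any concrete
+# LrUpdaterHook subclass and is an assumed name, not defined here):
+#
+#   IPULrHook = wrap_lr_updater_hook(SomeLrUpdaterHook)
+#   runner.register_hook(IPULrHook(by_epoch=True), priority='VERY_HIGH')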
+
+
+def wrap_optimizer_hook(optimizer_hook_class):
+ """A wrapper function to wrap OptimizerHook.
+
+    This is a non-intrusive implementation of wrapping the optimizer hook
+    (otherwise every config file would need to be changed to use an IPU
+    optimizer hook). IPU's clip-norm implementation is different from
+    pytorch's, so an error is raised when clip-norm is used.
+ """
+
+ class ipu_optimizer_hook_class(OptimizerHook):
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ if self.grad_clip is not None:
+ raise NotImplementedError('IPU does not support gradient clip')
+
+ return ipu_optimizer_hook_class
+
+
+if (TORCH_VERSION != 'parrots'
+ and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
+
+ @HOOKS.register_module()
+ class IPUFp16OptimizerHook(OptimizerHook):
+ """FP16 optimizer hook (using PyTorch's implementation).
+
+ If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend,
+ to take care of the optimization procedure.
+
+ Args:
+ loss_scale (float | str | dict): Scale factor configuration.
+ If loss_scale is a float, static loss scaling will be used with
+ the specified scale. If loss_scale is a string, it must be
+ 'dynamic', then dynamic loss scaling will be used.
+             It can also be a dict containing arguments of GradScaler.
+ Defaults to 512. For Pytorch >= 1.6, mmcv uses official
+ implementation of GradScaler. If you use a dict version of
+ loss_scale to create GradScaler, please refer to:
+ https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
+ for the parameters.
+
+ Examples:
+ >>> loss_scale = dict(
+ ... init_scale=65536.0,
+ ... growth_factor=2.0,
+ ... backoff_factor=0.5,
+ ... growth_interval=2000
+ ... )
+ >>> optimizer_hook = Fp16OptimizerHook(loss_scale=loss_scale)
+ """
+
+ def __init__(self,
+ grad_clip=None,
+ coalesce=True,
+ bucket_size_mb=-1,
+ loss_scale=512.,
+ distributed=True):
+ assert grad_clip is None,\
+ 'IPU mode does not support `grad_clip` currently'
+ assert coalesce,\
+ 'implemented all reduce in distributed training currently'
+ assert bucket_size_mb == -1,\
+ '`bucket_size_mb` should not be set in IPU mode'
+ self.distributed = distributed
+ self._scale_update_param = None
+ if loss_scale == 'dynamic':
+ raise NotImplementedError(
+ 'IPU mode does not support dynamic loss scale currently')
+ elif isinstance(loss_scale, float):
+ self.loss_scale = loss_scale
+ elif isinstance(loss_scale, dict):
+ raise NotImplementedError(
+ 'IPU mode supports single scale currently')
+ else:
+ raise ValueError(
+ f'loss_scale should be float, but got {loss_scale} ')
+
+ def after_train_iter(self, runner):
+ pass
+
+else:
+ raise RuntimeError('The IPU mode only supports torch 1.6 and above')
diff --git a/mmcv/mmcv/device/ipu/model_wrapper.py b/mmcv/mmcv/device/ipu/model_wrapper.py
new file mode 100755
index 0000000000000000000000000000000000000000..c345537e29b27cf7fff740269da8643c9570cd36
--- /dev/null
+++ b/mmcv/mmcv/device/ipu/model_wrapper.py
@@ -0,0 +1,721 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import inspect
+from collections import OrderedDict
+from typing import Optional, Union
+
+import poptorch
+import torch
+import torch.nn as nn
+from poptorch import PoplarExecutor, __version__, identity_loss
+from poptorch._args_parser import ArgsParser
+
+from mmcv.runner import auto_fp16
+from .hierarchical_data_manager import HierarchicalDataManager
+from .utils import compare_ndarray, model_sharding, recomputation_checkpoint
+
+
+class DictArgsParser(ArgsParser):
+ """A helper class for handling model input.
+
+ Args:
+ inputs (list): Inputs of model.
+ """
+
+ def __init__(self, inputs):
+ # Combine args and kwargs:
+ self._has_variadic_arguments = True
+ self._varnames = list(inputs.keys())
+ self._defaults = [inspect.Parameter.empty for _ in self._varnames]
+ self._warned_not_contiguous_input = False
+
+
+class WrappedNet(nn.Module):
+ """A net wrapper for model conversion.
+
+ This wrapper will make some changes and add some extra functions to
+ training/inference model.
+
+ Args:
+ model (:obj:`nn.Module`): The model to run.
+ inputs_manager (:obj:`HierarchicalDataManager`): A parser
+ converting inputs from tuple to dictionary.
+ outputs_manager (:obj:`HierarchicalDataManager`): A parser
+ converting outputs from dictionary to tuple.
+ inter_outputs_in_cpu (dict): Specify the features to be
+ recorded.
+ modules_to_record (mmcv.Config, list): Index or name of modules which
+ will be recorded for output. It is necessary to specify output for
+ static graph of model training or inference.
+ """
+
+ def __init__(self,
+ model,
+ inputs_manager,
+ outputs_manager,
+ inter_outputs_in_cpu,
+ modules_to_record=None):
+ super().__init__()
+ self.model = model
+ self.inputs_manager = inputs_manager
+ self.outputs_manager = outputs_manager
+ self.training = model.training
+ # Register a hook function to capture the intermediate features
+ # generated by the network to align the outputs between ipu and cpu
+ # Used to confirm whether the implementation of CPU is consistent
+ # with the implementation of IPU
+ self.inter_outputs_in_cpu = inter_outputs_in_cpu
+ if modules_to_record is None:
+ modules_to_record = []
+
+ for idx, (name, module) in enumerate(model.named_modules()):
+ if name in modules_to_record or idx in modules_to_record:
+ features_hook = self.get_input_output_hook(
+ name, idx, self.inter_outputs_in_cpu)
+ module.register_forward_hook(hook=features_hook)
+
+ def get_input_output_hook(self, name, idx, save_dict):
+
+ def input_output_hook(module, fea_in, fea_out):
+ if isinstance(fea_in, tuple):
+ fea_in = list(fea_in)
+ if isinstance(fea_out, tuple):
+ fea_out = list(fea_out)
+ save_dict[name] = {
+ 'fea_in': fea_in,
+ 'fea_out': fea_out,
+ 'idx': idx
+ }
+ return None
+
+ return input_output_hook
+
+ def forward(self, inputs_tuple):
+ """This function is used to be compiled to ipu, the inputs and outputs
+ need to be tuples, so here we need to restore the input back to a
+ dictionary and convert the output to a tuple."""
+ self.inputs_manager.update_all_tensors(inputs_tuple)
+ kwargs = {**(self.inputs_manager.hierarchical_data)}
+ if self.training:
+ outputs = self.forward_train(kwargs)
+ # tell poptorch which loss will be used finally
+ identity_loss(outputs['loss'], reduction='none')
+ else:
+ outputs = self.forward_eval(kwargs)
+
+ if isinstance(outputs, torch.Tensor):
+            # single tensor output is currently not supported,
+            # so wrap it with a dictionary and
+            # use a keyword to identify this case
+ outputs = {'output of WrappedNet: single tensor': outputs}
+
+ # if there are some features need to be record, add extra outputs
+ for name in self.inter_outputs_in_cpu:
+ outputs[name] = self.inter_outputs_in_cpu[name]
+
+        # record the positions of the returned tensors in the converting stage;
+        # in the real run stage, all the tensors are changed in-place,
+        # which means the output can be obtained directly outside this function
+ self.outputs_manager.record_hierarchical_data(outputs)
+ plain_outputs = self.outputs_manager.collect_all_tensors()
+ return plain_outputs
+
+ def forward_train(self, kwargs):
+ optimizer = kwargs.pop('optimizer')
+ outputs = self.train_step(kwargs, optimizer)
+ return outputs
+
+ def train_step(self, data, optimizer=None, **kwargs):
+ """The iteration step during training.
+
+ This method defines an iteration step during training, except for the
+ back propagation and optimizer updating, which are done in an optimizer
+ hook. Note that in some complicated cases or models, the whole process
+ including back propagation and optimizer updating are also defined in
+ this method, such as GAN.
+
+ Args:
+ data (dict): The output of dataloader.
+ optimizer (:obj:`torch.optim.Optimizer`, optional): The
+ optimizer of runner is passed to ``train_step()``. This
+ argument is unused and reserved.
+
+ Returns:
+ dict: Dict of outputs. The following fields are contained.
+ - loss (torch.Tensor): A tensor for back propagation, which \
+ can be a weighted sum of multiple losses.
+ - log_vars (dict): Dict contains all the variables to be sent \
+ to the logger.
+ - num_samples (int): Indicates the batch size (when the model \
+ is DDP, it means the batch size on each GPU), which is \
+ used for averaging the logs.
+ """
+ losses = self.model(**data)
+ loss, log_vars = self._parse_losses(losses)
+
+ outputs = dict(
+ loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))
+
+ return outputs
+
+ def _parse_losses(self, losses):
+ log_vars = OrderedDict()
+ for loss_name, loss_value in losses.items():
+ if isinstance(loss_value, torch.Tensor):
+ log_vars[loss_name] = loss_value.mean()
+ elif isinstance(loss_value, list):
+ log_vars[loss_name] = sum(loss.mean() for loss in loss_value)
+ elif isinstance(loss_value, dict):
+ for name, value in loss_value.items():
+ log_vars[name] = value
+ else:
+ raise TypeError(
+ f'{loss_name} is not a tensor or list of tensors')
+
+ loss = sum(value for key, value in log_vars.items() if 'loss' in key)
+ log_vars['loss'] = loss
+
+ return loss, log_vars
+
+ def forward_eval(self, kwargs):
+ img = kwargs.pop('img')
+ img_metas = kwargs.pop('img_metas', None)
+ return_loss = kwargs.pop('return_loss')
+ assert not return_loss
+        # TODO Temporarily hard-coded to disable post_process;
+        # otherwise, in the third trace (_check_trace),
+        # post_process would convert the output tensor to a numpy array
+        # automatically, resulting in a _check_trace failure
+ outputs = self.model(
+ img,
+ img_metas=img_metas,
+ return_loss=return_loss,
+ post_process=False)
+ return outputs
+
+
+class MMPoplarExecutor(PoplarExecutor):
+ """An executor for inputs/outputs parsing, model compilation, data
+ alignment and IPU upload/download.
+
+ Args:
+ model (:obj:`nn.Module`): The model to be compiled.
+ logger (:obj:`logging.Logger`): Logger used during running.
+ Defaults to None.
+ training (bool): Model in training mode or eval mode.
+ modules_to_record (mmcv.Config, list): Index or name of modules which
+ will be recorded for output. It is necessary to specify output for
+ static graph of model training or inference.
+ args (argument list): Arguments passed to the `__init__`
+ method of PoplarExecutor.
+ kwargs (keyword arguments): Keyword arguments passed to the `__init__`
+ method of PoplarExecutor.
+ """
+
+ def __init__(self,
+ model,
+ logger=None,
+ training=True,
+ modules_to_record=None,
+ *args,
+ **kwargs):
+        # self.model == self._user_model: the input pytorch model
+        # self._model: the wrapped model used to compile and update weights;
+        # these two models share the same weights.
+        # The wrapped model only accepts and outputs tuples, so
+        # HierarchicalDataManager converts dictionaries
+        # to tuples and back
+ self.inputs_manager = HierarchicalDataManager(logger=logger)
+ self.outputs_manager = HierarchicalDataManager(logger=logger)
+ self.logger = logger
+ # the features calculated by CPU
+ self.inter_outputs_in_cpu = {}
+ # the features calculated by IPU
+ self.inter_outputs_in_ipu = {}
+ if modules_to_record is None:
+            # It is possible that the IPU implementation of some operators
+            # is inconsistent with the expected (CPU) results; this mechanism
+            # can be used to confirm whether there is a problem
+ self.compare_with_cpu = False
+ else:
+ self.compare_with_cpu = True
+ # move model.fp16_enabled to self.fp16_enabled,
+        # modify the position where the input is automatically cast to half
+ if getattr(model, 'fp16_enabled', False):
+ model.fp16_enabled = False
+ self.fp16_enabled = True
+ # make torch.jit.trace convert self._model
+ model = WrappedNet(
+ model,
+ self.inputs_manager,
+ self.outputs_manager,
+ self.inter_outputs_in_cpu,
+ modules_to_record=modules_to_record)
+ super().__init__(model, training=training, *args, **kwargs)
+ # overwrite self._args_parser in train_step or val_step
+ self._args_parser = None
+ if training:
+ assert self.training
+ else:
+ assert not self.training
+
+ @property
+ def training(self):
+        # Accessing the `training` attribute of self would, since this class
+        # has no `training` attribute of its own, automatically fall back to
+        # the `training` attribute of self.model.
+        # However, the real attribute we want to check is self._training;
+        # self.model.training and self._training are often inconsistent.
+        # It is not clear whether this is a Poptorch bug or a deliberate
+        # design, so temporarily use this property to work around the problem
+ return self._training # comes from self.model._training
+
+ @auto_fp16(supported_types=(PoplarExecutor, ))
+ def run_model(self, data_dict):
+ # this function is used to parse input_dict
+ # and convert to output_dict
+ if self.isCompiled():
+ self.inputs_manager.record_hierarchical_data(data_dict)
+ inputs_tuple = tuple(self.inputs_manager.collect_all_tensors())
+ else:
+ # get tensors out of data and put them in a tuple
+ self.inputs_manager.record_hierarchical_data(data_dict)
+ inputs_tuple = tuple(self.inputs_manager.collect_all_tensors())
+            # skip the shape/value checks in the data managers after compilation
+ self.inputs_manager.quick()
+ self.outputs_manager.quick()
+
+ # parser args in the first iter
+ if self._args_parser is None:
+ self._args_parser = DictArgsParser({'args': inputs_tuple})
+
+ # run or convert model
+ # the plain_outputs will be used in converting stage
+ plain_outputs = self(inputs_tuple)
+
+ self.inputs_manager.clean_all_tensors()
+
+ # put list of tensors back to the output dict
+ # according to the same order
+ self.outputs_manager.update_all_tensors(plain_outputs)
+ # get the real output dictionary from self.outputs_manager
+ output_dict = self.outputs_manager.hierarchical_data
+
+ # split output_dict into inter_outputs_in_ipu
+ # and output of the torch model
+ torch_model_output = {}
+ for name in output_dict:
+ if name in self.inter_outputs_in_cpu:
+ self.inter_outputs_in_ipu[name] = output_dict[name]
+ else:
+ torch_model_output[name] = output_dict[name]
+
+ if 'output of WrappedNet: single tensor' in output_dict:
+ assert len(torch_model_output) == 1
+ assert isinstance(
+ torch_model_output['output of WrappedNet: single tensor'],
+ torch.Tensor)
+ torch_model_output = \
+ torch_model_output['output of WrappedNet: single tensor']
+
+ return torch_model_output
+
+ def train_step(self, data, optimizer=None, **kwargs):
+ # arguments from mmcls/models/classifiers/base.py:
+ # BaseClassifier.train_step
+ assert self.training
+ assert len(kwargs) == 0 # TODO, support later if necessary
+
+        # TODO support DataContainer as input;
+        # currently, auto_fp16 and HierarchicalDataManager take too much
+        # time traversing DataContainer
+ data['img_metas'] = None
+ num_samples = len(data['img'].data)
+
+ # TODO we will ignore optimizer because it will not be used in model,
+ # support later if necessary
+ data['optimizer'] = None
+ output_dict = self.run_model(data)
+
+        # outputs contain loss, log_vars and num_samples,
+        # but only loss (torch.Tensor) has been updated on the IPU;
+        # remove all unchanged vars and keep only the torch.Tensor
+ neat_output_dict = {'loss': output_dict['loss']}
+
+ # re-parse outputs, get back log_vars and num_samples
+ loss, log_vars = self.model._parse_losses(neat_output_dict)
+ final_output_dict = dict(
+ loss=loss, log_vars=log_vars, num_samples=num_samples)
+ return final_output_dict
+
+ def eval_call(self, img, img_metas=None, return_loss=True, **kwargs):
+ # arguments from mmdet/models/detectors/base.py:BaseDetector.forward
+        # temporary usage for eval mode
+ assert not self.training
+ assert len(kwargs) == 0 # TODO, support later if necessary
+ assert not return_loss
+ data = {'img': img, 'img_metas': img_metas, 'return_loss': return_loss}
+
+ output_dict = self.run_model(data)
+
+ return output_dict
+
+ def detachFromDevice(self):
+ if self.isCompiled() and self._is_attached:
+ super().detachFromDevice()
+
+ def attachToDevice(self):
+ if self.isCompiled() and not self._is_attached:
+ super().attachToDevice()
+
+
+class TrainEvalModel:
+ """A class maintaining training MMPoplarExecutor and inference
+ MMPoplarExecutor.
+
+ Args:
+ train_model (:obj:`nn.Module`): The training model to be compiled.
+ ``train_model`` can be None if only executing validation.
+ eval_model (:obj:`nn.Module`): The inference model to be compiled.
+ options (mmcv.Config, dict): Options that will be used to compile
+ and run the model.
+ optimizer (:obj:`torch.optim.Optimizer`, optional): torch
+ optimizer, necessary if in training mode
+ logger (:obj:`logging.Logger`): Logger used during running.
+ Defaults to None.
+ modules_to_record (mmcv.Config, list): Index or name of modules which
+ will be recorded for output. It is necessary to specify output for
+ static graph of model training or inference.
+ """
+
+ def __init__(self,
+ train_model,
+ eval_model,
+ options,
+ optimizer,
+ modules_to_record=None,
+ logger=None):
+ if train_model is None:
+ self._train_executor = None
+ self.training = False
+ else:
+ self._train_executor = get_training_model(
+ train_model,
+ options=options['training'],
+ optimizer=optimizer,
+ logger=logger,
+ modules_to_record=modules_to_record)
+ self.training = True
+ self._eval_executor = get_inference_model(
+ eval_model, options=options['inference'], logger=logger)
+
+ @property
+ def executor(self):
+ if self.training:
+ return self._train_executor
+ else:
+ return self._eval_executor
+
+ def train(self, mode: bool = True):
+ """Sets the module in training mode.
+
+ This has any effect only on certain modules. See documentations of
+ particular modules for details of their behaviors in
+ training/evaluation mode, if they are affected,
+ e.g. :class:`Dropout`, :class:`BatchNorm`, etc.
+
+ Args:
+ mode (bool): whether to set training mode (``True``) or evaluation
+ mode (``False``). Default: ``True``.
+
+ Returns:
+ Module: self
+ """
+ if not isinstance(mode, bool):
+ raise ValueError('training mode is expected to be boolean, '
+ f'but got {type(mode)}')
+ if self._train_executor is None and mode:
+ raise RuntimeError(
+                'The train_executor is not initialized. '
+                'If you want to initialize train_executor, '
+                'you need to input optimizer when converting pytorch model')
+
+ if mode == self.training:
+ self.model.train(mode)
+ return self
+ else:
+ if self.isCompiled():
+ # copy weights from IPU to cpu before off-load current session
+ self.copyWeightsToHost()
+                # detach the current session before changing the mode;
+                # if it is training mode and weights were updated,
+                # poptorch will copy weights from IPU to host
+ self.detachFromDevice()
+
+            self.training = mode  # the session changes when the mode changes
+ self.model.train(mode)
+
+ # after changing mode, attach the current new session,
+ # and this function will copy weights of model to device
+ self.attachToDevice()
+ return self
+
+ def eval(self):
+ """Sets the module in evaluation mode.
+
+ This has any effect only on certain modules.
+ See documentations of particular modules
+ for details of their behaviors in training/evaluation mode,
+ if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc.
+
+        This is equivalent to calling :meth:`self.train(False)`.
+
+ See :ref:`locally-disable-grad-doc` for a comparison between
+ `.eval()` and several similar mechanisms that may be confused with it.
+
+ Returns:
+ Module: self
+ """
+ return self.train(False)
+
+ def compare_data_between_ipu_and_cpu(self, inter_outputs_in_cpu,
+ inter_outputs_in_ipu):
+ for key, val in inter_outputs_in_cpu.items():
+ is_tensor = isinstance(val['fea_in'], torch.Tensor)
+ fea_in_cpu = val['fea_in']
+ fea_in_cpu_list = [fea_in_cpu] if is_tensor else fea_in_cpu
+ fea_in_ipu = inter_outputs_in_ipu[key]['fea_in']
+ fea_in_ipu_list = [fea_in_ipu] if is_tensor else fea_in_ipu
+
+ is_tensor = isinstance(val['fea_out'], torch.Tensor)
+ fea_out_cpu = val['fea_out']
+ fea_out_cpu_list = [fea_out_cpu] if is_tensor else fea_out_cpu
+ fea_out_ipu = inter_outputs_in_ipu[key]['fea_out']
+ fea_out_ipu_list = [fea_out_ipu] if is_tensor else fea_out_ipu
+
+ print('comparing layer:', key)
+ for idx, (featA, featB) in \
+ enumerate(zip(fea_in_cpu_list, fea_in_ipu_list)):
+ print('fea_in, tensor ', idx)
+ compare_ndarray(featA.detach().numpy(), featB.detach().numpy())
+ for idx, (featA, featB) in \
+ enumerate(zip(fea_out_cpu_list, fea_out_ipu_list)):
+ print('fea_out, tensor', idx)
+ compare_ndarray(featA.detach().numpy(), featB.detach().numpy())
+
+ # TODO Unified training and eval interface,
+ # merge train_step(train) and __call__(eval) together
+ def train_step(self, data, optimizer=None, **kwargs):
+ assert self.training, 'not supported train_step on eval mode'
+ inter_outputs_in_cpu = {}
+ if (self._train_executor.isCompiled()
+ and self._train_executor.compare_with_cpu):
+ self.copyWeightsToHost()
+ # run in CPU mode
+ self._train_executor.model.train_step(data, optimizer, **kwargs)
+ inter_outputs_in_cpu = {
+ **(self._train_executor.inter_outputs_in_cpu)
+ }
+ # run in IPU mode
+ result = self._train_executor.train_step(data, optimizer, **kwargs)
+ if (self._train_executor.isCompiled()
+ and self._train_executor.compare_with_cpu
+ and len(inter_outputs_in_cpu) > 0):
+ self.compare_data_between_ipu_and_cpu(
+ inter_outputs_in_cpu,
+ self._train_executor.inter_outputs_in_ipu)
+ return result
+
+ # TODO Unified training and eval interface,
+ # merge train_step(train) and __call__(eval) together
+ def __call__(self, *args, **kwargs):
+ if self.training:
+ raise NotImplementedError('use train_step rather than __call__')
+ else:
+ return self._eval_executor.eval_call(*args, **kwargs)
+
+ def __getattr__(self, attr):
+ return getattr(self.executor, attr)
+
+
+def get_training_model(model: nn.Module,
+ options: Optional[poptorch.Options] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ logger=None,
+ modules_to_record=None) -> poptorch.PoplarExecutor:
+ """Create a PopTorch training model from a PyTorch model, running on IPU
+ hardware in training mode.
+
+ Note:
+ PopTorch makes a shallow copy of the model. Changes to the
+ parameters in the returned training model affect the original model
+ and vice versa. However, primitive variable types are not synced: for
+ example calling ``model.train()`` on the original model, which
+ changes the ``training`` bool of the model instance, will not alter the
+ model returned by this function. You may need to call ``model.train()``
+ on your model before you call this function for correct behavior.
+
+ Args:
+ model (:obj:`nn.Module`): The model to run.
+ options (poptorch.Options): Options that will be used to compile
+ and run the model.
+ optimizer (:obj:`torch.optim.Optimizer`, optional): The optimizers
+ to apply during training.
+ logger (:obj:`logging.Logger`): Logger used during running.
+ Defaults to None.
+ modules_to_record (mmcv.Config, list): Index or name of modules which
+ will be recorded for output. It is necessary to specify output for
+ static graph of model training or inference.
+
+ Returns:
+ The :class:`poptorch.PoplarExecutor` wrapper to use in place
+ of ``model``.
+ """
+ # Create a copy of the original model in case it needs to be wrapped
+ maybe_wrapped_model = copy.copy(model)
+
+ return MMPoplarExecutor(
+ model=maybe_wrapped_model,
+ logger=logger,
+ options=options,
+ training=True,
+ optimizer=optimizer,
+ user_model=model,
+ modules_to_record=modules_to_record,
+ poptorch_version=__version__)
+
+
+def get_inference_model(model: Union[nn.Module, poptorch.PoplarExecutor],
+ options: Optional[poptorch.Options] = None,
+ logger=None) -> poptorch.PoplarExecutor:
+ """Create a PopTorch inference model from a PyTorch model, running on IPU
+ hardware in inference mode.
+
+ Note:
+ PopTorch makes a shallow copy of the model. Changes to the
+ parameters in the returned inference model affect the original model
+ and vice versa. However, primitive variable types are not synced: for
+ example calling ``model.eval()`` on the original model will not alter
+ the model returned by this function. You may need to call
+ ``model.eval()`` on your model before you call this function for
+ correct behavior.
+
+ Args:
+ model (:obj:`nn.Module`): The model to run.
+ options (poptorch.Options): Options that will be used to compile
+ and run the model.
+ logger (:obj:`logging.Logger`): Logger used during running.
+ Defaults to None.
+
+ Returns:
+ The :class:`poptorch.PoplarExecutor` wrapper to use in place of
+ ``model``.
+ """
+
+ return MMPoplarExecutor(
+ model=copy.copy(model),
+ logger=logger,
+ options=options,
+ training=False,
+ poptorch_version=__version__)
+
+
+def ipu_model_wrapper(model,
+ options,
+ optimizer=None,
+ logger=None,
+ modules_to_record=None,
+ ipu_model_cfg=None,
+ fp16_cfg=None):
+ """Convert torch model to IPU model.
+
+ Args:
+ model (nn.Module): The target model to be converted.
+ options (dict[str, poptorch.Options]): IPU options, generated
+ by :func:`cfg2options`.
+ optimizer (:obj:`torch.optim.Optimizer`, optional): torch
+ optimizer, necessary if in training mode
+ logger (:obj:`logging.Logger`): Logger used during training.
+ modules_to_record (mmcv.Config, list): Index or name of modules which
+ will be recorded for output. It is necessary to specify output for
+ static graph of model training or inference.
+ ipu_model_cfg (dict): A dictionary contains train_split_edges and
+ train_ckpt_nodes, See details in :func:`model_sharding` and
+ :func:`recomputation_checkpoint` functions.
+ fp16_cfg (dict): Config for IPU fp16 training. Currently supports
+ configs: `loss_scale`, `velocity_accum_type` and `accum_type`.
+ See details in
+ https://docs.graphcore.ai/projects/poptorch-user-guide/en/latest/index.html
+
+ Returns:
+ TrainEvalModel: IPU wrapped model.
+ """
+ if ipu_model_cfg is None:
+ ipu_model_cfg = {}
+ training = model.training if optimizer is not None else False
+ # set mixed-precision
+ if fp16_cfg is not None:
+ from mmcv.runner import wrap_fp16_model
+ loss_scale = fp16_cfg['loss_scale']
+ wrap_fp16_model(model)
+ model.half()
+        # TODO temporary usage to set loss scaling for the original torch optimizer
+ if optimizer is not None:
+ optimizer.loss_scaling = loss_scale
+ if fp16_cfg.get('velocity_accum_type', False):
+ if fp16_cfg['velocity_accum_type'] == 'half':
+ optimizer.velocity_accum_type = torch.half
+ else:
+ optimizer.velocity_accum_type = torch.float32
+ if fp16_cfg.get('accum_type', False):
+ if fp16_cfg['accum_type'] == 'half':
+ optimizer.accum_type = torch.half
+ else:
+ optimizer.accum_type = torch.float32
+ # TODO support feature alignment for fp16
+ if modules_to_record is not None:
+ raise NotImplementedError(
+ 'Feature alignment for fp16 is not implemented')
+
+ # set model partition
+ if optimizer is None:
+ train_model = None
+ else:
+ # split model into multi-IPUs if specified
+ train_model = model_sharding(
+ copy.copy(model).train(),
+ ipu_model_cfg.get('train_split_edges', []))
+
+ recomputation_checkpoint(train_model,
+ ipu_model_cfg.get('train_ckpt_nodes', []))
+
+ # TODO support feature alignment for gradient accumulation mode
+ gradient_accumulation = \
+ getattr(options['training'].Training, 'gradient_accumulation', 1)
+ if gradient_accumulation > 1:
+ assert modules_to_record is None, \
+ 'Feature alignment for grad-accumulation mode not implemented'
+
+ # TODO support feature alignment for multi-replica mode
+ replication_factor = \
+ getattr(options['training'], 'replication_factor', 1)
+ if replication_factor > 1:
+ assert modules_to_record is None, \
+ 'Feature alignment for multi-replica mode not implemented'
+
+ # TODO supports different model partitions between train and eval mode
+ assert len(ipu_model_cfg.get('eval_split_edges', [])) == 0,\
+ 'Currently, BeginBlock can only be used once on the same model'
+ eval_model = copy.copy(model).eval()
+
+ # wrap model for compilation
+ model = TrainEvalModel(
+ train_model,
+ eval_model,
+ options=options,
+ optimizer=optimizer,
+ logger=logger,
+ modules_to_record=modules_to_record)
+ model.train(training)
+ return model
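+
+
+# Usage sketch (illustrative; `cfg`, `net`, `optim` and `data_batch` are
+# assumed names, not defined in this module):
+#
+#   from .utils import cfg2options
+#   options = cfg2options(cfg.options_cfg)          # dict of poptorch.Options
+#   ipu_model = ipu_model_wrapper(net, options, optimizer=optim)
+#   outputs = ipu_model.train_step(data_batch)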
diff --git a/mmcv/mmcv/device/ipu/runner.py b/mmcv/mmcv/device/ipu/runner.py
new file mode 100755
index 0000000000000000000000000000000000000000..e2d4922677e08b2d6b5132a01034de8b043fa3f1
--- /dev/null
+++ b/mmcv/mmcv/device/ipu/runner.py
@@ -0,0 +1,142 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from mmcv.runner import (HOOKS, RUNNERS, BaseRunner, EpochBasedRunner,
+ IterBasedRunner)
+from mmcv.utils import IS_IPU_AVAILABLE
+
+if IS_IPU_AVAILABLE:
+ from .dataloader import IPUDataLoader
+ from .hook_wrapper import (IPUFp16OptimizerHook, wrap_lr_updater_hook,
+ wrap_optimizer_hook)
+ from .model_wrapper import ipu_model_wrapper
+ from .utils import build_from_cfg_with_wrapper, cfg2options
+
+
+class IPUBaseRunner(BaseRunner):
+ """A base runner for IPU.
+
+ This runner has some extra processes for IPU which are shown below:
+
+ 1. Parse options for IPU
+    2. Wrap the pytorch model for IPU
+    3. Raise errors when illegal usage is encountered
+    4. Input IPU options and initialize the dataloader if an instance of
+       IPUDataLoader is found
+
+ Args:
+ model (:obj:`nn.Module`): The model to run.
+ options_cfg (mmcv.Config, dict): Options that will be used to compile
+ and run the model.
+ modules_to_record (mmcv.Config, list): Index or name of modules which
+ will be recorded for output. It is necessary to specify output for
+ static graph of model training or inference.
+ ipu_model_cfg (mmcv.Config, dict): Config of model partition and
+ recomputing checkpoint
+ fp16_cfg (mmcv.Config): Config for fp16 training.
+ batch_processor (callable): A callable method that process a data
+ batch. Should be None for IPU runner
+ kwargs (Dict[str, Any], optional): Keyword arguments will be passed to
+ ``base_runner.BaseRunner``.
+ """
+
+ def __init__(self,
+ model,
+ options_cfg=None,
+ modules_to_record=None,
+ ipu_model_cfg=None,
+ fp16_cfg=None,
+ batch_processor=None,
+ **kwargs):
+ assert hasattr(model, 'train_step') and batch_processor is None,\
+ 'only support model with train_step'
+
+ if options_cfg is None:
+ options_cfg = {}
+ # call BaseRunner.__init__() here
+ super().__init__(model, **kwargs)
+
+ # process options of ipu
+ if IS_IPU_AVAILABLE:
+ self.options = cfg2options(options_cfg)
+ self.model = ipu_model_wrapper(
+ self.model,
+ self.options,
+ self.optimizer,
+ self.logger,
+ modules_to_record=modules_to_record,
+ ipu_model_cfg=ipu_model_cfg,
+ fp16_cfg=fp16_cfg)
+ else:
+ raise NotImplementedError('cpu mode on IPURunner is not supported')
+
+ def register_lr_hook(self, lr_config):
+ if lr_config is None:
+ return
+ assert isinstance(lr_config, dict)
+ assert 'policy' in lr_config
+ policy_type = lr_config.pop('policy')
+ # If the type of policy is all in lower case,
+ # e.g., 'cyclic', then its first letter will be capitalized,
+ # e.g., to be 'Cyclic'.
+ # This is for the convenient usage of Lr updater.
+        # Since this is not applicable for `CosineAnnealingLrUpdater`,
+        # the string will not be changed
+ # if it contains capital letters.
+ if policy_type == policy_type.lower():
+ policy_type = policy_type.title()
+ hook_type = policy_type + 'LrUpdaterHook'
+ lr_config['type'] = hook_type
+ hook = build_from_cfg_with_wrapper(lr_config, HOOKS,
+ wrap_lr_updater_hook)
+ self.register_hook(hook, priority='VERY_HIGH')
+
+ def register_optimizer_hook(self, optimizer_config):
+ if optimizer_config is None:
+ return
+ assert isinstance(optimizer_config, (dict, IPUFp16OptimizerHook))
+ if isinstance(optimizer_config, dict):
+ optimizer_config.setdefault('type', 'OptimizerHook')
+ hook = build_from_cfg_with_wrapper(optimizer_config, HOOKS,
+ wrap_optimizer_hook)
+ else:
+ hook = optimizer_config
+ self.register_hook(hook, priority='ABOVE_NORMAL')
+
+ def run(self, data_loaders, workflow, *args, **kwargs):
+ for i, flow in enumerate(workflow):
+ mode, _ = flow
+ # initialize IPU dataloader if not initialized
+ assert isinstance(data_loaders[i], IPUDataLoader),\
+ 'IPU runner can only work with `IPUDataLoader`'
+ data_loaders[i].init(options=self.get_options(mode))
+
+ super().run(data_loaders, workflow, *args, **kwargs)
+
+ def get_options(self, mode):
+ if mode == 'train':
+ return self.options['training']
+ elif mode == 'val':
+ return self.options['inference']
+ else:
+ raise ValueError(f'mode should be train or val but got {mode}')
+
+
+@RUNNERS.register_module()
+class IPUEpochBasedRunner(IPUBaseRunner, EpochBasedRunner):
+ """Epoch-based Runner for IPU.
+
+    The inheritance order (MRO) is: IPUEpochBasedRunner -> IPUBaseRunner ->
+    EpochBasedRunner -> BaseRunner. This runner trains models epoch by epoch.
+ """
+ pass
+
+
+@RUNNERS.register_module()
+class IPUIterBasedRunner(IPUBaseRunner, IterBasedRunner):
+ """Iteration-based Runner for IPU.
+
+    The inheritance order (MRO) is: IPUIterBasedRunner -> IPUBaseRunner ->
+    IterBasedRunner -> BaseRunner. This runner trains models iteration by
+    iteration.
+ """
+ pass
diff --git a/mmcv/mmcv/device/ipu/utils.py b/mmcv/mmcv/device/ipu/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..79709db1ee1282e8daa6614ceb23481d3cd58338
--- /dev/null
+++ b/mmcv/mmcv/device/ipu/utils.py
@@ -0,0 +1,244 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+
+import numpy as np
+import popart
+import poptorch
+import torch
+import torch.nn as nn
+
+from mmcv.utils import Registry
+
+
+def _options_assigner(cfg, options_node):
+ # set popart.options by config
+ # cfg: dict, python data type
+ # options_node: python module or function
+ if isinstance(cfg, dict):
+ for key in cfg:
+ _options_assigner(cfg[key], getattr(options_node, key))
+ elif isinstance(cfg, (int, float, str, list)):
+ if callable(options_node):
+ options_node(cfg)
+ else:
+ error_msg = f'options_node type {type(options_node)} not supported'
+ raise NotImplementedError(error_msg)
+ else:
+ error_msg = f'cfg type {type(cfg)} not supported'
+ raise NotImplementedError(error_msg)
+
+
+def cfg2options(cfg):
+ """Parse dictionary to ipu options.
+
+ Args:
+ cfg (dict): A dictionary of ipu settings.
+
+ Returns:
+ dict[str, poptorch.Options]: Training options and inference options
+ of IPU.
+ """
+ # set ipu options for inference and training by config
+ train_cfg = cfg.pop('train_cfg', {})
+ eval_cfg = cfg.pop('eval_cfg', {})
+    eval_cfg['replicationFactor'] = 1  # eval mode only uses one replica
+ eval_cfg['executionStrategy'] = 'ShardedExecution'
+ # overwrite default ipu cfg with specified train cfgs
+ training_ipu_cfg = {**cfg, **train_cfg}
+ # overwrite default ipu cfg with specified eval cfgs
+ inference_ipu_cfg = {**cfg, **eval_cfg}
+
+ ipu_options = {
+ 'training': _cast_to_options(training_ipu_cfg),
+ 'inference': _cast_to_options(inference_ipu_cfg)
+ }
+
+    # TODO: make the options below configurable
+ ipu_options['training']._Popart.set('disableGradAccumulationTensorStreams',
+ True)
+ ipu_options['training']._Popart.set(
+ 'accumulateOuterFragmentSettings.schedule',
+ int(popart.AccumulateOuterFragmentSchedule.OverlapMemoryOptimized))
+ ipu_options['training'].Precision.enableStochasticRounding(True)
+
+ return ipu_options
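+# A minimal usage sketch of ``cfg2options`` (the option keys below are
+# assumptions and must correspond to attributes/methods of poptorch.Options):
+#   options_cfg = dict(
+#       randomSeed=42,
+#       deviceIterations=1,
+#       train_cfg=dict(replicationFactor=2),
+#       eval_cfg=dict(deviceIterations=4))
+#   ipu_options = cfg2options(options_cfg)
+#   # ipu_options['training'] / ipu_options['inference'] are poptorch.Options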
+
+
+def _cast_to_options(cfg):
+    # Options that cannot be assigned directly are parsed with the explicit
+    # if-statements below; the remaining ones are assigned generically via
+    # _options_assigner.
+ options = poptorch.Options()
+
+ if 'availableMemoryProportion' in cfg:
+ available_memory_proportion = cfg.pop('availableMemoryProportion')
+ mem_props = {}
+ for i, mem_prop in enumerate(available_memory_proportion):
+ mem_props[f'IPU{i}'] = mem_prop
+ options.setAvailableMemoryProportion(mem_props)
+
+ if 'executionStrategy' in cfg:
+ execution_strategy = cfg.pop('executionStrategy')
+ if execution_strategy == 'SameAsIpu':
+ options.setExecutionStrategy(
+ poptorch.PipelinedExecution(
+ getattr(poptorch.AutoStage, execution_strategy)))
+ elif execution_strategy == 'ShardedExecution':
+ options.setExecutionStrategy(poptorch.ShardedExecution())
+ else:
+ raise NotImplementedError(
+ 'executionStrategy should be "SameAsIpu" or "ShardedExecution"'
+ f', but got {execution_strategy}')
+
+ if 'partialsType' in cfg:
+ partials_type = cfg.pop('partialsType')
+ options.Precision.setPartialsType(getattr(
+ torch, partials_type)) # half or float
+
+ _options_assigner(cfg, options)
+ return options
+
+
+def model_sharding(model, split_edges):
+    """Split a model in-place across multiple IPUs.
+
+ Args:
+ model (nn.Module): The target model to be split.
+        split_edges (list of dict): Layer names or layer indices at which the
+            model is split. Each item of ``split_edges`` is a dictionary,
+            which may contain the following key-value pairs:
+
+            - layer_to_call: Name or index of the layer that begins the block.
+ - user_id (optional): A user defined identifier for the block.
+ - ipu_id: The id of the IPU to run on.
+
+ Examples:
+ >>> split_edges = [
+ ... dict(layer_to_call='model.conv1', ipu_id=0),
+ ... dict(layer_to_call='model.conv3', ipu_id=1)]
+ >>> sharding_model = model_sharding(torch_model, split_edges)
+
+ Returns:
+ nn.Module: Split model.
+ """
+ if len(split_edges) == 0:
+ return model
+ assert isinstance(split_edges, list)
+    split_edges_dict = {edge['layer_to_call']: edge for edge in split_edges}
+
+    for idx, (name, module) in enumerate(model.named_modules()):
+        if idx in split_edges_dict and name in split_edges_dict:
+            raise ValueError(
+                'The same layer is referenced twice while doing model'
+                f' partition: idx is {idx} and name is {name}')
+
+        edge = split_edges_dict.pop(name, None)
+        edge = split_edges_dict.pop(idx, edge)
+        if edge is not None:
+            poptorch.BeginBlock(module, edge.get('user_id', name),
+                                edge['ipu_id'])
+
+    # ensure all split_edges are used
+    if len(split_edges_dict) > 0:
+        split_edge_names = list(split_edges_dict.keys())
+ raise RuntimeError(
+ f'split_edges: {split_edge_names} are not contained in the model')
+ return model
+
+
+def recomputation_checkpoint(model: nn.Module, module_names: list):
+ """Annotates the output of a module to be checkpointed instead of
+ recomputed.
+
+    If recomputation mode is enabled, the IPU releases the activations of
+    intermediate layers to save memory and recalculates them during the
+    backward pass. This function declares which intermediate activations
+    should be kept so that the recomputation of those layers can be skipped.
+
+ Args:
+ model (nn.Module): The target model to apply recomputation
+ checkpoint.
+ module_names (list): Layer names of module.
+ """
+
+ def recompute_outputs(module, inputs, outputs):
+ if isinstance(outputs, tuple):
+ return tuple(poptorch.recomputationCheckpoint(y) for y in outputs)
+ else:
+ return poptorch.recomputationCheckpoint(outputs)
+
+ for name, module in model.named_modules():
+ if name in module_names:
+ module.register_forward_hook(recompute_outputs)
+ module_names.remove(name)
+
+ # check all module_names are used
+ assert len(module_names) == 0,\
+ f'recomputed nodes: {module_names} are not contained in the model'
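+# Usage sketch (the module names are hypothetical and must match entries of
+# ``model.named_modules()``):
+#   recomputation_checkpoint(model, ['backbone.layer2', 'backbone.layer3'])
+# The registered forward hooks wrap the listed modules' outputs with
+# ``poptorch.recomputationCheckpoint`` so they are kept instead of recomputed.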
+
+
+def compare_ndarray(featA, featB, rtol=1e-3, atol=1e-5):
+    """Compare two activations or weights and print the mismatch, if any."""
+ try:
+ np.testing.assert_allclose(featA, featB, rtol=rtol, atol=atol)
+ except AssertionError as e:
+ print(e)
+
+
+def build_from_cfg_with_wrapper(cfg,
+ registry,
+ wrapper_func=None,
+ default_args=None):
+ """Build a module from config dict and wrap module with "wrapper_func".
+
+ Args:
+ cfg (dict): Config dict. It should at least contain the key "type".
+ registry (:obj:`Registry`): The registry to search the type from.
+ default_args (dict, optional): Default initialization arguments.
+        wrapper_func (function, optional): Function used to wrap the class
+            before instantiation. Default: None.
+
+ Returns:
+ object: The constructed object.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
+ if 'type' not in cfg:
+ if default_args is None or 'type' not in default_args:
+ raise KeyError(
+ '`cfg` or `default_args` must contain the key "type", '
+ f'but got {cfg}\n{default_args}')
+ if not isinstance(registry, Registry):
+ raise TypeError('registry must be an mmcv.Registry object, '
+ f'but got {type(registry)}')
+ if not (isinstance(default_args, dict) or default_args is None):
+ raise TypeError('default_args must be a dict or None, '
+ f'but got {type(default_args)}')
+
+ args = cfg.copy()
+
+ if default_args is not None:
+ for name, value in default_args.items():
+ args.setdefault(name, value)
+
+ obj_type = args.pop('type')
+ if isinstance(obj_type, str):
+ obj_cls = registry.get(obj_type)
+ if obj_cls is None:
+ raise KeyError(
+ f'{obj_type} is not in the {registry.name} registry')
+ elif inspect.isclass(obj_type):
+ obj_cls = obj_type
+ else:
+ raise TypeError(
+ f'type must be a str or valid type, but got {type(obj_type)}')
+
+ if wrapper_func is None:
+ wrapped_obj_cls = obj_cls
+ else:
+ wrapped_obj_cls = wrapper_func(obj_cls)
+ try:
+ return wrapped_obj_cls(**args)
+ except Exception as e:
+ # Normal TypeError does not print class name.
+ raise type(e)(f'{wrapped_obj_cls.__name__}: {e}')
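+# Usage sketch (``MyHook`` and ``my_wrapper`` are hypothetical; ``MyHook`` is
+# assumed to be registered in the ``HOOKS`` registry):
+#   hook = build_from_cfg_with_wrapper(dict(type='MyHook', interval=10),
+#                                      HOOKS, wrapper_func=my_wrapper)
+# which is roughly equivalent to ``my_wrapper(MyHook)(interval=10)``.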
diff --git a/mmcv/mmcv/device/mlu/__init__.py b/mmcv/mmcv/device/mlu/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..77c71ccf3ce38f3cbc9911f1d9d4b05a531771f2
--- /dev/null
+++ b/mmcv/mmcv/device/mlu/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .data_parallel import MLUDataParallel
+from .distributed import MLUDistributedDataParallel
+
+__all__ = ['MLUDataParallel', 'MLUDistributedDataParallel']
diff --git a/mmcv/mmcv/device/mlu/_functions.py b/mmcv/mmcv/device/mlu/_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..75660fa9b3635fed049cb150639244a658534824
--- /dev/null
+++ b/mmcv/mmcv/device/mlu/_functions.py
@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Union
+
+import torch
+
+
+def scatter(input: Union[List, torch.Tensor], devices: List) -> List:
+    """Scatter copies the tensor(s) to the MLU directly."""
+ if isinstance(input, list):
+ outputs = [scatter(_input, devices) for _input in input]
+ return outputs
+ elif isinstance(input, torch.Tensor):
+ output = input.contiguous()
+ return output.to('mlu') if devices != [-1] else output
+ else:
+ raise Exception(f'Unknown type {type(input)}.')
+
+
+class Scatter:
+
+ @staticmethod
+ def forward(target_mlus, input):
+ outputs = scatter(input, target_mlus)
+ return tuple(outputs) if isinstance(outputs, list) else (outputs, )
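+# Behaviour sketch: ``scatter(t, [-1])`` keeps a tensor ``t`` on the host,
+# while any other device list (e.g. ``[0]``) copies it to the MLU via
+# ``t.to('mlu')``; lists are handled element-wise and returned as lists.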
diff --git a/mmcv/mmcv/device/mlu/data_parallel.py b/mmcv/mmcv/device/mlu/data_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebe14c0a55c92f96ec7f782a591ac10b007942dc
--- /dev/null
+++ b/mmcv/mmcv/device/mlu/data_parallel.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import torch
+
+from mmcv.parallel import MMDataParallel
+from .scatter_gather import scatter_kwargs
+
+
+class MLUDataParallel(MMDataParallel):
+ """The MLUDataParallel module that supports DataContainer.
+
+    MLUDataParallel is a class inherited from MMDataParallel, which supports
+ MLU training and inference only.
+
+ The main differences with MMDataParallel:
+
+    - It only supports a single MLU card, and only uses the first card to
+      run training and inference.
+
+ - It uses direct host-to-device copy instead of stream-background
+ scatter.
+
+ .. warning::
+ MLUDataParallel only supports single MLU training, if you need to
+ train with multiple MLUs, please use MLUDistributedDataParallel
+ instead. If you have multiple MLUs, you can set the environment
+ variable ``MLU_VISIBLE_DEVICES=0`` (or any other card number(s))
+ to specify the running device.
+
+ Args:
+ module (:class:`nn.Module`): Module to be encapsulated.
+ dim (int): Dimension used to scatter the data. Defaults to 0.
+ """
+
+ def __init__(self, *args, dim=0, **kwargs):
+ super().__init__(*args, dim=dim, **kwargs)
+ self.device_ids = [0]
+ self.src_device_obj = torch.device('mlu:0')
+
+ def scatter(self, inputs, kwargs, device_ids):
+ return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
diff --git a/mmcv/mmcv/device/mlu/distributed.py b/mmcv/mmcv/device/mlu/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..3768c754c908b219fd5a770d69e6ed5416781ba8
--- /dev/null
+++ b/mmcv/mmcv/device/mlu/distributed.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from mmcv.parallel import MMDistributedDataParallel
+from .scatter_gather import scatter_kwargs
+
+
+class MLUDistributedDataParallel(MMDistributedDataParallel):
+    """The DDP module that supports DataContainer.
+
+    MLUDDP has one difference from MMDDP: it moves data to the MLU by copying
+    instead of scattering.
+ """
+
+ def to_kwargs(self, inputs, kwargs, device_id):
+ # Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8
+ # to move all tensors to device_id
+ return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)
+
+ def scatter(self, inputs, kwargs, device_ids):
+ return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
diff --git a/mmcv/mmcv/device/mlu/scatter_gather.py b/mmcv/mmcv/device/mlu/scatter_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b0c9b96f51252e4c510f66a2ec5fb7522716e29
--- /dev/null
+++ b/mmcv/mmcv/device/mlu/scatter_gather.py
@@ -0,0 +1,59 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmcv.parallel.data_container import DataContainer
+from ._functions import Scatter
+
+
+def scatter(inputs, target_mlus, dim=0):
+ """Scatter inputs to target mlu.
+
+ The only difference from original :func:`scatter` is to add support for
+ :type:`~mmcv.parallel.DataContainer`.
+ """
+
+ def scatter_map(obj):
+ if isinstance(obj, torch.Tensor):
+ if target_mlus != [-1]:
+ obj = obj.to('mlu')
+ return [obj]
+ else:
+ # for CPU inference we use self-implemented scatter
+ return Scatter.forward(target_mlus, obj)
+ if isinstance(obj, DataContainer):
+ if obj.cpu_only:
+ return obj.data
+ else:
+ return Scatter.forward(target_mlus, obj.data)
+ if isinstance(obj, tuple) and len(obj) > 0:
+ return list(zip(*map(scatter_map, obj)))
+ if isinstance(obj, list) and len(obj) > 0:
+ out = list(map(list, zip(*map(scatter_map, obj))))
+ return out
+ if isinstance(obj, dict) and len(obj) > 0:
+ out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
+ return out
+        return [obj for _ in target_mlus]
+
+ # After scatter_map is called, a scatter_map cell will exist. This cell
+ # has a reference to the actual function scatter_map, which has references
+ # to a closure that has a reference to the scatter_map cell (because the
+ # fn is recursive). To avoid this reference cycle, we set the function to
+ # None, clearing the cell
+ try:
+ return scatter_map(inputs)
+ finally:
+ scatter_map = None
+
+
+def scatter_kwargs(inputs, kwargs, target_mlus, dim=0):
+ """Scatter with support for kwargs dictionary."""
+ inputs = scatter(inputs, target_mlus, dim) if inputs else []
+ kwargs = scatter(kwargs, target_mlus, dim) if kwargs else []
+ if len(inputs) < len(kwargs):
+ inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
+ elif len(kwargs) < len(inputs):
+ kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
+ inputs = tuple(inputs)
+ kwargs = tuple(kwargs)
+ return inputs, kwargs
diff --git a/mmcv/mmcv/device/mps/__init__.py b/mmcv/mmcv/device/mps/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e28144ef0ae8cf65527cefc469d07c7ff854c688
--- /dev/null
+++ b/mmcv/mmcv/device/mps/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .data_parallel import MPSDataParallel
+
+__all__ = ['MPSDataParallel']
diff --git a/mmcv/mmcv/device/mps/data_parallel.py b/mmcv/mmcv/device/mps/data_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ae5396d24193376432ae98b792ec89fac678738
--- /dev/null
+++ b/mmcv/mmcv/device/mps/data_parallel.py
@@ -0,0 +1,34 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import torch
+
+from mmcv.parallel import MMDataParallel
+from ..scatter_gather import scatter_kwargs
+
+
+class MPSDataParallel(MMDataParallel):
+ """The MPSDataParallel module that supports DataContainer.
+
+    MPSDataParallel is a class inherited from MMDataParallel, which supports
+ MPS training and inference only.
+
+ The main differences with MMDataParallel:
+
+    - It only supports a single MPS device, and only uses the first device to
+      run training and inference.
+
+ - It uses direct host-to-device copy instead of stream-background
+ scatter.
+
+ Args:
+ module (:class:`nn.Module`): Module to be encapsulated.
+ dim (int): Dimension used to scatter the data. Defaults to 0.
+ """
+
+ def __init__(self, *args, dim=0, **kwargs):
+ super().__init__(*args, dim=dim, **kwargs)
+ self.device_ids = [0]
+ self.src_device_obj = torch.device('mps:0')
+
+ def scatter(self, inputs, kwargs, device_ids):
+ return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
diff --git a/mmcv/mmcv/device/scatter_gather.py b/mmcv/mmcv/device/scatter_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..744b0ca51e9de4cb7c43d60a986621461519f781
--- /dev/null
+++ b/mmcv/mmcv/device/scatter_gather.py
@@ -0,0 +1,64 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmcv.parallel.data_container import DataContainer
+from mmcv.utils import deprecated_api_warning
+from ._functions import Scatter
+from .utils import get_device
+
+
+@deprecated_api_warning({'target_mlus': 'target_devices'})
+def scatter(inputs, target_devices, dim=0):
+ """Scatter inputs to target devices.
+
+ The only difference from original :func:`scatter` is to add support for
+ :type:`~mmcv.parallel.DataContainer`.
+ """
+ current_device = get_device()
+
+ def scatter_map(obj):
+ if isinstance(obj, torch.Tensor):
+ if target_devices != [-1]:
+ obj = obj.to(current_device)
+ return [obj]
+ else:
+ # for CPU inference we use self-implemented scatter
+ return Scatter.forward(target_devices, obj)
+ if isinstance(obj, DataContainer):
+ if obj.cpu_only:
+ return obj.data
+ else:
+ return Scatter.forward(target_devices, obj.data)
+ if isinstance(obj, tuple) and len(obj) > 0:
+ return list(zip(*map(scatter_map, obj)))
+ if isinstance(obj, list) and len(obj) > 0:
+ out = list(map(list, zip(*map(scatter_map, obj))))
+ return out
+ if isinstance(obj, dict) and len(obj) > 0:
+ out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
+ return out
+ return [obj for _ in target_devices]
+
+ # After scatter_map is called, a scatter_map cell will exist. This cell
+ # has a reference to the actual function scatter_map, which has references
+ # to a closure that has a reference to the scatter_map cell (because the
+ # fn is recursive). To avoid this reference cycle, we set the function to
+ # None, clearing the cell
+ try:
+ return scatter_map(inputs)
+ finally:
+ scatter_map = None
+
+
+@deprecated_api_warning({'target_mlus': 'target_devices'})
+def scatter_kwargs(inputs, kwargs, target_devices, dim=0):
+ """Scatter with support for kwargs dictionary."""
+ inputs = scatter(inputs, target_devices, dim) if inputs else []
+ kwargs = scatter(kwargs, target_devices, dim) if kwargs else []
+ if len(inputs) < len(kwargs):
+ inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
+ elif len(kwargs) < len(inputs):
+ kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
+ inputs = tuple(inputs)
+ kwargs = tuple(kwargs)
+ return inputs, kwargs
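+# Padding sketch: if ``inputs`` scatters into more chunks than ``kwargs`` (or
+# vice versa), the shorter side is padded with empty tuples/dicts, e.g.
+#   scatter_kwargs((x,), {}, [0])  ->  (((x_on_device,),), ({},))
+# where ``x_on_device`` is ``x`` moved to the device reported by get_device().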
diff --git a/mmcv/mmcv/device/utils.py b/mmcv/mmcv/device/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2adec08dd98ad83cce3a9c28d3a6651808f7112
--- /dev/null
+++ b/mmcv/mmcv/device/utils.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE
+
+
+def get_device() -> str:
+ """Returns the currently existing device type.
+
+ Returns:
+ str: cuda | mlu | mps | cpu.
+ """
+ if IS_CUDA_AVAILABLE:
+ return 'cuda'
+ elif IS_MLU_AVAILABLE:
+ return 'mlu'
+ elif IS_MPS_AVAILABLE:
+ return 'mps'
+ else:
+ return 'cpu'
diff --git a/mmcv/mmcv/engine/__init__.py b/mmcv/mmcv/engine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3193b7f664e19ce2458d81c836597fa22e4bb082
--- /dev/null
+++ b/mmcv/mmcv/engine/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .test import (collect_results_cpu, collect_results_gpu, multi_gpu_test,
+ single_gpu_test)
+
+__all__ = [
+ 'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test',
+ 'single_gpu_test'
+]
diff --git a/mmcv/mmcv/engine/test.py b/mmcv/mmcv/engine/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..83546caec47fb11952fd820b342c71b83b74fac2
--- /dev/null
+++ b/mmcv/mmcv/engine/test.py
@@ -0,0 +1,213 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import pickle
+import shutil
+import tempfile
+import time
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.utils.data import DataLoader
+
+import mmcv
+from mmcv.runner import get_dist_info
+
+
+def single_gpu_test(model: nn.Module, data_loader: DataLoader) -> list:
+ """Test model with a single gpu.
+
+    This method tests the model with a single gpu and displays a test
+    progress bar.
+
+ Args:
+ model (nn.Module): Model to be tested.
+        data_loader (DataLoader): PyTorch data loader.
+
+ Returns:
+ list: The prediction results.
+ """
+ model.eval()
+ results = []
+ dataset = data_loader.dataset
+ prog_bar = mmcv.ProgressBar(len(dataset))
+ for data in data_loader:
+ with torch.no_grad():
+ result = model(return_loss=False, **data)
+ results.extend(result)
+
+        # Assume result has the same length as batch_size
+ # refer to https://github.com/open-mmlab/mmcv/issues/985
+ batch_size = len(result)
+ for _ in range(batch_size):
+ prog_bar.update()
+ return results
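+# Usage sketch (hypothetical model/dataloader; the wrapped model is expected
+# to accept ``return_loss=False`` and batches produced by an MMCV-style
+# dataloader):
+#   model = MMDataParallel(model, device_ids=[0])
+#   outputs = single_gpu_test(model, data_loader)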
+
+
+def multi_gpu_test(model: nn.Module,
+ data_loader: DataLoader,
+ tmpdir: Optional[str] = None,
+ gpu_collect: bool = False) -> Optional[list]:
+ """Test model with multiple gpus.
+
+    This method tests the model with multiple gpus and collects the results
+    under two different modes: gpu and cpu modes. By setting
+    ``gpu_collect=True``, it encodes results to gpu tensors and uses gpu
+    communication for results collection. In cpu mode, it saves the results
+    on different gpus to ``tmpdir`` and collects them by the rank 0 worker.
+
+ Args:
+ model (nn.Module): Model to be tested.
+        data_loader (DataLoader): PyTorch data loader.
+ tmpdir (str): Path of directory to save the temporary results from
+ different gpus under cpu mode.
+ gpu_collect (bool): Option to use either gpu or cpu to collect results.
+
+ Returns:
+ list: The prediction results.
+ """
+ model.eval()
+ results = []
+ dataset = data_loader.dataset
+ rank, world_size = get_dist_info()
+ if rank == 0:
+ prog_bar = mmcv.ProgressBar(len(dataset))
+ time.sleep(2) # This line can prevent deadlock problem in some cases.
+ for i, data in enumerate(data_loader):
+ with torch.no_grad():
+ result = model(return_loss=False, **data)
+ results.extend(result)
+
+ if rank == 0:
+ batch_size = len(result)
+ batch_size_all = batch_size * world_size
+ if batch_size_all + prog_bar.completed > len(dataset):
+ batch_size_all = len(dataset) - prog_bar.completed
+ for _ in range(batch_size_all):
+ prog_bar.update()
+
+ # collect results from all ranks
+ if gpu_collect:
+ result_from_ranks = collect_results_gpu(results, len(dataset))
+ else:
+ result_from_ranks = collect_results_cpu(results, len(dataset), tmpdir)
+ return result_from_ranks
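+# Usage sketch under torch.distributed (hypothetical setup): wrap the model
+# with MMDistributedDataParallel, then call
+#   outputs = multi_gpu_test(model, data_loader, tmpdir='.dist_test',
+#                            gpu_collect=False)
+# Only rank 0 receives the merged list; all other ranks get ``None``.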
+
+
+def collect_results_cpu(result_part: list,
+ size: int,
+ tmpdir: Optional[str] = None) -> Optional[list]:
+ """Collect results under cpu mode.
+
+    In cpu mode, this function will save the results on different gpus to
+ ``tmpdir`` and collect them by the rank 0 worker.
+
+ Args:
+ result_part (list): Result list containing result parts
+ to be collected.
+ size (int): Size of the results, commonly equal to length of
+ the results.
+        tmpdir (str | None): Temporary directory in which to store the
+            collected results. If set to None, a random temporary directory
+            will be created.
+
+ Returns:
+ list: The collected results.
+ """
+ rank, world_size = get_dist_info()
+ # create a tmp dir if it is not specified
+ if tmpdir is None:
+ MAX_LEN = 512
+ # 32 is whitespace
+ dir_tensor = torch.full((MAX_LEN, ),
+ 32,
+ dtype=torch.uint8,
+ device='cuda')
+ if rank == 0:
+ mmcv.mkdir_or_exist('.dist_test')
+ tmpdir = tempfile.mkdtemp(dir='.dist_test')
+ tmpdir = torch.tensor(
+ bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
+ dir_tensor[:len(tmpdir)] = tmpdir
+ dist.broadcast(dir_tensor, 0)
+ tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+ else:
+ mmcv.mkdir_or_exist(tmpdir)
+ # dump the part result to the dir
+ part_file = osp.join(tmpdir, f'part_{rank}.pkl') # type: ignore
+ mmcv.dump(result_part, part_file)
+ dist.barrier()
+ # collect all parts
+ if rank != 0:
+ return None
+ else:
+ # load results of all parts from tmp dir
+ part_list = []
+ for i in range(world_size):
+ part_file = osp.join(tmpdir, f'part_{i}.pkl') # type: ignore
+ part_result = mmcv.load(part_file)
+ # When data is severely insufficient, an empty part_result
+            # on a certain gpu could make the overall outputs empty.
+ if part_result:
+ part_list.append(part_result)
+ # sort the results
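+        # (With a non-shuffling DistributedSampler, the usual test-time setup,
+        # sample i of rank r corresponds to dataset index i * world_size + r,
+        # so interleaving the per-rank lists with zip restores dataset order.)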
+ ordered_results = []
+ for res in zip(*part_list):
+ ordered_results.extend(list(res))
+ # the dataloader may pad some samples
+ ordered_results = ordered_results[:size]
+ # remove tmp dir
+ shutil.rmtree(tmpdir) # type: ignore
+ return ordered_results
+
+
+def collect_results_gpu(result_part: list, size: int) -> Optional[list]:
+ """Collect results under gpu mode.
+
+    In gpu mode, this function will encode results to gpu tensors and use gpu
+ communication for results collection.
+
+ Args:
+ result_part (list): Result list containing result parts
+ to be collected.
+ size (int): Size of the results, commonly equal to length of
+ the results.
+
+ Returns:
+ list: The collected results.
+ """
+ rank, world_size = get_dist_info()
+ # dump result part to tensor with pickle
+ part_tensor = torch.tensor(
+ bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+ # gather all result part tensor shape
+ shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+ shape_list = [shape_tensor.clone() for _ in range(world_size)]
+ dist.all_gather(shape_list, shape_tensor)
+ # padding result part tensor to max length
+ shape_max = torch.tensor(shape_list).max()
+ part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+ part_send[:shape_tensor[0]] = part_tensor
+ part_recv_list = [
+ part_tensor.new_zeros(shape_max) for _ in range(world_size)
+ ]
+ # gather all result part
+ dist.all_gather(part_recv_list, part_send)
+
+ if rank == 0:
+ part_list = []
+ for recv, shape in zip(part_recv_list, shape_list):
+ part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())
+ # When data is severely insufficient, an empty part_result
+            # on a certain gpu could make the overall outputs empty.
+ if part_result:
+ part_list.append(part_result)
+ # sort the results
+ ordered_results = []
+ for res in zip(*part_list):
+ ordered_results.extend(list(res))
+ # the dataloader may pad some samples
+ ordered_results = ordered_results[:size]
+ return ordered_results
+ else:
+ return None
diff --git a/mmcv/mmcv/fileio/__init__.py b/mmcv/mmcv/fileio/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2051b85f7e59bff7bdbaa131849ce8cd31f059a4
--- /dev/null
+++ b/mmcv/mmcv/fileio/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .file_client import BaseStorageBackend, FileClient
+from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
+from .io import dump, load, register_handler
+from .parse import dict_from_file, list_from_file
+
+__all__ = [
+ 'BaseStorageBackend', 'FileClient', 'load', 'dump', 'register_handler',
+ 'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler',
+ 'list_from_file', 'dict_from_file'
+]
diff --git a/mmcv/mmcv/fileio/file_client.py b/mmcv/mmcv/fileio/file_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee7c3164e2c631c546dfe3345c45f8b8394a9995
--- /dev/null
+++ b/mmcv/mmcv/fileio/file_client.py
@@ -0,0 +1,1173 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+import os
+import os.path as osp
+import re
+import tempfile
+import warnings
+from abc import ABCMeta, abstractmethod
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Any, Generator, Iterator, Optional, Tuple, Union
+from urllib.request import urlopen
+
+import mmcv
+from mmcv.utils.misc import has_method
+from mmcv.utils.path import is_filepath
+
+
+class BaseStorageBackend(metaclass=ABCMeta):
+ """Abstract class of storage backends.
+
+    All backends need to implement two APIs: ``get()`` and ``get_text()``.
+    ``get()`` reads the file as a byte stream and ``get_text()`` reads the
+    file as text.
+ """
+
+ # a flag to indicate whether the backend can create a symlink for a file
+ _allow_symlink = False
+
+ @property
+ def name(self):
+ return self.__class__.__name__
+
+ @property
+ def allow_symlink(self):
+ return self._allow_symlink
+
+ @abstractmethod
+ def get(self, filepath):
+ pass
+
+ @abstractmethod
+ def get_text(self, filepath):
+ pass
+
+
+class CephBackend(BaseStorageBackend):
+ """Ceph storage backend (for internal use).
+
+ Args:
+ path_mapping (dict|None): path mapping dict from local path to Petrel
+ path. When ``path_mapping={'src': 'dst'}``, ``src`` in ``filepath``
+ will be replaced by ``dst``. Default: None.
+
+ .. warning::
+ :class:`mmcv.fileio.file_client.CephBackend` will be deprecated,
+ please use :class:`mmcv.fileio.file_client.PetrelBackend` instead.
+ """
+
+ def __init__(self, path_mapping=None):
+ try:
+ import ceph
+ except ImportError:
+ raise ImportError('Please install ceph to enable CephBackend.')
+
+ warnings.warn(
+ 'CephBackend will be deprecated, please use PetrelBackend instead',
+ DeprecationWarning)
+ self._client = ceph.S3Client()
+ assert isinstance(path_mapping, dict) or path_mapping is None
+ self.path_mapping = path_mapping
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ if self.path_mapping is not None:
+ for k, v in self.path_mapping.items():
+ filepath = filepath.replace(k, v)
+ value = self._client.Get(filepath)
+ value_buf = memoryview(value)
+ return value_buf
+
+ def get_text(self, filepath, encoding=None):
+ raise NotImplementedError
+
+
+class PetrelBackend(BaseStorageBackend):
+ """Petrel storage backend (for internal use).
+
+ PetrelBackend supports reading and writing data to multiple clusters.
+ If the file path contains the cluster name, PetrelBackend will read data
+    from the specified cluster or write data to it. Otherwise, PetrelBackend
+    will access the default cluster.
+
+ Args:
+ path_mapping (dict, optional): Path mapping dict from local path to
+ Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in
+ ``filepath`` will be replaced by ``dst``. Default: None.
+ enable_mc (bool, optional): Whether to enable memcached support.
+ Default: True.
+
+ Examples:
+ >>> filepath1 = 's3://path/of/file'
+ >>> filepath2 = 'cluster-name:s3://path/of/file'
+ >>> client = PetrelBackend()
+ >>> client.get(filepath1) # get data from default cluster
+ >>> client.get(filepath2) # get data from 'cluster-name' cluster
+ """
+
+ def __init__(self,
+ path_mapping: Optional[dict] = None,
+ enable_mc: bool = True):
+ try:
+ from petrel_client import client
+ except ImportError:
+ raise ImportError('Please install petrel_client to enable '
+ 'PetrelBackend.')
+
+ self._client = client.Client(enable_mc=enable_mc)
+ assert isinstance(path_mapping, dict) or path_mapping is None
+ self.path_mapping = path_mapping
+
+ def _map_path(self, filepath: Union[str, Path]) -> str:
+ """Map ``filepath`` to a string path whose prefix will be replaced by
+ :attr:`self.path_mapping`.
+
+ Args:
+ filepath (str): Path to be mapped.
+ """
+ filepath = str(filepath)
+ if self.path_mapping is not None:
+ for k, v in self.path_mapping.items():
+ filepath = filepath.replace(k, v)
+ return filepath
+
+ def _format_path(self, filepath: str) -> str:
+ """Convert a ``filepath`` to standard format of petrel oss.
+
+ If the ``filepath`` is concatenated by ``os.path.join``, in a Windows
+        environment, the ``filepath`` will be in the format of
+ 's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the
+ above ``filepath`` will be converted to 's3://bucket_name/image.jpg'.
+
+ Args:
+ filepath (str): Path to be formatted.
+ """
+ return re.sub(r'\\+', '/', filepath)
+
+ def get(self, filepath: Union[str, Path]) -> memoryview:
+ """Read data from a given ``filepath`` with 'rb' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+
+ Returns:
+ memoryview: A memory view of expected bytes object to avoid
+ copying. The memoryview object can be converted to bytes by
+ ``value_buf.tobytes()``.
+ """
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ value = self._client.Get(filepath)
+ value_buf = memoryview(value)
+ return value_buf
+
+ def get_text(self,
+ filepath: Union[str, Path],
+ encoding: str = 'utf-8') -> str:
+ """Read data from a given ``filepath`` with 'r' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+
+ Returns:
+ str: Expected text reading from ``filepath``.
+ """
+ return str(self.get(filepath), encoding=encoding)
+
+ def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
+ """Save data to a given ``filepath``.
+
+ Args:
+ obj (bytes): Data to be saved.
+ filepath (str or Path): Path to write data.
+ """
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ self._client.put(filepath, obj)
+
+ def put_text(self,
+ obj: str,
+ filepath: Union[str, Path],
+ encoding: str = 'utf-8') -> None:
+ """Save data to a given ``filepath``.
+
+ Args:
+ obj (str): Data to be written.
+ filepath (str or Path): Path to write data.
+ encoding (str): The encoding format used to encode the ``obj``.
+ Default: 'utf-8'.
+ """
+ self.put(bytes(obj, encoding=encoding), filepath)
+
+ def remove(self, filepath: Union[str, Path]) -> None:
+ """Remove a file.
+
+ Args:
+ filepath (str or Path): Path to be removed.
+ """
+ if not has_method(self._client, 'delete'):
+ raise NotImplementedError(
+ 'Current version of Petrel Python SDK has not supported '
+ 'the `delete` method, please use a higher version or dev'
+ ' branch instead.')
+
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ self._client.delete(filepath)
+
+ def exists(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path exists.
+
+ Args:
+ filepath (str or Path): Path to be checked whether exists.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
+ """
+ if not (has_method(self._client, 'contains')
+ and has_method(self._client, 'isdir')):
+ raise NotImplementedError(
+ 'Current version of Petrel Python SDK has not supported '
+ 'the `contains` and `isdir` methods, please use a higher'
+ 'version or dev branch instead.')
+
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ return self._client.contains(filepath) or self._client.isdir(filepath)
+
+ def isdir(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a directory.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a
+ directory.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a directory,
+ ``False`` otherwise.
+ """
+ if not has_method(self._client, 'isdir'):
+ raise NotImplementedError(
+ 'Current version of Petrel Python SDK has not supported '
+ 'the `isdir` method, please use a higher version or dev'
+ ' branch instead.')
+
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ return self._client.isdir(filepath)
+
+ def isfile(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a file.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a file.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a file, ``False``
+ otherwise.
+ """
+ if not has_method(self._client, 'contains'):
+ raise NotImplementedError(
+ 'Current version of Petrel Python SDK has not supported '
+ 'the `contains` method, please use a higher version or '
+ 'dev branch instead.')
+
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ return self._client.contains(filepath)
+
+ def join_path(self, filepath: Union[str, Path],
+ *filepaths: Union[str, Path]) -> str:
+ """Concatenate all file paths.
+
+ Args:
+ filepath (str or Path): Path to be concatenated.
+
+ Returns:
+ str: The result after concatenation.
+ """
+ filepath = self._format_path(self._map_path(filepath))
+ if filepath.endswith('/'):
+ filepath = filepath[:-1]
+ formatted_paths = [filepath]
+ for path in filepaths:
+ formatted_paths.append(self._format_path(self._map_path(path)))
+ return '/'.join(formatted_paths)
+
+ @contextmanager
+ def get_local_path(
+ self,
+ filepath: Union[str,
+ Path]) -> Generator[Union[str, Path], None, None]:
+ """Download a file from ``filepath`` and return a temporary path.
+
+        ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`.
+        It can be called with a ``with`` statement, and when exiting the
+        ``with`` statement, the temporary path will be released.
+
+ Args:
+ filepath (str | Path): Download a file from ``filepath``.
+
+ Examples:
+ >>> client = PetrelBackend()
+            >>> # After exiting from the ``with`` clause,
+ >>> # the path will be removed
+ >>> with client.get_local_path('s3://path/of/your/file') as path:
+ ... # do something here
+
+ Yields:
+ Iterable[str]: Only yield one temporary path.
+ """
+ filepath = self._map_path(filepath)
+ filepath = self._format_path(filepath)
+ assert self.isfile(filepath)
+ try:
+ f = tempfile.NamedTemporaryFile(delete=False)
+ f.write(self.get(filepath))
+ f.close()
+ yield f.name
+ finally:
+ os.remove(f.name)
+
+ def list_dir_or_file(self,
+ dir_path: Union[str, Path],
+ list_dir: bool = True,
+ list_file: bool = True,
+ suffix: Optional[Union[str, Tuple[str]]] = None,
+ recursive: bool = False) -> Iterator[str]:
+ """Scan a directory to find the interested directories or files in
+ arbitrary order.
+
+ Note:
+ Petrel has no concept of directories but it simulates the directory
+ hierarchy in the filesystem through public prefixes. In addition,
+ if the returned path ends with '/', it means the path is a public
+ prefix which is a logical directory.
+
+ Note:
+ :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
+            In addition, the returned path of a directory will not contain
+            the trailing '/', which is consistent with other backends.
+
+ Args:
+ dir_path (str | Path): Path of the directory.
+ list_dir (bool): List the directories. Default: True.
+ list_file (bool): List the path of files. Default: True.
+ suffix (str or tuple[str], optional): File suffix
+ that we are interested in. Default: None.
+ recursive (bool): If set to True, recursively scan the
+ directory. Default: False.
+
+ Yields:
+ Iterable[str]: A relative path to ``dir_path``.
+ """
+ if not has_method(self._client, 'list'):
+ raise NotImplementedError(
+ 'Current version of Petrel Python SDK has not supported '
+ 'the `list` method, please use a higher version or dev'
+ ' branch instead.')
+
+ dir_path = self._map_path(dir_path)
+ dir_path = self._format_path(dir_path)
+ if list_dir and suffix is not None:
+ raise TypeError(
+ '`list_dir` should be False when `suffix` is not None')
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('`suffix` must be a string or tuple of strings')
+
+ # Petrel's simulated directory hierarchy assumes that directory paths
+ # should end with `/`
+ if not dir_path.endswith('/'):
+ dir_path += '/'
+
+ root = dir_path
+
+ def _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+ recursive):
+ for path in self._client.list(dir_path):
+ # the `self.isdir` is not used here to determine whether path
+ # is a directory, because `self.isdir` relies on
+ # `self._client.list`
+ if path.endswith('/'): # a directory path
+ next_dir_path = self.join_path(dir_path, path)
+ if list_dir:
+ # get the relative path and exclude the last
+ # character '/'
+ rel_dir = next_dir_path[len(root):-1]
+ yield rel_dir
+ if recursive:
+ yield from _list_dir_or_file(next_dir_path, list_dir,
+ list_file, suffix,
+ recursive)
+ else: # a file path
+ absolute_path = self.join_path(dir_path, path)
+ rel_path = absolute_path[len(root):]
+ if (suffix is None
+ or rel_path.endswith(suffix)) and list_file:
+ yield rel_path
+
+ return _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+ recursive)
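+    # Usage sketch (hypothetical bucket path):
+    #   client = PetrelBackend()
+    #   for name in client.list_dir_or_file('s3://bucket/prefix',
+    #                                       list_dir=False, suffix='.jpg'):
+    #       print(name)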
+
+
+class MemcachedBackend(BaseStorageBackend):
+ """Memcached storage backend.
+
+ Attributes:
+ server_list_cfg (str): Config file for memcached server list.
+ client_cfg (str): Config file for memcached client.
+ sys_path (str | None): Additional path to be appended to `sys.path`.
+ Default: None.
+ """
+
+ def __init__(self, server_list_cfg, client_cfg, sys_path=None):
+ if sys_path is not None:
+ import sys
+ sys.path.append(sys_path)
+ try:
+ import mc
+ except ImportError:
+ raise ImportError(
+ 'Please install memcached to enable MemcachedBackend.')
+
+ self.server_list_cfg = server_list_cfg
+ self.client_cfg = client_cfg
+ self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg,
+ self.client_cfg)
+        # mc.pyvector serves as a pointer to a memory cache
+ self._mc_buffer = mc.pyvector()
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ import mc
+ self._client.Get(filepath, self._mc_buffer)
+ value_buf = mc.ConvertBuffer(self._mc_buffer)
+ return value_buf
+
+ def get_text(self, filepath, encoding=None):
+ raise NotImplementedError
+
+
+class LmdbBackend(BaseStorageBackend):
+ """Lmdb storage backend.
+
+ Args:
+ db_path (str): Lmdb database path.
+ readonly (bool, optional): Lmdb environment parameter. If True,
+ disallow any write operations. Default: True.
+ lock (bool, optional): Lmdb environment parameter. If False, when
+ concurrent access occurs, do not lock the database. Default: False.
+ readahead (bool, optional): Lmdb environment parameter. If False,
+ disable the OS filesystem readahead mechanism, which may improve
+ random read performance when a database is larger than RAM.
+ Default: False.
+
+ Attributes:
+ db_path (str): Lmdb database path.
+ """
+
+ def __init__(self,
+ db_path,
+ readonly=True,
+ lock=False,
+ readahead=False,
+ **kwargs):
+ try:
+ import lmdb # NOQA
+ except ImportError:
+ raise ImportError('Please install lmdb to enable LmdbBackend.')
+
+ self.db_path = str(db_path)
+ self.readonly = readonly
+ self.lock = lock
+ self.readahead = readahead
+ self.kwargs = kwargs
+ self._client = None
+
+ def get(self, filepath):
+ """Get values according to the filepath.
+
+ Args:
+ filepath (str | obj:`Path`): Here, filepath is the lmdb key.
+ """
+ if self._client is None:
+ self._client = self._get_client()
+
+ with self._client.begin(write=False) as txn:
+ value_buf = txn.get(str(filepath).encode('utf-8'))
+ return value_buf
+
+ def get_text(self, filepath, encoding=None):
+ raise NotImplementedError
+
+ def _get_client(self):
+ import lmdb
+
+ return lmdb.open(
+ self.db_path,
+ readonly=self.readonly,
+ lock=self.lock,
+ readahead=self.readahead,
+ **self.kwargs)
+
+    def __del__(self):
+        # the client may never have been created if `get` was not called
+        if self._client is not None:
+            self._client.close()
+
+
+class HardDiskBackend(BaseStorageBackend):
+ """Raw hard disks storage backend."""
+
+ _allow_symlink = True
+
+ def get(self, filepath: Union[str, Path]) -> bytes:
+ """Read data from a given ``filepath`` with 'rb' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+
+ Returns:
+ bytes: Expected bytes object.
+ """
+ with open(filepath, 'rb') as f:
+ value_buf = f.read()
+ return value_buf
+
+ def get_text(self,
+ filepath: Union[str, Path],
+ encoding: str = 'utf-8') -> str:
+ """Read data from a given ``filepath`` with 'r' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+
+ Returns:
+ str: Expected text reading from ``filepath``.
+ """
+ with open(filepath, encoding=encoding) as f:
+ value_buf = f.read()
+ return value_buf
+
+ def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
+ """Write data to a given ``filepath`` with 'wb' mode.
+
+ Note:
+ ``put`` will create a directory if the directory of ``filepath``
+ does not exist.
+
+ Args:
+ obj (bytes): Data to be written.
+ filepath (str or Path): Path to write data.
+ """
+ mmcv.mkdir_or_exist(osp.dirname(filepath))
+ with open(filepath, 'wb') as f:
+ f.write(obj)
+
+ def put_text(self,
+ obj: str,
+ filepath: Union[str, Path],
+ encoding: str = 'utf-8') -> None:
+ """Write data to a given ``filepath`` with 'w' mode.
+
+ Note:
+ ``put_text`` will create a directory if the directory of
+ ``filepath`` does not exist.
+
+ Args:
+ obj (str): Data to be written.
+ filepath (str or Path): Path to write data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+ """
+ mmcv.mkdir_or_exist(osp.dirname(filepath))
+ with open(filepath, 'w', encoding=encoding) as f:
+ f.write(obj)
+
+ def remove(self, filepath: Union[str, Path]) -> None:
+ """Remove a file.
+
+ Args:
+ filepath (str or Path): Path to be removed.
+ """
+ os.remove(filepath)
+
+ def exists(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path exists.
+
+ Args:
+ filepath (str or Path): Path to be checked whether exists.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
+ """
+ return osp.exists(filepath)
+
+ def isdir(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a directory.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a
+ directory.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a directory,
+ ``False`` otherwise.
+ """
+ return osp.isdir(filepath)
+
+ def isfile(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a file.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a file.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a file, ``False``
+ otherwise.
+ """
+ return osp.isfile(filepath)
+
+ def join_path(self, filepath: Union[str, Path],
+ *filepaths: Union[str, Path]) -> str:
+ """Concatenate all file paths.
+
+ Join one or more filepath components intelligently. The return value
+ is the concatenation of filepath and any members of *filepaths.
+
+ Args:
+ filepath (str or Path): Path to be concatenated.
+
+ Returns:
+ str: The result of concatenation.
+ """
+ return osp.join(filepath, *filepaths)
+
+ @contextmanager
+ def get_local_path(
+ self,
+ filepath: Union[str,
+ Path]) -> Generator[Union[str, Path], None, None]:
+ """Only for unified API and do nothing."""
+ yield filepath
+
+ def list_dir_or_file(self,
+ dir_path: Union[str, Path],
+ list_dir: bool = True,
+ list_file: bool = True,
+ suffix: Optional[Union[str, Tuple[str]]] = None,
+ recursive: bool = False) -> Iterator[str]:
+ """Scan a directory to find the interested directories or files in
+ arbitrary order.
+
+ Note:
+ :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
+
+ Args:
+ dir_path (str | Path): Path of the directory.
+ list_dir (bool): List the directories. Default: True.
+ list_file (bool): List the path of files. Default: True.
+ suffix (str or tuple[str], optional): File suffix
+ that we are interested in. Default: None.
+ recursive (bool): If set to True, recursively scan the
+ directory. Default: False.
+
+ Yields:
+ Iterable[str]: A relative path to ``dir_path``.
+ """
+ if list_dir and suffix is not None:
+ raise TypeError('`suffix` should be None when `list_dir` is True')
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('`suffix` must be a string or tuple of strings')
+
+ root = dir_path
+
+ def _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+ recursive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith('.') and entry.is_file():
+ rel_path = osp.relpath(entry.path, root)
+ if (suffix is None
+ or rel_path.endswith(suffix)) and list_file:
+ yield rel_path
+ elif osp.isdir(entry.path):
+ if list_dir:
+ rel_dir = osp.relpath(entry.path, root)
+ yield rel_dir
+ if recursive:
+ yield from _list_dir_or_file(entry.path, list_dir,
+ list_file, suffix,
+ recursive)
+
+ return _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+ recursive)
+
+
+class HTTPBackend(BaseStorageBackend):
+    """HTTP and HTTPS storage backend."""
+
+ def get(self, filepath):
+ value_buf = urlopen(filepath).read()
+ return value_buf
+
+ def get_text(self, filepath, encoding='utf-8'):
+ value_buf = urlopen(filepath).read()
+ return value_buf.decode(encoding)
+
+ @contextmanager
+ def get_local_path(
+ self, filepath: str) -> Generator[Union[str, Path], None, None]:
+ """Download a file from ``filepath``.
+
+        ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`.
+        It can be called with a ``with`` statement, and when exiting the
+        ``with`` statement, the temporary path will be released.
+
+ Args:
+ filepath (str): Download a file from ``filepath``.
+
+ Examples:
+ >>> client = HTTPBackend()
+            >>> # After exiting from the ``with`` clause,
+ >>> # the path will be removed
+ >>> with client.get_local_path('http://path/of/your/file') as path:
+ ... # do something here
+ """
+ try:
+ f = tempfile.NamedTemporaryFile(delete=False)
+ f.write(self.get(filepath))
+ f.close()
+ yield f.name
+ finally:
+ os.remove(f.name)
+
+
+class FileClient:
+ """A general file client to access files in different backends.
+
+ The client loads a file or text in a specified backend from its path
+ and returns it as a binary or text file. There are two ways to choose a
+    backend: the name of the backend or the prefix of the path. Although both
+    can be used to choose a storage backend, ``backend`` has the higher
+    priority, that is, if both are set, the storage backend will be chosen by
+    the ``backend`` argument. If both are `None`, the disk backend will be
+    chosen. Note that other backend accessors can also be registered with a
+    given name, prefixes, and backend class. In addition, the singleton
+    pattern is used to avoid repeated object creation: if the arguments are
+    the same, the same object will be returned.
+
+ Args:
+ backend (str, optional): The storage backend type. Options are "disk",
+ "ceph", "memcached", "lmdb", "http" and "petrel". Default: None.
+ prefix (str, optional): The prefix of the registered storage backend.
+ Options are "s3", "http", "https". Default: None.
+
+ Examples:
+ >>> # only set backend
+ >>> file_client = FileClient(backend='petrel')
+ >>> # only set prefix
+ >>> file_client = FileClient(prefix='s3')
+ >>> # set both backend and prefix but use backend to choose client
+ >>> file_client = FileClient(backend='petrel', prefix='s3')
+ >>> # if the arguments are the same, the same object is returned
+ >>> file_client1 = FileClient(backend='petrel')
+ >>> file_client1 is file_client
+ True
+
+ Attributes:
+ client (:obj:`BaseStorageBackend`): The backend object.
+ """
+
+ _backends = {
+ 'disk': HardDiskBackend,
+ 'ceph': CephBackend,
+ 'memcached': MemcachedBackend,
+ 'lmdb': LmdbBackend,
+ 'petrel': PetrelBackend,
+ 'http': HTTPBackend,
+ }
+
+ _prefix_to_backends = {
+ 's3': PetrelBackend,
+ 'http': HTTPBackend,
+ 'https': HTTPBackend,
+ }
+
+ _instances: dict = {}
+
+ client: Any
+
+ def __new__(cls, backend=None, prefix=None, **kwargs):
+ if backend is None and prefix is None:
+ backend = 'disk'
+ if backend is not None and backend not in cls._backends:
+ raise ValueError(
+ f'Backend {backend} is not supported. Currently supported ones'
+ f' are {list(cls._backends.keys())}')
+ if prefix is not None and prefix not in cls._prefix_to_backends:
+ raise ValueError(
+ f'prefix {prefix} is not supported. Currently supported ones '
+ f'are {list(cls._prefix_to_backends.keys())}')
+
+ # concatenate the arguments to a unique key for determining whether
+ # objects with the same arguments were created
+ arg_key = f'{backend}:{prefix}'
+ for key, value in kwargs.items():
+ arg_key += f':{key}:{value}'
+
+ if arg_key in cls._instances:
+ _instance = cls._instances[arg_key]
+ else:
+ # create a new object and put it to _instance
+ _instance = super().__new__(cls)
+ if backend is not None:
+ _instance.client = cls._backends[backend](**kwargs)
+ else:
+ _instance.client = cls._prefix_to_backends[prefix](**kwargs)
+
+ cls._instances[arg_key] = _instance
+
+ return _instance
+
+ @property
+ def name(self):
+ return self.client.name
+
+ @property
+ def allow_symlink(self):
+ return self.client.allow_symlink
+
+ @staticmethod
+ def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]:
+ """Parse the prefix of a uri.
+
+ Args:
+ uri (str | Path): Uri to be parsed that contains the file prefix.
+
+ Examples:
+ >>> FileClient.parse_uri_prefix('s3://path/of/your/file')
+ 's3'
+
+ Returns:
+ str | None: Return the prefix of uri if the uri contains '://' else
+ ``None``.
+ """
+ assert is_filepath(uri)
+ uri = str(uri)
+ if '://' not in uri:
+ return None
+ else:
+ prefix, _ = uri.split('://')
+            # In the case of PetrelBackend, the prefix may contain the cluster
+ # name like clusterName:s3
+ if ':' in prefix:
+ _, prefix = prefix.split(':')
+ return prefix
+
+ @classmethod
+ def infer_client(cls,
+ file_client_args: Optional[dict] = None,
+ uri: Optional[Union[str, Path]] = None) -> 'FileClient':
+ """Infer a suitable file client based on the URI and arguments.
+
+ Args:
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. Default: None.
+ uri (str | Path, optional): Uri to be parsed that contains the file
+ prefix. Default: None.
+
+ Examples:
+ >>> uri = 's3://path/of/your/file'
+ >>> file_client = FileClient.infer_client(uri=uri)
+ >>> file_client_args = {'backend': 'petrel'}
+ >>> file_client = FileClient.infer_client(file_client_args)
+
+ Returns:
+ FileClient: Instantiated FileClient object.
+ """
+ assert file_client_args is not None or uri is not None
+ if file_client_args is None:
+ file_prefix = cls.parse_uri_prefix(uri) # type: ignore
+ return cls(prefix=file_prefix)
+ else:
+ return cls(**file_client_args)
+
+ @classmethod
+ def _register_backend(cls, name, backend, force=False, prefixes=None):
+ if not isinstance(name, str):
+ raise TypeError('the backend name should be a string, '
+ f'but got {type(name)}')
+ if not inspect.isclass(backend):
+ raise TypeError(
+ f'backend should be a class but got {type(backend)}')
+ if not issubclass(backend, BaseStorageBackend):
+ raise TypeError(
+ f'backend {backend} is not a subclass of BaseStorageBackend')
+ if not force and name in cls._backends:
+ raise KeyError(
+ f'{name} is already registered as a storage backend, '
+ 'add "force=True" if you want to override it')
+
+ if name in cls._backends and force:
+ for arg_key, instance in list(cls._instances.items()):
+ if isinstance(instance.client, cls._backends[name]):
+ cls._instances.pop(arg_key)
+ cls._backends[name] = backend
+
+ if prefixes is not None:
+ if isinstance(prefixes, str):
+ prefixes = [prefixes]
+ else:
+ assert isinstance(prefixes, (list, tuple))
+ for prefix in prefixes:
+ if prefix not in cls._prefix_to_backends:
+ cls._prefix_to_backends[prefix] = backend
+ elif (prefix in cls._prefix_to_backends) and force:
+ overridden_backend = cls._prefix_to_backends[prefix]
+ if isinstance(overridden_backend, list):
+ overridden_backend = tuple(overridden_backend)
+ for arg_key, instance in list(cls._instances.items()):
+ if isinstance(instance.client, overridden_backend):
+ cls._instances.pop(arg_key)
+ cls._prefix_to_backends[prefix] = backend
+ else:
+ raise KeyError(
+ f'{prefix} is already registered as a storage backend,'
+ ' add "force=True" if you want to override it')
+
+ @classmethod
+ def register_backend(cls, name, backend=None, force=False, prefixes=None):
+ """Register a backend to FileClient.
+
+ This method can be used as a normal class method or a decorator.
+
+ .. code-block:: python
+
+ class NewBackend(BaseStorageBackend):
+
+ def get(self, filepath):
+ return filepath
+
+ def get_text(self, filepath):
+ return filepath
+
+ FileClient.register_backend('new', NewBackend)
+
+ or
+
+ .. code-block:: python
+
+ @FileClient.register_backend('new')
+ class NewBackend(BaseStorageBackend):
+
+ def get(self, filepath):
+ return filepath
+
+ def get_text(self, filepath):
+ return filepath
+
+ Args:
+ name (str): The name of the registered backend.
+ backend (class, optional): The backend class to be registered,
+ which must be a subclass of :class:`BaseStorageBackend`.
+ When this method is used as a decorator, backend is None.
+ Defaults to None.
+ force (bool, optional): Whether to override the backend if the name
+ has already been registered. Defaults to False.
+ prefixes (str or list[str] or tuple[str], optional): The prefixes
+ of the registered storage backend. Default: None.
+ `New in version 1.3.15.`
+ """
+ if backend is not None:
+ cls._register_backend(
+ name, backend, force=force, prefixes=prefixes)
+ return
+
+ def _register(backend_cls):
+ cls._register_backend(
+ name, backend_cls, force=force, prefixes=prefixes)
+ return backend_cls
+
+ return _register
+
+ def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]:
+ """Read data from a given ``filepath`` with 'rb' mode.
+
+ Note:
+ There are two types of return values for ``get``, one is ``bytes``
+ and the other is ``memoryview``. The advantage of using memoryview
+ is that you can avoid copying, and if you want to convert it to
+ ``bytes``, you can use ``.tobytes()``.
+
+ Args:
+ filepath (str or Path): Path to read data.
+
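+        Examples:
+            >>> # illustrative sketch: convert explicitly when ``bytes`` is
+            >>> # required, since some backends return ``memoryview``
+            >>> file_client = FileClient(backend='disk')
+            >>> buf = file_client.get('/path/of/your/file')
+            >>> buf = buf.tobytes() if isinstance(buf, memoryview) else buf
+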
+ Returns:
+ bytes | memoryview: Expected bytes object or a memory view of the
+ bytes object.
+ """
+ return self.client.get(filepath)
+
+ def get_text(self, filepath: Union[str, Path], encoding='utf-8') -> str:
+ """Read data from a given ``filepath`` with 'r' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+
+ Returns:
+ str: Expected text reading from ``filepath``.
+ """
+ return self.client.get_text(filepath, encoding)
+
+ def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
+ """Write data to a given ``filepath`` with 'wb' mode.
+
+ Note:
+ ``put`` should create a directory if the directory of ``filepath``
+ does not exist.
+
+ Args:
+ obj (bytes): Data to be written.
+ filepath (str or Path): Path to write data.
+ """
+ self.client.put(obj, filepath)
+
+ def put_text(self, obj: str, filepath: Union[str, Path]) -> None:
+ """Write data to a given ``filepath`` with 'w' mode.
+
+ Note:
+ ``put_text`` should create a directory if the directory of
+ ``filepath`` does not exist.
+
+ Args:
+ obj (str): Data to be written.
+ filepath (str or Path): Path to write data.
+ """
+ self.client.put_text(obj, filepath)
+
+ def remove(self, filepath: Union[str, Path]) -> None:
+ """Remove a file.
+
+ Args:
+ filepath (str, Path): Path to be removed.
+ """
+ self.client.remove(filepath)
+
+ def exists(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path exists.
+
+ Args:
+            filepath (str or Path): Path to be checked whether it exists.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
+ """
+ return self.client.exists(filepath)
+
+ def isdir(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a directory.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a
+ directory.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a directory,
+ ``False`` otherwise.
+ """
+ return self.client.isdir(filepath)
+
+ def isfile(self, filepath: Union[str, Path]) -> bool:
+ """Check whether a file path is a file.
+
+ Args:
+ filepath (str or Path): Path to be checked whether it is a file.
+
+ Returns:
+ bool: Return ``True`` if ``filepath`` points to a file, ``False``
+ otherwise.
+ """
+ return self.client.isfile(filepath)
+
+ def join_path(self, filepath: Union[str, Path],
+ *filepaths: Union[str, Path]) -> str:
+ """Concatenate all file paths.
+
+ Join one or more filepath components intelligently. The return value
+ is the concatenation of filepath and any members of *filepaths.
+
+ Args:
+ filepath (str or Path): Path to be concatenated.
+
+ Returns:
+ str: The result of concatenation.
+ """
+ return self.client.join_path(filepath, *filepaths)
+
+ @contextmanager
+ def get_local_path(
+ self,
+ filepath: Union[str,
+ Path]) -> Generator[Union[str, Path], None, None]:
+ """Download data from ``filepath`` and write the data to local path.
+
+        ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`.
+        It can be called with a ``with`` statement, and when exiting the
+        ``with`` statement, the temporary path will be released.
+
+ Note:
+ If the ``filepath`` is a local path, just return itself.
+
+ .. warning::
+ ``get_local_path`` is an experimental interface that may change in
+ the future.
+
+ Args:
+            filepath (str or Path): Path of the data to be read.
+
+ Examples:
+ >>> file_client = FileClient(prefix='s3')
+ >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path:
+ ... # do something here
+
+ Yields:
+ Iterable[str]: Only yield one path.
+ """
+ with self.client.get_local_path(str(filepath)) as local_path:
+ yield local_path
+
+ def list_dir_or_file(self,
+ dir_path: Union[str, Path],
+ list_dir: bool = True,
+ list_file: bool = True,
+ suffix: Optional[Union[str, Tuple[str]]] = None,
+ recursive: bool = False) -> Iterator[str]:
+        """Scan a directory to find the directories or files of interest in
+        arbitrary order.
+
+ Note:
+ :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
+
+ Args:
+ dir_path (str | Path): Path of the directory.
+ list_dir (bool): List the directories. Default: True.
+ list_file (bool): List the path of files. Default: True.
+ suffix (str or tuple[str], optional): File suffix
+ that we are interested in. Default: None.
+ recursive (bool): If set to True, recursively scan the
+ directory. Default: False.
+
+ Yields:
+ Iterable[str]: A relative path to ``dir_path``.
+ """
+ yield from self.client.list_dir_or_file(dir_path, list_dir, list_file,
+ suffix, recursive)
diff --git a/mmcv/mmcv/fileio/handlers/__init__.py b/mmcv/mmcv/fileio/handlers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa24d91972837b8756b225f4879bac20436eb72a
--- /dev/null
+++ b/mmcv/mmcv/fileio/handlers/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import BaseFileHandler
+from .json_handler import JsonHandler
+from .pickle_handler import PickleHandler
+from .yaml_handler import YamlHandler
+
+__all__ = ['BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler']
diff --git a/mmcv/mmcv/fileio/handlers/base.py b/mmcv/mmcv/fileio/handlers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c9cc15b67cbf7d320c2b9c6cbd441a5d5adf235
--- /dev/null
+++ b/mmcv/mmcv/fileio/handlers/base.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+
+class BaseFileHandler(metaclass=ABCMeta):
+ # `str_like` is a flag to indicate whether the type of file object is
+ # str-like object or bytes-like object. Pickle only processes bytes-like
+ # objects but json only processes str-like object. If it is str-like
+ # object, `StringIO` will be used to process the buffer.
+ str_like = True
+
+ @abstractmethod
+ def load_from_fileobj(self, file, **kwargs):
+ pass
+
+ @abstractmethod
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ pass
+
+ @abstractmethod
+ def dump_to_str(self, obj, **kwargs):
+ pass
+
+ def load_from_path(self, filepath: str, mode: str = 'r', **kwargs):
+ with open(filepath, mode) as f:
+ return self.load_from_fileobj(f, **kwargs)
+
+ def dump_to_path(self, obj, filepath: str, mode: str = 'w', **kwargs):
+ with open(filepath, mode) as f:
+ self.dump_to_fileobj(obj, f, **kwargs)
diff --git a/mmcv/mmcv/fileio/handlers/json_handler.py b/mmcv/mmcv/fileio/handlers/json_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..18d4f15f74139d20adff18b20be5529c592a66b6
--- /dev/null
+++ b/mmcv/mmcv/fileio/handlers/json_handler.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+
+import numpy as np
+
+from .base import BaseFileHandler
+
+
+def set_default(obj):
+ """Set default json values for non-serializable values.
+
+ It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list.
+ It also converts ``np.generic`` (including ``np.int32``, ``np.float32``,
+ etc.) into plain numbers of plain python built-in types.
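+
+    Examples:
+        >>> # illustrative sketch
+        >>> json.dumps({'ids': np.arange(3)}, default=set_default)
+        '{"ids": [0, 1, 2]}'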
+ """
+ if isinstance(obj, (set, range)):
+ return list(obj)
+ elif isinstance(obj, np.ndarray):
+ return obj.tolist()
+ elif isinstance(obj, np.generic):
+ return obj.item()
+ raise TypeError(f'{type(obj)} is unsupported for json dump')
+
+
+class JsonHandler(BaseFileHandler):
+
+ def load_from_fileobj(self, file):
+ return json.load(file)
+
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ kwargs.setdefault('default', set_default)
+ json.dump(obj, file, **kwargs)
+
+ def dump_to_str(self, obj, **kwargs):
+ kwargs.setdefault('default', set_default)
+ return json.dumps(obj, **kwargs)
diff --git a/mmcv/mmcv/fileio/handlers/pickle_handler.py b/mmcv/mmcv/fileio/handlers/pickle_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..073856fd25a731b42f3cd19269ad95744b20598f
--- /dev/null
+++ b/mmcv/mmcv/fileio/handlers/pickle_handler.py
@@ -0,0 +1,26 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pickle
+
+from .base import BaseFileHandler
+
+
+class PickleHandler(BaseFileHandler):
+
+ str_like = False
+
+ def load_from_fileobj(self, file, **kwargs):
+ return pickle.load(file, **kwargs)
+
+ def load_from_path(self, filepath, **kwargs):
+ return super().load_from_path(filepath, mode='rb', **kwargs)
+
+ def dump_to_str(self, obj, **kwargs):
+ kwargs.setdefault('protocol', 2)
+ return pickle.dumps(obj, **kwargs)
+
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ kwargs.setdefault('protocol', 2)
+ pickle.dump(obj, file, **kwargs)
+
+ def dump_to_path(self, obj, filepath, **kwargs):
+ super().dump_to_path(obj, filepath, mode='wb', **kwargs)
diff --git a/mmcv/mmcv/fileio/handlers/yaml_handler.py b/mmcv/mmcv/fileio/handlers/yaml_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c1b077943d634b3ddcf5ee470855179b8308e9c
--- /dev/null
+++ b/mmcv/mmcv/fileio/handlers/yaml_handler.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import yaml
+
+try:
+ from yaml import CDumper as Dumper
+ from yaml import CLoader as Loader
+except ImportError:
+ from yaml import Loader, Dumper # type: ignore
+
+from .base import BaseFileHandler # isort:skip
+
+
+class YamlHandler(BaseFileHandler):
+
+ def load_from_fileobj(self, file, **kwargs):
+ kwargs.setdefault('Loader', Loader)
+ return yaml.load(file, **kwargs)
+
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ kwargs.setdefault('Dumper', Dumper)
+ yaml.dump(obj, file, **kwargs)
+
+ def dump_to_str(self, obj, **kwargs):
+ kwargs.setdefault('Dumper', Dumper)
+ return yaml.dump(obj, **kwargs)
diff --git a/mmcv/mmcv/fileio/io.py b/mmcv/mmcv/fileio/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..91192103cf331e8ceb970d6f1f5ac050137c0871
--- /dev/null
+++ b/mmcv/mmcv/fileio/io.py
@@ -0,0 +1,163 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from io import BytesIO, StringIO
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, TextIO, Union
+
+from ..utils import is_list_of
+from .file_client import FileClient
+from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
+
+FileLikeObject = Union[TextIO, StringIO, BytesIO]
+
+file_handlers = {
+ 'json': JsonHandler(),
+ 'yaml': YamlHandler(),
+ 'yml': YamlHandler(),
+ 'pickle': PickleHandler(),
+ 'pkl': PickleHandler()
+}
+
+
+def load(file: Union[str, Path, FileLikeObject],
+ file_format: Optional[str] = None,
+ file_client_args: Optional[Dict] = None,
+ **kwargs):
+ """Load data from json/yaml/pickle files.
+
+ This method provides a unified api for loading data from serialized files.
+
+ Note:
+        In v1.3.16 and later, ``load`` supports loading data from serialized
+        files that can be stored in different backends.
+
+ Args:
+ file (str or :obj:`Path` or file-like object): Filename or a file-like
+ object.
+ file_format (str, optional): If not specified, the file format will be
+ inferred from the file extension, otherwise use the specified one.
+ Currently supported formats include "json", "yaml/yml" and
+ "pickle/pkl".
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
+ Default: None.
+
+ Examples:
+        >>> load('/path/of/your/file') # file is stored on disk
+        >>> load('https://path/of/your/file') # file is stored on the Internet
+        >>> load('s3://path/of/your/file') # file is stored in petrel
+
+ Returns:
+ The content from the file.
+ """
+ if isinstance(file, Path):
+ file = str(file)
+ if file_format is None and isinstance(file, str):
+ file_format = file.split('.')[-1]
+ if file_format not in file_handlers:
+ raise TypeError(f'Unsupported format: {file_format}')
+
+ handler = file_handlers[file_format]
+ f: FileLikeObject
+ if isinstance(file, str):
+ file_client = FileClient.infer_client(file_client_args, file)
+ if handler.str_like:
+ with StringIO(file_client.get_text(file)) as f:
+ obj = handler.load_from_fileobj(f, **kwargs)
+ else:
+ with BytesIO(file_client.get(file)) as f:
+ obj = handler.load_from_fileobj(f, **kwargs)
+ elif hasattr(file, 'read'):
+ obj = handler.load_from_fileobj(file, **kwargs)
+ else:
+ raise TypeError('"file" must be a filepath str or a file-object')
+ return obj
+
+
+def dump(obj: Any,
+ file: Optional[Union[str, Path, FileLikeObject]] = None,
+ file_format: Optional[str] = None,
+ file_client_args: Optional[Dict] = None,
+ **kwargs):
+ """Dump data to json/yaml/pickle strings or files.
+
+ This method provides a unified api for dumping data as strings or to files,
+ and also supports custom arguments for each file format.
+
+ Note:
+        In v1.3.16 and later, ``dump`` supports dumping data as strings or to
+        files which can be saved to different backends.
+
+ Args:
+ obj (any): The python object to be dumped.
+ file (str or :obj:`Path` or file-like object, optional): If not
+ specified, then the object is dumped to a str, otherwise to a file
+ specified by the filename or file-like object.
+ file_format (str, optional): Same as :func:`load`.
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
+ Default: None.
+
+ Examples:
+ >>> dump('hello world', '/path/of/your/file') # disk
+ >>> dump('hello world', 's3://path/of/your/file') # ceph or petrel
+
+    Returns:
+        The dumped str/bytes when ``file`` is None, otherwise None.
+ """
+ if isinstance(file, Path):
+ file = str(file)
+ if file_format is None:
+ if isinstance(file, str):
+ file_format = file.split('.')[-1]
+ elif file is None:
+ raise ValueError(
+ 'file_format must be specified since file is None')
+ if file_format not in file_handlers:
+ raise TypeError(f'Unsupported format: {file_format}')
+ f: FileLikeObject
+ handler = file_handlers[file_format]
+ if file is None:
+ return handler.dump_to_str(obj, **kwargs)
+ elif isinstance(file, str):
+ file_client = FileClient.infer_client(file_client_args, file)
+ if handler.str_like:
+ with StringIO() as f:
+ handler.dump_to_fileobj(obj, f, **kwargs)
+ file_client.put_text(f.getvalue(), file)
+ else:
+ with BytesIO() as f:
+ handler.dump_to_fileobj(obj, f, **kwargs)
+ file_client.put(f.getvalue(), file)
+ elif hasattr(file, 'write'):
+ handler.dump_to_fileobj(obj, file, **kwargs)
+ else:
+ raise TypeError('"file" must be a filename str or a file-object')
+
+
+def _register_handler(handler: BaseFileHandler,
+ file_formats: Union[str, List[str]]) -> None:
+ """Register a handler for some file extensions.
+
+ Args:
+ handler (:obj:`BaseFileHandler`): Handler to be registered.
+ file_formats (str or list[str]): File formats to be handled by this
+ handler.
+ """
+ if not isinstance(handler, BaseFileHandler):
+ raise TypeError(
+ f'handler must be a child of BaseFileHandler, not {type(handler)}')
+ if isinstance(file_formats, str):
+ file_formats = [file_formats]
+ if not is_list_of(file_formats, str):
+ raise TypeError('file_formats must be a str or a list of str')
+ for ext in file_formats:
+ file_handlers[ext] = handler
+
+
+def register_handler(file_formats: Union[str, list], **kwargs) -> Callable:
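+    """Register a handler for some file extensions.
+
+    A thin decorator wrapper around :func:`_register_handler`.
+
+    Examples:
+        >>> # illustrative sketch; ``TxtHandler`` is a hypothetical handler
+        >>> @register_handler('txt')
+        ... class TxtHandler(BaseFileHandler):
+        ...     def load_from_fileobj(self, file):
+        ...         return file.read()
+        ...     def dump_to_fileobj(self, obj, file):
+        ...         file.write(str(obj))
+        ...     def dump_to_str(self, obj):
+        ...         return str(obj)
+    """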
+
+ def wrap(cls):
+ _register_handler(cls(**kwargs), file_formats)
+ return cls
+
+ return wrap
diff --git a/mmcv/mmcv/fileio/parse.py b/mmcv/mmcv/fileio/parse.py
new file mode 100644
index 0000000000000000000000000000000000000000..f28e59119325a1bb68b38dd884c59b68dbed6508
--- /dev/null
+++ b/mmcv/mmcv/fileio/parse.py
@@ -0,0 +1,99 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from io import StringIO
+from pathlib import Path
+from typing import Dict, List, Optional, Union
+
+from .file_client import FileClient
+
+
+def list_from_file(filename: Union[str, Path],
+ prefix: str = '',
+ offset: int = 0,
+ max_num: int = 0,
+ encoding: str = 'utf-8',
+ file_client_args: Optional[Dict] = None) -> List:
+ """Load a text file and parse the content as a list of strings.
+
+ Note:
+        In v1.3.16 and later, ``list_from_file`` supports loading a text file
+        which can be stored in different backends and parsing the content as
+        a list of strings.
+
+ Args:
+ filename (str): Filename.
+ prefix (str): The prefix to be inserted to the beginning of each item.
+        offset (int): The number of lines to skip at the beginning.
+        max_num (int): The maximum number of lines to be read,
+            zero and negative values mean no limitation.
+ encoding (str): Encoding used to open the file. Default utf-8.
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
+ Default: None.
+
+ Examples:
+ >>> list_from_file('/path/of/your/file') # disk
+ ['hello', 'world']
+ >>> list_from_file('s3://path/of/your/file') # ceph or petrel
+ ['hello', 'world']
+
+ Returns:
+ list[str]: A list of strings.
+ """
+ cnt = 0
+ item_list = []
+ file_client = FileClient.infer_client(file_client_args, filename)
+ with StringIO(file_client.get_text(filename, encoding)) as f:
+ for _ in range(offset):
+ f.readline()
+ for line in f:
+ if 0 < max_num <= cnt:
+ break
+ item_list.append(prefix + line.rstrip('\n\r'))
+ cnt += 1
+ return item_list
+
+
+def dict_from_file(filename: Union[str, Path],
+ key_type: type = str,
+ encoding: str = 'utf-8',
+ file_client_args: Optional[Dict] = None) -> Dict:
+ """Load a text file and parse the content as a dict.
+
+    Each line of the text file should contain two or more columns split by
+    whitespace or tabs. The first column will be parsed as dict keys, and
+ the following columns will be parsed as dict values.
+
+ Note:
+ In v1.3.16 and later, ``dict_from_file`` supports loading a text file
+        which can be stored in different backends and parsing the content as
+ a dict.
+
+ Args:
+        filename (str): Filename.
+        key_type (type): Type of the dict keys. str is used by default and
+            type conversion will be performed if specified.
+ encoding (str): Encoding used to open the file. Default utf-8.
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
+ Default: None.
+
+ Examples:
+ >>> dict_from_file('/path/of/your/file') # disk
+ {'key1': 'value1', 'key2': 'value2'}
+ >>> dict_from_file('s3://path/of/your/file') # ceph or petrel
+ {'key1': 'value1', 'key2': 'value2'}
+
+ Returns:
+ dict: The parsed contents.
+ """
+ mapping = {}
+ file_client = FileClient.infer_client(file_client_args, filename)
+ with StringIO(file_client.get_text(filename, encoding)) as f:
+ for line in f:
+ items = line.rstrip('\n').split()
+ assert len(items) >= 2
+ key = key_type(items[0])
+ val = items[1:] if len(items) > 2 else items[1]
+ mapping[key] = val
+ return mapping
diff --git a/mmcv/mmcv/image/__init__.py b/mmcv/mmcv/image/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..92ecec4046a6f5ee25b4ea07215ed7c7c810dcfa
--- /dev/null
+++ b/mmcv/mmcv/image/__init__.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .colorspace import (bgr2gray, bgr2hls, bgr2hsv, bgr2rgb, bgr2ycbcr,
+ gray2bgr, gray2rgb, hls2bgr, hsv2bgr, imconvert,
+ rgb2bgr, rgb2gray, rgb2ycbcr, ycbcr2bgr, ycbcr2rgb)
+from .geometric import (cutout, imcrop, imflip, imflip_, impad,
+ impad_to_multiple, imrescale, imresize, imresize_like,
+ imresize_to_multiple, imrotate, imshear, imtranslate,
+ rescale_size)
+from .io import imfrombytes, imread, imwrite, supported_backends, use_backend
+from .misc import tensor2imgs
+from .photometric import (adjust_brightness, adjust_color, adjust_contrast,
+ adjust_hue, adjust_lighting, adjust_sharpness,
+ auto_contrast, clahe, imdenormalize, imequalize,
+ iminvert, imnormalize, imnormalize_, lut_transform,
+ posterize, solarize)
+
+__all__ = [
+ 'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb',
+ 'hls2bgr', 'hsv2bgr', 'imconvert', 'rgb2bgr', 'rgb2gray', 'imrescale',
+ 'imresize', 'imresize_like', 'imresize_to_multiple', 'rescale_size',
+ 'imcrop', 'imflip', 'imflip_', 'impad', 'impad_to_multiple', 'imrotate',
+ 'imfrombytes', 'imread', 'imwrite', 'supported_backends', 'use_backend',
+ 'imdenormalize', 'imnormalize', 'imnormalize_', 'iminvert', 'posterize',
+ 'solarize', 'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr',
+ 'tensor2imgs', 'imshear', 'imtranslate', 'adjust_color', 'imequalize',
+ 'adjust_brightness', 'adjust_contrast', 'lut_transform', 'clahe',
+ 'adjust_sharpness', 'auto_contrast', 'cutout', 'adjust_lighting',
+ 'adjust_hue'
+]
diff --git a/mmcv/mmcv/image/colorspace.py b/mmcv/mmcv/image/colorspace.py
new file mode 100644
index 0000000000000000000000000000000000000000..08f9952408c8e0bb38b17c10e2089e900ed418c2
--- /dev/null
+++ b/mmcv/mmcv/image/colorspace.py
@@ -0,0 +1,309 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Callable, Union
+
+import cv2
+import numpy as np
+
+
+def imconvert(img: np.ndarray, src: str, dst: str) -> np.ndarray:
+ """Convert an image from the src colorspace to dst colorspace.
+
+ Args:
+ img (ndarray): The input image.
+ src (str): The source colorspace, e.g., 'rgb', 'hsv'.
+ dst (str): The destination colorspace, e.g., 'rgb', 'hsv'.
+
+ Returns:
+ ndarray: The converted image.
+ """
+ code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')
+ out_img = cv2.cvtColor(img, code)
+ return out_img
+
+
+def bgr2gray(img: np.ndarray, keepdim: bool = False) -> np.ndarray:
+ """Convert a BGR image to grayscale image.
+
+ Args:
+ img (ndarray): The input image.
+ keepdim (bool): If False (by default), then return the grayscale image
+ with 2 dims, otherwise 3 dims.
+
+ Returns:
+ ndarray: The converted grayscale image.
+ """
+ out_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ if keepdim:
+ out_img = out_img[..., None]
+ return out_img
+
+
+def rgb2gray(img: np.ndarray, keepdim: bool = False) -> np.ndarray:
+    """Convert an RGB image to a grayscale image.
+
+ Args:
+ img (ndarray): The input image.
+ keepdim (bool): If False (by default), then return the grayscale image
+ with 2 dims, otherwise 3 dims.
+
+ Returns:
+ ndarray: The converted grayscale image.
+ """
+ out_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+ if keepdim:
+ out_img = out_img[..., None]
+ return out_img
+
+
+def gray2bgr(img: np.ndarray) -> np.ndarray:
+ """Convert a grayscale image to BGR image.
+
+ Args:
+ img (ndarray): The input image.
+
+ Returns:
+ ndarray: The converted BGR image.
+ """
+ img = img[..., None] if img.ndim == 2 else img
+ out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+ return out_img
+
+
+def gray2rgb(img: np.ndarray) -> np.ndarray:
+ """Convert a grayscale image to RGB image.
+
+ Args:
+ img (ndarray): The input image.
+
+ Returns:
+ ndarray: The converted RGB image.
+ """
+ img = img[..., None] if img.ndim == 2 else img
+ out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+ return out_img
+
+
+def _convert_input_type_range(img: np.ndarray) -> np.ndarray:
+ """Convert the type and range of the input image.
+
+ It converts the input image to np.float32 type and range of [0, 1].
+ It is mainly used for pre-processing the input image in colorspace
+ conversion functions such as rgb2ycbcr and ycbcr2rgb.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ (ndarray): The converted image with type of np.float32 and range of
+ [0, 1].
+ """
+ img_type = img.dtype
+ img = img.astype(np.float32)
+ if img_type == np.float32:
+ pass
+ elif img_type == np.uint8:
+ img /= 255.
+ else:
+ raise TypeError('The img type should be np.float32 or np.uint8, '
+ f'but got {img_type}')
+ return img
+
+
+def _convert_output_type_range(
+ img: np.ndarray, dst_type: Union[np.uint8, np.float32]) -> np.ndarray:
+ """Convert the type and range of the image according to dst_type.
+
+ It converts the image to desired type and range. If `dst_type` is np.uint8,
+ images will be converted to np.uint8 type with range [0, 255]. If
+ `dst_type` is np.float32, it converts the image to np.float32 type with
+ range [0, 1].
+ It is mainly used for post-processing images in colorspace conversion
+ functions such as rgb2ycbcr and ycbcr2rgb.
+
+ Args:
+ img (ndarray): The image to be converted with np.float32 type and
+ range [0, 255].
+ dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
+ converts the image to np.uint8 type with range [0, 255]. If
+ dst_type is np.float32, it converts the image to np.float32 type
+ with range [0, 1].
+
+ Returns:
+ (ndarray): The converted image with desired type and range.
+ """
+ if dst_type not in (np.uint8, np.float32):
+ raise TypeError('The dst_type should be np.float32 or np.uint8, '
+ f'but got {dst_type}')
+ if dst_type == np.uint8:
+ img = img.round()
+ else:
+ img /= 255.
+ return img.astype(dst_type)
+
+
+def rgb2ycbcr(img: np.ndarray, y_only: bool = False) -> np.ndarray:
+    """Convert an RGB image to a YCbCr image.
+
+ This function produces the same results as Matlab's `rgb2ycbcr` function.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ ndarray: The converted YCbCr image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img)
+ if y_only:
+ out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
+ else:
+ out_img = np.matmul(
+ img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+ [24.966, 112.0, -18.214]]) + [16, 128, 128]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def bgr2ycbcr(img: np.ndarray, y_only: bool = False) -> np.ndarray:
+ """Convert a BGR image to YCbCr image.
+
+ The bgr version of rgb2ycbcr.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ ndarray: The converted YCbCr image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img)
+ if y_only:
+ out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
+ else:
+ out_img = np.matmul(
+ img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+ [65.481, -37.797, 112.0]]) + [16, 128, 128]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def ycbcr2rgb(img: np.ndarray) -> np.ndarray:
+ """Convert a YCbCr image to RGB image.
+
+ This function produces the same results as Matlab's ycbcr2rgb function.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ ndarray: The converted RGB image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img) * 255
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
+ [0, -0.00153632, 0.00791071],
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [
+ -222.921, 135.576, -276.836
+ ]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def ycbcr2bgr(img: np.ndarray) -> np.ndarray:
+ """Convert a YCbCr image to BGR image.
+
+ The bgr version of ycbcr2rgb.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ ndarray: The converted BGR image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img) * 255
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
+ [0.00791071, -0.00153632, 0],
+ [0, -0.00318811, 0.00625893]]) * 255.0 + [
+ -276.836, 135.576, -222.921
+ ]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def convert_color_factory(src: str, dst: str) -> Callable:
+
+ code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')
+
+ def convert_color(img: np.ndarray) -> np.ndarray:
+ out_img = cv2.cvtColor(img, code)
+ return out_img
+
+ convert_color.__doc__ = f"""Convert a {src.upper()} image to {dst.upper()}
+ image.
+
+ Args:
+ img (ndarray or str): The input image.
+
+ Returns:
+ ndarray: The converted {dst.upper()} image.
+ """
+
+ return convert_color
+
+
+bgr2rgb = convert_color_factory('bgr', 'rgb')
+
+rgb2bgr = convert_color_factory('rgb', 'bgr')
+
+bgr2hsv = convert_color_factory('bgr', 'hsv')
+
+hsv2bgr = convert_color_factory('hsv', 'bgr')
+
+bgr2hls = convert_color_factory('bgr', 'hls')
+
+hls2bgr = convert_color_factory('hls', 'bgr')
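+
+
+# Usage sketch (illustrative): each converter generated above is a thin
+# wrapper around ``cv2.cvtColor`` with a fixed conversion code, e.g.
+#   >>> rgb = bgr2rgb(np.zeros((4, 4, 3), dtype=np.uint8))
+#   >>> rgb.shape
+#   (4, 4, 3)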
diff --git a/mmcv/mmcv/image/geometric.py b/mmcv/mmcv/image/geometric.py
new file mode 100644
index 0000000000000000000000000000000000000000..eecd795ea08127055cd8e90eb11c5e51fe586c18
--- /dev/null
+++ b/mmcv/mmcv/image/geometric.py
@@ -0,0 +1,741 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numbers
+
+import cv2
+import numpy as np
+
+from ..utils import to_2tuple
+from .io import imread_backend
+
+try:
+ from PIL import Image
+except ImportError:
+ Image = None
+
+
+def _scale_size(size, scale):
+ """Rescale a size by a ratio.
+
+ Args:
+ size (tuple[int]): (w, h).
+ scale (float | tuple(float)): Scaling factor.
+
+ Returns:
+ tuple[int]: scaled size.
+ """
+ if isinstance(scale, (float, int)):
+ scale = (scale, scale)
+ w, h = size
+ return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)
+
+
+cv2_interp_codes = {
+ 'nearest': cv2.INTER_NEAREST,
+ 'bilinear': cv2.INTER_LINEAR,
+ 'bicubic': cv2.INTER_CUBIC,
+ 'area': cv2.INTER_AREA,
+ 'lanczos': cv2.INTER_LANCZOS4
+}
+
+# Pillow >= v9.1.0 uses a slightly different naming scheme for filters.
+# Set pillow_interp_codes according to the naming scheme used.
+if Image is not None:
+ if hasattr(Image, 'Resampling'):
+ pillow_interp_codes = {
+ 'nearest': Image.Resampling.NEAREST,
+ 'bilinear': Image.Resampling.BILINEAR,
+ 'bicubic': Image.Resampling.BICUBIC,
+ 'box': Image.Resampling.BOX,
+ 'lanczos': Image.Resampling.LANCZOS,
+ 'hamming': Image.Resampling.HAMMING
+ }
+ else:
+ pillow_interp_codes = {
+ 'nearest': Image.NEAREST,
+ 'bilinear': Image.BILINEAR,
+ 'bicubic': Image.BICUBIC,
+ 'box': Image.BOX,
+ 'lanczos': Image.LANCZOS,
+ 'hamming': Image.HAMMING
+ }
+
+
+def imresize(img,
+ size,
+ return_scale=False,
+ interpolation='bilinear',
+ out=None,
+ backend=None):
+ """Resize image to a given size.
+
+ Args:
+ img (ndarray): The input image.
+ size (tuple[int]): Target size (w, h).
+ return_scale (bool): Whether to return `w_scale` and `h_scale`.
+ interpolation (str): Interpolation method, accepted values are
+ "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
+ backend, "nearest", "bilinear" for 'pillow' backend.
+ out (ndarray): The output destination.
+ backend (str | None): The image resize backend type. Options are `cv2`,
+ `pillow`, `None`. If backend is None, the global imread_backend
+ specified by ``mmcv.use_backend()`` will be used. Default: None.
+
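+    Examples:
+        >>> # illustrative sketch with a hypothetical all-zero image
+        >>> img = np.zeros((100, 200, 3), dtype=np.uint8)
+        >>> imresize(img, (50, 40)).shape
+        (40, 50, 3)
+        >>> imresize(img, (50, 40), return_scale=True)[1:]
+        (0.25, 0.4)
+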
+ Returns:
+ tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+ `resized_img`.
+ """
+ h, w = img.shape[:2]
+ if backend is None:
+ backend = imread_backend
+ if backend not in ['cv2', 'pillow']:
+        raise ValueError(f'backend: {backend} is not supported for resize. '
+ f"Supported backends are 'cv2', 'pillow'")
+
+ if backend == 'pillow':
+        assert img.dtype == np.uint8, 'Pillow backend only supports uint8 type'
+ pil_image = Image.fromarray(img)
+ pil_image = pil_image.resize(size, pillow_interp_codes[interpolation])
+ resized_img = np.array(pil_image)
+ else:
+ resized_img = cv2.resize(
+ img, size, dst=out, interpolation=cv2_interp_codes[interpolation])
+ if not return_scale:
+ return resized_img
+ else:
+ w_scale = size[0] / w
+ h_scale = size[1] / h
+ return resized_img, w_scale, h_scale
+
+
+def imresize_to_multiple(img,
+ divisor,
+ size=None,
+ scale_factor=None,
+ keep_ratio=False,
+ return_scale=False,
+ interpolation='bilinear',
+ out=None,
+ backend=None):
+    """Resize an image according to a given size or scale factor and then round
+    up the resized or rescaled image size to the nearest value that can be
+    divided by the divisor.
+
+ Args:
+ img (ndarray): The input image.
+ divisor (int | tuple): Resized image size will be a multiple of
+ divisor. If divisor is a tuple, divisor should be
+ (w_divisor, h_divisor).
+ size (None | int | tuple[int]): Target size (w, h). Default: None.
+ scale_factor (None | float | tuple[float]): Multiplier for spatial
+ size. Should match input size if it is a tuple and the 2D style is
+ (w_scale_factor, h_scale_factor). Default: None.
+ keep_ratio (bool): Whether to keep the aspect ratio when resizing the
+ image. Default: False.
+ return_scale (bool): Whether to return `w_scale` and `h_scale`.
+ interpolation (str): Interpolation method, accepted values are
+ "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
+ backend, "nearest", "bilinear" for 'pillow' backend.
+ out (ndarray): The output destination.
+ backend (str | None): The image resize backend type. Options are `cv2`,
+ `pillow`, `None`. If backend is None, the global imread_backend
+ specified by ``mmcv.use_backend()`` will be used. Default: None.
+
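+    Examples:
+        >>> # illustrative sketch with a hypothetical all-zero image
+        >>> img = np.zeros((200, 300, 3), dtype=np.uint8)
+        >>> imresize_to_multiple(img, 32, size=(256, 256)).shape
+        (256, 256, 3)
+        >>> imresize_to_multiple(img, 100, scale_factor=0.5).shape
+        (100, 200, 3)
+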
+ Returns:
+ tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+ `resized_img`.
+ """
+ h, w = img.shape[:2]
+ if size is not None and scale_factor is not None:
+ raise ValueError('only one of size or scale_factor should be defined')
+ elif size is None and scale_factor is None:
+ raise ValueError('one of size or scale_factor should be defined')
+ elif size is not None:
+ size = to_2tuple(size)
+ if keep_ratio:
+ size = rescale_size((w, h), size, return_scale=False)
+ else:
+ size = _scale_size((w, h), scale_factor)
+
+ divisor = to_2tuple(divisor)
+ size = tuple(int(np.ceil(s / d)) * d for s, d in zip(size, divisor))
+ resized_img, w_scale, h_scale = imresize(
+ img,
+ size,
+ return_scale=True,
+ interpolation=interpolation,
+ out=out,
+ backend=backend)
+ if return_scale:
+ return resized_img, w_scale, h_scale
+ else:
+ return resized_img
+
+
+def imresize_like(img,
+ dst_img,
+ return_scale=False,
+ interpolation='bilinear',
+ backend=None):
+ """Resize image to the same size of a given image.
+
+ Args:
+ img (ndarray): The input image.
+ dst_img (ndarray): The target image.
+ return_scale (bool): Whether to return `w_scale` and `h_scale`.
+ interpolation (str): Same as :func:`resize`.
+ backend (str | None): Same as :func:`resize`.
+
+ Returns:
+ tuple or ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+ `resized_img`.
+ """
+ h, w = dst_img.shape[:2]
+ return imresize(img, (w, h), return_scale, interpolation, backend=backend)
+
+
+def rescale_size(old_size, scale, return_scale=False):
+ """Calculate the new size to be rescaled to.
+
+ Args:
+ old_size (tuple[int]): The old size (w, h) of image.
+ scale (float | tuple[int]): The scaling factor or maximum size.
+ If it is a float number, then the image will be rescaled by this
+ factor, else if it is a tuple of 2 integers, then the image will
+ be rescaled as large as possible within the scale.
+ return_scale (bool): Whether to return the scaling factor besides the
+ rescaled image size.
+
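+    Examples:
+        >>> # illustrative sketch: a float rescales both edges, a tuple caps
+        >>> # the long/short edges while keeping the aspect ratio
+        >>> rescale_size((640, 480), 0.5)
+        (320, 240)
+        >>> rescale_size((640, 480), (1000, 600))
+        (800, 600)
+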
+ Returns:
+ tuple[int]: The new rescaled image size.
+ """
+ w, h = old_size
+ if isinstance(scale, (float, int)):
+ if scale <= 0:
+ raise ValueError(f'Invalid scale {scale}, must be positive.')
+ scale_factor = scale
+ elif isinstance(scale, tuple):
+ max_long_edge = max(scale)
+ max_short_edge = min(scale)
+ scale_factor = min(max_long_edge / max(h, w),
+ max_short_edge / min(h, w))
+ else:
+ raise TypeError(
+ f'Scale must be a number or tuple of int, but got {type(scale)}')
+
+ new_size = _scale_size((w, h), scale_factor)
+
+ if return_scale:
+ return new_size, scale_factor
+ else:
+ return new_size
+
+
+def imrescale(img,
+ scale,
+ return_scale=False,
+ interpolation='bilinear',
+ backend=None):
+ """Resize image while keeping the aspect ratio.
+
+ Args:
+ img (ndarray): The input image.
+ scale (float | tuple[int]): The scaling factor or maximum size.
+ If it is a float number, then the image will be rescaled by this
+ factor, else if it is a tuple of 2 integers, then the image will
+ be rescaled as large as possible within the scale.
+ return_scale (bool): Whether to return the scaling factor besides the
+ rescaled image.
+ interpolation (str): Same as :func:`resize`.
+ backend (str | None): Same as :func:`resize`.
+
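+    Examples:
+        >>> # illustrative sketch with a hypothetical all-zero image
+        >>> img = np.zeros((100, 200, 3), dtype=np.uint8)
+        >>> imrescale(img, 0.5).shape
+        (50, 100, 3)
+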
+ Returns:
+ ndarray: The rescaled image.
+ """
+ h, w = img.shape[:2]
+ new_size, scale_factor = rescale_size((w, h), scale, return_scale=True)
+ rescaled_img = imresize(
+ img, new_size, interpolation=interpolation, backend=backend)
+ if return_scale:
+ return rescaled_img, scale_factor
+ else:
+ return rescaled_img
+
+
+def imflip(img, direction='horizontal'):
+ """Flip an image horizontally or vertically.
+
+ Args:
+ img (ndarray): Image to be flipped.
+ direction (str): The flip direction, either "horizontal" or
+ "vertical" or "diagonal".
+
+ Returns:
+ ndarray: The flipped image.
+ """
+ assert direction in ['horizontal', 'vertical', 'diagonal']
+ if direction == 'horizontal':
+ return np.flip(img, axis=1)
+ elif direction == 'vertical':
+ return np.flip(img, axis=0)
+ else:
+ return np.flip(img, axis=(0, 1))
+
+
+def imflip_(img, direction='horizontal'):
+ """Inplace flip an image horizontally or vertically.
+
+ Args:
+ img (ndarray): Image to be flipped.
+ direction (str): The flip direction, either "horizontal" or
+ "vertical" or "diagonal".
+
+ Returns:
+ ndarray: The flipped image (inplace).
+ """
+ assert direction in ['horizontal', 'vertical', 'diagonal']
+ if direction == 'horizontal':
+ return cv2.flip(img, 1, img)
+ elif direction == 'vertical':
+ return cv2.flip(img, 0, img)
+ else:
+ return cv2.flip(img, -1, img)
+
+
+def imrotate(img,
+ angle,
+ center=None,
+ scale=1.0,
+ border_value=0,
+ interpolation='bilinear',
+ auto_bound=False):
+ """Rotate an image.
+
+ Args:
+ img (ndarray): Image to be rotated.
+ angle (float): Rotation angle in degrees, positive values mean
+ clockwise rotation.
+ center (tuple[float], optional): Center point (w, h) of the rotation in
+ the source image. If not specified, the center of the image will be
+ used.
+ scale (float): Isotropic scale factor.
+ border_value (int): Border value.
+ interpolation (str): Same as :func:`resize`.
+ auto_bound (bool): Whether to adjust the image size to cover the whole
+ rotated image.
+
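+    Examples:
+        >>> # illustrative sketch: with auto_bound the canvas grows to fit
+        >>> img = np.zeros((100, 200, 3), dtype=np.uint8)
+        >>> imrotate(img, 90).shape
+        (100, 200, 3)
+        >>> imrotate(img, 90, auto_bound=True).shape
+        (200, 100, 3)
+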
+ Returns:
+ ndarray: The rotated image.
+ """
+ if center is not None and auto_bound:
+ raise ValueError('`auto_bound` conflicts with `center`')
+ h, w = img.shape[:2]
+ if center is None:
+ center = ((w - 1) * 0.5, (h - 1) * 0.5)
+ assert isinstance(center, tuple)
+
+ matrix = cv2.getRotationMatrix2D(center, -angle, scale)
+ if auto_bound:
+ cos = np.abs(matrix[0, 0])
+ sin = np.abs(matrix[0, 1])
+ new_w = h * sin + w * cos
+ new_h = h * cos + w * sin
+ matrix[0, 2] += (new_w - w) * 0.5
+ matrix[1, 2] += (new_h - h) * 0.5
+ w = int(np.round(new_w))
+ h = int(np.round(new_h))
+ rotated = cv2.warpAffine(
+ img,
+ matrix, (w, h),
+ flags=cv2_interp_codes[interpolation],
+ borderValue=border_value)
+ return rotated
+
+
+def bbox_clip(bboxes, img_shape):
+ """Clip bboxes to fit the image shape.
+
+ Args:
+ bboxes (ndarray): Shape (..., 4*k)
+ img_shape (tuple[int]): (height, width) of the image.
+
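+    Examples:
+        >>> # illustrative sketch: clip a box that exceeds a 240x320 image
+        >>> bbox_clip(np.array([-5., 10., 400., 300.]), (240, 320)).tolist()
+        [0.0, 10.0, 319.0, 239.0]
+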
+ Returns:
+ ndarray: Clipped bboxes.
+ """
+ assert bboxes.shape[-1] % 4 == 0
+ cmin = np.empty(bboxes.shape[-1], dtype=bboxes.dtype)
+ cmin[0::2] = img_shape[1] - 1
+ cmin[1::2] = img_shape[0] - 1
+ clipped_bboxes = np.maximum(np.minimum(bboxes, cmin), 0)
+ return clipped_bboxes
+
+
+def bbox_scaling(bboxes, scale, clip_shape=None):
+ """Scaling bboxes w.r.t the box center.
+
+ Args:
+ bboxes (ndarray): Shape(..., 4).
+ scale (float): Scaling factor.
+ clip_shape (tuple[int], optional): If specified, bboxes that exceed the
+ boundary will be clipped according to the given shape (h, w).
+
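+    Examples:
+        >>> # illustrative sketch: scale a single box by 2x around its center
+        >>> bbox_scaling(np.array([10., 10., 19., 19.]), 2.0).tolist()
+        [5.0, 5.0, 24.0, 24.0]
+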
+ Returns:
+ ndarray: Scaled bboxes.
+ """
+ if float(scale) == 1.0:
+ scaled_bboxes = bboxes.copy()
+ else:
+ w = bboxes[..., 2] - bboxes[..., 0] + 1
+ h = bboxes[..., 3] - bboxes[..., 1] + 1
+ dw = (w * (scale - 1)) * 0.5
+ dh = (h * (scale - 1)) * 0.5
+ scaled_bboxes = bboxes + np.stack((-dw, -dh, dw, dh), axis=-1)
+ if clip_shape is not None:
+ return bbox_clip(scaled_bboxes, clip_shape)
+ else:
+ return scaled_bboxes
+
+
+def imcrop(img, bboxes, scale=1.0, pad_fill=None):
+ """Crop image patches.
+
+ 3 steps: scale the bboxes -> clip bboxes -> crop and pad.
+
+ Args:
+ img (ndarray): Image to be cropped.
+ bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes.
+ scale (float, optional): Scale ratio of bboxes, the default value
+            1.0 means no scaling.
+ pad_fill (Number | list[Number]): Value to be filled for padding.
+ Default: None, which means no padding.
+
+ Returns:
+ list[ndarray] | ndarray: The cropped image patches.
+ """
+ chn = 1 if img.ndim == 2 else img.shape[2]
+ if pad_fill is not None:
+ if isinstance(pad_fill, (int, float)):
+ pad_fill = [pad_fill for _ in range(chn)]
+ assert len(pad_fill) == chn
+
+ _bboxes = bboxes[None, ...] if bboxes.ndim == 1 else bboxes
+ scaled_bboxes = bbox_scaling(_bboxes, scale).astype(np.int32)
+ clipped_bbox = bbox_clip(scaled_bboxes, img.shape)
+
+ patches = []
+ for i in range(clipped_bbox.shape[0]):
+ x1, y1, x2, y2 = tuple(clipped_bbox[i, :])
+ if pad_fill is None:
+ patch = img[y1:y2 + 1, x1:x2 + 1, ...]
+ else:
+ _x1, _y1, _x2, _y2 = tuple(scaled_bboxes[i, :])
+ if chn == 1:
+ patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1)
+ else:
+ patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1, chn)
+ patch = np.array(
+ pad_fill, dtype=img.dtype) * np.ones(
+ patch_shape, dtype=img.dtype)
+ x_start = 0 if _x1 >= 0 else -_x1
+ y_start = 0 if _y1 >= 0 else -_y1
+ w = x2 - x1 + 1
+ h = y2 - y1 + 1
+ patch[y_start:y_start + h, x_start:x_start + w,
+ ...] = img[y1:y1 + h, x1:x1 + w, ...]
+ patches.append(patch)
+
+ if bboxes.ndim == 1:
+ return patches[0]
+ else:
+ return patches
+
+
+def impad(img,
+ *,
+ shape=None,
+ padding=None,
+ pad_val=0,
+ padding_mode='constant'):
+ """Pad the given image to a certain shape or pad on all sides with
+ specified padding mode and padding value.
+
+ Args:
+ img (ndarray): Image to be padded.
+ shape (tuple[int]): Expected padding shape (h, w). Default: None.
+ padding (int or tuple[int]): Padding on each border. If a single int is
+ provided this is used to pad all borders. If tuple of length 2 is
+ provided this is the padding on left/right and top/bottom
+ respectively. If a tuple of length 4 is provided this is the
+ padding for the left, top, right and bottom borders respectively.
+ Default: None. Note that `shape` and `padding` can not be both
+ set.
+ pad_val (Number | Sequence[Number]): Values to be filled in padding
+ areas when padding_mode is 'constant'. Default: 0.
+ padding_mode (str): Type of padding. Should be: constant, edge,
+ reflect or symmetric. Default: constant.
+ - constant: pads with a constant value, this value is specified
+ with pad_val.
+ - edge: pads with the last value at the edge of the image.
+ - reflect: pads with reflection of image without repeating the last
+ value on the edge. For example, padding [1, 2, 3, 4] with 2
+ elements on both sides in reflect mode will result in
+ [3, 2, 1, 2, 3, 4, 3, 2].
+ - symmetric: pads with reflection of image repeating the last value
+ on the edge. For example, padding [1, 2, 3, 4] with 2 elements on
+ both sides in symmetric mode will result in
+ [2, 1, 1, 2, 3, 4, 4, 3]
+
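+    Examples:
+        >>> # illustrative sketch with a hypothetical all-zero image
+        >>> img = np.zeros((10, 20, 3), dtype=np.uint8)
+        >>> impad(img, shape=(16, 32)).shape
+        (16, 32, 3)
+        >>> impad(img, padding=(1, 2, 3, 4)).shape
+        (16, 24, 3)
+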
+ Returns:
+ ndarray: The padded image.
+ """
+
+ assert (shape is not None) ^ (padding is not None)
+ if shape is not None:
+ width = max(shape[1] - img.shape[1], 0)
+ height = max(shape[0] - img.shape[0], 0)
+ padding = (0, 0, width, height)
+
+ # check pad_val
+ if isinstance(pad_val, tuple):
+ assert len(pad_val) == img.shape[-1]
+ elif not isinstance(pad_val, numbers.Number):
+        raise TypeError('pad_val must be an int or a tuple. '
+ f'But received {type(pad_val)}')
+
+ # check padding
+ if isinstance(padding, tuple) and len(padding) in [2, 4]:
+ if len(padding) == 2:
+ padding = (padding[0], padding[1], padding[0], padding[1])
+ elif isinstance(padding, numbers.Number):
+ padding = (padding, padding, padding, padding)
+ else:
+        raise ValueError('Padding must be an int or a 2- or 4-element tuple. '
+                         f'But received {padding}')
+
+ # check padding mode
+ assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
+
+ border_type = {
+ 'constant': cv2.BORDER_CONSTANT,
+ 'edge': cv2.BORDER_REPLICATE,
+ 'reflect': cv2.BORDER_REFLECT_101,
+ 'symmetric': cv2.BORDER_REFLECT
+ }
+ img = cv2.copyMakeBorder(
+ img,
+ padding[1],
+ padding[3],
+ padding[0],
+ padding[2],
+ border_type[padding_mode],
+ value=pad_val)
+
+ return img
+
+
+def impad_to_multiple(img, divisor, pad_val=0):
+    """Pad an image to ensure that each edge is a multiple of some number.
+
+ Args:
+ img (ndarray): Image to be padded.
+        divisor (int): Padded image edges will be multiples of the divisor.
+ pad_val (Number | Sequence[Number]): Same as :func:`impad`.
+
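+    Examples:
+        >>> # illustrative sketch: both edges are rounded up to a multiple of 32
+        >>> img = np.zeros((100, 150, 3), dtype=np.uint8)
+        >>> impad_to_multiple(img, 32).shape
+        (128, 160, 3)
+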
+ Returns:
+ ndarray: The padded image.
+ """
+ pad_h = int(np.ceil(img.shape[0] / divisor)) * divisor
+ pad_w = int(np.ceil(img.shape[1] / divisor)) * divisor
+ return impad(img, shape=(pad_h, pad_w), pad_val=pad_val)
+
+
+def cutout(img, shape, pad_val=0):
+ """Randomly cut out a rectangle from the original img.
+
+ Args:
+ img (ndarray): Image to be cutout.
+        shape (int | tuple[int]): Expected cutout shape (h, w). If given as an
+            int, the value will be used for both h and w.
+ pad_val (int | float | tuple[int | float]): Values to be filled in the
+ cut area. Defaults to 0.
+
+ Returns:
+ ndarray: The cutout image.
+ """
+
+ channels = 1 if img.ndim == 2 else img.shape[2]
+ if isinstance(shape, int):
+ cut_h, cut_w = shape, shape
+ else:
+ assert isinstance(shape, tuple) and len(shape) == 2, \
+            f'shape must be an int or a tuple with length 2, but got type ' \
+ f'{type(shape)} instead.'
+ cut_h, cut_w = shape
+ if isinstance(pad_val, (int, float)):
+ pad_val = tuple([pad_val] * channels)
+ elif isinstance(pad_val, tuple):
+ assert len(pad_val) == channels, \
+            'Expected the number of elements in the tuple to equal the ' \
+            'channels of the input image. Found {} vs {}'.format(
+ len(pad_val), channels)
+ else:
+ raise TypeError(f'Invalid type {type(pad_val)} for `pad_val`')
+
+ img_h, img_w = img.shape[:2]
+ y0 = np.random.uniform(img_h)
+ x0 = np.random.uniform(img_w)
+
+ y1 = int(max(0, y0 - cut_h / 2.))
+ x1 = int(max(0, x0 - cut_w / 2.))
+ y2 = min(img_h, y1 + cut_h)
+ x2 = min(img_w, x1 + cut_w)
+
+ if img.ndim == 2:
+ patch_shape = (y2 - y1, x2 - x1)
+ else:
+ patch_shape = (y2 - y1, x2 - x1, channels)
+
+ img_cutout = img.copy()
+ patch = np.array(
+ pad_val, dtype=img.dtype) * np.ones(
+ patch_shape, dtype=img.dtype)
+ img_cutout[y1:y2, x1:x2, ...] = patch
+
+ return img_cutout
+
+
+def _get_shear_matrix(magnitude, direction='horizontal'):
+ """Generate the shear matrix for transformation.
+
+ Args:
+ magnitude (int | float): The magnitude used for shear.
+        direction (str): The shear direction, either "horizontal"
+ or "vertical".
+
+ Returns:
+ ndarray: The shear matrix with dtype float32.
+ """
+ if direction == 'horizontal':
+ shear_matrix = np.float32([[1, magnitude, 0], [0, 1, 0]])
+ elif direction == 'vertical':
+ shear_matrix = np.float32([[1, 0, 0], [magnitude, 1, 0]])
+ return shear_matrix
+
+
+def imshear(img,
+ magnitude,
+ direction='horizontal',
+ border_value=0,
+ interpolation='bilinear'):
+ """Shear an image.
+
+ Args:
+ img (ndarray): Image to be sheared with format (h, w)
+ or (h, w, c).
+ magnitude (int | float): The magnitude used for shear.
+        direction (str): The shear direction, either "horizontal"
+ or "vertical".
+ border_value (int | tuple[int]): Value used in case of a
+ constant border.
+ interpolation (str): Same as :func:`resize`.
+
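+    Examples:
+        >>> # illustrative sketch: horizontal shear keeps the output size
+        >>> img = np.zeros((50, 80, 3), dtype=np.uint8)
+        >>> imshear(img, 0.5).shape
+        (50, 80, 3)
+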
+ Returns:
+ ndarray: The sheared image.
+ """
+ assert direction in ['horizontal',
+ 'vertical'], f'Invalid direction: {direction}'
+ height, width = img.shape[:2]
+ if img.ndim == 2:
+ channels = 1
+ elif img.ndim == 3:
+ channels = img.shape[-1]
+ if isinstance(border_value, int):
+ border_value = tuple([border_value] * channels)
+ elif isinstance(border_value, tuple):
+ assert len(border_value) == channels, \
+            'Expected the number of elements in the tuple to equal the ' \
+            'channels of the input image. Found {} vs {}'.format(
+ len(border_value), channels)
+ else:
+ raise ValueError(
+ f'Invalid type {type(border_value)} for `border_value`')
+ shear_matrix = _get_shear_matrix(magnitude, direction)
+ sheared = cv2.warpAffine(
+ img,
+ shear_matrix,
+ (width, height),
+        # Note: when the number of elements in `border_value` is greater
+        # than 3 (e.g. shearing masks whose number of channels is larger
+        # than 3), `cv2.warpAffine` will raise a TypeError. Here we simply
+        # slice the first 3 values in `border_value`.
+ borderValue=border_value[:3],
+ flags=cv2_interp_codes[interpolation])
+ return sheared
+
+
+def _get_translate_matrix(offset, direction='horizontal'):
+ """Generate the translate matrix.
+
+ Args:
+ offset (int | float): The offset used for translate.
+ direction (str): The translate direction, either
+ "horizontal" or "vertical".
+
+ Returns:
+ ndarray: The translate matrix with dtype float32.
+ """
+ if direction == 'horizontal':
+ translate_matrix = np.float32([[1, 0, offset], [0, 1, 0]])
+ elif direction == 'vertical':
+ translate_matrix = np.float32([[1, 0, 0], [0, 1, offset]])
+ return translate_matrix
+
+
+def imtranslate(img,
+ offset,
+ direction='horizontal',
+ border_value=0,
+ interpolation='bilinear'):
+ """Translate an image.
+
+ Args:
+ img (ndarray): Image to be translated with format
+ (h, w) or (h, w, c).
+ offset (int | float): The offset used for translate.
+ direction (str): The translate direction, either "horizontal"
+ or "vertical".
+ border_value (int | tuple[int]): Value used in case of a
+ constant border.
+ interpolation (str): Same as :func:`resize`.
+
+ Returns:
+ ndarray: The translated image.
+ """
+ assert direction in ['horizontal',
+ 'vertical'], f'Invalid direction: {direction}'
+ height, width = img.shape[:2]
+ if img.ndim == 2:
+ channels = 1
+ elif img.ndim == 3:
+ channels = img.shape[-1]
+ if isinstance(border_value, int):
+ border_value = tuple([border_value] * channels)
+ elif isinstance(border_value, tuple):
+ assert len(border_value) == channels, \
+            'Expected the number of elements in the tuple to equal the ' \
+            'channels of the input image. Found {} vs {}'.format(
+ len(border_value), channels)
+ else:
+ raise ValueError(
+ f'Invalid type {type(border_value)} for `border_value`.')
+ translate_matrix = _get_translate_matrix(offset, direction)
+ translated = cv2.warpAffine(
+ img,
+ translate_matrix,
+ (width, height),
+        # Note: when the number of elements in `border_value` is greater
+        # than 3 (e.g. translating masks whose number of channels is
+        # larger than 3), `cv2.warpAffine` will raise a TypeError. Here we
+        # simply slice the first 3 values in `border_value`.
+ borderValue=border_value[:3],
+ flags=cv2_interp_codes[interpolation])
+ return translated
diff --git a/mmcv/mmcv/image/io.py b/mmcv/mmcv/image/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae81b561a84cccfa4923364679dce56d762db1bc
--- /dev/null
+++ b/mmcv/mmcv/image/io.py
@@ -0,0 +1,314 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import io
+import os.path as osp
+import warnings
+from pathlib import Path
+
+import cv2
+import numpy as np
+from cv2 import (IMREAD_COLOR, IMREAD_GRAYSCALE, IMREAD_IGNORE_ORIENTATION,
+ IMREAD_UNCHANGED)
+
+from mmcv.fileio import FileClient
+from mmcv.utils import is_filepath, is_str
+
+try:
+ from turbojpeg import TJCS_RGB, TJPF_BGR, TJPF_GRAY, TurboJPEG
+except ImportError:
+ TJCS_RGB = TJPF_GRAY = TJPF_BGR = TurboJPEG = None
+
+try:
+ from PIL import Image, ImageOps
+except ImportError:
+ Image = None
+
+try:
+ import tifffile
+except ImportError:
+ tifffile = None
+
+jpeg = None
+supported_backends = ['cv2', 'turbojpeg', 'pillow', 'tifffile']
+
+imread_flags = {
+ 'color': IMREAD_COLOR,
+ 'grayscale': IMREAD_GRAYSCALE,
+ 'unchanged': IMREAD_UNCHANGED,
+ 'color_ignore_orientation': IMREAD_IGNORE_ORIENTATION | IMREAD_COLOR,
+ 'grayscale_ignore_orientation':
+ IMREAD_IGNORE_ORIENTATION | IMREAD_GRAYSCALE
+}
+
+imread_backend = 'cv2'
+
+
+def use_backend(backend):
+ """Select a backend for image decoding.
+
+ Args:
+ backend (str): The image decoding backend type. Options are `cv2`,
+ `pillow`, `turbojpeg` (see https://github.com/lilohuang/PyTurboJPEG)
+ and `tifffile`. `turbojpeg` is faster but it only supports `.jpeg`
+ file format.
+ """
+ assert backend in supported_backends
+ global imread_backend
+ imread_backend = backend
+ if imread_backend == 'turbojpeg':
+ if TurboJPEG is None:
+ raise ImportError('`PyTurboJPEG` is not installed')
+ global jpeg
+ if jpeg is None:
+ jpeg = TurboJPEG()
+ elif imread_backend == 'pillow':
+ if Image is None:
+ raise ImportError('`Pillow` is not installed')
+ elif imread_backend == 'tifffile':
+ if tifffile is None:
+ raise ImportError('`tifffile` is not installed')
+
+
+def _jpegflag(flag='color', channel_order='bgr'):
+ channel_order = channel_order.lower()
+ if channel_order not in ['rgb', 'bgr']:
+ raise ValueError('channel order must be either "rgb" or "bgr"')
+
+ if flag == 'color':
+ if channel_order == 'bgr':
+ return TJPF_BGR
+ elif channel_order == 'rgb':
+ return TJCS_RGB
+ elif flag == 'grayscale':
+ return TJPF_GRAY
+ else:
+ raise ValueError('flag must be "color" or "grayscale"')
+
+
+def _pillow2array(img, flag='color', channel_order='bgr'):
+ """Convert a pillow image to numpy array.
+
+ Args:
+ img (:obj:`PIL.Image.Image`): The image loaded using PIL
+ flag (str): Flags specifying the color type of a loaded image,
+ candidates are 'color', 'grayscale' and 'unchanged'.
+ Default to 'color'.
+ channel_order (str): The channel order of the output image array,
+ candidates are 'bgr' and 'rgb'. Default to 'bgr'.
+
+ Returns:
+ np.ndarray: The converted numpy array
+ """
+ channel_order = channel_order.lower()
+ if channel_order not in ['rgb', 'bgr']:
+ raise ValueError('channel order must be either "rgb" or "bgr"')
+
+ if flag == 'unchanged':
+ array = np.array(img)
+ if array.ndim >= 3 and array.shape[2] >= 3: # color image
+ array[:, :, :3] = array[:, :, (2, 1, 0)] # RGB to BGR
+ else:
+ # Handle exif orientation tag
+ if flag in ['color', 'grayscale']:
+ img = ImageOps.exif_transpose(img)
+ # If the image mode is not 'RGB', convert it to 'RGB' first.
+ if img.mode != 'RGB':
+ if img.mode != 'LA':
+ # Most formats except 'LA' can be directly converted to RGB
+ img = img.convert('RGB')
+ else:
+ # When the mode is 'LA', the default conversion will fill in
+ # the canvas with black, which sometimes shadows black objects
+ # in the foreground.
+ #
+ # Therefore, a random color (124, 117, 104) is used for canvas
+ img_rgba = img.convert('RGBA')
+ img = Image.new('RGB', img_rgba.size, (124, 117, 104))
+ img.paste(img_rgba, mask=img_rgba.split()[3]) # 3 is alpha
+ if flag in ['color', 'color_ignore_orientation']:
+ array = np.array(img)
+ if channel_order != 'rgb':
+ array = array[:, :, ::-1] # RGB to BGR
+ elif flag in ['grayscale', 'grayscale_ignore_orientation']:
+ img = img.convert('L')
+ array = np.array(img)
+ else:
+ raise ValueError(
+ 'flag must be "color", "grayscale", "unchanged", '
+ f'"color_ignore_orientation" or "grayscale_ignore_orientation"'
+ f' but got {flag}')
+ return array
+
+
+def imread(img_or_path,
+ flag='color',
+ channel_order='bgr',
+ backend=None,
+ file_client_args=None):
+ """Read an image.
+
+ Note:
+        In v1.4.1 and later, the `file_client_args` parameter was added.
+
+ Args:
+ img_or_path (ndarray or str or Path): Either a numpy array or str or
+ pathlib.Path. If it is a numpy array (loaded image), then
+ it will be returned as is.
+ flag (str): Flags specifying the color type of a loaded image,
+ candidates are `color`, `grayscale`, `unchanged`,
+ `color_ignore_orientation` and `grayscale_ignore_orientation`.
+ By default, `cv2` and `pillow` backend would rotate the image
+ according to its EXIF info unless called with `unchanged` or
+ `*_ignore_orientation` flags. `turbojpeg` and `tifffile` backend
+ always ignore image's EXIF info regardless of the flag.
+ The `turbojpeg` backend only supports `color` and `grayscale`.
+ channel_order (str): Order of channel, candidates are `bgr` and `rgb`.
+ backend (str | None): The image decoding backend type. Options are
+ `cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`.
+ If backend is None, the global imread_backend specified by
+ ``mmcv.use_backend()`` will be used. Default: None.
+ file_client_args (dict | None): Arguments to instantiate a
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
+ Default: None.
+
+ Returns:
+ ndarray: Loaded image array.
+
+ Examples:
+ >>> import mmcv
+ >>> img_path = '/path/to/img.jpg'
+ >>> img = mmcv.imread(img_path)
+ >>> img = mmcv.imread(img_path, flag='color', channel_order='rgb',
+ ... backend='cv2')
+ >>> img = mmcv.imread(img_path, flag='color', channel_order='bgr',
+ ... backend='pillow')
+ >>> s3_img_path = 's3://bucket/img.jpg'
+ >>> # infer the file backend by the prefix s3
+ >>> img = mmcv.imread(s3_img_path)
+ >>> # manually set the file backend petrel
+ >>> img = mmcv.imread(s3_img_path, file_client_args={
+ ... 'backend': 'petrel'})
+ >>> http_img_path = 'http://path/to/img.jpg'
+ >>> img = mmcv.imread(http_img_path)
+ >>> img = mmcv.imread(http_img_path, file_client_args={
+ ... 'backend': 'http'})
+ """
+
+ if isinstance(img_or_path, Path):
+ img_or_path = str(img_or_path)
+
+ if isinstance(img_or_path, np.ndarray):
+ return img_or_path
+ elif is_str(img_or_path):
+ file_client = FileClient.infer_client(file_client_args, img_or_path)
+ img_bytes = file_client.get(img_or_path)
+ return imfrombytes(img_bytes, flag, channel_order, backend)
+ else:
+ raise TypeError('"img" must be a numpy array or a str or '
+ 'a pathlib.Path object')
+
+
+def imfrombytes(content, flag='color', channel_order='bgr', backend=None):
+ """Read an image from bytes.
+
+ Args:
+ content (bytes): Image bytes got from files or other streams.
+ flag (str): Same as :func:`imread`.
+ channel_order (str): The channel order of the output, candidates
+ are 'bgr' and 'rgb'. Default to 'bgr'.
+ backend (str | None): The image decoding backend type. Options are
+ `cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`. If backend is
+ None, the global imread_backend specified by ``mmcv.use_backend()``
+ will be used. Default: None.
+
+ Returns:
+ ndarray: Loaded image array.
+
+ Examples:
+ >>> img_path = '/path/to/img.jpg'
+ >>> with open(img_path, 'rb') as f:
+ >>> img_buff = f.read()
+ >>> img = mmcv.imfrombytes(img_buff)
+ >>> img = mmcv.imfrombytes(img_buff, flag='color', channel_order='rgb')
+ >>> img = mmcv.imfrombytes(img_buff, backend='pillow')
+ >>> img = mmcv.imfrombytes(img_buff, backend='cv2')
+ """
+
+ if backend is None:
+ backend = imread_backend
+ if backend not in supported_backends:
+ raise ValueError(
+ f'backend: {backend} is not supported. Supported '
+ "backends are 'cv2', 'turbojpeg', 'pillow', 'tifffile'")
+ if backend == 'turbojpeg':
+ img = jpeg.decode(content, _jpegflag(flag, channel_order))
+ if img.shape[-1] == 1:
+ img = img[:, :, 0]
+ return img
+ elif backend == 'pillow':
+ with io.BytesIO(content) as buff:
+ img = Image.open(buff)
+ img = _pillow2array(img, flag, channel_order)
+ return img
+ elif backend == 'tifffile':
+ with io.BytesIO(content) as buff:
+ img = tifffile.imread(buff)
+ return img
+ else:
+ img_np = np.frombuffer(content, np.uint8)
+ flag = imread_flags[flag] if is_str(flag) else flag
+ img = cv2.imdecode(img_np, flag)
+ if flag == IMREAD_COLOR and channel_order == 'rgb':
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
+ return img
+
+
+def imwrite(img,
+ file_path,
+ params=None,
+ auto_mkdir=None,
+ file_client_args=None):
+ """Write image to file.
+
+ Note:
+        In v1.4.1 and later, the `file_client_args` parameter was added.
+
+ Warning:
+        The parameter `auto_mkdir` will be deprecated in the future and
+        every file client will create directories automatically.
+
+ Args:
+ img (ndarray): Image array to be written.
+ file_path (str): Image file path.
+ params (None or list): Same as opencv :func:`imwrite` interface.
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+ whether to create it automatically. It will be deprecated.
+ file_client_args (dict | None): Arguments to instantiate a
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
+ Default: None.
+
+ Returns:
+ bool: Successful or not.
+
+ Examples:
+ >>> # write to hard disk client
+ >>> ret = mmcv.imwrite(img, '/path/to/img.jpg')
+ >>> # infer the file backend by the prefix s3
+ >>> ret = mmcv.imwrite(img, 's3://bucket/img.jpg')
+ >>> # manually set the file backend petrel
+ >>> ret = mmcv.imwrite(img, 's3://bucket/img.jpg', file_client_args={
+ ... 'backend': 'petrel'})
+ """
+ assert is_filepath(file_path)
+ file_path = str(file_path)
+ if auto_mkdir is not None:
+ warnings.warn(
+ 'The parameter `auto_mkdir` will be deprecated in the future and '
+            'every file client will create directories automatically.')
+ file_client = FileClient.infer_client(file_client_args, file_path)
+ img_ext = osp.splitext(file_path)[-1]
+ # Encode image according to image suffix.
+ # For example, if image path is '/path/your/img.jpg', the encode
+ # format is '.jpg'.
+ flag, img_buff = cv2.imencode(img_ext, img, params)
+ file_client.put(img_buff.tobytes(), file_path)
+ return flag
diff --git a/mmcv/mmcv/image/misc.py b/mmcv/mmcv/image/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..43934a689dd7ac6d35b772b7ce9921ff3b1fff50
--- /dev/null
+++ b/mmcv/mmcv/image/misc.py
@@ -0,0 +1,53 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+
+import mmcv
+
+try:
+ import torch
+except ImportError:
+ torch = None
+
+
+def tensor2imgs(tensor, mean=None, std=None, to_rgb=True):
+ """Convert tensor to 3-channel images or 1-channel gray images.
+
+ Args:
+ tensor (torch.Tensor): Tensor that contains multiple images, shape (
+ N, C, H, W). :math:`C` can be either 3 or 1.
+        mean (tuple[float], optional): Mean of images. If None,
+            (0, 0, 0) will be used for 3-channel tensors and (0, ) for
+            1-channel tensors. Defaults to None.
+        std (tuple[float], optional): Standard deviation of images. If None,
+            (1, 1, 1) will be used for 3-channel tensors and (1, ) for
+            1-channel tensors. Defaults to None.
+        to_rgb (bool, optional): Whether the tensor was converted to RGB
+            format in the first place. If so, convert it back to BGR.
+            For a 1-channel tensor, it must be False. Defaults to True.
+
+ Returns:
+ list[np.ndarray]: A list that contains multiple images.
+ """
+
+ if torch is None:
+ raise RuntimeError('pytorch is not installed')
+ assert torch.is_tensor(tensor) and tensor.ndim == 4
+ channels = tensor.size(1)
+ assert channels in [1, 3]
+ if mean is None:
+ mean = (0, ) * channels
+ if std is None:
+ std = (1, ) * channels
+ assert (channels == len(mean) == len(std) == 3) or \
+ (channels == len(mean) == len(std) == 1 and not to_rgb)
+
+ num_imgs = tensor.size(0)
+ mean = np.array(mean, dtype=np.float32)
+ std = np.array(std, dtype=np.float32)
+ imgs = []
+ for img_id in range(num_imgs):
+ img = tensor[img_id, ...].cpu().numpy().transpose(1, 2, 0)
+ img = mmcv.imdenormalize(
+ img, mean, std, to_bgr=to_rgb).astype(np.uint8)
+ imgs.append(np.ascontiguousarray(img))
+ return imgs
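+
+# Usage sketch (editor's illustration, not part of upstream mmcv):
+#
+#   >>> import torch
+#   >>> batch = torch.rand(2, 3, 32, 32)         # NCHW, values in [0, 1]
+#   >>> imgs = tensor2imgs(batch, to_rgb=False)  # mean/std default to 0/1
+#   >>> len(imgs), imgs[0].shape
+#   (2, (32, 32, 3))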
diff --git a/mmcv/mmcv/image/photometric.py b/mmcv/mmcv/image/photometric.py
new file mode 100644
index 0000000000000000000000000000000000000000..b41cea7172ae0ece858d868b73dc65deaea3510c
--- /dev/null
+++ b/mmcv/mmcv/image/photometric.py
@@ -0,0 +1,471 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+
+from ..utils import is_tuple_of
+from .colorspace import bgr2gray, gray2bgr
+
+
+def imnormalize(img, mean, std, to_rgb=True):
+ """Normalize an image with mean and std.
+
+ Args:
+ img (ndarray): Image to be normalized.
+        mean (ndarray): The mean to be used for normalization.
+        std (ndarray): The std to be used for normalization.
+        to_rgb (bool): Whether to convert the image from BGR to RGB.
+
+ Returns:
+ ndarray: The normalized image.
+ """
+ img = img.copy().astype(np.float32)
+ return imnormalize_(img, mean, std, to_rgb)
+
+
+def imnormalize_(img, mean, std, to_rgb=True):
+ """Inplace normalize an image with mean and std.
+
+ Args:
+ img (ndarray): Image to be normalized.
+        mean (ndarray): The mean to be used for normalization.
+        std (ndarray): The std to be used for normalization.
+        to_rgb (bool): Whether to convert the image from BGR to RGB.
+
+ Returns:
+ ndarray: The normalized image.
+ """
+ # cv2 inplace normalization does not accept uint8
+ assert img.dtype != np.uint8
+ mean = np.float64(mean.reshape(1, -1))
+ stdinv = 1 / np.float64(std.reshape(1, -1))
+ if to_rgb:
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace
+ cv2.subtract(img, mean, img) # inplace
+ cv2.multiply(img, stdinv, img) # inplace
+ return img
+
+
+def imdenormalize(img, mean, std, to_bgr=True):
+ assert img.dtype != np.uint8
+ mean = mean.reshape(1, -1).astype(np.float64)
+ std = std.reshape(1, -1).astype(np.float64)
+ img = cv2.multiply(img, std) # make a copy
+ cv2.add(img, mean, img) # inplace
+ if to_bgr:
+ cv2.cvtColor(img, cv2.COLOR_RGB2BGR, img) # inplace
+ return img
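+
+# Usage sketch (editor's illustration, not part of upstream mmcv): normalize
+# with the commonly used ImageNet statistics, then invert the transform.
+#
+#   >>> img = np.random.randint(0, 256, (4, 4, 3), dtype=np.uint8)  # BGR
+#   >>> mean = np.array([123.675, 116.28, 103.53])
+#   >>> std = np.array([58.395, 57.12, 57.375])
+#   >>> norm = imnormalize(img, mean, std, to_rgb=True)
+#   >>> back = imdenormalize(norm, mean, std, to_bgr=True)
+#   >>> np.allclose(img, back, atol=1e-3)
+#   True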
+
+
+def iminvert(img):
+ """Invert (negate) an image.
+
+ Args:
+ img (ndarray): Image to be inverted.
+
+ Returns:
+ ndarray: The inverted image.
+ """
+ return np.full_like(img, 255) - img
+
+
+def solarize(img, thr=128):
+ """Solarize an image (invert all pixel values above a threshold)
+
+ Args:
+ img (ndarray): Image to be solarized.
+ thr (int): Threshold for solarizing (0 - 255).
+
+ Returns:
+ ndarray: The solarized image.
+ """
+ img = np.where(img < thr, img, 255 - img)
+ return img
+
+
+def posterize(img, bits):
+ """Posterize an image (reduce the number of bits for each color channel)
+
+ Args:
+ img (ndarray): Image to be posterized.
+ bits (int): Number of bits (1 to 8) to use for posterizing.
+
+ Returns:
+ ndarray: The posterized image.
+ """
+ shift = 8 - bits
+ img = np.left_shift(np.right_shift(img, shift), shift)
+ return img
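+
+# Usage sketch (editor's illustration, not part of upstream mmcv; `img` is
+# assumed to be a uint8 image array):
+#
+#   >>> solarized = solarize(img, thr=128)   # invert pixels >= 128
+#   >>> posterized = posterize(img, bits=4)  # keep the 4 most significant bits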
+
+
+def adjust_color(img, alpha=1, beta=None, gamma=0):
+ r"""It blends the source image and its gray image:
+
+ .. math::
+ output = img * alpha + gray\_img * beta + gamma
+
+ Args:
+ img (ndarray): The input source image.
+ alpha (int | float): Weight for the source image. Default 1.
+ beta (int | float): Weight for the converted gray image.
+ If None, it's assigned the value (1 - `alpha`).
+ gamma (int | float): Scalar added to each sum.
+ Same as :func:`cv2.addWeighted`. Default 0.
+
+ Returns:
+ ndarray: Colored image which has the same size and dtype as input.
+ """
+ gray_img = bgr2gray(img)
+ gray_img = np.tile(gray_img[..., None], [1, 1, 3])
+ if beta is None:
+ beta = 1 - alpha
+ colored_img = cv2.addWeighted(img, alpha, gray_img, beta, gamma)
+ if not colored_img.dtype == np.uint8:
+ # Note when the dtype of `img` is not the default `np.uint8`
+ # (e.g. np.float32), the value in `colored_img` got from cv2
+ # is not guaranteed to be in range [0, 255], so here clip
+ # is needed.
+ colored_img = np.clip(colored_img, 0, 255)
+ return colored_img
+
+
+def imequalize(img):
+ """Equalize the image histogram.
+
+ This function applies a non-linear mapping to the input image,
+ in order to create a uniform distribution of grayscale values
+ in the output image.
+
+ Args:
+ img (ndarray): Image to be equalized.
+
+ Returns:
+ ndarray: The equalized image.
+ """
+
+ def _scale_channel(im, c):
+ """Scale the data in the corresponding channel."""
+ im = im[:, :, c]
+ # Compute the histogram of the image channel.
+ histo = np.histogram(im, 256, (0, 255))[0]
+        # For computing the step, keep only the nonzero bins.
+ nonzero_histo = histo[histo > 0]
+ step = (np.sum(nonzero_histo) - nonzero_histo[-1]) // 255
+ if not step:
+ lut = np.array(range(256))
+ else:
+ # Compute the cumulative sum, shifted by step // 2
+ # and then normalized by step.
+ lut = (np.cumsum(histo) + (step // 2)) // step
+ # Shift lut, prepending with 0.
+ lut = np.concatenate([[0], lut[:-1]], 0)
+ # handle potential integer overflow
+ lut[lut > 255] = 255
+ # If step is zero, return the original image.
+ # Otherwise, index from lut.
+ return np.where(np.equal(step, 0), im, lut[im])
+
+ # Scales each channel independently and then stacks
+ # the result.
+ s1 = _scale_channel(img, 0)
+ s2 = _scale_channel(img, 1)
+ s3 = _scale_channel(img, 2)
+ equalized_img = np.stack([s1, s2, s3], axis=-1)
+ return equalized_img.astype(img.dtype)
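+
+# Usage sketch (editor's illustration, not part of upstream mmcv; `img` is
+# assumed to be a uint8 BGR image array):
+#
+#   >>> desaturated = adjust_color(img, alpha=0.5)  # blend halfway to gray
+#   >>> equalized = imequalize(img)                 # per-channel equalization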
+
+
+def adjust_brightness(img, factor=1.):
+ """Adjust image brightness.
+
+ This function controls the brightness of an image. An
+ enhancement factor of 0.0 gives a black image.
+ A factor of 1.0 gives the original image. This function
+ blends the source image and the degenerated black image:
+
+ .. math::
+ output = img * factor + degenerated * (1 - factor)
+
+ Args:
+ img (ndarray): Image to be brightened.
+        factor (float): A value that controls the enhancement.
+            A factor of 1.0 returns the original image; lower
+            values give a darker image and higher values a
+            brighter one. Default 1.
+
+ Returns:
+ ndarray: The brightened image.
+ """
+ degenerated = np.zeros_like(img)
+    # Note: manually convert the dtype to np.float32 to match the
+    # result of PIL.ImageEnhance.Brightness as closely as possible.
+ # Set beta=1-factor, and gamma=0
+ brightened_img = cv2.addWeighted(
+ img.astype(np.float32), factor, degenerated.astype(np.float32),
+ 1 - factor, 0)
+ brightened_img = np.clip(brightened_img, 0, 255)
+ return brightened_img.astype(img.dtype)
+
+
+def adjust_contrast(img, factor=1.):
+ """Adjust image contrast.
+
+ This function controls the contrast of an image. An
+ enhancement factor of 0.0 gives a solid grey
+ image. A factor of 1.0 gives the original image. It
+ blends the source image and the degenerated mean image:
+
+ .. math::
+ output = img * factor + degenerated * (1 - factor)
+
+ Args:
+ img (ndarray): Image to be contrasted. BGR order.
+ factor (float): Same as :func:`mmcv.adjust_brightness`.
+
+ Returns:
+ ndarray: The contrasted image.
+ """
+ gray_img = bgr2gray(img)
+ hist = np.histogram(gray_img, 256, (0, 255))[0]
+ mean = round(np.sum(gray_img) / np.sum(hist))
+ degenerated = (np.ones_like(img[..., 0]) * mean).astype(img.dtype)
+ degenerated = gray2bgr(degenerated)
+ contrasted_img = cv2.addWeighted(
+ img.astype(np.float32), factor, degenerated.astype(np.float32),
+ 1 - factor, 0)
+ contrasted_img = np.clip(contrasted_img, 0, 255)
+ return contrasted_img.astype(img.dtype)
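+
+# Usage sketch (editor's illustration, not part of upstream mmcv; `img` is
+# assumed to be a uint8 BGR image array):
+#
+#   >>> darker = adjust_brightness(img, factor=0.5)
+#   >>> punchier = adjust_contrast(img, factor=1.5)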
+
+
+def auto_contrast(img, cutoff=0):
+ """Auto adjust image contrast.
+
+    This function maximizes (normalizes) image contrast by first removing the
+    cutoff percent of the lightest and darkest pixels from the histogram and
+    then remapping the image so that the darkest pixel becomes black (0) and
+    the lightest becomes white (255).
+
+ Args:
+ img (ndarray): Image to be contrasted. BGR order.
+ cutoff (int | float | tuple): The cutoff percent of the lightest and
+ darkest pixels to be removed. If given as tuple, it shall be
+ (low, high). Otherwise, the single value will be used for both.
+ Defaults to 0.
+
+ Returns:
+ ndarray: The contrasted image.
+ """
+
+ def _auto_contrast_channel(im, c, cutoff):
+ im = im[:, :, c]
+ # Compute the histogram of the image channel.
+ histo = np.histogram(im, 256, (0, 255))[0]
+ # Remove cut-off percent pixels from histo
+ histo_sum = np.cumsum(histo)
+ cut_low = histo_sum[-1] * cutoff[0] // 100
+ cut_high = histo_sum[-1] - histo_sum[-1] * cutoff[1] // 100
+ histo_sum = np.clip(histo_sum, cut_low, cut_high) - cut_low
+ histo = np.concatenate([[histo_sum[0]], np.diff(histo_sum)], 0)
+
+ # Compute mapping
+ low, high = np.nonzero(histo)[0][0], np.nonzero(histo)[0][-1]
+ # If all the values have been cut off, return the origin img
+ if low >= high:
+ return im
+ scale = 255.0 / (high - low)
+ offset = -low * scale
+ lut = np.array(range(256))
+ lut = lut * scale + offset
+ lut = np.clip(lut, 0, 255)
+ return lut[im]
+
+ if isinstance(cutoff, (int, float)):
+ cutoff = (cutoff, cutoff)
+ else:
+ assert isinstance(cutoff, tuple), 'cutoff must be of type int, ' \
+ f'float or tuple, but got {type(cutoff)} instead.'
+ # Auto adjusts contrast for each channel independently and then stacks
+ # the result.
+ s1 = _auto_contrast_channel(img, 0, cutoff)
+ s2 = _auto_contrast_channel(img, 1, cutoff)
+ s3 = _auto_contrast_channel(img, 2, cutoff)
+ contrasted_img = np.stack([s1, s2, s3], axis=-1)
+ return contrasted_img.astype(img.dtype)
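+
+# Usage sketch (editor's illustration, not part of upstream mmcv):
+#
+#   >>> stretched = auto_contrast(img, cutoff=1)       # clip 1% at both ends
+#   >>> stretched = auto_contrast(img, cutoff=(2, 5))  # asymmetric (low, high)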
+
+
+def adjust_sharpness(img, factor=1., kernel=None):
+ """Adjust image sharpness.
+
+ This function controls the sharpness of an image. An
+ enhancement factor of 0.0 gives a blurred image. A
+ factor of 1.0 gives the original image. And a factor
+ of 2.0 gives a sharpened image. It blends the source
+ image and the degenerated mean image:
+
+ .. math::
+ output = img * factor + degenerated * (1 - factor)
+
+ Args:
+ img (ndarray): Image to be sharpened. BGR order.
+ factor (float): Same as :func:`mmcv.adjust_brightness`.
+ kernel (np.ndarray, optional): Filter kernel to be applied on the img
+ to obtain the degenerated img. Defaults to None.
+
+ Note:
+        No sanity check is enforced on the user-provided kernel. With an
+        inappropriate kernel, ``adjust_sharpness`` may not sharpen at all
+        and will instead apply whatever transform the kernel defines.
+
+ Returns:
+ ndarray: The sharpened image.
+ """
+
+ if kernel is None:
+ # adopted from PIL.ImageFilter.SMOOTH
+ kernel = np.array([[1., 1., 1.], [1., 5., 1.], [1., 1., 1.]]) / 13
+ assert isinstance(kernel, np.ndarray), \
+ f'kernel must be of type np.ndarray, but got {type(kernel)} instead.'
+ assert kernel.ndim == 2, \
+ f'kernel must have a dimension of 2, but got {kernel.ndim} instead.'
+
+ degenerated = cv2.filter2D(img, -1, kernel)
+ sharpened_img = cv2.addWeighted(
+ img.astype(np.float32), factor, degenerated.astype(np.float32),
+ 1 - factor, 0)
+ sharpened_img = np.clip(sharpened_img, 0, 255)
+ return sharpened_img.astype(img.dtype)
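+
+# Usage sketch (editor's illustration, not part of upstream mmcv):
+#
+#   >>> blurred = adjust_sharpness(img, factor=0.)   # fully smoothed
+#   >>> sharper = adjust_sharpness(img, factor=2.)   # over-sharpened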
+
+
+def adjust_lighting(img, eigval, eigvec, alphastd=0.1, to_rgb=True):
+ """AlexNet-style PCA jitter.
+
+    This data augmentation is proposed in "ImageNet Classification with Deep
+    Convolutional Neural Networks" (the AlexNet paper).
+
+ Args:
+        img (ndarray): Image whose lighting is to be adjusted. BGR order.
+        eigval (ndarray): the eigenvalues of the covariance matrix of pixel
+            values.
+        eigvec (ndarray): the eigenvectors of the covariance matrix of pixel
+            values.
+        alphastd (float): The standard deviation of the distribution of
+            alpha. Defaults to 0.1.
+ to_rgb (bool): Whether to convert img to rgb.
+
+ Returns:
+ ndarray: The adjusted image.
+ """
+ assert isinstance(eigval, np.ndarray) and isinstance(eigvec, np.ndarray), \
+ f'eigval and eigvec should both be of type np.ndarray, got ' \
+ f'{type(eigval)} and {type(eigvec)} instead.'
+
+ assert eigval.ndim == 1 and eigvec.ndim == 2
+ assert eigvec.shape == (3, eigval.shape[0])
+ n_eigval = eigval.shape[0]
+ assert isinstance(alphastd, float), 'alphastd should be of type float, ' \
+ f'got {type(alphastd)} instead.'
+
+ img = img.copy().astype(np.float32)
+ if to_rgb:
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace
+
+ alpha = np.random.normal(0, alphastd, n_eigval)
+ alter = eigvec \
+ * np.broadcast_to(alpha.reshape(1, n_eigval), (3, n_eigval)) \
+ * np.broadcast_to(eigval.reshape(1, n_eigval), (3, n_eigval))
+ alter = np.broadcast_to(alter.sum(axis=1).reshape(1, 1, 3), img.shape)
+ img_adjusted = img + alter
+ return img_adjusted
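+
+# Usage sketch (editor's illustration, not part of upstream mmcv; the values
+# below are the PCA lighting statistics commonly used for ImageNet-trained
+# models, scaled to the 0-255 range):
+#
+#   >>> eigval = np.array([55.4625, 4.7940, 1.1475])
+#   >>> eigvec = np.array([[-0.5675, 0.7192, 0.4009],
+#   ...                    [-0.5808, -0.0045, -0.8140],
+#   ...                    [-0.5836, -0.6948, 0.4203]])
+#   >>> jittered = adjust_lighting(img, eigval, eigvec, alphastd=0.1)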
+
+
+def lut_transform(img, lut_table):
+ """Transform array by look-up table.
+
+ The function lut_transform fills the output array with values from the
+ look-up table. Indices of the entries are taken from the input array.
+
+ Args:
+ img (ndarray): Image to be transformed.
+ lut_table (ndarray): look-up table of 256 elements; in case of
+ multi-channel input array, the table should either have a single
+ channel (in this case the same table is used for all channels) or
+ the same number of channels as in the input array.
+
+ Returns:
+ ndarray: The transformed image.
+ """
+ assert isinstance(img, np.ndarray)
+ assert 0 <= np.min(img) and np.max(img) <= 255
+ assert isinstance(lut_table, np.ndarray)
+ assert lut_table.shape == (256, )
+
+ return cv2.LUT(np.array(img, dtype=np.uint8), lut_table)
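+
+# Usage sketch (editor's illustration, not part of upstream mmcv): a simple
+# gamma-correction table applied through the LUT.
+#
+#   >>> gamma = 2.2
+#   >>> lut = (np.arange(256) / 255.) ** (1. / gamma) * 255
+#   >>> corrected = lut_transform(img, lut.astype(np.uint8))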
+
+
+def clahe(img, clip_limit=40.0, tile_grid_size=(8, 8)):
+ """Use CLAHE method to process the image.
+
+ See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
+ Graphics Gems, 1994:474-485.` for more information.
+
+ Args:
+ img (ndarray): Image to be processed.
+ clip_limit (float): Threshold for contrast limiting. Default: 40.0.
+ tile_grid_size (tuple[int]): Size of grid for histogram equalization.
+ Input image will be divided into equally sized rectangular tiles.
+ It defines the number of tiles in row and column. Default: (8, 8).
+
+ Returns:
+ ndarray: The processed image.
+ """
+ assert isinstance(img, np.ndarray)
+ assert img.ndim == 2
+ assert isinstance(clip_limit, (float, int))
+ assert is_tuple_of(tile_grid_size, int)
+ assert len(tile_grid_size) == 2
+
+ clahe = cv2.createCLAHE(clip_limit, tile_grid_size)
+ return clahe.apply(np.array(img, dtype=np.uint8))
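+
+# Usage sketch (editor's illustration, not part of upstream mmcv; `clahe`
+# expects a single-channel (2-D) image):
+#
+#   >>> gray = bgr2gray(img)
+#   >>> enhanced = clahe(gray, clip_limit=40.0, tile_grid_size=(8, 8))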
+
+
+def adjust_hue(img: np.ndarray, hue_factor: float) -> np.ndarray:
+ """Adjust hue of an image.
+
+ The image hue is adjusted by converting the image to HSV and cyclically
+ shifting the intensities in the hue channel (H). The image is then
+    converted back to the original image mode.
+
+ `hue_factor` is the amount of shift in H channel and must be in the
+ interval `[-0.5, 0.5]`.
+
+ Modified from
+ https://github.com/pytorch/vision/blob/main/torchvision/
+ transforms/functional.py
+
+ Args:
+ img (ndarray): Image to be adjusted.
+ hue_factor (float): How much to shift the hue channel. Should be in
+ [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
+ HSV space in positive and negative direction respectively.
+ 0 means no shift. Therefore, both -0.5 and 0.5 will give an image
+ with complementary colors while 0 gives the original image.
+
+ Returns:
+ ndarray: Hue adjusted image.
+ """
+
+ if not (-0.5 <= hue_factor <= 0.5):
+ raise ValueError(f'hue_factor:{hue_factor} is not in [-0.5, 0.5].')
+ if not (isinstance(img, np.ndarray) and (img.ndim in {2, 3})):
+ raise TypeError('img should be ndarray with dim=[2 or 3].')
+
+ dtype = img.dtype
+ img = img.astype(np.uint8)
+ hsv_img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV_FULL)
+ h, s, v = cv2.split(hsv_img)
+ h = h.astype(np.uint8)
+    # uint8 addition takes care of wrapping around the hue boundary
+ with np.errstate(over='ignore'):
+ h += np.uint8(hue_factor * 255)
+ hsv_img = cv2.merge([h, s, v])
+ return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2RGB_FULL).astype(dtype)
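+
+# Usage sketch (editor's illustration, not part of upstream mmcv; note that
+# `adjust_hue` assumes an RGB image, unlike the BGR-oriented helpers above):
+#
+#   >>> rgb = img[..., ::-1]                       # BGR -> RGB
+#   >>> shifted = adjust_hue(rgb, hue_factor=0.1)
+#   >>> complementary = adjust_hue(rgb, hue_factor=-0.5)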
diff --git a/mmcv/mmcv/model_zoo/deprecated.json b/mmcv/mmcv/model_zoo/deprecated.json
new file mode 100644
index 0000000000000000000000000000000000000000..25cf6f28caecc22a77e3136fefa6b8dfc0e6cb5b
--- /dev/null
+++ b/mmcv/mmcv/model_zoo/deprecated.json
@@ -0,0 +1,6 @@
+{
+ "resnet50_caffe": "detectron/resnet50_caffe",
+ "resnet50_caffe_bgr": "detectron2/resnet50_caffe_bgr",
+ "resnet101_caffe": "detectron/resnet101_caffe",
+ "resnet101_caffe_bgr": "detectron2/resnet101_caffe_bgr"
+}
diff --git a/mmcv/mmcv/model_zoo/mmcls.json b/mmcv/mmcv/model_zoo/mmcls.json
new file mode 100644
index 0000000000000000000000000000000000000000..c073a41d0aeb44ee0243f97ecc3558de538f9300
--- /dev/null
+++ b/mmcv/mmcv/model_zoo/mmcls.json
@@ -0,0 +1,59 @@
+{
+ "vgg11": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_batch256_imagenet_20210208-4271cd6c.pth",
+ "vgg13": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_batch256_imagenet_20210208-4d1d6080.pth",
+ "vgg16": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_batch256_imagenet_20210208-db26f1a5.pth",
+ "vgg19": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_batch256_imagenet_20210208-e6920e4a.pth",
+ "vgg11_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_bn_batch256_imagenet_20210207-f244902c.pth",
+ "vgg13_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_bn_batch256_imagenet_20210207-1a8b7864.pth",
+ "vgg16_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_bn_batch256_imagenet_20210208-7e55cd29.pth",
+ "vgg19_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_bn_batch256_imagenet_20210208-da620c4f.pth",
+ "resnet18": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth",
+ "resnet34": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth",
+ "resnet50": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth",
+ "resnet101": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth",
+ "resnet152": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_8xb32_in1k_20210901-4d7582fa.pth",
+ "resnet50_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_b32x8_imagenet_20210531-db14775a.pth",
+ "resnet101_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.pth",
+ "resnet152_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.pth",
+ "resnext50_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext50_32x4d_b32x8_imagenet_20210429-56066e27.pth",
+ "resnext101_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x4d_b32x8_imagenet_20210506-e0fa3dd5.pth",
+ "resnext101_32x8d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x8d_b32x8_imagenet_20210506-23a247d5.pth",
+ "resnext152_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext152_32x4d_b32x8_imagenet_20210524-927787be.pth",
+ "se-resnet50": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet50_batch256_imagenet_20200804-ae206104.pth",
+ "se-resnet101": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet101_batch256_imagenet_20200804-ba5b51d4.pth",
+ "resnest50": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest50_imagenet_converted-1ebf0afe.pth",
+ "resnest101": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest101_imagenet_converted-032caa52.pth",
+ "resnest200": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest200_imagenet_converted-581a60f2.pth",
+ "resnest269": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest269_imagenet_converted-59930960.pth",
+ "shufflenet_v1": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v1/shufflenet_v1_batch1024_imagenet_20200804-5d6cec73.pth",
+ "shufflenet_v2": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v2/shufflenet_v2_batch1024_imagenet_20200812-5bf4721e.pth",
+ "mobilenet_v2": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth",
+ "mobilenet_v3_small": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_small-8427ecf0.pth",
+ "mobilenet_v3_large": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v3/convert/mobilenet_v3_large-3ea3c186.pth",
+ "repvgg_A0": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A0_3rdparty_4xb64-coslr-120e_in1k_20210909-883ab98c.pth",
+ "repvgg_A1": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A1_3rdparty_4xb64-coslr-120e_in1k_20210909-24003a24.pth",
+ "repvgg_A2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A2_3rdparty_4xb64-coslr-120e_in1k_20210909-97d7695a.pth",
+ "repvgg_B0": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B0_3rdparty_4xb64-coslr-120e_in1k_20210909-446375f4.pth",
+ "repvgg_B1": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1_3rdparty_4xb64-coslr-120e_in1k_20210909-750cdf67.pth",
+ "repvgg_B1g2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g2_3rdparty_4xb64-coslr-120e_in1k_20210909-344f6422.pth",
+ "repvgg_B1g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g4_3rdparty_4xb64-coslr-120e_in1k_20210909-d4c1a642.pth",
+ "repvgg_B2": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2_3rdparty_4xb64-coslr-120e_in1k_20210909-bd6b937c.pth",
+ "repvgg_B2g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-7b7955f0.pth",
+ "repvgg_B3": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-dda968bf.pth",
+ "repvgg_B3g4": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-4e54846a.pth",
+ "repvgg_D2se": "https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-D2se_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-cf3139b7.pth",
+ "res2net101_w26": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net101-w26-s4_3rdparty_8xb32_in1k_20210927-870b6c36.pth",
+ "res2net50_w14": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w14-s8_3rdparty_8xb32_in1k_20210927-bc967bf1.pth",
+ "res2net50_w26": "https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w26-s8_3rdparty_8xb32_in1k_20210927-f547a94b.pth",
+ "swin_tiny": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_b16x64_300e_imagenet_20210616_090925-66df6be6.pth",
+ "swin_small": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219-7f9d988b.pth",
+ "swin_base": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_base_patch4_window7_224_22kto1k-f967f799.pth",
+ "swin_large": "https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin_large_patch4_window7_224_22kto1k-5f0996db.pth",
+ "t2t_vit_t_14": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth",
+ "t2t_vit_t_19": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-7f1478d5.pth",
+ "t2t_vit_t_24": "https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth",
+ "tnt_small": "https://download.openmmlab.com/mmclassification/v0/tnt/tnt-small-p16_3rdparty_in1k_20210903-c56ee7df.pth",
+ "vit_base_p16": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-98e8652b.pth",
+ "vit_base_p32": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-base-p32_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-9cea8599.pth",
+ "vit_large_p16": "https://download.openmmlab.com/mmclassification/v0/vit/finetune/vit-large-p16_in21k-pre-3rdparty_ft-64xb64_in1k-384_20210928-b20ba619.pth"
+}
diff --git a/mmcv/mmcv/model_zoo/open_mmlab.json b/mmcv/mmcv/model_zoo/open_mmlab.json
new file mode 100644
index 0000000000000000000000000000000000000000..8311db4feef92faa0841c697d75efbee8430c3a0
--- /dev/null
+++ b/mmcv/mmcv/model_zoo/open_mmlab.json
@@ -0,0 +1,50 @@
+{
+ "vgg16_caffe": "https://download.openmmlab.com/pretrain/third_party/vgg16_caffe-292e1171.pth",
+ "detectron/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_caffe-788b5fa3.pth",
+ "detectron2/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_msra-5891d200.pth",
+ "detectron/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_caffe-3ad79236.pth",
+ "detectron2/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_msra-6cc46731.pth",
+ "detectron2/resnext101_32x8d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x8d-1516f1aa.pth",
+ "resnext50_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext50-32x4d-0ab1a123.pth",
+ "resnext101_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d-a5af3160.pth",
+ "resnext101_64x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth",
+ "contrib/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth",
+ "detectron/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn-9186a21c.pth",
+ "detectron/resnet101_gn": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn-cac0ab98.pth",
+ "jhu/resnet50_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_ws-15beedd8.pth",
+ "jhu/resnet101_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth",
+ "jhu/resnext50_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth",
+ "jhu/resnext101_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth",
+ "jhu/resnext50_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth",
+ "jhu/resnext101_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth",
+ "msra/hrnetv2_w18_small": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18_small-b5a04e21.pth",
+ "msra/hrnetv2_w18": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18-00eb2006.pth",
+ "msra/hrnetv2_w32": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth",
+ "msra/hrnetv2_w40": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w40-ed0b031c.pth",
+ "msra/hrnetv2_w48": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w48-d2186c55.pth",
+ "bninception_caffe": "https://download.openmmlab.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth",
+ "kin400/i3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth",
+ "kin400/nl3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth",
+ "res2net101_v1d_26w_4s": "https://download.openmmlab.com/pretrain/third_party/res2net101_v1d_26w_4s_mmdetv2-f0a600f9.pth",
+ "regnetx_400mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_400mf-a5b10d96.pth",
+ "regnetx_800mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth",
+ "regnetx_1.6gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth",
+ "regnetx_3.2gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth",
+ "regnetx_4.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth",
+ "regnetx_6.4gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth",
+ "regnetx_8.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth",
+ "regnetx_12gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth",
+ "resnet18_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet18_v1c-b5776b93.pth",
+ "resnet50_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet50_v1c-2cccc1ad.pth",
+ "resnet101_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet101_v1c-e67eebb6.pth",
+ "mmedit/vgg16": "https://download.openmmlab.com/mmediting/third_party/vgg_state_dict.pth",
+ "mmedit/res34_en_nomixup": "https://download.openmmlab.com/mmediting/third_party/model_best_resnet34_En_nomixup.pth",
+ "mmedit/mobilenet_v2": "https://download.openmmlab.com/mmediting/third_party/mobilenet_v2.pth",
+ "contrib/mobilenet_v3_large": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_large-bc2c3fd3.pth",
+ "contrib/mobilenet_v3_small": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_small-47085aa1.pth",
+ "resnest50": "https://download.openmmlab.com/pretrain/third_party/resnest50_d2-7497a55b.pth",
+ "resnest101": "https://download.openmmlab.com/pretrain/third_party/resnest101_d2-f3b931b2.pth",
+ "resnest200": "https://download.openmmlab.com/pretrain/third_party/resnest200_d2-ca88e41f.pth",
+ "darknet53": "https://download.openmmlab.com/pretrain/third_party/darknet53-a628ea1b.pth",
+ "mmdet/mobilenet_v2": "https://download.openmmlab.com/mmdetection/v2.0/third_party/mobilenet_v2_batch256_imagenet-ff34753d.pth"
+}
diff --git a/mmcv/mmcv/model_zoo/torchvision_0.12.json b/mmcv/mmcv/model_zoo/torchvision_0.12.json
new file mode 100644
index 0000000000000000000000000000000000000000..06defe67484dff91cf6f69109324cb1dd9d64bc3
--- /dev/null
+++ b/mmcv/mmcv/model_zoo/torchvision_0.12.json
@@ -0,0 +1,57 @@
+{
+ "alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
+ "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
+ "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
+ "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
+ "densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth",
+ "efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
+ "efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
+ "efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
+ "efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
+ "efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
+ "efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
+ "efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
+ "efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
+ "googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth",
+ "inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
+ "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
+ "mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
+ "mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
+ "regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
+ "regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
+ "regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
+ "regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
+ "regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
+ "regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
+ "regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
+ "regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
+ "regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
+ "regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
+ "regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
+ "regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
+ "regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
+ "regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
+ "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
+ "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
+ "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
+ "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
+ "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth",
+ "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
+ "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
+ "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
+ "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
+ "shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
+ "shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
+ "shufflenetv2_x1.5": null,
+ "shufflenetv2_x2.0": null,
+ "squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
+ "squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
+ "vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth",
+ "vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth",
+ "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
+ "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
+ "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
+ "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
+ "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
+ "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth"
+}
diff --git a/mmcv/mmcv/onnx/__init__.py b/mmcv/mmcv/onnx/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d7eb5b0db770144ac6676bd1c7e80d7d2eb7e02
--- /dev/null
+++ b/mmcv/mmcv/onnx/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .info import is_custom_op_loaded
+from .symbolic import register_extra_symbolics
+
+__all__ = ['register_extra_symbolics', 'is_custom_op_loaded']
diff --git a/mmcv/mmcv/onnx/info.py b/mmcv/mmcv/onnx/info.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8325a9c0d0dc3b48b77e9da307341059017ea28
--- /dev/null
+++ b/mmcv/mmcv/onnx/info.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import warnings
+
+import torch
+
+
+def is_custom_op_loaded() -> bool:
+
+    # The following text-style strings are taken from the colorama package
+ bright_style, reset_style = '\x1b[1m', '\x1b[0m'
+ red_text, blue_text = '\x1b[31m', '\x1b[34m'
+ white_background = '\x1b[107m'
+
+ msg = white_background + bright_style + red_text
+    msg += 'DeprecationWarning: This function will be deprecated in the '
+    msg += 'future. ' + blue_text
+    msg += 'Please switch to the unified model deployment toolbox '
+    msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
+ msg += reset_style
+ warnings.warn(msg)
+
+ flag = False
+ try:
+ from ..tensorrt import is_tensorrt_plugin_loaded
+ flag = is_tensorrt_plugin_loaded()
+ except (ImportError, ModuleNotFoundError):
+ pass
+ if not flag:
+ try:
+ from ..ops import get_onnxruntime_op_path
+ ort_lib_path = get_onnxruntime_op_path()
+ flag = os.path.exists(ort_lib_path)
+ except (ImportError, ModuleNotFoundError):
+ pass
+ return flag or torch.__version__ == 'parrots'
diff --git a/mmcv/mmcv/onnx/onnx_utils/__init__.py b/mmcv/mmcv/onnx/onnx_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef101fec61e72abc0eb90266d453b5b22331378d
--- /dev/null
+++ b/mmcv/mmcv/onnx/onnx_utils/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) OpenMMLab. All rights reserved.
diff --git a/mmcv/mmcv/onnx/onnx_utils/symbolic_helper.py b/mmcv/mmcv/onnx/onnx_utils/symbolic_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc9e96f8fbbb0cadec23411ddf93b31a90d049d0
--- /dev/null
+++ b/mmcv/mmcv/onnx/onnx_utils/symbolic_helper.py
@@ -0,0 +1,331 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Modified from https://github.com/pytorch/pytorch."""
+import warnings
+from functools import wraps
+from sys import maxsize
+
+import torch
+import torch.onnx
+# This import monkey-patches graph manipulation methods on Graph, used for the
+# ONNX symbolics
+import torch.onnx.utils
+from torch._C import ListType
+
+# ---------------------------------------------------------------------------------
+# Helper functions
+# ---------------------------------------------------------------------------------
+
+# Save some builtins as locals, because we'll shadow them below
+_sum = sum
+
+
+def _parse_arg(value, desc):
+ if desc == 'none':
+ return value
+ if desc == 'v' or not _is_value(value):
+ return value
+ if value.node().mustBeNone():
+ return None
+ if value.node().kind() == 'onnx::Constant':
+ tval = value.node()['value']
+ if desc == 'i':
+ return int(tval)
+ elif desc == 'f':
+ return float(tval)
+ elif desc == 'b':
+ return bool(tval)
+ elif desc == 's':
+ return str(tval)
+ elif desc == 't':
+ return tval
+ elif desc == 'is':
+ return [int(v) for v in tval]
+ elif desc == 'fs':
+ return [float(v) for v in tval]
+ else:
+ raise RuntimeError(
+ "ONNX symbolic doesn't know to interpret Constant node")
+ elif value.node().kind() == 'prim::ListConstruct':
+ if desc == 'is':
+ for v in value.node().inputs():
+ if v.node().kind() != 'onnx::Constant':
+ raise RuntimeError(
+ "Failed to export an ONNX attribute '" +
+ v.node().kind() +
+ "', since it's not constant, please try to make "
+ 'things (e.g., kernel size) static if possible')
+ return [int(v.node()['value']) for v in value.node().inputs()]
+ else:
+ raise RuntimeError(
+ "ONNX symbolic doesn't know to interpret ListConstruct node")
+
+ raise RuntimeError(f'Unexpected node type: {value.node().kind()}')
+
+
+def _maybe_get_const(value, desc):
+ if _is_value(value) and value.node().kind() == 'onnx::Constant':
+ return _parse_arg(value, desc)
+ return value
+
+
+def _maybe_get_scalar(value):
+ value_t = _maybe_get_const(value, 't')
+ if isinstance(value_t, torch.Tensor) and value_t.shape == ():
+ return value_t
+ return value
+
+
+def _get_const(value, desc, arg_name):
+ if _is_value(value) and value.node().kind() not in ('onnx::Constant',
+ 'prim::Constant'):
+ raise RuntimeError('ONNX symbolic expected a constant'
+ ' value of the {} argument, got `{}`'.format(
+ arg_name, value))
+ return _parse_arg(value, desc)
+
+
+def _unpack_list(list_value):
+ list_node = list_value.node()
+ assert list_node.kind() == 'prim::ListConstruct'
+ return list(list_node.inputs())
+
+
+# Check if list_value is output from prim::ListConstruct
+# This is usually called before _unpack_list to ensure the list can be
+# unpacked.
+def _is_packed_list(list_value):
+ return _is_value(
+ list_value) and list_value.node().kind() == 'prim::ListConstruct'
+
+
+def parse_args(*arg_descriptors):
+
+ def decorator(fn):
+ fn._arg_descriptors = arg_descriptors
+
+ def wrapper(g, *args):
+ # some args may be optional, so the length may be smaller
+ assert len(arg_descriptors) >= len(args)
+ args = [
+ _parse_arg(arg, arg_desc)
+ for arg, arg_desc in zip(args, arg_descriptors)
+ ]
+ return fn(g, *args)
+
+ # In Python 2 functools.wraps chokes on partially applied functions, so
+ # we need this as a workaround
+ try:
+ wrapper = wraps(fn)(wrapper)
+ except Exception:
+ pass
+ return wrapper
+
+ return decorator
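+
+# Usage sketch (editor's illustration, not part of upstream code; the op name
+# below is hypothetical): a symbolic function decorated with `parse_args`
+# receives its trailing arguments already parsed per the descriptors.
+#
+#   @parse_args('v', 'i', 'f')
+#   def my_symbolic(g, input, dim, eps):
+#       # `input` stays a graph Value ('v'), `dim` is parsed to a Python int
+#       # ('i') and `eps` to a float ('f') before this body runs.
+#       return g.op('SomeCustomOp', input, axis_i=dim, eps_f=eps)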
+
+
+def _scalar(x):
+ """Convert a scalar tensor into a Python value."""
+ assert x.numel() == 1
+ return x.item()
+
+
+def _if_scalar_type_as(g, self, tensor):
+ """Convert self into the same type of tensor, as necessary."""
+ if isinstance(self, torch._C.Value):
+ return self
+
+ scalar_type = tensor.type().scalarType()
+ if scalar_type:
+ ty = scalar_type.lower()
+ return getattr(self, ty)()
+
+ return self
+
+
+def _is_none(x):
+ return x.node().mustBeNone()
+
+
+def _is_value(x):
+ return isinstance(x, torch._C.Value)
+
+
+def _is_tensor_list(x):
+ return x.type().isSubtypeOf(ListType.ofTensors())
+
+
+def _unimplemented(op, msg):
+ warnings.warn('ONNX export failed on ' + op + ' because ' + msg +
+ ' not supported')
+
+
+def _try_get_scalar_type(*args):
+ for arg in args:
+ try:
+ return arg.type().scalarType()
+ except RuntimeError:
+ pass
+ return None
+
+
+def _topk_helper(g, input, k, dim, largest=True, sorted=False, out=None):
+ if out is not None:
+        _unimplemented('TopK', 'Out parameter')
+ if not _is_value(k):
+ k = g.op('Constant', value_t=torch.tensor([k], dtype=torch.int64))
+ else:
+ k = g.op('Reshape', k, g.op('Constant', value_t=torch.tensor([1])))
+ return g.op(
+ 'TopK',
+ input,
+ k,
+ axis_i=dim,
+ largest_i=largest,
+ sorted_i=sorted,
+ outputs=2)
+
+
+def _slice_helper(g,
+ input,
+ axes,
+ starts,
+ ends,
+ steps=None,
+ dynamic_slice=False):
+ # TODO(ruobing): add support for opset<10
+ from torch.onnx.symbolic_opset10 import _slice
+ return _slice(g, input, axes, starts, ends, steps, dynamic_slice)
+
+
+def _unsqueeze_helper(g, input, dim):
+ from torch.onnx.symbolic_opset9 import unsqueeze
+ return unsqueeze(g, input, dim)
+
+
+def _interpolate_size_to_scales(g, input, output_size, dim):
+ output_size = _maybe_get_const(output_size, 'is')
+ if _is_value(output_size):
+ offset = 2
+ offsets = g.op(
+ 'Constant', value_t=torch.ones(offset, dtype=torch.float32))
+ dividend = g.op(
+ 'Cast', output_size, to_i=cast_pytorch_to_onnx['Float'])
+ divisor = _slice_helper(
+ g, g.op('Shape', input), axes=[0], ends=[maxsize], starts=[offset])
+ divisor = g.op('Cast', divisor, to_i=cast_pytorch_to_onnx['Float'])
+ scale_dims = g.op('Div', dividend, divisor)
+ scales = g.op('Concat', offsets, scale_dims, axis_i=0)
+ else:
+ scales_constant = [
+ 1. if i < 2 else float(output_size[-(dim - i)]) /
+ float(input.type().sizes()[-(dim - i)]) for i in range(0, dim)
+ ]
+ scales = g.op(
+ 'Constant',
+ value_t=torch.tensor(scales_constant, dtype=torch.float32))
+ return scales
+
+
+def _interpolate_get_scales_if_available(g, scales):
+ if len(scales) == 0:
+ return None
+ # scales[0] is NoneType in Pytorch == 1.5.1
+ # scales[0] is TensorType with sizes = [] in Pytorch == 1.6.0
+ # scales[0] is ListType in Pytorch == 1.7.0
+ # scales[0] is TensorType with sizes = [2] in Pytorch == 1.8.0
+ scale_desc = 'fs' if scales[0].type().kind() == 'ListType' or (
+ scales[0].type().kind() == 'TensorType' and
+ (sum(scales[0].type().sizes()) > 1)) else 'f'
+ available_scales = _maybe_get_const(
+ scales[0], scale_desc) != -1 and not _is_none(scales[0])
+
+ if not available_scales:
+ return None
+
+ offsets = g.op('Constant', value_t=torch.ones(2, dtype=torch.float32))
+ if scale_desc == 'fs':
+ scales_list = g.op(
+ 'Constant',
+ value_t=torch.tensor(_maybe_get_const(scales[0], scale_desc)))
+ # modify to support PyTorch==1.7.0
+ # https://github.com/pytorch/pytorch/blob/75ee5756715e7161314ce037474843b68f69fc04/torch/onnx/symbolic_helper.py#L375 # noqa: E501
+ scales = g.op('Concat', offsets, scales_list, axis_i=0)
+ else:
+ # for PyTorch < 1.7.0
+ scales_list = []
+ for scale in scales:
+ unsqueezed_scale = _unsqueeze_helper(g, scale, 0)
+ # ONNX only supports float for the scales. double -> float.
+ unsqueezed_scale = g.op(
+ 'Cast', unsqueezed_scale, to_i=cast_pytorch_to_onnx['Float'])
+ scales_list.append(unsqueezed_scale)
+ scales = g.op('Concat', offsets, *scales_list, axis_i=0)
+ return scales
+
+
+def _get_interpolate_attributes(g, mode, args):
+ if mode == 'nearest':
+ align_corners = None
+ scales = args[0:]
+ else:
+ align_corners = args[0]
+ scales = args[1:]
+ scales = _interpolate_get_scales_if_available(g, scales)
+ return scales, align_corners
+
+
+def _interpolate_get_scales(g, scale_factor, dim):
+ offsets = g.op('Constant', value_t=torch.ones(2, dtype=torch.float32))
+ if isinstance(scale_factor.type(), torch._C.ListType):
+ return g.op('Concat', offsets, scale_factor, axis_i=0)
+ else:
+ scale_factor = _unsqueeze_helper(g, scale_factor, 0)
+ scale_factor = g.op(
+ 'Cast', scale_factor, to_i=cast_pytorch_to_onnx['Float'])
+ scales = [scale_factor for i in range(dim - 2)]
+ scale_factor = g.op('Concat', offsets, *scales, axis_i=0)
+ return scale_factor
+
+
+def _size_helper(g, self, dim):
+ full_shape = g.op('Shape', self)
+ from torch.onnx.symbolic_opset9 import select
+ return select(g, full_shape, g.op('Constant', value_t=torch.tensor([0])),
+ dim)
+
+
+def _avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override,
+ name):
+ if divisor_override and divisor_override.node().kind() != 'prim::Constant':
+ return _unimplemented(name, 'divisor_override')
+ if not stride:
+ stride = kernel_size
+ padding = tuple(tuple_fn(padding))
+ return padding
+
+
+# Metaprogram symbolics for each ATen native specialized cast operator.
+# For e.g. we specify a function named `_cast_uint8_t` that instantiates an
+# ONNX cast node with `to` attribute 'UINT8'
+#
+# TODO: remove these once we support Type's in the JIT IR and we can once again
+# use the unified toType operator
+cast_pytorch_to_onnx = {
+ 'Byte': torch.onnx.TensorProtoDataType.UINT8,
+ 'Char': torch.onnx.TensorProtoDataType.INT8,
+ 'Double': torch.onnx.TensorProtoDataType.DOUBLE,
+ 'Float': torch.onnx.TensorProtoDataType.FLOAT,
+ 'Half': torch.onnx.TensorProtoDataType.FLOAT16,
+ 'Int': torch.onnx.TensorProtoDataType.INT32,
+ 'Long': torch.onnx.TensorProtoDataType.INT64,
+ 'Short': torch.onnx.TensorProtoDataType.INT16,
+ 'Bool': torch.onnx.TensorProtoDataType.BOOL,
+ 'ComplexFloat': torch.onnx.TensorProtoDataType.COMPLEX64,
+ 'ComplexDouble': torch.onnx.TensorProtoDataType.COMPLEX128,
+ 'Undefined': torch.onnx.TensorProtoDataType.UNDEFINED,
+}
+
+# Global set to store the list of quantized operators in the network.
+# This is currently only used in the conversion of quantized ops from PT
+# -> C2 via ONNX.
+_quantized_ops: set = set()
diff --git a/mmcv/mmcv/onnx/symbolic.py b/mmcv/mmcv/onnx/symbolic.py
new file mode 100644
index 0000000000000000000000000000000000000000..3599b3f26683ea2d1907aa5e839e02e474791370
--- /dev/null
+++ b/mmcv/mmcv/onnx/symbolic.py
@@ -0,0 +1,509 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Modified from https://github.com/pytorch/pytorch."""
+import os
+import warnings
+
+import numpy as np
+import torch
+from torch.nn.modules.utils import _pair, _single, _triple
+from torch.onnx.symbolic_helper import parse_args
+from torch.onnx.symbolic_registry import register_op
+
+from .onnx_utils import symbolic_helper as sym_help
+
+
+def _interpolate(name, dim, interpolate_mode):
+
+ def symbolic_fn(g, input, output_size, *args):
+ scales, align_corners = sym_help._get_interpolate_attributes(
+ g, interpolate_mode, args)
+ align_corners = sym_help._maybe_get_scalar(align_corners)
+ transformation_mode = 'asymmetric' \
+ if interpolate_mode == 'nearest' \
+ else 'align_corners' if align_corners else 'pytorch_half_pixel'
+ empty_tensor = g.op(
+ 'Constant', value_t=torch.tensor([], dtype=torch.float32))
+
+ if scales is None:
+ if 'ONNX_BACKEND' in os.environ and os.environ[
+ 'ONNX_BACKEND'] == 'TensorRT':
+ input_size = input.type().sizes()
+ # slice the first two dim
+ input_size = input_size[:2]
+ # convert output_size to int type
+ output_size = sym_help._maybe_get_const(output_size, 'is')
+ input_size.extend(output_size)
+ output_size = g.op(
+ 'Constant',
+ value_t=torch.tensor(input_size, dtype=torch.int64))
+ else:
+ input_size = g.op('Shape', input)
+ input_size_beg = sym_help._slice_helper(
+ g, input_size, axes=[0], ends=[2], starts=[0])
+ output_size = g.op(
+ 'Cast',
+ output_size,
+ to_i=sym_help.cast_pytorch_to_onnx['Long'])
+ output_size = g.op(
+ 'Concat', input_size_beg, output_size, axis_i=0)
+ scales = g.op(
+ 'Constant', value_t=torch.tensor([], dtype=torch.float32))
+ return g.op(
+ 'Resize',
+ input,
+ empty_tensor,
+ # roi only takes effect with
+ # coordinate_transformation_mode="tf_crop_and_resize"
+ scales, # scales is not needed since we are sending out_size
+ output_size,
+ coordinate_transformation_mode_s=transformation_mode,
+ cubic_coeff_a_f=-0.75, # only valid when mode="cubic"
+ mode_s=interpolate_mode, # nearest, linear, or cubic
+ nearest_mode_s='floor') # only valid when mode="nearest"
+ else:
+ return g.op(
+ 'Resize',
+ input,
+ empty_tensor,
+ # roi only takes effect with
+ # coordinate_transformation_mode="tf_crop_and_resize"
+ scales, # resize according to the provided scales
+ coordinate_transformation_mode_s=transformation_mode,
+ cubic_coeff_a_f=-0.75, # only valid when mode="cubic"
+ mode_s=interpolate_mode, # nearest, linear, or cubic
+ nearest_mode_s='floor') # only valid when mode="nearest"
+
+ return symbolic_fn
+
+
+upsample_nearest1d = _interpolate('upsample_nearest1d', 3, 'nearest')
+upsample_nearest2d = _interpolate('upsample_nearest2d', 4, 'nearest')
+upsample_nearest3d = _interpolate('upsample_nearest3d', 5, 'nearest')
+upsample_linear1d = _interpolate('upsample_linear1d', 3, 'linear')
+upsample_bilinear2d = _interpolate('upsample_bilinear2d', 4, 'linear')
+upsample_trilinear3d = _interpolate('upsample_trilinear3d', 5, 'linear')
+upsample_bicubic2d = _interpolate('upsample_bicubic2d', 4, 'cubic')
+
+
+@parse_args('v', 'v', 'i', 'i', 'i', 'none')
+def topk(g, self, k, dim, largest, sorted, out=None):
+ return sym_help._topk_helper(
+ g, self, k, dim, largest=largest, sorted=sorted, out=out)
+
+
+def masked_select(g, self, mask):
+ from torch.onnx.symbolic_opset9 import expand_as, nonzero
+ index = nonzero(g, expand_as(g, mask, self))
+ return g.op('GatherND', self, index)
+
+
+def _prepare_onnx_paddings(g, dim, pad):
+ pad_len = torch.onnx.symbolic_opset9.size(
+ g, pad, g.op('Constant', value_t=torch.tensor([0])))
+ # Set extension = [0] * (dim * 2 - len(pad))
+ extension = g.op(
+ 'Sub',
+ g.op('Mul',
+ g.op('Constant', value_t=torch.tensor(dim, dtype=torch.int64)),
+ g.op('Constant', value_t=torch.tensor(2, dtype=torch.int64))),
+ pad_len)
+ pad = g.op('Cast', pad, to_i=sym_help.cast_pytorch_to_onnx['Long'])
+ paddings = g.op(
+ 'Concat',
+ pad,
+ g.op(
+ 'ConstantOfShape',
+ extension,
+ value_t=torch.tensor([0], dtype=torch.int64)),
+ axis_i=0)
+ paddings = g.op('Reshape', paddings,
+ g.op('Constant', value_t=torch.tensor([-1, 2])))
+ paddings = g.op(
+ 'Transpose',
+ torch.onnx.symbolic_opset10.flip(g, paddings, [0]),
+ perm_i=[1, 0])
+ paddings = g.op('Reshape', paddings,
+ g.op('Constant', value_t=torch.tensor([-1])))
+ padding_c = g.op(
+ 'Cast', paddings, to_i=sym_help.cast_pytorch_to_onnx['Long'])
+ return padding_c
+
+
+def constant_pad_nd(g, input, padding, value=None):
+ mode = 'constant'
+ value = sym_help._maybe_get_scalar(value)
+ value = sym_help._if_scalar_type_as(g, value, input)
+ pad = _prepare_onnx_paddings(g, input.type().dim(), padding)
+ return g.op('Pad', input, pad, value, mode_s=mode)
+
+
+def reflection_pad(g, input, padding):
+ mode = 'reflect'
+ paddings = _prepare_onnx_paddings(g, input.type().dim(), padding)
+ return g.op('Pad', input, paddings, mode_s=mode)
+
+
+reflection_pad1d = reflection_pad
+reflection_pad2d = reflection_pad
+reflection_pad3d = reflection_pad
+
+
+def _avg_pool(name, tuple_fn):
+
+ @parse_args('v', 'is', 'is', 'is', 'i', 'i', 'none')
+ def symbolic_fn(g,
+ input,
+ kernel_size,
+ stride,
+ padding,
+ ceil_mode,
+ count_include_pad,
+ divisor_override=None):
+ padding = sym_help._avgpool_helper(tuple_fn, padding, kernel_size,
+ stride, divisor_override, name)
+ if not stride:
+ stride = kernel_size
+ if count_include_pad:
+ input = g.op(
+ 'Pad',
+ input,
+ g.op(
+ 'Constant',
+ value_t=torch.tensor(((0, ) * 2 + padding) * 2)),
+ mode_s='constant')
+ padding = (0, ) * len(padding)
+ output = g.op(
+ 'AveragePool',
+ input,
+ kernel_shape_i=tuple_fn(kernel_size),
+ strides_i=tuple_fn(stride),
+ pads_i=padding * 2,
+ ceil_mode_i=ceil_mode)
+ return output
+
+ return symbolic_fn
+
+
+avg_pool1d = _avg_pool('avg_pool1d', _single)
+avg_pool2d = _avg_pool('avg_pool2d', _pair)
+avg_pool3d = _avg_pool('avg_pool3d', _triple)
+
+
+def _get_im2col_indices_along_dim(g, input_d, kernel_size_d, dilation_d,
+ padding_d, stride_d):
+ # Input is always 4-D (N, C, H, W)
+ # Calculate indices of sliding blocks along spatial dimension
+ # Slide the kernel over the input along each spatial dim d:
+ # each dimension d ranges from 0 to
+ # input[d] + 2 * padding[d] - dilation[d] * (kernel_size[d] - 1)
+ # with step size = stride
+
+ blocks_d = g.op('Add', input_d,
+ g.op('Constant', value_t=torch.tensor(padding_d * 2)))
+ blocks_d = g.op(
+ 'Sub', blocks_d,
+ g.op(
+ 'Constant',
+ value_t=torch.tensor(dilation_d * (kernel_size_d - 1))))
+
+ # Stride kernel over input and find starting indices along dim d
+ blocks_d_indices = g.op('Range', g.op('Constant', value_t=torch.tensor(0)),
+ blocks_d,
+ g.op('Constant', value_t=torch.tensor(stride_d)))
+
+ # Apply dilation on kernel and find its indices along dim d
+ kernel_grid = np.arange(0, kernel_size_d * dilation_d, dilation_d)
+ kernel_grid = g.op('Constant', value_t=torch.tensor([kernel_grid]))
+
+ # Broadcast and add kernel starting positions (indices) with
+ # kernel_grid along dim d, to get block indices along dim d
+ blocks_d_indices = g.op(
+ 'Unsqueeze', blocks_d_indices, axes_i=[0]) # Reshape to [1, -1]
+ kernel_mask = g.op('Reshape', kernel_grid,
+ g.op('Constant', value_t=torch.tensor([-1, 1])))
+ block_mask = g.op('Add', blocks_d_indices, kernel_mask)
+
+ return block_mask
+
+
+def _get_im2col_padded_input(g, input, padding_h, padding_w):
+ # Input is always 4-D tensor (N, C, H, W)
+ # Padding tensor has the following format: (padding_h, padding_w)
+ # Reshape the padding to follow ONNX format:
+ # (dim1_begin, dim2_begin,...,dim1_end, dim2_end,...)
+ pad = g.op(
+ 'Constant', value_t=torch.LongTensor([0, 0, padding_h, padding_w] * 2))
+ return g.op('Pad', input, pad)
+
+
+def _get_im2col_output_shape(g, input, kernel_h, kernel_w):
+ batch_dim = size(g, input, g.op('Constant', value_t=torch.tensor(0)))
+ channel_dim = size(g, input, g.op('Constant', value_t=torch.tensor(1)))
+ channel_unfolded = g.op(
+ 'Mul', channel_dim,
+ g.op('Constant', value_t=torch.tensor(kernel_h * kernel_w)))
+
+ return g.op(
+ 'Concat',
+ g.op('Unsqueeze', batch_dim, axes_i=[0]),
+ g.op('Unsqueeze', channel_unfolded, axes_i=[0]),
+ g.op('Constant', value_t=torch.tensor([-1])),
+ axis_i=0)
+
+
+def size(g, self, dim=None):
+ if dim is None:
+ return g.op('Shape', self)
+ return sym_help._size_helper(g, self, dim)
+
+
+@parse_args('v', 'is', 'is', 'is', 'is')
+def im2col(g, input, kernel_size, dilation, padding, stride):
+ # Input is always 4-D tensor (N, C, H, W)
+ # All other args are int[2]
+
+ input_h = size(g, input, g.op('Constant', value_t=torch.tensor(2)))
+ input_w = size(g, input, g.op('Constant', value_t=torch.tensor(3)))
+
+ stride_h, stride_w = stride[0], stride[1]
+ padding_h, padding_w = padding[0], padding[1]
+ dilation_h, dilation_w = dilation[0], dilation[1]
+ kernel_h, kernel_w = kernel_size[0], kernel_size[1]
+
+ blocks_row_indices = _get_im2col_indices_along_dim(g, input_h, kernel_h,
+ dilation_h, padding_h,
+ stride_h)
+ blocks_col_indices = _get_im2col_indices_along_dim(g, input_w, kernel_w,
+ dilation_w, padding_w,
+ stride_w)
+
+ output_shape = _get_im2col_output_shape(g, input, kernel_h, kernel_w)
+ padded_input = _get_im2col_padded_input(g, input, padding_h, padding_w)
+
+ output = g.op('Gather', padded_input, blocks_row_indices, axis_i=2)
+ output = g.op('Gather', output, blocks_col_indices, axis_i=4)
+ output = g.op('Transpose', output, perm_i=[0, 1, 2, 4, 3, 5])
+ return g.op('Reshape', output, output_shape)
+
+
+@parse_args('v', 'i')
+def one_hot(g, self, num_classes):
+ values = g.op('Constant', value_t=torch.LongTensor([0, 1]))
+ depth = g.op('Constant', value_t=torch.LongTensor([num_classes]))
+ return g.op('OneHot', self, depth, values, axis_i=-1)
+
+
+@parse_args('v', 'i', 'none')
+def softmax(g, input, dim, dtype=None):
+ input_dim = input.type().dim()
+ if input_dim:
+ # TODO: remove this as onnx opset 11 spec allows negative axes
+ if dim < 0:
+ dim = input_dim + dim
+ if input_dim == dim + 1:
+ softmax = g.op('Softmax', input, axis_i=dim)
+ if dtype and dtype.node().kind() != 'prim::Constant':
+ parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype')
+ softmax = g.op(
+ 'Cast',
+ softmax,
+ to_i=sym_help.scalar_type_to_onnx[parsed_dtype])
+ return softmax
+
+ max_value = g.op('ReduceMax', input, axes_i=[dim], keepdims_i=1)
+ input = g.op('Sub', input, max_value)
+ exp = g.op('Exp', input)
+ sum = g.op('ReduceSum', exp, axes_i=[dim])
+ softmax = g.op('Div', exp, sum)
+ if dtype and dtype.node().kind() != 'prim::Constant':
+ parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype')
+ softmax = g.op(
+ 'Cast', softmax, to_i=sym_help.scalar_type_to_onnx[parsed_dtype])
+ return softmax
+
+
+def _adaptive_pool(name, type, tuple_fn, fn=None):
+
+ @parse_args('v', 'is')
+ def symbolic_fn(g, input, output_size):
+ if output_size == [1] * len(output_size) and type == 'AveragePool':
+ return g.op('GlobalAveragePool', input)
+ if not input.isCompleteTensor():
+ if output_size == [1] * len(output_size):
+ return g.op('GlobalMaxPool', input), None
+ raise NotImplementedError(
+ '[Adaptive pool]:input size not accessible')
+ dim = input.type().sizes()[2:]
+ if output_size == [1] * len(output_size) and type == 'MaxPool':
+ return g.op('GlobalMaxPool', input), None
+
+ # compute stride = floor(input_size / output_size)
+ s = [int(dim[i] / output_size[i]) for i in range(0, len(dim))]
+
+ # compute kernel_size = input_size - (output_size - 1) * stride
+ k = [dim[i] - (output_size[i] - 1) * s[i] for i in range(0, len(dim))]
+
+ # call max_poolxd_with_indices to get indices in the output
+ if type == 'MaxPool':
+ return fn(g, input, k, k, (0, ) * len(dim), (1, ) * len(dim),
+ False)
+ output = g.op(
+ type,
+ input,
+ kernel_shape_i=tuple_fn(k),
+ strides_i=tuple_fn(s),
+ ceil_mode_i=False)
+ return output
+
+ return symbolic_fn
+
+
+adaptive_avg_pool1d = _adaptive_pool('adaptive_avg_pool1d', 'AveragePool',
+ _single)
+adaptive_avg_pool2d = _adaptive_pool('adaptive_avg_pool2d', 'AveragePool',
+ _pair)
+adaptive_avg_pool3d = _adaptive_pool('adaptive_avg_pool3d', 'AveragePool',
+ _triple)
+
+
+def new_full(g,
+ self,
+ size,
+ fill_value,
+ dtype,
+ layout,
+ device,
+ pin_memory=False):
+ from torch.onnx.symbolic_opset9 import full
+ if dtype is None and self.isCompleteTensor():
+ dtype = self.type().scalarType()
+ dtype = sym_help.scalar_type_to_onnx.index(
+ sym_help.cast_pytorch_to_onnx[dtype])
+ return full(g, size, fill_value, dtype, layout, device, pin_memory)
+
+
+@parse_args('v', 'v', 'i', 'i', 'i')
+def grid_sampler(g,
+ input,
+ grid,
+ interpolation_mode,
+ padding_mode,
+ align_corners=False):
+ return g.op(
+ 'mmcv::grid_sampler',
+ input,
+ grid,
+ interpolation_mode_i=interpolation_mode,
+ padding_mode_i=padding_mode,
+ align_corners_i=align_corners)
+
+
+@parse_args('v', 'i')
+def cummax(g, input, dim):
+ return g.op('mmcv::cummax', input, dim_i=dim, outputs=2)
+
+
+@parse_args('v', 'i')
+def cummin(g, input, dim):
+ return g.op('mmcv::cummin', input, dim_i=dim, outputs=2)
+
+
+@parse_args('v', 'v', 'is')
+def roll(g, input, shifts, dims):
+ from packaging import version
+ from torch.onnx.symbolic_opset9 import squeeze
+ input_shape = g.op('Shape', input)
+
+ need_flatten = len(dims) == 0
+ # If dims is not specified, the tensor will be flattened before
+ # rolling and then restored to the original shape.
+ if need_flatten:
+ resize_shape = input_shape
+ input = g.op('Reshape', input,
+ g.op('Constant', value_t=torch.LongTensor([1, -1])))
+ input_shape = g.op('Shape', input)
+ dims = [1]
+
+ for index, dim in enumerate(dims):
+ end_size = sym_help._slice_helper(
+ g, input_shape, axes=[0], ends=[dim + 1], starts=[dim])
+ shift_size = sym_help._slice_helper(
+ g, shifts, axes=[0], ends=[index + 1], starts=[index])
+ slice_size = g.op('Sub', end_size, shift_size)
+
+ # Cannot use Mod because TensorRT does not support it
+ div_size = g.op('Div', slice_size, end_size)
+ slice_size = g.op('Sub', slice_size, g.op('Mul', end_size, div_size))
+
+ if version.parse(torch.__version__) >= version.parse('1.7.0'):
+ # add dim=0 for pytorch 1.9.0
+ end_size = squeeze(g, end_size, 0)
+ slice_size = squeeze(g, slice_size, 0)
+ else:
+ end_size = g.op('Squeeze', end_size)
+ slice_size = g.op('Squeeze', slice_size)
+ dim = torch.LongTensor([dim])
+
+ input_slice0 = sym_help._slice_helper(
+ g,
+ input,
+ axes=dim,
+ starts=torch.LongTensor([0]),
+ ends=slice_size,
+ dynamic_slice=True)
+ input_slice1 = sym_help._slice_helper(
+ g,
+ input,
+ axes=dim,
+ ends=end_size,
+ starts=slice_size,
+ dynamic_slice=True)
+
+ input = g.op('Concat', input_slice1, input_slice0, axis_i=dim)
+
+ if need_flatten:
+ input = g.op('Reshape', input, resize_shape)
+
+ return input
+
+
+def register_extra_symbolics(opset=11):
+ # Following strings of text style are from colorama package
+ bright_style, reset_style = '\x1b[1m', '\x1b[0m'
+ red_text, blue_text = '\x1b[31m', '\x1b[34m'
+ white_background = '\x1b[107m'
+
+ msg = white_background + bright_style + red_text
+ msg += 'DeprecationWarning: This function will be deprecated in the future. '
+ msg += blue_text + 'Please use the unified model deployment toolbox '
+ msg += 'MMDeploy instead: https://github.com/open-mmlab/mmdeploy'
+ msg += reset_style
+ warnings.warn(msg)
+
+ register_op('one_hot', one_hot, '', opset)
+ register_op('im2col', im2col, '', opset)
+ register_op('topk', topk, '', opset)
+ register_op('softmax', softmax, '', opset)
+ register_op('constant_pad_nd', constant_pad_nd, '', opset)
+ register_op('reflection_pad1d', reflection_pad1d, '', opset)
+ register_op('reflection_pad2d', reflection_pad2d, '', opset)
+ register_op('reflection_pad3d', reflection_pad3d, '', opset)
+ register_op('avg_pool1d', avg_pool1d, '', opset)
+ register_op('avg_pool2d', avg_pool2d, '', opset)
+ register_op('avg_pool3d', avg_pool3d, '', opset)
+ register_op('adaptive_avg_pool1d', adaptive_avg_pool1d, '', opset)
+ register_op('adaptive_avg_pool2d', adaptive_avg_pool2d, '', opset)
+ register_op('adaptive_avg_pool3d', adaptive_avg_pool3d, '', opset)
+ register_op('masked_select', masked_select, '', opset)
+ register_op('upsample_nearest1d', upsample_nearest1d, '', opset)
+ register_op('upsample_nearest2d', upsample_nearest2d, '', opset)
+ register_op('upsample_nearest3d', upsample_nearest3d, '', opset)
+ register_op('upsample_linear1d', upsample_linear1d, '', opset)
+ register_op('upsample_bilinear2d', upsample_bilinear2d, '', opset)
+ register_op('upsample_trilinear3d', upsample_trilinear3d, '', opset)
+ register_op('upsample_bicubic2d', upsample_bicubic2d, '', opset)
+ register_op('new_full', new_full, '', opset)
+ register_op('grid_sampler', grid_sampler, '', opset)
+ register_op('cummax', cummax, '', opset)
+ register_op('cummin', cummin, '', opset)
+ register_op('roll', roll, '', opset)
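+
+
+# A minimal usage sketch (comment only, not executed). It illustrates the
+# typical flow of registering these extra symbolics before ONNX export;
+# `MyModel` is a placeholder and opset 11 matches the default above:
+#
+#   import torch
+#   from mmcv.onnx.symbolic import register_extra_symbolics
+#
+#   register_extra_symbolics(opset=11)
+#   model = MyModel().eval()
+#   dummy_input = torch.randn(1, 3, 224, 224)
+#   torch.onnx.export(model, dummy_input, 'model.onnx', opset_version=11)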
diff --git a/mmcv/mmcv/ops/__init__.py b/mmcv/mmcv/ops/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..a65f14fff5f92039947d82a291fca09408f69f87
--- /dev/null
+++ b/mmcv/mmcv/ops/__init__.py
@@ -0,0 +1,106 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .active_rotated_filter import active_rotated_filter
+from .assign_score_withk import assign_score_withk
+from .ball_query import ball_query
+from .bbox import bbox_overlaps
+from .border_align import BorderAlign, border_align
+from .box_iou_rotated import box_iou_rotated
+from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
+from .cc_attention import CrissCrossAttention
+from .chamfer_distance import chamfer_distance
+from .contour_expand import contour_expand
+from .convex_iou import convex_giou, convex_iou
+from .corner_pool import CornerPool
+from .correlation import Correlation
+from .deform_conv import DeformConv2d, DeformConv2dPack, deform_conv2d
+from .deform_roi_pool import (DeformRoIPool, DeformRoIPoolPack,
+ ModulatedDeformRoIPoolPack, deform_roi_pool)
+from .deprecated_wrappers import Conv2d_deprecated as Conv2d
+from .deprecated_wrappers import ConvTranspose2d_deprecated as ConvTranspose2d
+from .deprecated_wrappers import Linear_deprecated as Linear
+from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d
+from .diff_iou_rotated import diff_iou_rotated_2d, diff_iou_rotated_3d
+from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
+ sigmoid_focal_loss, softmax_focal_loss)
+from .furthest_point_sample import (furthest_point_sample,
+ furthest_point_sample_with_dist)
+from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu
+from .gather_points import gather_points
+from .group_points import GroupAll, QueryAndGroup, grouping_operation
+from .info import (get_compiler_version, get_compiling_cuda_version,
+ get_onnxruntime_op_path)
+from .iou3d import (boxes_iou3d, boxes_iou_bev, boxes_overlap_bev, nms3d,
+ nms3d_normal, nms_bev, nms_normal_bev)
+from .knn import knn
+from .masked_conv import MaskedConv2d, masked_conv2d
+from .min_area_polygons import min_area_polygons
+from .modulated_deform_conv import (ModulatedDeformConv2d,
+ ModulatedDeformConv2dPack,
+ modulated_deform_conv2d)
+from .multi_scale_deform_attn import MultiScaleDeformableAttention
+from .nms import batched_nms, nms, nms_match, nms_rotated, soft_nms
+from .pixel_group import pixel_group
+from .point_sample import (SimpleRoIAlign, point_sample,
+ rel_roi_point_to_rel_img_point)
+from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu,
+ points_in_boxes_part)
+from .points_in_polygons import points_in_polygons
+from .points_sampler import PointsSampler
+from .prroi_pool import PrRoIPool, prroi_pool
+from .psa_mask import PSAMask
+from .riroi_align_rotated import RiRoIAlignRotated, riroi_align_rotated
+from .roi_align import RoIAlign, roi_align
+from .roi_align_rotated import RoIAlignRotated, roi_align_rotated
+from .roi_pool import RoIPool, roi_pool
+from .roiaware_pool3d import RoIAwarePool3d
+from .roipoint_pool3d import RoIPointPool3d
+from .rotated_feature_align import rotated_feature_align
+from .saconv import SAConv2d
+from .scatter_points import DynamicScatter, dynamic_scatter
+from .sparse_conv import (SparseConv2d, SparseConv3d, SparseConvTranspose2d,
+ SparseConvTranspose3d, SparseInverseConv2d,
+ SparseInverseConv3d, SubMConv2d, SubMConv3d)
+from .sparse_modules import SparseModule, SparseSequential
+from .sparse_pool import SparseMaxPool2d, SparseMaxPool3d
+from .sparse_structure import SparseConvTensor, scatter_nd
+from .sync_bn import SyncBatchNorm
+from .three_interpolate import three_interpolate
+from .three_nn import three_nn
+from .tin_shift import TINShift, tin_shift
+from .upfirdn2d import upfirdn2d
+from .voxelize import Voxelization, voxelization
+
+__all__ = [
+ 'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe',
+ 'carafe_naive', 'CornerPool', 'DeformConv2d', 'DeformConv2dPack',
+ 'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack',
+ 'ModulatedDeformRoIPoolPack', 'deform_roi_pool', 'SigmoidFocalLoss',
+ 'SoftmaxFocalLoss', 'sigmoid_focal_loss', 'softmax_focal_loss',
+ 'get_compiler_version', 'get_compiling_cuda_version',
+ 'get_onnxruntime_op_path', 'MaskedConv2d', 'masked_conv2d',
+ 'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack',
+ 'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match',
+ 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
+ 'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
+ 'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
+ 'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk',
+ 'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query',
+ 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
+ 'rotated_feature_align', 'RiRoIAlignRotated', 'riroi_align_rotated',
+ 'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'QueryAndGroup',
+ 'GroupAll', 'grouping_operation', 'contour_expand', 'three_nn',
+ 'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign',
+ 'border_align', 'gather_points', 'furthest_point_sample',
+ 'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
+ 'boxes_iou3d', 'boxes_iou_bev', 'boxes_overlap_bev', 'nms_bev',
+ 'nms_normal_bev', 'nms3d', 'nms3d_normal', 'Voxelization', 'voxelization',
+ 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d', 'SparseConv2d',
+ 'SparseConv3d', 'SparseConvTranspose2d', 'SparseConvTranspose3d',
+ 'SparseInverseConv2d', 'SparseInverseConv3d', 'SubMConv2d', 'SubMConv3d',
+ 'SparseModule', 'SparseSequential', 'SparseMaxPool2d', 'SparseMaxPool3d',
+ 'SparseConvTensor', 'scatter_nd', 'points_in_boxes_part',
+ 'points_in_boxes_cpu', 'points_in_boxes_all', 'points_in_polygons',
+ 'min_area_polygons', 'active_rotated_filter', 'convex_iou', 'convex_giou',
+ 'diff_iou_rotated_2d', 'diff_iou_rotated_3d', 'chamfer_distance',
+ 'PrRoIPool', 'prroi_pool'
+]
diff --git a/mmcv/mmcv/ops/active_rotated_filter.py b/mmcv/mmcv/ops/active_rotated_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..46c2aa7806ab62a6d0544f6dc1fb609af3a8a483
--- /dev/null
+++ b/mmcv/mmcv/ops/active_rotated_filter.py
@@ -0,0 +1,64 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
+import torch
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext',
+ ['active_rotated_filter_forward', 'active_rotated_filter_backward'])
+
+
+class ActiveRotatedFilterFunction(Function):
+ """Encoding the orientation information and generating orientation-
+ sensitive features.
+
+ The details are described in the paper `Align Deep Features for Oriented
+ Object Detection`_.
+ """
+
+ @staticmethod
+ def forward(ctx, input: torch.Tensor,
+ indices: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ input (torch.Tensor): Input features with shape
+ [num_output_planes, num_input_planes, num_orientations, H, W].
+ indices (torch.Tensor): Indices with shape
+ [num_orientations, H, W, num_rotations].
+
+ Returns:
+ torch.Tensor: Refined features with shape [num_output_planes *
+ num_rotations, num_input_planes * num_orientations, H, W].
+ """
+ ctx.save_for_backward(input, indices)
+ op, ip, o, h, w = input.size()
+ o, h, w, r = indices.size()
+ output = input.new_zeros((op * r, ip * o, h, w))
+ ext_module.active_rotated_filter_forward(input, indices, output)
+
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, None]:
+ """
+ Args:
+ grad_out (torch.Tensor): The gradient of the output features
+ with shape [num_output_planes * num_rotations,
+ num_input_planes * num_orientations, H, W].
+
+ Returns:
+ torch.Tensor: The gradient of the input features with shape
+ [num_output_planes, num_input_planes, num_orientations, H, W].
+ """
+ input, indices = ctx.saved_tensors
+ grad_in = torch.zeros_like(input)
+ ext_module.active_rotated_filter_backward(grad_out, indices, grad_in)
+ return grad_in, None
+
+
+active_rotated_filter = ActiveRotatedFilterFunction.apply
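+
+
+# A minimal usage sketch (comment only; the tensors below are placeholders).
+# Shapes follow the docstrings above and the op requires the compiled `_ext`
+# CUDA extension:
+#
+#   feats = torch.randn(32, 16, 8, 3, 3, device='cuda')                # [op, ip, o, H, W]
+#   indices = torch.zeros(8, 3, 3, 4, dtype=torch.int, device='cuda')  # [o, H, W, r]
+#   out = active_rotated_filter(feats, indices)                        # [op * r, ip * o, H, W]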
diff --git a/mmcv/mmcv/ops/assign_score_withk.py b/mmcv/mmcv/ops/assign_score_withk.py
new file mode 100644
index 0000000000000000000000000000000000000000..deca0892bddc52b51e9d2543a9e893f0bd67ebdb
--- /dev/null
+++ b/mmcv/mmcv/ops/assign_score_withk.py
@@ -0,0 +1,131 @@
+from typing import Tuple
+
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['assign_score_withk_forward', 'assign_score_withk_backward'])
+
+
+class AssignScoreWithK(Function):
+ r"""Perform weighted sum to generate output features according to scores.
+ Modified from `PAConv`_.
+
+ This is a memory-efficient CUDA implementation of the assign_scores
+ operation, which first transforms all point features with the weight bank,
+ then assembles neighbor features with ``knn_idx`` and performs a weighted
+ sum using ``scores``.
+
+ See appendix Sec. D of the paper for more detailed descriptions.
+
+ Note:
+ This implementation assumes using ``neighbor`` kernel input, which is
+ (point_features - center_features, point_features).
+ See https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/
+ pointnet2/paconv.py#L128 for more details.
+ """
+
+ @staticmethod
+ def forward(ctx,
+ scores: torch.Tensor,
+ point_features: torch.Tensor,
+ center_features: torch.Tensor,
+ knn_idx: torch.Tensor,
+ aggregate: str = 'sum') -> torch.Tensor:
+ """
+ Args:
+ scores (torch.Tensor): (B, npoint, K, M), predicted scores to
+ aggregate weight matrices in the weight bank.
+ ``npoint`` is the number of sampled centers.
+ ``K`` is the number of queried neighbors.
+ ``M`` is the number of weight matrices in the weight bank.
+ point_features (torch.Tensor): (B, N, M, out_dim)
+ Pre-computed point features to be aggregated.
+ center_features (torch.Tensor): (B, N, M, out_dim)
+ Pre-computed center features to be aggregated.
+ knn_idx (torch.Tensor): (B, npoint, K), index of sampled kNN.
+ We assume the first idx in each row is the idx of the center.
+ aggregate (str, optional): Aggregation method.
+ Can be 'sum', 'avg' or 'max'. Defaults to 'sum'.
+
+ Returns:
+ torch.Tensor: (B, out_dim, npoint, K), the aggregated features.
+ """
+ agg = {'sum': 0, 'avg': 1, 'max': 2}
+
+ B, N, M, out_dim = point_features.size()
+ _, npoint, K, _ = scores.size()
+
+ output = point_features.new_zeros((B, out_dim, npoint, K))
+ ext_module.assign_score_withk_forward(
+ point_features.contiguous(),
+ center_features.contiguous(),
+ scores.contiguous(),
+ knn_idx.contiguous(),
+ output,
+ B=B,
+ N0=N,
+ N1=npoint,
+ M=M,
+ K=K,
+ O=out_dim,
+ aggregate=agg[aggregate])
+
+ ctx.save_for_backward(output, point_features, center_features, scores,
+ knn_idx)
+ ctx.agg = agg[aggregate]
+
+ return output
+
+ @staticmethod
+ def backward(
+ ctx, grad_out: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
+ """
+ Args:
+ grad_out (torch.Tensor): (B, out_dim, npoint, K)
+
+ Returns:
+ tuple[torch.Tensor]: A tuple containing five elements. The first
+ is the gradient of ``scores`` with shape (B, npoint, K, M). The
+ second is the gradient of ``point_features`` with shape
+ (B, N, M, out_dim). The third is the gradient of
+ ``center_features`` with shape (B, N, M, out_dim). The last
+ two are ``None``.
+ """
+ _, point_features, center_features, scores, knn_idx = ctx.saved_tensors
+
+ agg = ctx.agg
+
+ B, N, M, out_dim = point_features.size()
+ _, npoint, K, _ = scores.size()
+
+ grad_point_features = point_features.new_zeros(point_features.shape)
+ grad_center_features = center_features.new_zeros(center_features.shape)
+ grad_scores = scores.new_zeros(scores.shape)
+
+ ext_module.assign_score_withk_backward(
+ grad_out.contiguous(),
+ point_features.contiguous(),
+ center_features.contiguous(),
+ scores.contiguous(),
+ knn_idx.contiguous(),
+ grad_point_features,
+ grad_center_features,
+ grad_scores,
+ B=B,
+ N0=N,
+ N1=npoint,
+ M=M,
+ K=K,
+ O=out_dim,
+ aggregate=agg)
+
+ return grad_scores, grad_point_features, \
+ grad_center_features, None, None
+
+
+assign_score_withk = AssignScoreWithK.apply
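+
+
+# A minimal usage sketch (comment only; shapes follow the forward() docstring
+# and the op requires the compiled `_ext` CUDA extension):
+#
+#   B, N, npoint, K, M, out_dim = 2, 64, 16, 8, 4, 32
+#   scores = torch.rand(B, npoint, K, M, device='cuda')
+#   point_feats = torch.rand(B, N, M, out_dim, device='cuda')
+#   center_feats = torch.rand(B, N, M, out_dim, device='cuda')
+#   knn_idx = torch.randint(0, N, (B, npoint, K), device='cuda')
+#   out = assign_score_withk(scores, point_feats, center_feats, knn_idx)
+#   # out has shape (B, out_dim, npoint, K)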
diff --git a/mmcv/mmcv/ops/ball_query.py b/mmcv/mmcv/ops/ball_query.py
new file mode 100644
index 0000000000000000000000000000000000000000..d24e0446ca81a19a9e2d4b822cb32533f941d78f
--- /dev/null
+++ b/mmcv/mmcv/ops/ball_query.py
@@ -0,0 +1,58 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['ball_query_forward'])
+
+
+class BallQuery(Function):
+ """Find nearby points in spherical space."""
+
+ @staticmethod
+ def forward(ctx, min_radius: float, max_radius: float, sample_num: int,
+ xyz: torch.Tensor, center_xyz: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ min_radius (float): minimum radius of the balls.
+ max_radius (float): maximum radius of the balls.
+ sample_num (int): maximum number of features in the balls.
+ xyz (torch.Tensor): (B, N, 3) xyz coordinates of the features.
+ center_xyz (torch.Tensor): (B, npoint, 3) centers of the ball
+ query.
+
+ Returns:
+ torch.Tensor: (B, npoint, nsample) tensor with the indices of the
+ features that form the query balls.
+ """
+ assert center_xyz.is_contiguous()
+ assert xyz.is_contiguous()
+ assert min_radius < max_radius
+
+ B, N, _ = xyz.size()
+ npoint = center_xyz.size(1)
+ idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int)
+
+ ext_module.ball_query_forward(
+ center_xyz,
+ xyz,
+ idx,
+ b=B,
+ n=N,
+ m=npoint,
+ min_radius=min_radius,
+ max_radius=max_radius,
+ nsample=sample_num)
+ if torch.__version__ != 'parrots':
+ ctx.mark_non_differentiable(idx)
+ return idx
+
+ @staticmethod
+ def backward(ctx, a=None) -> Tuple[None, None, None, None]:
+ return None, None, None, None
+
+
+ball_query = BallQuery.apply
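+
+
+# A minimal usage sketch (comment only; shapes follow the docstring above and
+# the op requires the compiled `_ext` CUDA extension):
+#
+#   xyz = torch.rand(2, 1024, 3, device='cuda')          # (B, N, 3)
+#   centers = xyz[:, :128, :].contiguous()               # (B, npoint, 3)
+#   idx = ball_query(0.0, 0.2, 16, xyz, centers)         # (B, npoint, nsample), int indices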
diff --git a/mmcv/mmcv/ops/bbox.py b/mmcv/mmcv/ops/bbox.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf6bd43bbb0adcb4b6d104a815f73ed2e5912069
--- /dev/null
+++ b/mmcv/mmcv/ops/bbox.py
@@ -0,0 +1,130 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['bbox_overlaps'])
+
+
+def _bbox_overlaps_cpu(bboxes1: torch.Tensor,
+ bboxes2: torch.Tensor,
+ mode: str = 'iou',
+ aligned: bool = False,
+ offset: int = 0) -> torch.Tensor:
+ assert mode in ['iou', 'iof']
+
+ if aligned:
+ lt = torch.max(bboxes1[:, :2], bboxes2[:, :2]) # [rows, 2]
+ rb = torch.min(bboxes1[:, 2:], bboxes2[:, 2:]) # [rows, 2]
+
+ wh = (rb - lt + offset).clamp(min=0) # [rows, 2]
+ overlap = wh[:, 0] * wh[:, 1]
+ area1 = (bboxes1[:, 2] - bboxes1[:, 0] + offset) * (
+ bboxes1[:, 3] - bboxes1[:, 1] + offset)
+
+ if mode == 'iou':
+ area2 = (bboxes2[:, 2] - bboxes2[:, 0] + offset) * (
+ bboxes2[:, 3] - bboxes2[:, 1] + offset)
+ ious = overlap / (area1 + area2 - overlap)
+ else:
+ ious = overlap / area1
+ else:
+ lt = torch.max(bboxes1[:, None, :2], bboxes2[:, :2]) # [rows, cols, 2]
+ rb = torch.min(bboxes1[:, None, 2:], bboxes2[:, 2:]) # [rows, cols, 2]
+
+ wh = (rb - lt + offset).clamp(min=0) # [rows, cols, 2]
+ overlap = wh[:, :, 0] * wh[:, :, 1]
+ area1 = (bboxes1[:, 2] - bboxes1[:, 0] + offset) * (
+ bboxes1[:, 3] - bboxes1[:, 1] + offset)
+
+ if mode == 'iou':
+ area2 = (bboxes2[:, 2] - bboxes2[:, 0] + offset) * (
+ bboxes2[:, 3] - bboxes2[:, 1] + offset)
+ ious = overlap / (area1[:, None] + area2 - overlap)
+ else:
+ ious = overlap / (area1[:, None])
+
+ return ious
+
+
+def bbox_overlaps(bboxes1: torch.Tensor,
+ bboxes2: torch.Tensor,
+ mode: str = 'iou',
+ aligned: bool = False,
+ offset: int = 0) -> torch.Tensor:
+ """Calculate overlap between two set of bboxes.
+
+ If ``aligned`` is ``False``, then calculate the ious between each bbox
+ of bboxes1 and bboxes2, otherwise the ious between each aligned pair of
+ bboxes1 and bboxes2.
+
+ Args:
+ bboxes1 (torch.Tensor): shape (m, 4) in <x1, y1, x2, y2> format or
+ empty.
+ bboxes2 (torch.Tensor): shape (n, 4) in <x1, y1, x2, y2> format or
+ empty. If aligned is ``True``, then m and n must be equal.
+ mode (str): "iou" (intersection over union) or "iof" (intersection
+ over foreground).
+
+ Returns:
+ torch.Tensor: The IoUs between the boxes. If ``aligned`` is
+ ``False``, the shape of ious is (m, n), else (m, 1).
+
+ Example:
+ >>> bboxes1 = torch.FloatTensor([
+ >>> [0, 0, 10, 10],
+ >>> [10, 10, 20, 20],
+ >>> [32, 32, 38, 42],
+ >>> ])
+ >>> bboxes2 = torch.FloatTensor([
+ >>> [0, 0, 10, 20],
+ >>> [0, 10, 10, 19],
+ >>> [10, 10, 20, 20],
+ >>> ])
+ >>> bbox_overlaps(bboxes1, bboxes2)
+ tensor([[0.5000, 0.0000, 0.0000],
+ [0.0000, 0.0000, 1.0000],
+ [0.0000, 0.0000, 0.0000]])
+
+ Example:
+ >>> empty = torch.FloatTensor([])
+ >>> nonempty = torch.FloatTensor([
+ >>> [0, 0, 10, 9],
+ >>> ])
+ >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1)
+ >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0)
+ >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0)
+ """
+
+ mode_dict = {'iou': 0, 'iof': 1}
+ assert mode in mode_dict.keys()
+ mode_flag = mode_dict[mode]
+ # Either the boxes are empty or the length of boxes' last dimension is 4
+ assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0)
+ assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0)
+ assert offset == 1 or offset == 0
+
+ rows = bboxes1.size(0)
+ cols = bboxes2.size(0)
+ if aligned:
+ assert rows == cols
+
+ if rows * cols == 0:
+ return bboxes1.new(rows, 1) if aligned else bboxes1.new(rows, cols)
+
+ if bboxes1.device.type == 'cpu':
+ return _bbox_overlaps_cpu(
+ bboxes1, bboxes2, mode=mode, aligned=aligned, offset=offset)
+ else:
+ if aligned:
+ ious = bboxes1.new_zeros(rows)
+ else:
+ ious = bboxes1.new_zeros((rows, cols))
+ ext_module.bbox_overlaps(
+ bboxes1,
+ bboxes2,
+ ious,
+ mode=mode_flag,
+ aligned=aligned,
+ offset=offset)
+ return ious
diff --git a/mmcv/mmcv/ops/border_align.py b/mmcv/mmcv/ops/border_align.py
new file mode 100644
index 0000000000000000000000000000000000000000..c09501b962cfce10b1da87e6b651d61911eb8406
--- /dev/null
+++ b/mmcv/mmcv/ops/border_align.py
@@ -0,0 +1,114 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# modified from
+# https://github.com/Megvii-BaseDetection/cvpods/blob/master/cvpods/layers/border_align.py
+
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['border_align_forward', 'border_align_backward'])
+
+
+class BorderAlignFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input, boxes, pool_size):
+ return g.op(
+ 'mmcv::MMCVBorderAlign', input, boxes, pool_size_i=pool_size)
+
+ @staticmethod
+ def forward(ctx, input: torch.Tensor, boxes: torch.Tensor,
+ pool_size: int) -> torch.Tensor:
+ ctx.pool_size = pool_size
+ ctx.input_shape = input.size()
+
+ assert boxes.ndim == 3, 'boxes must be with shape [B, H*W, 4]'
+ assert boxes.size(2) == 4, \
+ 'the last dimension of boxes must be (x1, y1, x2, y2)'
+ assert input.size(1) % 4 == 0, \
+ 'the channel for input feature must be divisible by factor 4'
+
+ # [B, C//4, H*W, 4]
+ output_shape = (input.size(0), input.size(1) // 4, boxes.size(1), 4)
+ output = input.new_zeros(output_shape)
+ # `argmax_idx` only used for backward
+ argmax_idx = input.new_zeros(output_shape).to(torch.int)
+
+ ext_module.border_align_forward(
+ input, boxes, output, argmax_idx, pool_size=ctx.pool_size)
+
+ ctx.save_for_backward(boxes, argmax_idx)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx,
+ grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
+ boxes, argmax_idx = ctx.saved_tensors
+ grad_input = grad_output.new_zeros(ctx.input_shape)
+ # a complex head architecture may leave grad_output non-contiguous
+ grad_output = grad_output.contiguous()
+ ext_module.border_align_backward(
+ grad_output,
+ boxes,
+ argmax_idx,
+ grad_input,
+ pool_size=ctx.pool_size)
+ return grad_input, None, None
+
+
+border_align = BorderAlignFunction.apply
+
+
+class BorderAlign(nn.Module):
+ r"""Border align pooling layer.
+
+ Applies border_align over the input feature based on predicted bboxes.
+ The details are described in the paper
+ `BorderDet: Border Feature for Dense Object Detection`_.
+
+ For each border line (e.g. top, left, bottom or right) of each box,
+ border_align does the following:
+
+ 1. uniformly samples ``pool_size`` + 1 positions on this line, including
+ the start and end points.
+ 2. the corresponding features at these positions are computed by bilinear
+ interpolation.
+ 3. max pooling over all ``pool_size`` + 1 positions is used to compute
+ the pooled feature.
+
+ Args:
+ pool_size (int): number of positions sampled over the boxes' borders
+ (e.g. top, bottom, left, right).
+ """
+
+ def __init__(self, pool_size: int):
+ super().__init__()
+ self.pool_size = pool_size
+
+ def forward(self, input: torch.Tensor,
+ boxes: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ input: Features with shape [N,4C,H,W]. Channels ranged in [0,C),
+ [C,2C), [2C,3C), [3C,4C) represent the top, left, bottom,
+ right features respectively.
+ boxes: Boxes with shape [N,H*W,4]. Coordinate format (x1,y1,x2,y2).
+
+ Returns:
+ torch.Tensor: Pooled features with shape [N,C,H*W,4]. The order is
+ (top,left,bottom,right) for the last dimension.
+ """
+ return border_align(input, boxes, self.pool_size)
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'(pool_size={self.pool_size})'
+ return s
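+
+
+# A minimal usage sketch (comment only; the boxes below are arbitrary
+# placeholders). Shapes follow the class docstring and the op requires the
+# compiled `_ext` CUDA extension:
+#
+#   feats = torch.rand(1, 4 * 16, 10, 10, device='cuda')     # [N, 4C, H, W]
+#   boxes = torch.rand(1, 10 * 10, 4, device='cuda') * 10    # [N, H*W, 4]
+#   pooled = BorderAlign(pool_size=10)(feats, boxes)         # [N, C, H*W, 4]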
diff --git a/mmcv/mmcv/ops/box_iou_rotated.py b/mmcv/mmcv/ops/box_iou_rotated.py
new file mode 100644
index 0000000000000000000000000000000000000000..2443af27c92146ed4328e8f94b1415c7e72c542b
--- /dev/null
+++ b/mmcv/mmcv/ops/box_iou_rotated.py
@@ -0,0 +1,148 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['box_iou_rotated'])
+
+
+def box_iou_rotated(bboxes1: torch.Tensor,
+ bboxes2: torch.Tensor,
+ mode: str = 'iou',
+ aligned: bool = False,
+ clockwise: bool = True) -> torch.Tensor:
+ """Return intersection-over-union (Jaccard index) of boxes.
+
+ Both sets of boxes are expected to be in
+ (x_center, y_center, width, height, angle) format.
+
+ If ``aligned`` is ``False``, then calculate the ious between each bbox
+ of bboxes1 and bboxes2, otherwise the ious between each aligned pair of
+ bboxes1 and bboxes2.
+
+ .. note::
+ The operator assumes:
+
+ 1) The positive direction along x axis is left -> right.
+
+ 2) The positive direction along y axis is top -> down.
+
+ 3) The w border is in parallel with x axis when angle = 0.
+
+ However, there are 2 opposite definitions of the positive angular
+ direction, clockwise (CW) and counter-clockwise (CCW). MMCV supports
+ both definitions and uses CW by default.
+
+ Please set ``clockwise=False`` if you are using the CCW definition.
+
+ The coordinate system when ``clockwise`` is ``True`` (default)
+
+ .. code-block:: none
+
+ 0-------------------> x (0 rad)
+ | A-------------B
+ | | |
+ | | box h
+ | | angle=0 |
+ | D------w------C
+ v
+ y (pi/2 rad)
+
+ In such a coordinate system, the rotation matrix is
+
+ .. math::
+ \\begin{pmatrix}
+ \\cos\\alpha & -\\sin\\alpha \\\\
+ \\sin\\alpha & \\cos\\alpha
+ \\end{pmatrix}
+
+ The coordinates of the corner point A can be calculated as:
+
+ .. math::
+ P_A=
+ \\begin{pmatrix} x_A \\\\ y_A\\end{pmatrix}
+ =
+ \\begin{pmatrix} x_{center} \\\\ y_{center}\\end{pmatrix} +
+ \\begin{pmatrix}\\cos\\alpha & -\\sin\\alpha \\\\
+ \\sin\\alpha & \\cos\\alpha\\end{pmatrix}
+ \\begin{pmatrix} -0.5w \\\\ -0.5h\\end{pmatrix} \\\\
+ =
+ \\begin{pmatrix} x_{center}-0.5w\\cos\\alpha+0.5h\\sin\\alpha
+ \\\\
+ y_{center}-0.5w\\sin\\alpha-0.5h\\cos\\alpha\\end{pmatrix}
+
+
+ The coordinate system when ``clockwise`` is ``False``
+
+ .. code-block:: none
+
+ 0-------------------> x (0 rad)
+ | A-------------B
+ | | |
+ | | box h
+ | | angle=0 |
+ | D------w------C
+ v
+ y (-pi/2 rad)
+
+ In such a coordinate system, the rotation matrix is
+
+ .. math::
+ \\begin{pmatrix}
+ \\cos\\alpha & \\sin\\alpha \\\\
+ -\\sin\\alpha & \\cos\\alpha
+ \\end{pmatrix}
+
+ The coordinates of the corner point A can be calculated as:
+
+ .. math::
+ P_A=
+ \\begin{pmatrix} x_A \\\\ y_A\\end{pmatrix}
+ =
+ \\begin{pmatrix} x_{center} \\\\ y_{center}\\end{pmatrix} +
+ \\begin{pmatrix}\\cos\\alpha & \\sin\\alpha \\\\
+ -\\sin\\alpha & \\cos\\alpha\\end{pmatrix}
+ \\begin{pmatrix} -0.5w \\\\ -0.5h\\end{pmatrix} \\\\
+ =
+ \\begin{pmatrix} x_{center}-0.5w\\cos\\alpha-0.5h\\sin\\alpha
+ \\\\
+ y_{center}+0.5w\\sin\\alpha-0.5h\\cos\\alpha\\end{pmatrix}
+
+ Args:
+ bboxes1 (torch.Tensor): rotated bboxes 1. It has shape (N, 5),
+ indicating (x, y, w, h, theta) for each row. Note that theta is in
+ radians.
+ bboxes2 (torch.Tensor): rotated bboxes 2. It has shape (M, 5),
+ indicating (x, y, w, h, theta) for each row. Note that theta is in
+ radians.
+ mode (str): "iou" (intersection over union) or "iof" (intersection
+ over foreground).
+ clockwise (bool): flag indicating whether the positive angular
+ orientation is clockwise. Defaults to True.
+ `New in version 1.4.3.`
+
+ Returns:
+ torch.Tensor: The IoUs between the boxes. If ``aligned`` is
+ ``False``, the shape of ious is (N, M), else (N,).
+ """
+ assert mode in ['iou', 'iof']
+ mode_dict = {'iou': 0, 'iof': 1}
+ mode_flag = mode_dict[mode]
+ rows = bboxes1.size(0)
+ cols = bboxes2.size(0)
+ if aligned:
+ ious = bboxes1.new_zeros(rows)
+ else:
+ ious = bboxes1.new_zeros(rows * cols)
+ if not clockwise:
+ flip_mat = bboxes1.new_ones(bboxes1.shape[-1])
+ flip_mat[-1] = -1
+ bboxes1 = bboxes1 * flip_mat
+ bboxes2 = bboxes2 * flip_mat
+ bboxes1 = bboxes1.contiguous()
+ bboxes2 = bboxes2.contiguous()
+ ext_module.box_iou_rotated(
+ bboxes1, bboxes2, ious, mode_flag=mode_flag, aligned=aligned)
+ if not aligned:
+ ious = ious.view(rows, cols)
+ return ious
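+
+
+# A minimal usage sketch (comment only; requires the compiled `_ext`
+# extension). Boxes follow the (x_center, y_center, w, h, theta-in-radians)
+# convention described above:
+#
+#   boxes1 = torch.tensor([[10., 10., 4., 2., 0.0]])
+#   boxes2 = torch.tensor([[10., 10., 4., 2., 0.5]])
+#   ious = box_iou_rotated(boxes1, boxes2)                        # shape (1, 1)
+#   ious_aligned = box_iou_rotated(boxes1, boxes2, aligned=True)  # shape (1,)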
diff --git a/mmcv/mmcv/ops/carafe.py b/mmcv/mmcv/ops/carafe.py
new file mode 100644
index 0000000000000000000000000000000000000000..18230c08074f5309e791810a4774e294084c3f5b
--- /dev/null
+++ b/mmcv/mmcv/ops/carafe.py
@@ -0,0 +1,301 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.autograd import Function
+from torch.nn.modules.module import Module
+
+from ..cnn import UPSAMPLE_LAYERS, normal_init, xavier_init
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'carafe_naive_forward', 'carafe_naive_backward', 'carafe_forward',
+ 'carafe_backward'
+])
+
+
+class CARAFENaiveFunction(Function):
+
+ @staticmethod
+ def symbolic(g, features: Tensor, masks: Tensor, kernel_size: int,
+ group_size: int, scale_factor: int) -> Tensor:
+ return g.op(
+ 'mmcv::MMCVCARAFENaive',
+ features,
+ masks,
+ kernel_size_i=kernel_size,
+ group_size_i=group_size,
+ scale_factor_f=scale_factor)
+
+ @staticmethod
+ def forward(ctx, features: Tensor, masks: Tensor, kernel_size: int,
+ group_size: int, scale_factor: int) -> Tensor:
+ assert scale_factor >= 1
+ assert masks.size(1) == kernel_size * kernel_size * group_size
+ assert masks.size(-1) == features.size(-1) * scale_factor
+ assert masks.size(-2) == features.size(-2) * scale_factor
+ assert features.size(1) % group_size == 0
+ assert (kernel_size - 1) % 2 == 0 and kernel_size >= 1
+ ctx.kernel_size = kernel_size
+ ctx.group_size = group_size
+ ctx.scale_factor = scale_factor
+ ctx.feature_size = features.size()
+ ctx.mask_size = masks.size()
+
+ n, c, h, w = features.size()
+ output = features.new_zeros((n, c, h * scale_factor, w * scale_factor))
+ ext_module.carafe_naive_forward(
+ features,
+ masks,
+ output,
+ kernel_size=kernel_size,
+ group_size=group_size,
+ scale_factor=scale_factor)
+
+ if features.requires_grad or masks.requires_grad or \
+ torch.__version__ == 'parrots':
+ ctx.save_for_backward(features, masks)
+ return output
+
+ @staticmethod
+ def backward(
+ ctx,
+ grad_output: Tensor) -> Tuple[Tensor, Tensor, None, None, None]:
+ assert grad_output.is_cuda
+
+ features, masks = ctx.saved_tensors
+ kernel_size = ctx.kernel_size
+ group_size = ctx.group_size
+ scale_factor = ctx.scale_factor
+
+ grad_input = torch.zeros_like(features)
+ grad_masks = torch.zeros_like(masks)
+ ext_module.carafe_naive_backward(
+ grad_output.contiguous(),
+ features,
+ masks,
+ grad_input,
+ grad_masks,
+ kernel_size=kernel_size,
+ group_size=group_size,
+ scale_factor=scale_factor)
+
+ return grad_input, grad_masks, None, None, None
+
+
+carafe_naive = CARAFENaiveFunction.apply
+
+
+class CARAFENaive(Module):
+
+ def __init__(self, kernel_size: int, group_size: int, scale_factor: int):
+ super().__init__()
+
+ assert isinstance(kernel_size, int) and isinstance(
+ group_size, int) and isinstance(scale_factor, int)
+ self.kernel_size = kernel_size
+ self.group_size = group_size
+ self.scale_factor = scale_factor
+
+ def forward(self, features: Tensor, masks: Tensor) -> Tensor:
+ return carafe_naive(features, masks, self.kernel_size, self.group_size,
+ self.scale_factor)
+
+
+class CARAFEFunction(Function):
+
+ @staticmethod
+ def symbolic(g, features: Tensor, masks: Tensor, kernel_size: int,
+ group_size: int, scale_factor: int) -> Tensor:
+ return g.op(
+ 'mmcv::MMCVCARAFE',
+ features,
+ masks,
+ kernel_size_i=kernel_size,
+ group_size_i=group_size,
+ scale_factor_f=scale_factor)
+
+ @staticmethod
+ def forward(ctx, features: Tensor, masks: Tensor, kernel_size: int,
+ group_size: int, scale_factor: int) -> Tensor:
+ assert scale_factor >= 1
+ assert masks.size(1) == kernel_size * kernel_size * group_size
+ assert masks.size(-1) == features.size(-1) * scale_factor
+ assert masks.size(-2) == features.size(-2) * scale_factor
+ assert features.size(1) % group_size == 0
+ assert (kernel_size - 1) % 2 == 0 and kernel_size >= 1
+ ctx.kernel_size = kernel_size
+ ctx.group_size = group_size
+ ctx.scale_factor = scale_factor
+ ctx.feature_size = features.size()
+ ctx.mask_size = masks.size()
+
+ n, c, h, w = features.size()
+ output = features.new_zeros((n, c, h * scale_factor, w * scale_factor))
+ routput = features.new_zeros(output.size(), requires_grad=False)
+ rfeatures = features.new_zeros(features.size(), requires_grad=False)
+ rmasks = masks.new_zeros(masks.size(), requires_grad=False)
+ ext_module.carafe_forward(
+ features,
+ masks,
+ rfeatures,
+ routput,
+ rmasks,
+ output,
+ kernel_size=kernel_size,
+ group_size=group_size,
+ scale_factor=scale_factor)
+
+ if features.requires_grad or masks.requires_grad or \
+ torch.__version__ == 'parrots':
+ ctx.save_for_backward(features, masks, rfeatures)
+ return output
+
+ @staticmethod
+ def backward(
+ ctx,
+ grad_output: Tensor) -> Tuple[Tensor, Tensor, None, None, None]:
+ assert grad_output.is_cuda
+
+ features, masks, rfeatures = ctx.saved_tensors
+ kernel_size = ctx.kernel_size
+ group_size = ctx.group_size
+ scale_factor = ctx.scale_factor
+
+ rgrad_output = torch.zeros_like(grad_output, requires_grad=False)
+ rgrad_input_hs = torch.zeros_like(grad_output, requires_grad=False)
+ rgrad_input = torch.zeros_like(features, requires_grad=False)
+ rgrad_masks = torch.zeros_like(masks, requires_grad=False)
+ grad_input = torch.zeros_like(features, requires_grad=False)
+ grad_masks = torch.zeros_like(masks, requires_grad=False)
+ ext_module.carafe_backward(
+ grad_output.contiguous(),
+ rfeatures,
+ masks,
+ rgrad_output,
+ rgrad_input_hs,
+ rgrad_input,
+ rgrad_masks,
+ grad_input,
+ grad_masks,
+ kernel_size=kernel_size,
+ group_size=group_size,
+ scale_factor=scale_factor)
+ return grad_input, grad_masks, None, None, None
+
+
+carafe = CARAFEFunction.apply
+
+
+class CARAFE(Module):
+ """ CARAFE: Content-Aware ReAssembly of FEatures
+
+ Please refer to `CARAFE: Content-Aware ReAssembly of FEatures`_ for
+ more details.
+
+ Args:
+ kernel_size (int): reassemble kernel size
+ group_size (int): reassemble group size
+ scale_factor (int): upsample ratio
+
+ Returns:
+ upsampled feature map
+ """
+
+ def __init__(self, kernel_size: int, group_size: int, scale_factor: int):
+ super().__init__()
+
+ assert isinstance(kernel_size, int) and isinstance(
+ group_size, int) and isinstance(scale_factor, int)
+ self.kernel_size = kernel_size
+ self.group_size = group_size
+ self.scale_factor = scale_factor
+
+ def forward(self, features: Tensor, masks: Tensor) -> Tensor:
+ return carafe(features, masks, self.kernel_size, self.group_size,
+ self.scale_factor)
+
+
+@UPSAMPLE_LAYERS.register_module(name='carafe')
+class CARAFEPack(nn.Module):
+ """A unified package of CARAFE upsampler that contains: 1) channel
+ compressor 2) content encoder 3) CARAFE op.
+
+ Official implementation of ICCV 2019 paper
+ `CARAFE: Content-Aware ReAssembly of FEatures`_.
+
+ Args:
+ channels (int): input feature channels
+ scale_factor (int): upsample ratio
+ up_kernel (int): kernel size of CARAFE op
+ up_group (int): group size of CARAFE op
+ encoder_kernel (int): kernel size of content encoder
+ encoder_dilation (int): dilation of content encoder
+ compressed_channels (int): output channels of channels compressor
+
+ Returns:
+ upsampled feature map
+ """
+
+ def __init__(self,
+ channels: int,
+ scale_factor: int,
+ up_kernel: int = 5,
+ up_group: int = 1,
+ encoder_kernel: int = 3,
+ encoder_dilation: int = 1,
+ compressed_channels: int = 64):
+ super().__init__()
+ self.channels = channels
+ self.scale_factor = scale_factor
+ self.up_kernel = up_kernel
+ self.up_group = up_group
+ self.encoder_kernel = encoder_kernel
+ self.encoder_dilation = encoder_dilation
+ self.compressed_channels = compressed_channels
+ self.channel_compressor = nn.Conv2d(channels, self.compressed_channels,
+ 1)
+ self.content_encoder = nn.Conv2d(
+ self.compressed_channels,
+ self.up_kernel * self.up_kernel * self.up_group *
+ self.scale_factor * self.scale_factor,
+ self.encoder_kernel,
+ padding=int((self.encoder_kernel - 1) * self.encoder_dilation / 2),
+ dilation=self.encoder_dilation,
+ groups=1)
+ self.init_weights()
+
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m, distribution='uniform')
+ normal_init(self.content_encoder, std=0.001)
+
+ def kernel_normalizer(self, mask: Tensor) -> Tensor:
+ mask = F.pixel_shuffle(mask, self.scale_factor)
+ n, mask_c, h, w = mask.size()
+ # use float division explicitly,
+ # to avoid inconsistency while exporting to onnx
+ mask_channel = int(mask_c / float(self.up_kernel**2))
+ mask = mask.view(n, mask_channel, -1, h, w)
+
+ mask = F.softmax(mask, dim=2, dtype=mask.dtype)
+ mask = mask.view(n, mask_c, h, w).contiguous()
+
+ return mask
+
+ def feature_reassemble(self, x: Tensor, mask: Tensor) -> Tensor:
+ x = carafe(x, mask, self.up_kernel, self.up_group, self.scale_factor)
+ return x
+
+ def forward(self, x: Tensor) -> Tensor:
+ compressed_x = self.channel_compressor(x)
+ mask = self.content_encoder(compressed_x)
+ mask = self.kernel_normalizer(mask)
+
+ x = self.feature_reassemble(x, mask)
+ return x
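+
+
+# A minimal usage sketch (comment only). CARAFEPack predicts its own
+# reassembly kernels, so only a feature map is needed; the underlying carafe
+# op requires the compiled `_ext` CUDA extension:
+#
+#   up = CARAFEPack(channels=256, scale_factor=2).cuda()
+#   x = torch.rand(1, 256, 32, 32, device='cuda')
+#   y = up(x)                                   # (1, 256, 64, 64)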
diff --git a/mmcv/mmcv/ops/cc_attention.py b/mmcv/mmcv/ops/cc_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e5d3325263f18f6b5eb0bfbc522eeaef1999e3b
--- /dev/null
+++ b/mmcv/mmcv/ops/cc_attention.py
@@ -0,0 +1,84 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmcv.cnn import PLUGIN_LAYERS, Scale
+
+
+def NEG_INF_DIAG(n: int, device: torch.device) -> torch.Tensor:
+ """Returns a diagonal matrix of size [n, n].
+
+ The diagonal entries are all "-inf". This avoids counting the
+ overlapping element in Criss-Cross attention twice.
+ """
+ return torch.diag(torch.tensor(float('-inf')).to(device).repeat(n), 0)
+
+
+@PLUGIN_LAYERS.register_module()
+class CrissCrossAttention(nn.Module):
+ """Criss-Cross Attention Module.
+
+ .. note::
+ Before v1.3.13, we use a CUDA op. Since v1.3.13, we switch
+ to a pure PyTorch and equivalent implementation. For more
+ details, please refer to https://github.com/open-mmlab/mmcv/pull/1201.
+
+ Speed comparison for one forward pass
+
+ - Input size: [2,512,97,97]
+ - Device: 1 NVIDIA GeForce RTX 2080 Ti
+
+ +-----------------------+---------------+------------+---------------+
+ | |PyTorch version|CUDA version|Relative speed |
+ +=======================+===============+============+===============+
+ |with torch.no_grad() |0.00554402 s |0.0299619 s |5.4x |
+ +-----------------------+---------------+------------+---------------+
+ |no with torch.no_grad()|0.00562803 s |0.0301349 s |5.4x |
+ +-----------------------+---------------+------------+---------------+
+
+ Args:
+ in_channels (int): Channels of the input feature map.
+ """
+
+ def __init__(self, in_channels: int) -> None:
+ super().__init__()
+ self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
+ self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
+ self.value_conv = nn.Conv2d(in_channels, in_channels, 1)
+ self.gamma = Scale(0.)
+ self.in_channels = in_channels
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """forward function of Criss-Cross Attention.
+
+ Args:
+ x (torch.Tensor): Input feature with the shape of
+ (batch_size, in_channels, height, width).
+
+ Returns:
+ torch.Tensor: Output of the layer, with the shape of
+ (batch_size, in_channels, height, width)
+ """
+ B, C, H, W = x.size()
+ query = self.query_conv(x)
+ key = self.key_conv(x)
+ value = self.value_conv(x)
+ energy_H = torch.einsum('bchw,bciw->bwhi', query, key) + NEG_INF_DIAG(
+ H, query.device)
+ energy_H = energy_H.transpose(1, 2)
+ energy_W = torch.einsum('bchw,bchj->bhwj', query, key)
+ attn = F.softmax(
+ torch.cat([energy_H, energy_W], dim=-1), dim=-1) # [B,H,W,(H+W)]
+ out = torch.einsum('bciw,bhwi->bchw', value, attn[..., :H])
+ out += torch.einsum('bchj,bhwj->bchw', value, attn[..., H:])
+
+ out = self.gamma(out) + x
+ out = out.contiguous()
+
+ return out
+
+ def __repr__(self) -> str:
+ s = self.__class__.__name__
+ s += f'(in_channels={self.in_channels})'
+ return s
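+
+
+# A minimal usage sketch (comment only); this module is pure PyTorch, so it
+# also runs on CPU:
+#
+#   cca = CrissCrossAttention(in_channels=64)
+#   x = torch.rand(2, 64, 16, 16)
+#   y = cca(x)                                  # (2, 64, 16, 16)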
diff --git a/mmcv/mmcv/ops/chamfer_distance.py b/mmcv/mmcv/ops/chamfer_distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..d68eafb47c85418c374a1eaf086478e3fc0cb1d1
--- /dev/null
+++ b/mmcv/mmcv/ops/chamfer_distance.py
@@ -0,0 +1,95 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Sequence, Tuple
+
+import torch
+from torch import Tensor
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['chamfer_distance_forward', 'chamfer_distance_backward'])
+
+
+class ChamferDistanceFunction(Function):
+ """This is an implementation of the 2D Chamfer Distance.
+
+ It has been used in the paper `Oriented RepPoints for Aerial Object
+ Detection (CVPR 2022)`_.
+ """
+
+ @staticmethod
+ def forward(ctx, xyz1: Tensor, xyz2: Tensor) -> Sequence[Tensor]:
+ """
+ Args:
+ xyz1 (Tensor): Point set with shape (B, N, 2).
+ xyz2 (Tensor): Point set with shape (B, N, 2).
+
+ Returns:
+ Sequence[Tensor]:
+
+ - dist1 (Tensor): Chamfer distance (xyz1 to xyz2) with
+ shape (B, N).
+ - dist2 (Tensor): Chamfer distance (xyz2 to xyz1) with
+ shape (B, N).
+ - idx1 (Tensor): Index of chamfer distance (xyz1 to xyz2)
+ with shape (B, N), which is used to compute the gradient.
+ - idx2 (Tensor): Index of chamfer distance (xyz2 to xyz1)
+ with shape (B, N), which is used to compute the gradient.
+ """
+ batch_size, n, _ = xyz1.size()
+ _, m, _ = xyz2.size()
+ device = xyz1.device
+ xyz1 = xyz1.contiguous()
+ xyz2 = xyz2.contiguous()
+
+ dist1 = torch.zeros(batch_size, n).to(device)
+ dist2 = torch.zeros(batch_size, m).to(device)
+ idx1 = torch.zeros(batch_size, n).type(torch.IntTensor).to(device)
+ idx2 = torch.zeros(batch_size, m).type(torch.IntTensor).to(device)
+
+ ext_module.chamfer_distance_forward(xyz1, xyz2, dist1, dist2, idx1,
+ idx2)
+ ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
+ return dist1, dist2, idx1, idx2
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_dist1: Tensor, grad_dist2: Tensor,
+ grad_idx1: Tensor,
+ grad_idx2: Tensor) -> Tuple[Tensor, Tensor]:
+ """
+
+ Args:
+ grad_dist1 (Tensor): Gradient of chamfer distance
+ (xyz1 to xyz2) with shape (B, N).
+ grad_dist2 (Tensor): Gradient of chamfer distance
+ (xyz2 to xyz1) with shape (B, N).
+ grad_idx1 (Tensor): Index of chamfer distance (xyz1 to xyz2)
+ with shape (B, N), which is used to compute the gradient.
+ grad_idx2 (Tensor): Index of chamfer distance (xyz2 to xyz1)
+ with shape (B, N), which is used to compute the gradient.
+
+ Returns:
+ Tuple[Tensor, Tensor]:
+
+ - grad_xyz1 (Tensor): Gradient of the point set with shape \
+ (B, N, 2).
+ - grad_xyz2 (Tensor): Gradient of the point set with shape \
+ (B, N, 2).
+ """
+ xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
+ device = grad_dist1.device
+ grad_dist1 = grad_dist1.contiguous()
+ grad_dist2 = grad_dist2.contiguous()
+ grad_xyz1 = torch.zeros(xyz1.size()).to(device)
+ grad_xyz2 = torch.zeros(xyz2.size()).to(device)
+
+ ext_module.chamfer_distance_backward(xyz1, xyz2, grad_xyz1, grad_xyz2,
+ grad_dist1, grad_dist2, idx1,
+ idx2)
+ return grad_xyz1, grad_xyz2
+
+
+chamfer_distance = ChamferDistanceFunction.apply
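+
+
+# Illustrative usage sketch (not part of the original file): `chamfer_distance`
+# expects two batched 2D point sets and returns per-point distances plus the
+# matched indices. The compiled `_ext` op is assumed to run on CUDA tensors.
+if __name__ == '__main__':
+    if torch.cuda.is_available():
+        xyz1 = torch.rand(4, 100, 2, device='cuda')
+        xyz2 = torch.rand(4, 80, 2, device='cuda')
+        dist1, dist2, idx1, idx2 = chamfer_distance(xyz1, xyz2)
+        # dist1: (4, 100), dist2: (4, 80); a symmetric loss averages both.
+        loss = dist1.mean() + dist2.mean()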
diff --git a/mmcv/mmcv/ops/contour_expand.py b/mmcv/mmcv/ops/contour_expand.py
new file mode 100644
index 0000000000000000000000000000000000000000..7184609ad9b64d421c17fdfe4a1a0dbeb62d64c8
--- /dev/null
+++ b/mmcv/mmcv/ops/contour_expand.py
@@ -0,0 +1,52 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Union
+
+import numpy as np
+import torch
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['contour_expand'])
+
+
+def contour_expand(kernel_mask: Union[np.ndarray, torch.Tensor],
+                   internal_kernel_label: Union[np.ndarray, torch.Tensor],
+ min_kernel_area: int, kernel_num: int) -> list:
+ """Expand kernel contours so that foreground pixels are assigned into
+ instances.
+
+ Args:
+        kernel_mask (np.ndarray or torch.Tensor): The instance kernel mask
+            with size hxw.
+        internal_kernel_label (np.ndarray or torch.Tensor): The instance
+            internal kernel label with size hxw.
+ min_kernel_area (int): The minimum kernel area.
+ kernel_num (int): The instance kernel number.
+
+ Returns:
+ list: The instance index map with size hxw.
+ """
+ assert isinstance(kernel_mask, (torch.Tensor, np.ndarray))
+ assert isinstance(internal_kernel_label, (torch.Tensor, np.ndarray))
+ assert isinstance(min_kernel_area, int)
+ assert isinstance(kernel_num, int)
+
+ if isinstance(kernel_mask, np.ndarray):
+ kernel_mask = torch.from_numpy(kernel_mask)
+ if isinstance(internal_kernel_label, np.ndarray):
+ internal_kernel_label = torch.from_numpy(internal_kernel_label)
+
+ if torch.__version__ == 'parrots':
+ if kernel_mask.shape[0] == 0 or internal_kernel_label.shape[0] == 0:
+ label = []
+ else:
+ label = ext_module.contour_expand(
+ kernel_mask,
+ internal_kernel_label,
+ min_kernel_area=min_kernel_area,
+ kernel_num=kernel_num)
+ label = label.tolist() # type: ignore
+ else:
+ label = ext_module.contour_expand(kernel_mask, internal_kernel_label,
+ min_kernel_area, kernel_num)
+ return label
diff --git a/mmcv/mmcv/ops/convex_iou.py b/mmcv/mmcv/ops/convex_iou.py
new file mode 100644
index 0000000000000000000000000000000000000000..50050363ac5b08cfa8f86dd186ab7087fac6f48a
--- /dev/null
+++ b/mmcv/mmcv/ops/convex_iou.py
@@ -0,0 +1,52 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
+import torch
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['convex_iou', 'convex_giou'])
+
+
+def convex_giou(pointsets: torch.Tensor,
+ polygons: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Return generalized intersection-over-union (Jaccard index) between point
+ sets and polygons.
+
+ Args:
+ pointsets (torch.Tensor): It has shape (N, 18),
+ indicating (x1, y1, x2, y2, ..., x9, y9) for each row.
+ polygons (torch.Tensor): It has shape (N, 8),
+ indicating (x1, y1, x2, y2, x3, y3, x4, y4) for each row.
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]: The first element is the gious
+ between point sets and polygons with the shape (N,). The second
+ element is the gradient of point sets with the shape (N, 18).
+ """
+ output = pointsets.new_zeros((pointsets.size(0), 19))
+ ext_module.convex_giou(pointsets, polygons, output)
+ convex_giou = output[:, -1]
+ points_grad = output[:, 0:-1]
+ return convex_giou, points_grad
+
+
+def convex_iou(pointsets: torch.Tensor,
+ polygons: torch.Tensor) -> torch.Tensor:
+ """Return intersection-over-union (Jaccard index) between point sets and
+ polygons.
+
+ Args:
+ pointsets (torch.Tensor): It has shape (N, 18),
+ indicating (x1, y1, x2, y2, ..., x9, y9) for each row.
+ polygons (torch.Tensor): It has shape (K, 8),
+ indicating (x1, y1, x2, y2, x3, y3, x4, y4) for each row.
+
+ Returns:
+ torch.Tensor: Return the ious between point sets and polygons with the
+ shape (N, K).
+ """
+ N, K = pointsets.size(0), polygons.size(0)
+ ious = pointsets.new_zeros((N, K))
+ ext_module.convex_iou(pointsets, polygons, ious)
+ return ious
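+
+
+# Illustrative usage sketch (not part of the original file): both ops dispatch
+# to the compiled `_ext` extension, so CUDA tensors are assumed here.
+if __name__ == '__main__':
+    if torch.cuda.is_available():
+        pointsets = torch.rand(16, 18, device='cuda')  # 9 (x, y) points per row
+        polygons = torch.rand(32, 8, device='cuda')  # 4 corners per row
+        ious = convex_iou(pointsets, polygons)  # shape (16, 32)
+        gious, grad = convex_giou(pointsets, polygons[:16])  # (16,), (16, 18)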
diff --git a/mmcv/mmcv/ops/corner_pool.py b/mmcv/mmcv/ops/corner_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..17ce24952a3b229fb552f450429c948e70aefa19
--- /dev/null
+++ b/mmcv/mmcv/ops/corner_pool.py
@@ -0,0 +1,156 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import Tensor, nn
+from torch.autograd import Function
+
+_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3}
+
+
+def _corner_pool(x: Tensor, dim: int, flip: bool) -> Tensor:
+ size = x.size(dim)
+ output = x.clone()
+
+ ind = 1
+ while ind < size:
+ if flip:
+ cur_start = 0
+ cur_len = size - ind
+ next_start = ind
+ next_len = size - ind
+ else:
+ cur_start = ind
+ cur_len = size - ind
+ next_start = 0
+ next_len = size - ind
+
+ # max_temp should be cloned for backward computation
+ max_temp = output.narrow(dim, cur_start, cur_len).clone()
+ cur_temp = output.narrow(dim, cur_start, cur_len)
+ next_temp = output.narrow(dim, next_start, next_len)
+
+ cur_temp[...] = torch.where(max_temp > next_temp, max_temp, next_temp)
+
+ ind = ind << 1
+
+ return output
+
+
+class TopPoolFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input: Tensor) -> Tensor:
+ output = g.op(
+ 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['top']))
+ return output
+
+ @staticmethod
+ def forward(ctx, input: Tensor) -> Tensor:
+ return _corner_pool(input, 2, True)
+
+
+class BottomPoolFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input: Tensor) -> Tensor:
+ output = g.op(
+ 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['bottom']))
+ return output
+
+ @staticmethod
+ def forward(ctx, input: Tensor) -> Tensor:
+ return _corner_pool(input, 2, False)
+
+
+class LeftPoolFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input: Tensor) -> Tensor:
+ output = g.op(
+ 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['left']))
+ return output
+
+ @staticmethod
+ def forward(ctx, input: Tensor) -> Tensor:
+ return _corner_pool(input, 3, True)
+
+
+class RightPoolFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input: Tensor) -> Tensor:
+ output = g.op(
+ 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['right']))
+ return output
+
+ @staticmethod
+ def forward(ctx, input: Tensor) -> Tensor:
+ return _corner_pool(input, 3, False)
+
+
+class CornerPool(nn.Module):
+ """Corner Pooling.
+
+ Corner Pooling is a new type of pooling layer that helps a
+ convolutional network better localize corners of bounding boxes.
+
+    Please refer to `CornerNet: Detecting Objects as Paired Keypoints
+    <https://arxiv.org/abs/1808.01244>`_ for more details.
+
+ Code is modified from https://github.com/princeton-vl/CornerNet-Lite.
+
+ Args:
+ mode (str): Pooling orientation for the pooling layer
+
+ - 'bottom': Bottom Pooling
+ - 'left': Left Pooling
+ - 'right': Right Pooling
+ - 'top': Top Pooling
+
+ Returns:
+ Feature map after pooling.
+ """
+
+ pool_functions = {
+ 'bottom': BottomPoolFunction,
+ 'left': LeftPoolFunction,
+ 'right': RightPoolFunction,
+ 'top': TopPoolFunction,
+ }
+
+ cummax_dim_flip = {
+ 'bottom': (2, False),
+ 'left': (3, True),
+ 'right': (3, False),
+ 'top': (2, True),
+ }
+
+ def __init__(self, mode: str):
+ super().__init__()
+ assert mode in self.pool_functions
+ self.mode = mode
+ self.corner_pool: Function = self.pool_functions[mode]
+
+ def forward(self, x: Tensor) -> Tensor:
+ if torch.__version__ != 'parrots' and torch.__version__ >= '1.5.0':
+ if torch.onnx.is_in_onnx_export():
+ assert torch.__version__ >= '1.7.0', \
+                    'When `cummax` serves as an intermediate component whose '\
+                    'output is used as input to other modules, the PyTorch '\
+                    'version must be >= 1.7.0; otherwise an error such as '\
+                    '`RuntimeError: tuple appears in op that does not '\
+                    'forward tuples, unsupported kind: prim::PythonOp` '\
+                    'is raised.'
+
+ dim, flip = self.cummax_dim_flip[self.mode]
+ if flip:
+ x = x.flip(dim)
+ pool_tensor, _ = torch.cummax(x, dim=dim)
+ if flip:
+ pool_tensor = pool_tensor.flip(dim)
+ return pool_tensor
+ else:
+ if torch.onnx.is_in_onnx_export():
+ return self.corner_pool.apply(x)
+ else:
+ dim, flip = self.cummax_dim_flip[self.mode]
+ return _corner_pool(x, dim, flip)
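+
+
+# Illustrative usage sketch (not part of the original file): CornerPool is a
+# plain nn.Module, so it can be combined like any other layer. On recent
+# PyTorch versions the forward pass falls back to `torch.cummax` and therefore
+# also runs on CPU.
+if __name__ == '__main__':
+    feat = torch.randn(2, 16, 24, 24)
+    corner_feat = CornerPool('top')(feat) + CornerPool('left')(feat)
+    assert corner_feat.shape == feat.shape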
diff --git a/mmcv/mmcv/ops/correlation.py b/mmcv/mmcv/ops/correlation.py
new file mode 100644
index 0000000000000000000000000000000000000000..319b7646782637e9ebaac4ef07b82d1f460031b5
--- /dev/null
+++ b/mmcv/mmcv/ops/correlation.py
@@ -0,0 +1,200 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
+import torch
+from torch import Tensor, nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['correlation_forward', 'correlation_backward'])
+
+
+class CorrelationFunction(Function):
+
+ @staticmethod
+ def forward(ctx,
+ input1: Tensor,
+ input2: Tensor,
+ kernel_size: int = 1,
+ max_displacement: int = 1,
+ stride: int = 1,
+ padding: int = 1,
+ dilation: int = 1,
+ dilation_patch: int = 1) -> Tensor:
+
+ ctx.save_for_backward(input1, input2)
+
+ kH, kW = ctx.kernel_size = _pair(kernel_size)
+ patch_size = max_displacement * 2 + 1
+ ctx.patch_size = patch_size
+ dH, dW = ctx.stride = _pair(stride)
+ padH, padW = ctx.padding = _pair(padding)
+ dilationH, dilationW = ctx.dilation = _pair(dilation)
+ dilation_patchH, dilation_patchW = ctx.dilation_patch = _pair(
+ dilation_patch)
+
+ output_size = CorrelationFunction._output_size(ctx, input1)
+
+ output = input1.new_zeros(output_size)
+
+ ext_module.correlation_forward(
+ input1,
+ input2,
+ output,
+ kH=kH,
+ kW=kW,
+ patchH=patch_size,
+ patchW=patch_size,
+ padH=padH,
+ padW=padW,
+ dilationH=dilationH,
+ dilationW=dilationW,
+ dilation_patchH=dilation_patchH,
+ dilation_patchW=dilation_patchW,
+ dH=dH,
+ dW=dW)
+
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(
+ ctx, grad_output: Tensor
+ ) -> Tuple[Tensor, Tensor, None, None, None, None, None, None]:
+ input1, input2 = ctx.saved_tensors
+
+ kH, kW = ctx.kernel_size
+ patch_size = ctx.patch_size
+ padH, padW = ctx.padding
+ dilationH, dilationW = ctx.dilation
+ dilation_patchH, dilation_patchW = ctx.dilation_patch
+ dH, dW = ctx.stride
+ grad_input1 = torch.zeros_like(input1)
+ grad_input2 = torch.zeros_like(input2)
+
+ ext_module.correlation_backward(
+ grad_output,
+ input1,
+ input2,
+ grad_input1,
+ grad_input2,
+ kH=kH,
+ kW=kW,
+ patchH=patch_size,
+ patchW=patch_size,
+ padH=padH,
+ padW=padW,
+ dilationH=dilationH,
+ dilationW=dilationW,
+ dilation_patchH=dilation_patchH,
+ dilation_patchW=dilation_patchW,
+ dH=dH,
+ dW=dW)
+ return grad_input1, grad_input2, None, None, None, None, None, None
+
+ @staticmethod
+ def _output_size(ctx, input1):
+ iH, iW = input1.size(2), input1.size(3)
+ batch_size = input1.size(0)
+ kH, kW = ctx.kernel_size
+ patch_size = ctx.patch_size
+ dH, dW = ctx.stride
+ padH, padW = ctx.padding
+ dilationH, dilationW = ctx.dilation
+ dilatedKH = (kH - 1) * dilationH + 1
+ dilatedKW = (kW - 1) * dilationW + 1
+
+ oH = int((iH + 2 * padH - dilatedKH) / dH + 1)
+ oW = int((iW + 2 * padW - dilatedKW) / dW + 1)
+
+ output_size = (batch_size, patch_size, patch_size, oH, oW)
+ return output_size
+
+
+class Correlation(nn.Module):
+ r"""Correlation operator
+
+ This correlation operator works for optical flow correlation computation.
+
+ There are two batched tensors with shape :math:`(N, C, H, W)`,
+    and the correlation output's shape is :math:`(N, max\_displacement \times
+    2 + 1, max\_displacement \times 2 + 1, H_{out}, W_{out})`
+
+ where
+
+ .. math::
+ H_{out} = \left\lfloor\frac{H_{in} + 2 \times padding -
+ dilation \times (kernel\_size - 1) - 1}
+ {stride} + 1\right\rfloor
+
+ .. math::
+ W_{out} = \left\lfloor\frac{W_{in} + 2 \times padding - dilation
+ \times (kernel\_size - 1) - 1}
+ {stride} + 1\right\rfloor
+
+ the correlation item :math:`(N_i, dy, dx)` is formed by taking the sliding
+ window convolution between input1 and shifted input2,
+
+ .. math::
+ Corr(N_i, dx, dy) =
+ \sum_{c=0}^{C-1}
+ input1(N_i, c) \star
+ \mathcal{S}(input2(N_i, c), dy, dx)
+
+ where :math:`\star` is the valid 2d sliding window convolution operator,
+ and :math:`\mathcal{S}` means shifting the input features (auto-complete
+ zero marginal), and :math:`dx, dy` are shifting distance, :math:`dx, dy \in
+ [-max\_displacement \times dilation\_patch, max\_displacement \times
+ dilation\_patch]`.
+
+ Args:
+        kernel_size (int): The size of the sliding window, i.e. the local
+            neighborhood around the center points that is involved in the
+            correlation computation. Defaults to 1.
+ max_displacement (int): The radius for computing correlation volume,
+ but the actual working space can be dilated by dilation_patch.
+ Defaults to 1.
+ stride (int): The stride of the sliding blocks in the input spatial
+ dimensions. Defaults to 1.
+ padding (int): Zero padding added to all four sides of the input1.
+ Defaults to 0.
+        dilation (int): The spacing within the local neighborhood that will
+            be involved in the correlation. Defaults to 1.
+        dilation_patch (int): The spacing between positions at which the
+            correlation is computed. Defaults to 1.
+ """
+
+ def __init__(self,
+ kernel_size: int = 1,
+ max_displacement: int = 1,
+ stride: int = 1,
+ padding: int = 0,
+ dilation: int = 1,
+ dilation_patch: int = 1) -> None:
+ super().__init__()
+ self.kernel_size = kernel_size
+ self.max_displacement = max_displacement
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.dilation_patch = dilation_patch
+
+ def forward(self, input1: Tensor, input2: Tensor) -> Tensor:
+ return CorrelationFunction.apply(input1, input2, self.kernel_size,
+ self.max_displacement, self.stride,
+ self.padding, self.dilation,
+ self.dilation_patch)
+
+ def __repr__(self) -> str:
+ s = self.__class__.__name__
+ s += f'(kernel_size={self.kernel_size}, '
+ s += f'max_displacement={self.max_displacement}, '
+ s += f'stride={self.stride}, '
+ s += f'padding={self.padding}, '
+ s += f'dilation={self.dilation}, '
+ s += f'dilation_patch={self.dilation_patch})'
+ return s
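+
+
+# Illustrative usage sketch (not part of the original file): with kernel_size=1
+# and padding=0 the output keeps the spatial size and gains a
+# (2 * max_displacement + 1) ** 2 cost-volume dimension. The op relies on the
+# compiled `_ext` extension, so a CUDA device is assumed.
+if __name__ == '__main__':
+    if torch.cuda.is_available():
+        corr = Correlation(max_displacement=4)
+        feat1 = torch.randn(2, 32, 48, 48, device='cuda')
+        feat2 = torch.randn(2, 32, 48, 48, device='cuda')
+        cost_volume = corr(feat1, feat2)  # shape (2, 9, 9, 48, 48)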
diff --git a/mmcv/mmcv/ops/csrc/README.md b/mmcv/mmcv/ops/csrc/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..dbc82b534b1ab27593361b3053cb61e12fbd420e
--- /dev/null
+++ b/mmcv/mmcv/ops/csrc/README.md
@@ -0,0 +1,189 @@
+# Code Structure of CUDA operators
+
+This folder contains all non-python code for MMCV custom ops. Please follow the same architecture if you want to add new ops.
+
+## Directories Tree
+
+```folder
+.
+├── common
+│ ├── box_iou_rotated_utils.hpp
+│ ├── parrots_cpp_helper.hpp
+│ ├── parrots_cuda_helper.hpp
+│ ├── pytorch_cpp_helper.hpp
+│ ├── pytorch_cuda_helper.hpp
+│ ├── pytorch_device_registry.hpp
+│ ├── cuda
+│ │ ├── common_cuda_helper.hpp
+│ │ ├── parrots_cudawarpfunction.cuh
+│ │ ├── ...
+│ │ └── ops_cuda_kernel.cuh
+| ├── mps
+│ │ ├── MPSLibrary.h
+│ │ ├── ...
+│ │ └── MPSUtils.h
+| ├── mlu
+│ │ └── ...
+| └── utils
+│ │ └── ...
+├── onnxruntime
+│ ├── onnxruntime_register.h
+│ ├── onnxruntime_session_options_config_keys.h
+│ ├── ort_mmcv_utils.h
+│ ├── ...
+│ ├── onnx_ops.h
+│ └── cpu
+│ ├── onnxruntime_register.cpp
+│ ├── ...
+│ └── onnx_ops_impl.cpp
+├── parrots
+│ ├── ...
+│ ├── ops.cpp
+│ ├── ops_parrots.cpp
+│ └── ops_pytorch.h
+├── pytorch
+│ ├── info.cpp
+│ ├── pybind.cpp
+│ ├── ...
+│ ├── ops.cpp
+│ ├── cuda
+│ │ ├── ...
+│ │ └── ops_cuda.cu
+│ ├── cpu
+│ │ ├── ...
+│ │ └── ops.cpp
+│ ├── mps
+│ │ ├── ...
+│ | └── op_mps.mm
+│ └── mlu
+│ ├── ...
+│ └── op_mlu.cpp
+└── tensorrt
+ ├── trt_cuda_helper.cuh
+ ├── trt_plugin_helper.hpp
+ ├── trt_plugin.hpp
+ ├── trt_serialize.hpp
+ ├── ...
+ ├── trt_ops.hpp
+ └── plugins
+ ├── trt_cuda_helper.cu
+ ├── trt_plugin.cpp
+ ├── ...
+ ├── trt_ops.cpp
+ └── trt_ops_kernel.cu
+```
+
+## Components
+
+- `common`: This directory contains all tools and shared codes.
+  - `cuda`: The CUDA kernels that can be shared by all backends. **HIP** kernels are also placed here since they have a similar syntax.
+ - `mps`: The tools used to support MPS ops. **NOTE** that MPS support is **experimental**.
+ - `mlu`: The MLU kernels used to support [Cambricon](https://www.cambricon.com/) device.
+ - `utils`: The kernels and utils of spconv.
+- `onnxruntime`: **ONNX Runtime** support for custom ops. It has been deprecated; please try the latest custom ops in [MMDeploy](https://github.com/open-mmlab/mmdeploy).
+  - `cpu`: CPU implementation of supported ops.
+- `parrots`: **Parrots** is a deep learning framework for model training and inference. Parrots custom ops are placed in this directory.
+- `pytorch`: **PyTorch** custom ops are supported by binding C++ to Python with **pybind11**. The ops implementation and binding codes are placed in this directory.
+  - `cuda`: This directory contains cuda kernel launchers, which feed memory pointers of tensors to the cuda kernels in `common/cuda`. The launchers provide the c++ interface to the cuda implementation of the corresponding custom ops.
+  - `cpu`: This directory contains cpu implementations of the corresponding custom ops.
+  - `mlu`: This directory contains the launchers of the MLU kernels.
+  - `mps`: MPS ops implementation and launchers.
+- `tensorrt`: **TensorRT** support for custom ops. It has been deprecated; please try the latest custom ops in [MMDeploy](https://github.com/open-mmlab/mmdeploy).
+ - `plugins`: This directory contains the implementation of the supported custom ops. Some ops might also use shared cuda kernel in `common/cuda`.
+
+## How to add new PyTorch ops?
+
+1. (Optional) Add shared kernel in `common` to support special hardware platform.
+
+ ```c++
+ // src/common/cuda/new_ops_cuda_kernel.cuh
+
+   template <typename T>
+ __global__ void new_ops_forward_cuda_kernel(const T* input, T* output, ...) {
+ // forward here
+ }
+
+ ```
+
+ Add cuda kernel launcher in `pytorch/cuda`.
+
+ ```c++
+ // src/pytorch/cuda
+   #include <new_ops_cuda_kernel.cuh>
+
+ void NewOpsForwardCUDAKernelLauncher(Tensor input, Tensor output, ...){
+ // initialize
+ at::cuda::CUDAGuard device_guard(input.device());
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ ...
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ input.scalar_type(), "new_ops_forward_cuda_kernel", ([&] {
+           new_ops_forward_cuda_kernel<scalar_t>
+               <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                   input.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),...);
+ }));
+ AT_CUDA_CHECK(cudaGetLastError());
+ }
+ ```
+
+2. Register implementation for different devices.
+
+ ```c++
+ // src/pytorch/cuda/cudabind.cpp
+ ...
+
+ Tensor new_ops_forward_cuda(Tensor input, Tensor output, ...){
+ // implement cuda forward here
+ // use `NewOpsForwardCUDAKernelLauncher` here
+ }
+ // declare interface here.
+ Tensor new_ops_forward_impl(Tensor input, Tensor output, ...);
+ // register the implementation for given device (CUDA here).
+ REGISTER_DEVICE_IMPL(new_ops_forward_impl, CUDA, new_ops_forward_cuda);
+ ```
+
+3. Add ops implementation in `pytorch` directory. Select different implementations according to device type.
+
+ ```c++
+ // src/pytorch/new_ops.cpp
+ Tensor new_ops_forward_impl(Tensor input, Tensor output, ...){
+ // dispatch the implementation according to the device type of input.
+ DISPATCH_DEVICE_IMPL(new_ops_forward_impl, input, output, ...);
+ }
+ ...
+
+ Tensor new_ops_forward(Tensor input, Tensor output, ...){
+ return new_ops_forward_impl(input, output, ...);
+ }
+ ```
+
+4. Binding the implementation in `pytorch/pybind.cpp`
+
+ ```c++
+ // src/pytorch/pybind.cpp
+
+ ...
+
+ Tensor new_ops_forward(Tensor input, Tensor output, ...);
+
+ ...
+
+ // bind with pybind11
+ m.def("new_ops_forward", &new_ops_forward, "new_ops_forward",
+ py::arg("input"), py::arg("output"), ...);
+
+ ...
+
+ ```
+
+5. Build MMCV again. The new op can then be used in Python:
+
+ ```python
+ from ..utils import ext_loader
+ ext_module = ext_loader.load_ext('_ext', ['new_ops_forward'])
+
+ ...
+
+ ext_module.new_ops_forward(input, output, ...)
+
+ ```
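+
+   To make the new op differentiable on the Python side, the raw extension
+   call is usually wrapped in a `torch.autograd.Function`. The sketch below is
+   illustrative only; `new_ops_forward`/`new_ops_backward` mirror the
+   placeholder op used throughout this guide and are not real MMCV symbols.
+
+   ```python
+   import torch
+   from torch.autograd import Function
+
+   from ..utils import ext_loader
+
+   ext_module = ext_loader.load_ext(
+       '_ext', ['new_ops_forward', 'new_ops_backward'])
+
+
+   class NewOpsFunction(Function):
+
+       @staticmethod
+       def forward(ctx, input):
+           output = input.new_zeros(input.size())
+           ext_module.new_ops_forward(input, output)
+           ctx.save_for_backward(input)
+           return output
+
+       @staticmethod
+       def backward(ctx, grad_output):
+           input, = ctx.saved_tensors
+           grad_input = torch.zeros_like(input)
+           ext_module.new_ops_backward(grad_output.contiguous(), input,
+                                       grad_input)
+           return grad_input
+
+
+   new_ops = NewOpsFunction.apply
+   ```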
diff --git a/mmcv/mmcv/ops/csrc/common/box_iou_rotated_utils.hpp b/mmcv/mmcv/ops/csrc/common/box_iou_rotated_utils.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..243200e156f1384b625d6bac7fa4c68e533d9441
--- /dev/null
+++ b/mmcv/mmcv/ops/csrc/common/box_iou_rotated_utils.hpp
@@ -0,0 +1,347 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+// modified from
+// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h
+#pragma once
+#include <cassert>
+#include <cmath>
+
+#ifdef __CUDACC__
+// Designates functions callable from the host (CPU) and the device (GPU)
+#define HOST_DEVICE __host__ __device__
+#define HOST_DEVICE_INLINE HOST_DEVICE __forceinline__
+#else
+#include <algorithm>
+#define HOST_DEVICE
+#define HOST_DEVICE_INLINE HOST_DEVICE inline
+#endif
+
+namespace {
+
+template <typename T>
+struct RotatedBox {
+ T x_ctr, y_ctr, w, h, a;
+};
+
+template <typename T>
+struct Point {
+ T x, y;
+ HOST_DEVICE_INLINE Point(const T& px = 0, const T& py = 0) : x(px), y(py) {}
+ HOST_DEVICE_INLINE Point operator+(const Point& p) const {
+ return Point(x + p.x, y + p.y);
+ }
+ HOST_DEVICE_INLINE Point& operator+=(const Point& p) {
+ x += p.x;
+ y += p.y;
+ return *this;
+ }
+ HOST_DEVICE_INLINE Point operator-(const Point& p) const {
+ return Point(x - p.x, y - p.y);
+ }
+ HOST_DEVICE_INLINE Point operator*(const T coeff) const {
+ return Point(x * coeff, y * coeff);
+ }
+};
+
+template <typename T>
+HOST_DEVICE_INLINE T dot_2d(const Point<T>& A, const Point<T>& B) {
+ return A.x * B.x + A.y * B.y;
+}
+
+template <typename T>
+HOST_DEVICE_INLINE T cross_2d(const Point<T>& A, const Point<T>& B) {
+ return A.x * B.y - B.x * A.y;
+}
+
+template <typename T>
+HOST_DEVICE_INLINE void get_rotated_vertices(const RotatedBox<T>& box,
+                                             Point<T> (&pts)[4]) {
+ // M_PI / 180. == 0.01745329251
+ // double theta = box.a * 0.01745329251;
+ // MODIFIED
+ double theta = box.a;
+ T cosTheta2 = (T)cos(theta) * 0.5f;
+ T sinTheta2 = (T)sin(theta) * 0.5f;
+
+ // y: top --> down; x: left --> right
+ pts[0].x = box.x_ctr - sinTheta2 * box.h - cosTheta2 * box.w;
+ pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w;
+ pts[1].x = box.x_ctr + sinTheta2 * box.h - cosTheta2 * box.w;
+ pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w;
+ pts[2].x = 2 * box.x_ctr - pts[0].x;
+ pts[2].y = 2 * box.y_ctr - pts[0].y;
+ pts[3].x = 2 * box.x_ctr - pts[1].x;
+ pts[3].y = 2 * box.y_ctr - pts[1].y;
+}
+
+template <typename T>
+HOST_DEVICE_INLINE int get_intersection_points(const Point<T> (&pts1)[4],
+                                               const Point<T> (&pts2)[4],
+                                               Point<T> (&intersections)[24]) {
+ // Line vector
+ // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
+  Point<T> vec1[4], vec2[4];
+ for (int i = 0; i < 4; i++) {
+ vec1[i] = pts1[(i + 1) % 4] - pts1[i];
+ vec2[i] = pts2[(i + 1) % 4] - pts2[i];
+ }
+
+ // Line test - test all line combos for intersection
+ int num = 0; // number of intersections
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 4; j++) {
+ // Solve for 2x2 Ax=b
+ T det = cross_2d(vec2[j], vec1[i]);
+
+ // This takes care of parallel lines
+ if (fabs(det) <= 1e-14) {
+ continue;
+ }
+
+ auto vec12 = pts2[j] - pts1[i];
+
+ T t1 = cross_2d(vec2[j], vec12) / det;
+ T t2 = cross_2d(vec1[i], vec12) / det;
+
+ if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) {
+ intersections[num++] = pts1[i] + vec1[i] * t1;
+ }
+ }
+ }
+
+ // Check for vertices of rect1 inside rect2
+ {
+ const auto& AB = vec2[0];
+ const auto& DA = vec2[3];
+ auto ABdotAB = dot_2d(AB, AB);
+ auto ADdotAD = dot_2d(DA, DA);
+ for (int i = 0; i < 4; i++) {
+ // assume ABCD is the rectangle, and P is the point to be judged
+ // P is inside ABCD iff. P's projection on AB lies within AB
+ // and P's projection on AD lies within AD
+
+ auto AP = pts1[i] - pts2[0];
+
+ auto APdotAB = dot_2d(AP, AB);
+ auto APdotAD = -dot_2d(AP, DA);
+
+ if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
+ (APdotAD <= ADdotAD)) {
+ intersections[num++] = pts1[i];
+ }
+ }
+ }
+
+ // Reverse the check - check for vertices of rect2 inside rect1
+ {
+ const auto& AB = vec1[0];
+ const auto& DA = vec1[3];
+ auto ABdotAB = dot_2d(AB, AB);
+ auto ADdotAD = dot_2d(DA, DA);
+ for (int i = 0; i < 4; i++) {
+ auto AP = pts2[i] - pts1[0];
+
+ auto APdotAB = dot_2d(AP, AB);
+ auto APdotAD = -dot_2d(AP, DA);
+
+ if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
+ (APdotAD <= ADdotAD)) {
+ intersections[num++] = pts2[i];
+ }
+ }
+ }
+
+ return num;
+}
+
+template <typename T>
+HOST_DEVICE_INLINE int convex_hull_graham(const Point<T> (&p)[24],
+                                          const int& num_in, Point<T> (&q)[24],
+                                          bool shift_to_zero = false) {
+ assert(num_in >= 2);
+
+ // Step 1:
+ // Find point with minimum y
+ // if more than 1 points have the same minimum y,
+ // pick the one with the minimum x.
+ int t = 0;
+ for (int i = 1; i < num_in; i++) {
+ if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) {
+ t = i;
+ }
+ }
+ auto& start = p[t]; // starting point
+
+ // Step 2:
+ // Subtract starting point from every points (for sorting in the next step)
+ for (int i = 0; i < num_in; i++) {
+ q[i] = p[i] - start;
+ }
+
+ // Swap the starting point to position 0
+ auto tmp = q[0];
+ q[0] = q[t];
+ q[t] = tmp;
+
+ // Step 3:
+ // Sort point 1 ~ num_in according to their relative cross-product values
+ // (essentially sorting according to angles)
+ // If the angles are the same, sort according to their distance to origin
+ T dist[24];
+ for (int i = 0; i < num_in; i++) {
+ dist[i] = dot_2d(q[i], q[i]);
+ }
+
+#ifdef __CUDACC__
+ // CUDA version
+ // In the future, we can potentially use thrust
+ // for sorting here to improve speed (though not guaranteed)
+ for (int i = 1; i < num_in - 1; i++) {
+ for (int j = i + 1; j < num_in; j++) {
+ T crossProduct = cross_2d(q[i], q[j]);
+ if ((crossProduct < -1e-6) ||
+ (fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) {
+ auto q_tmp = q[i];
+ q[i] = q[j];
+ q[j] = q_tmp;
+ auto dist_tmp = dist[i];
+ dist[i] = dist[j];
+ dist[j] = dist_tmp;
+ }
+ }
+ }
+#else
+ // CPU version
+ std::sort(q + 1, q + num_in,
+            [](const Point<T>& A, const Point<T>& B) -> bool {
+ T temp = cross_2d(A, B);
+ if (fabs(temp) < 1e-6) {
+ return dot_2d(A, A) < dot_2d(B, B);
+ } else {
+ return temp > 0;
+ }
+ });
+ // compute distance to origin after sort, since the points are now different.
+ for (int i = 0; i < num_in; i++) {
+ dist[i] = dot_2d(q[i], q[i]);
+ }
+#endif
+
+ // Step 4:
+ // Make sure there are at least 2 points (that don't overlap with each other)
+ // in the stack
+ int k; // index of the non-overlapped second point
+ for (k = 1; k < num_in; k++) {
+ if (dist[k] > 1e-8) {
+ break;
+ }
+ }
+ if (k == num_in) {
+ // We reach the end, which means the convex hull is just one point
+ q[0] = p[t];
+ return 1;
+ }
+ q[1] = q[k];
+ int m = 2; // 2 points in the stack
+ // Step 5:
+ // Finally we can start the scanning process.
+ // When a non-convex relationship between the 3 points is found
+ // (either concave shape or duplicated points),
+ // we pop the previous point from the stack
+ // until the 3-point relationship is convex again, or
+ // until the stack only contains two points
+ for (int i = k + 1; i < num_in; i++) {
+ while (m > 1 && cross_2d(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) {
+ m--;
+ }
+ q[m++] = q[i];
+ }
+
+ // Step 6 (Optional):
+ // In general sense we need the original coordinates, so we
+ // need to shift the points back (reverting Step 2)
+ // But if we're only interested in getting the area/perimeter of the shape
+ // We can simply return.
+ if (!shift_to_zero) {
+ for (int i = 0; i < m; i++) {
+ q[i] += start;
+ }
+ }
+
+ return m;
+}
+
+template <typename T>
+HOST_DEVICE_INLINE T polygon_area(const Point<T> (&q)[24], const int& m) {
+ if (m <= 2) {
+ return 0;
+ }
+
+ T area = 0;
+ for (int i = 1; i < m - 1; i++) {
+ area += fabs(cross_2d(q[i] - q[0], q[i + 1] - q[0]));
+ }
+
+ return area / 2.0;
+}
+
+template <typename T>
+HOST_DEVICE_INLINE T rotated_boxes_intersection(const RotatedBox<T>& box1,
+                                                const RotatedBox<T>& box2) {
+ // There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
+ // from rotated_rect_intersection_pts
+  Point<T> intersectPts[24], orderedPts[24];
+
+  Point<T> pts1[4];
+  Point<T> pts2[4];
+ get_rotated_vertices(box1, pts1);
+ get_rotated_vertices(box2, pts2);
+
+ int num = get_intersection_points(pts1, pts2, intersectPts);
+
+ if (num <= 2) {
+ return 0.0;
+ }
+
+ // Convex Hull to order the intersection points in clockwise order and find
+ // the contour area.
+ int num_convex = convex_hull_graham(intersectPts, num, orderedPts, true);
+ return polygon_area(orderedPts, num_convex);
+}
+
+} // namespace
+
+template <typename T>
+HOST_DEVICE_INLINE T single_box_iou_rotated(T const* const box1_raw,
+ T const* const box2_raw,
+ const int mode_flag) {
+ // shift center to the middle point to achieve higher precision in result
+  RotatedBox<T> box1, box2;
+ auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0;
+ auto center_shift_y = (box1_raw[1] + box2_raw[1]) / 2.0;
+ box1.x_ctr = box1_raw[0] - center_shift_x;
+ box1.y_ctr = box1_raw[1] - center_shift_y;
+ box1.w = box1_raw[2];
+ box1.h = box1_raw[3];
+ box1.a = box1_raw[4];
+ box2.x_ctr = box2_raw[0] - center_shift_x;
+ box2.y_ctr = box2_raw[1] - center_shift_y;
+ box2.w = box2_raw[2];
+ box2.h = box2_raw[3];
+ box2.a = box2_raw[4];
+
+ const T area1 = box1.w * box1.h;
+ const T area2 = box2.w * box2.h;
+ if (area1 < 1e-14 || area2 < 1e-14) {
+ return 0.f;
+ }
+
+ const T intersection = rotated_boxes_intersection(box1, box2);
+ T baseS = 1.0;
+ if (mode_flag == 0) {
+ baseS = (area1 + area2 - intersection);
+ } else if (mode_flag == 1) {
+ baseS = area1;
+ }
+ const T iou = intersection / baseS;
+ return iou;
+}
diff --git a/mmcv/mmcv/ops/csrc/common/cuda/active_rotated_filter_cuda_kernel.cuh b/mmcv/mmcv/ops/csrc/common/cuda/active_rotated_filter_cuda_kernel.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..36e41107ebd52d3cf5e9a71cffe6eddeed4f0765
--- /dev/null
+++ b/mmcv/mmcv/ops/csrc/common/cuda/active_rotated_filter_cuda_kernel.cuh
@@ -0,0 +1,59 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+// Modified from
+// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/orn/src/cuda/ActiveRotatingFilter_cuda.cu
+#ifndef ACTIVE_ROTATED_FILTER_CUDA_KERNEL_CUH
+#define ACTIVE_ROTATED_FILTER_CUDA_KERNEL_CUH
+
+#ifdef MMCV_USE_PARROTS
+#include "parrots_cuda_helper.hpp"
+#else
+#include "pytorch_cuda_helper.hpp"
+#endif
+
+template <typename scalar_t>
+__global__ void active_rotated_filter_forward_cuda_kernel(
+ const int nthreads, const scalar_t* weight_data, const int* indices_data,
+ const int num_input_planes, const int num_output_planes,
+ const int num_orientations, const int num_rotations, const int nEntry,
+ scalar_t* output_data) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ int l = index % nEntry;
+ int j = (index / nEntry) % num_input_planes;
+ int i = index / nEntry / num_input_planes;
+ int k;
+ scalar_t val = *(weight_data + index);
+ for (k = 0; k < num_rotations; k++) {
+ int idx = (int)(*(indices_data + l * num_rotations + k)) - 1;
+ scalar_t* target = output_data +
+ i * (num_rotations * num_input_planes * nEntry) +
+ k * (num_input_planes * nEntry) + j * (nEntry) + idx;
+ *target = val;
+ }
+ }
+}
+
+template <typename scalar_t>
+__global__ void active_rotated_filter_backward_cuda_kernel(
+ const int nthreads, const scalar_t* gradWeight_data,
+ const int* indices_data, const int num_input_planes,
+ const int num_output_planes, const int num_orientations,
+ const int num_rotations, const int nEntry, scalar_t* weight_data) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ int l = index % nEntry;
+ int j = (index / nEntry) % num_input_planes;
+ int i = index / nEntry / num_input_planes;
+ int k;
+ scalar_t* val = weight_data + index;
+ *val = 0;
+ scalar_t tmp = 0;
+ for (k = 0; k < num_rotations; k++) {
+ int idx = (int)(*(indices_data + l * num_rotations + k)) - 1;
+ scalar_t target =
+ *(gradWeight_data + i * (num_rotations * num_input_planes * nEntry) +
+ k * (num_input_planes * nEntry) + j * (nEntry) + idx);
+ tmp = tmp + target;
+ }
+ *val = tmp;
+ }
+}
+#endif // ACTIVE_ROTATED_FILTER_CUDA_KERNEL_CUH
diff --git a/mmcv/mmcv/ops/csrc/common/cuda/assign_score_withk_cuda_kernel.cuh b/mmcv/mmcv/ops/csrc/common/cuda/assign_score_withk_cuda_kernel.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..9f9250844b9ceeca0df0377640c3d28e3f61cecc
--- /dev/null
+++ b/mmcv/mmcv/ops/csrc/common/cuda/assign_score_withk_cuda_kernel.cuh
@@ -0,0 +1,116 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#ifndef ASSIGN_SCORE_WITHK_CUDA_KERNEL_CUH
+#define ASSIGN_SCORE_WITHK_CUDA_KERNEL_CUH
+
+#ifdef MMCV_USE_PARROTS
+#include "parrots_cuda_helper.hpp"
+#else
+#include "pytorch_cuda_helper.hpp"
+#endif
+
+// input: points(B,N0,M,O), centers(B,N0,M,O), scores(B,N1,K,M), knn_idx(B,N1,K)
+// output: fout(B,O,N)
+// algo: fout(b,i,k,j) = s(b,i,k,m)*p(b,c(i),k,m,j) = s(b,i,k,m)*p(b,i(k),m,j)
+// i(k) = idx(b,i,k)
+// sum: fout(b,i,j) = fout(b,i,j) + s(b,i,k,m)*p(b,i,k,m,j)
+// avg: fout(b,i,j) = sum(fout(b,i,k,j)) / k
+// max: fout(b,i,j) = max(fout(b,i,k,j), sum(s(b,i,k,m)*p(b,i,k,m,j)))
+
+template <typename T>
+__global__ void assign_score_withk_forward_cuda_kernel(
+ const int B, const int N0, const int N1, const int M, const int K,
+ const int O, const int aggregate, const T* points, const T* centers,
+ const T* scores, const int64_t* knn_idx, T* output) {
+ // ----- parallel loop for B, N1, K and O ---------
+ CUDA_1D_KERNEL_LOOP(i, B * O * N1 * K) {
+ // ------- loop for M ----------
+ const int b = (int)(i / (O * N1 * K));
+ const int o = (int)(i % (O * N1 * K) / (N1 * K));
+ const int n = (int)(i % (N1 * K) / K);
+ const int k = (int)(i % K);
+ const int cn = (int)knn_idx[b * K * N1 + n * K +
+ 0]; // The first neighbor is the center point
+ const int kn = (int)knn_idx[b * K * N1 + n * K + k];
+ if (kn >= N0 ||
+ kn < 0) { // if index overflows, it is out of the neighborhood range
+ return;
+ }
+ assert(b < B);
+ assert(kn < N0);
+ assert(cn < N0);
+ assert(o < O);
+ assert(n < N1);
+ const int out_idx = b * N1 * O * K + o * N1 * K + n * K + k;
+ T val = output[out_idx];
+ for (int m = 0; m < M; m++) {
+ val += points[b * N0 * M * O + kn * M * O + m * O + o] *
+ scores[b * N1 * K * M + n * K * M + k * M + m] -
+ centers[b * N0 * M * O + cn * M * O + m * O + o] *
+ scores[b * N1 * K * M + n * K * M + k * M + m];
+ }
+ output[out_idx] = val;
+ }
+}
+
+template <typename T>
+__global__ void assign_score_withk_points_backward_cuda_kernel(
+ const int B, const int N0, const int N, const int M, const int K,
+ const int O, const int aggregate, const T* grad_out, const T* scores,
+ const int64_t* knn_idx, T* grad_points, T* grad_centers) {
+ // ----- parallel loop for B, M, O ---------
+ CUDA_1D_KERNEL_LOOP(i, B * M * O) {
+ int b = (int)(i / (M * O));
+ int m = (int)(i % (M * O) / O);
+ int o = (int)(i % O);
+
+ // ----- loop for N,K ---------
+ for (int n = 0; n < N; n++) {
+ for (int k = 0; k < K; k++) {
+ int kn = knn_idx[b * N * K + n * K + k];
+ int cn = knn_idx[b * N * K + n * K + 0];
+ if (kn >= N0 || kn < 0) { // if index overflows, it is out of the
+ // neighborhood range
+ continue;
+ }
+ atomicAdd(grad_points + b * N0 * M * O + kn * M * O + m * O + o,
+ scores[b * N * K * M + n * K * M + k * M + m] *
+ grad_out[b * O * N * K + o * N * K + n * K + k]);
+ atomicAdd(grad_centers + b * N0 * M * O + cn * M * O + m * O + o,
+ -scores[b * N * K * M + n * K * M + k * M + m] *
+ grad_out[b * O * N * K + o * N * K + n * K + k]);
+ }
+ }
+ }
+}
+
+template <typename T>
+__global__ void assign_score_withk_scores_backward_cuda_kernel(
+ const int B, const int N0, const int N, const int M, const int K,
+ const int O, const int aggregate, const T* grad_out, const T* points,
+ const T* centers, const int64_t* knn_idx, T* grad_scores) {
+ // ----- parallel loop for B, N, K, M ---------
+ CUDA_1D_KERNEL_LOOP(i, B * N * K * M) {
+ const int b = (int)(i / (N * M * K));
+ const int n = (int)(i % (N * M * K) / M / K);
+ const int k = (int)(i % (M * K) / M);
+ const int m = (int)(i % M);
+ const int cn = knn_idx[b * N * K + n * K + 0];
+ const int kn = knn_idx[b * N * K + n * K + k];
+ if (kn >= N0 ||
+ kn < 0) { // if index overflows, it is out of the neighborhood range
+ return;
+ }
+
+ // -------------- loop for O ------------------------
+ const int out_idx = b * N * K * M + n * K * M + k * M + m;
+ T val = grad_scores[out_idx];
+ for (int o = 0; o < O; o++) {
+ val += (points[b * N0 * M * O + kn * M * O + m * O + o] -
+ centers[b * N0 * M * O + cn * M * O + m * O + o]) *
+ grad_out[b * O * N * K + o * N * K + n * K + k];
+ }
+ grad_scores[out_idx] = val;
+ }
+}
+
+#endif // ASSIGN_SCORE_WITHK_CUDA_KERNEL_CUH
diff --git a/mmcv/mmcv/ops/csrc/common/cuda/ball_query_cuda_kernel.cuh b/mmcv/mmcv/ops/csrc/common/cuda/ball_query_cuda_kernel.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..632b5c4940b33a9d8d839fa3f3b92e7b6a2bd29e
--- /dev/null
+++ b/mmcv/mmcv/ops/csrc/common/cuda/ball_query_cuda_kernel.cuh
@@ -0,0 +1,58 @@
+// Copyright (c) OpenMMLab. All rights reserved
+// Modified from
+// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query_gpu.cu
+#ifndef BALL_QUERY_CUDA_KERNEL_CUH
+#define BALL_QUERY_CUDA_KERNEL_CUH
+
+#ifdef MMCV_USE_PARROTS
+#include "parrots_cuda_helper.hpp"
+#else
+#include "pytorch_cuda_helper.hpp"
+#endif
+
+template <typename T>
+__global__ void ball_query_forward_cuda_kernel(int b, int n, int m,
+ float min_radius,
+ float max_radius, int nsample,
+ const T* new_xyz, const T* xyz,
+ int* idx) {
+ // new_xyz: (B, M, 3)
+ // xyz: (B, N, 3)
+ // output:
+ // idx: (B, M, nsample)
+ int bs_idx = blockIdx.y;
+ CUDA_1D_KERNEL_LOOP(pt_idx, m) {
+ if (bs_idx >= b) return;
+
+ new_xyz += bs_idx * m * 3 + pt_idx * 3;
+ xyz += bs_idx * n * 3;
+ idx += bs_idx * m * nsample + pt_idx * nsample;
+
+ float max_radius2 = max_radius * max_radius;
+ float min_radius2 = min_radius * min_radius;
+ T new_x = new_xyz[0];
+ T new_y = new_xyz[1];
+ T new_z = new_xyz[2];
+
+ int cnt = 0;
+ for (int k = 0; k < n; ++k) {
+ T x = xyz[k * 3 + 0];
+ T y = xyz[k * 3 + 1];
+ T z = xyz[k * 3 + 2];
+ T d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
+ (new_z - z) * (new_z - z);
+ if (d2 == 0 || (d2 >= min_radius2 && d2 < max_radius2)) {
+ if (cnt == 0) {
+ for (int l = 0; l < nsample; ++l) {
+ idx[l] = k;
+ }
+ }
+ idx[cnt] = k;
+ ++cnt;
+ if (cnt >= nsample) break;
+ }
+ }
+ }
+}
+
+#endif // BALL_QUERY_CUDA_KERNEL_CUH
diff --git a/mmcv/mmcv/ops/csrc/common/cuda/bbox_overlaps_cuda_kernel.cuh b/mmcv/mmcv/ops/csrc/common/cuda/bbox_overlaps_cuda_kernel.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..15bd91eca629895d3a99dde3fe6614036ca31dc9
--- /dev/null
+++ b/mmcv/mmcv/ops/csrc/common/cuda/bbox_overlaps_cuda_kernel.cuh
@@ -0,0 +1,147 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#ifndef BBOX_OVERLAPS_CUDA_KERNEL_CUH
+#define BBOX_OVERLAPS_CUDA_KERNEL_CUH
+
+#ifdef MMCV_USE_PARROTS
+#include "parrots_cuda_helper.hpp"
+#else
+#include "pytorch_cuda_helper.hpp"
+#endif
+
+template <typename T>
+__device__ __forceinline__ void load_bbox(const T* bbox, const int base, T& x1,
+ T& y1, T& x2, T& y2) {
+ x1 = bbox[base];
+ y1 = bbox[base + 1];
+ x2 = bbox[base + 2];
+ y2 = bbox[base + 3];
+}
+
+template <>
+__device__ __forceinline__ void load_bbox<float>(const float* bbox,
+ const int base, float& x1,
+ float& y1, float& x2,
+ float& y2) {
+  const float4 bbox_offset = reinterpret_cast<const float4*>(bbox + base)[0];
+ x1 = bbox_offset.x;
+ y1 = bbox_offset.y;
+ x2 = bbox_offset.z;
+ y2 = bbox_offset.w;
+}
+
+template <typename T>
+__global__ void bbox_overlaps_cuda_kernel(const T* bbox1, const T* bbox2,
+ T* ious, const int num_bbox1,
+ const int num_bbox2, const int mode,
+ const bool aligned,
+ const int offset) {
+ if (aligned) {
+ CUDA_1D_KERNEL_LOOP(index, num_bbox1) {
+ const int b1 = index;
+ const int b2 = index;
+
+ const int base1 = b1 << 2; // b1 * 4
+ T b1_x1, b1_y1, b1_x2, b1_y2;
+ load_bbox(bbox1, base1, b1_x1, b1_y1, b1_x2, b1_y2);
+ const T b1_area = (b1_x2 - b1_x1 + offset) * (b1_y2 - b1_y1 + offset);
+
+ const int base2 = b2 << 2; // b2 * 4
+ T b2_x1, b2_y1, b2_x2, b2_y2;
+ load_bbox(bbox2, base2, b2_x1, b2_y1, b2_x2, b2_y2);
+ const T b2_area = (b2_x2 - b2_x1 + offset) * (b2_y2 - b2_y1 + offset);
+
+ const T left = fmaxf(b1_x1, b2_x1), right = fminf(b1_x2, b2_x2);
+ const T top = fmaxf(b1_y1, b2_y1), bottom = fminf(b1_y2, b2_y2);
+ const T width = fmaxf(right - left + offset, 0.f);
+ const T height = fmaxf(bottom - top + offset, 0.f);
+ const T interS = width * height;
+
+ const T baseS =
+ fmaxf(mode == 0 ? b1_area + b2_area - interS : b1_area, T(offset));
+ ious[index] = interS / baseS;
+ }
+ } else {
+ CUDA_1D_KERNEL_LOOP(index, num_bbox1 * num_bbox2) {
+ const int b1 = index / num_bbox2;
+ const int b2 = index % num_bbox2;
+
+ const int base1 = b1 << 2; // b1 * 4
+ T b1_x1, b1_y1, b1_x2, b1_y2;
+ load_bbox(bbox1, base1, b1_x1, b1_y1, b1_x2, b1_y2);
+ const T b1_area = (b1_x2 - b1_x1 + offset) * (b1_y2 - b1_y1 + offset);
+
+ const int base2 = b2 << 2; // b2 * 4
+ T b2_x1, b2_y1, b2_x2, b2_y2;
+ load_bbox(bbox2, base2, b2_x1, b2_y1, b2_x2, b2_y2);
+ const T b2_area = (b2_x2 - b2_x1 + offset) * (b2_y2 - b2_y1 + offset);
+
+ const T left = fmaxf(b1_x1, b2_x1), right = fminf(b1_x2, b2_x2);
+ const T top = fmaxf(b1_y1, b2_y1), bottom = fminf(b1_y2, b2_y2);
+ const T width = fmaxf(right - left + offset, 0.f);
+ const T height = fmaxf(bottom - top + offset, 0.f);
+ const T interS = width * height;
+
+ const T baseS =
+ fmaxf(mode == 0 ? b1_area + b2_area - interS : b1_area, T(offset));
+ ious[index] = interS / baseS;
+ }
+ }
+}
+
+#if __CUDA_ARCH__ >= 530
+__device__ __forceinline__ __half __half_area(const __half x1, const __half y1,
+ const __half x2, const __half y2,
+ const __half offset) {
+ const __half half_w = __hadd(__hsub(x2, x1), offset);
+ const __half half_h = __hadd(__hsub(y2, y1), offset);
+ return __hmul(half_w, half_h);
+}
+
+__device__ __forceinline__ __half __half_max(const __half a, const __half b) {
+ return __hge(a, b) ? a : b;
+}
+
+__device__ __forceinline__ __half __half_min(const __half a, const __half b) {
+ return __hle(a, b) ? a : b;
+}
+
+// fp16 won't provide much increase when aligned==true. It is useful when
+// aligned==false, which would give you ~40% bonus.
+__device__ void bbox_overlaps_cuda_kernel_half(
+ const __half* bbox1, const __half* bbox2, __half* ious, const int num_bbox1,
+ const int num_bbox2, const int mode, const bool aligned, const int offset) {
+ const int num_output = aligned ? num_bbox1 : num_bbox1 * num_bbox2;
+ const __half h_offset = __int2half_rn(offset);
+ CUDA_1D_KERNEL_LOOP(index, num_output) {
+ const int b1 = aligned ? index : index / num_bbox2;
+ const int b2 = aligned ? index : index % num_bbox2;
+
+ const int base1 = b1 << 2;
+ __half b1_x1, b1_y1, b1_x2, b1_y2;
+ load_bbox<__half>(bbox1, base1, b1_x1, b1_y1, b1_x2, b1_y2);
+ const __half b1_area = __half_area(b1_x1, b1_y1, b1_x2, b1_y2, h_offset);
+
+ const int base2 = b2 << 2;
+ __half b2_x1, b2_y1, b2_x2, b2_y2;
+ load_bbox<__half>(bbox2, base2, b2_x1, b2_y1, b2_x2, b2_y2);
+ const __half b2_area = __half_area(b2_x1, b2_y1, b2_x2, b2_y2, h_offset);
+
+ const __half left = __half_max(b1_x1, b2_x1),
+ right = __half_min(b1_x2, b2_x2);
+ const __half top = __half_max(b1_y1, b2_y1),
+ bottom = __half_min(b1_y2, b2_y2);
+ const __half width =
+ __half_max(__hadd(__hsub(right, left), h_offset), __float2half(0.f));
+ const __half height =
+ __half_max(__hadd(__hsub(bottom, top), h_offset), __float2half(0.f));
+ const __half interS = __hmul(width, height);
+
+ const __half baseS = __half_max(
+ mode == 0 ? __hsub(__hadd(b1_area, b2_area), interS) : b1_area,
+ h_offset);
+ ious[index] = __hdiv(interS, baseS);
+ }
+}
+#endif // __CUDA_ARCH__ >= 530
+
+#endif // BBOX_OVERLAPS_CUDA_KERNEL_CUH
diff --git a/mmcv/mmcv/ops/csrc/common/cuda/border_align_cuda_kernel.cuh b/mmcv/mmcv/ops/csrc/common/cuda/border_align_cuda_kernel.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..1d2a2197b45ef5c82412c4b75d7819a7e27674f6
--- /dev/null
+++ b/mmcv/mmcv/ops/csrc/common/cuda/border_align_cuda_kernel.cuh
@@ -0,0 +1,200 @@
+// Copyright (c) OpenMMLab. All rights reserved
+// modified from
+// https://github.com/Megvii-BaseDetection/cvpods/blob/master/cvpods/layers/csrc/border_align/border_align_kernel.cu.
+// the main difference: (1) use `argmax_idx` for fast computing of gradient
+// during the backward. (2) `wh` is directly computed by `boxes`, rather than
+// passing it as argument to forward or backward functions.
+
+#ifndef BORDER_ALIGN_CUDA_KERNEL_CUH
+#define BORDER_ALIGN_CUDA_KERNEL_CUH
+
+#include <float.h>
+#ifdef MMCV_WITH_TRT
+#include "common_cuda_helper.hpp"
+#else // MMCV_WITH_TRT
+#ifdef MMCV_USE_PARROTS
+#include "parrots_cuda_helper.hpp"
+#else // MMCV_USE_PARROTS
+#include "pytorch_cuda_helper.hpp"
+#endif // MMCV_USE_PARROTS
+#endif // MMCV_WITH_TRT
+
+enum BorderMode { Top = 0, Left = 1, Bottom = 2, Right = 3 };
+
+/*** Forward ***/
+template <typename T>
+__global__ void border_align_forward_cuda_kernel(
+ const int nthreads, const T* input, const T* boxes, T* output,
+ int* argmax_idx, const int channels, const int box_size, const int height,
+ const int width, const int pool_size) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // (batch_idx, c_idx, box_idx) is an element paralleled for computing
+ // output, and `extreme_idx` is in range [0,3]
+ int batch_idx, c_idx, box_idx, extreme_idx, maxidx, *offset_argmax_idx;
+ const T *offset_box, *offset_input, *offset_box_x;
+ T *offset_output, box_width, box_height, stride, x_stride, y_stride, x, y,
+ val, maxval;
+
+ extreme_idx = threadIdx.y;
+ // shape (N, C, box_size, 4) for output
+ batch_idx = index / channels / box_size;
+ // shape (N, box_size, 4) for boxes
+ box_idx = index % box_size + batch_idx * box_size;
+ c_idx = (index / box_size) % channels;
+
+ offset_box = boxes + box_idx * 4;
+ box_width = *(offset_box + 2) - *offset_box;
+ box_height = *(offset_box + 3) - *(offset_box + 1);
+ offset_output = output + index * 4 + extreme_idx;
+ offset_argmax_idx = argmax_idx + index * 4 + extreme_idx;
+ // shape (N, 4C, h, w) for input.
+ // [0,C) for top feature, [C,2C) for left feature,
+ // [2C,3C) for bottom feature, [3C,4C) for right feature
+ offset_input =
+ input + (batch_idx * channels * 4 + extreme_idx * channels + c_idx) *
+ height * width;
+
+ // extreme_idx in [0,1] -> offset_box_x indexed at x1
+ // extreme_idx in [2,3] -> offset_box_x indexed at x2
+ offset_box_x = offset_box + extreme_idx / 2 * 2;
+
+ // (x1,y1) or (x2,y2) for (x,y)
+ x = *offset_box_x;
+ y = *(offset_box_x + 1);
+
+ switch (extreme_idx) {
+ // top
+ case BorderMode::Top:
+ stride = box_width / pool_size;
+ x_stride = stride;
+ y_stride = 0;
+ break;
+ // left
+ case BorderMode::Left:
+ stride = box_height / pool_size;
+ x_stride = 0;
+ y_stride = stride;
+ break;
+ // bottom
+ case BorderMode::Bottom:
+ stride = box_width / pool_size;
+ x_stride = -stride;
+ y_stride = 0;
+ break;
+ // right
+ case BorderMode::Right:
+ stride = box_height / pool_size;
+ x_stride = 0;
+ y_stride = -stride;
+ break;
+ }
+
+ // initialize maxval and maxidx with the start position (e.g. (x1,y1) or
+ // (x2,y2))
+ maxval = bilinear_interpolate(offset_input, height, width, y, x, index);
+ maxidx = 0;
+
+ // do max_pool along the border
+ for (int i = 1; i <= pool_size; i++) {
+ x += x_stride;
+ y += y_stride;
+ val = bilinear_interpolate(offset_input, height, width, y, x, index);
+ if (val > maxval) {
+ maxval = val;
+ maxidx = i;
+ }
+ }
+
+ // update output and argmax_idx
+ *offset_output = maxval;
+ *offset_argmax_idx = maxidx;
+ }
+}
+
+/*** Backward ***/
+template <typename T>
+__global__ void border_align_backward_cuda_kernel(
+ const int nthreads, const T* grad_output, const T* boxes,
+ const int* argmax_idx, T* grad_input, const int channels,
+ const int box_size, const int height, const int width,
+ const int pool_size) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // (batch_idx, c_idx, box_idx) is an element paralleled for computing
+ // output, and `extreme_idx` is in range [0,3]
+ int batch_idx, c_idx, box_idx, extreme_idx;
+ const int* offset_argmax_idx;
+ const T *offset_grad_output, *offset_box, *offset_box_x;
+ T *offset_grad_input, box_width, box_height, stride, x_stride, y_stride, x,
+ y;
+
+ extreme_idx = threadIdx.y;
+ batch_idx = index / channels / box_size;
+ box_idx = index % box_size + batch_idx * box_size;
+ c_idx = (index / box_size) % channels;
+
+ offset_box = boxes + box_idx * 4;
+ box_width = *(offset_box + 2) - *offset_box;
+ box_height = *(offset_box + 3) - *(offset_box + 1);
+ offset_grad_output = grad_output + index * 4 + extreme_idx;
+ offset_argmax_idx = argmax_idx + index * 4 + extreme_idx;
+ // [0,C) for top feature grad, [C,2C) for left feature grad,
+ // [2C,3C) for bottom feature grad, [3C,4C) for right feature grad
+ offset_grad_input = grad_input + (batch_idx * channels * 4 +
+ extreme_idx * channels + c_idx) *
+ height * width;
+
+ // extreme_idx in [0,1] -> offset_box_x indexed at x1
+ // extreme_idx in [2,3] -> offset_box_x indexed at x2
+ offset_box_x = offset_box + extreme_idx / 2 * 2;
+
+ switch (extreme_idx) {
+ // top
+ case BorderMode::Top:
+ stride = box_width / pool_size;
+ x_stride = stride;
+ y_stride = 0;
+ break;
+ // left
+ case BorderMode::Left:
+ stride = box_height / pool_size;
+ x_stride = 0;
+ y_stride = stride;
+ break;
+ // bottom
+ case BorderMode::Bottom:
+ stride = box_width / pool_size;
+ x_stride = -stride;
+ y_stride = 0;
+ break;
+ // right
+ case BorderMode::Right:
+ stride = box_height / pool_size;
+ x_stride = 0;
+ y_stride = -stride;
+ break;
+ }
+
+ // get position (x,y) which has maximum value during forward
+ x = *offset_box_x;
+ y = *(offset_box_x + 1);
+ x += x_stride * (T)(*offset_argmax_idx);
+ y += y_stride * (T)(*offset_argmax_idx);
+
+ T w1, w2, w3, w4;
+ int x_low, x_high, y_low, y_high;
+ bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4, x_low,
+ x_high, y_low, y_high, index);
+
+ // update grad_output
+ atomicAdd(offset_grad_input + y_low * width + x_low,
+ *offset_grad_output * w1);
+ atomicAdd(offset_grad_input + y_low * width + x_high,
+ *offset_grad_output * w2);
+ atomicAdd(offset_grad_input + y_high * width + x_low,
+ *offset_grad_output * w3);
+ atomicAdd(offset_grad_input + y_high * width + x_high,
+ *offset_grad_output * w4);
+ }
+}
+
+#endif // BORDER_ALIGN_CUDA_KERNEL_CUH
diff --git a/mmcv/mmcv/ops/csrc/common/cuda/box_iou_rotated_cuda.cuh b/mmcv/mmcv/ops/csrc/common/cuda/box_iou_rotated_cuda.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..abd47cd85437804310886de057b5a839a49481b2
--- /dev/null
+++ b/mmcv/mmcv/ops/csrc/common/cuda/box_iou_rotated_cuda.cuh
@@ -0,0 +1,81 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+// modified from
+// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu
+#ifndef BOX_IOU_ROTATED_CUDA_CUH
+#define BOX_IOU_ROTATED_CUDA_CUH
+
+#ifdef MMCV_USE_PARROTS
+#include "parrots_cuda_helper.hpp"
+#else
+#include "pytorch_cuda_helper.hpp"
+#endif
+#include "box_iou_rotated_utils.hpp"
+
+// 2D block with 32 * 16 = 512 threads per block
+const int BLOCK_DIM_X = 32;
+const int BLOCK_DIM_Y = 16;
+
+inline int divideUP(const int x, const int y) { return (((x) + (y)-1) / (y)); }
+
+template <typename T>
+__global__ void box_iou_rotated_cuda_kernel(
+ const int n_boxes1, const int n_boxes2, const T* dev_boxes1,
+ const T* dev_boxes2, T* dev_ious, const int mode_flag, const bool aligned) {
+ if (aligned) {
+ CUDA_1D_KERNEL_LOOP(index, n_boxes1) {
+ int b1 = index;
+ int b2 = index;
+
+ int base1 = b1 * 5;
+
+ float block_boxes1[5];
+ float block_boxes2[5];
+
+ block_boxes1[0] = dev_boxes1[base1 + 0];
+ block_boxes1[1] = dev_boxes1[base1 + 1];
+ block_boxes1[2] = dev_boxes1[base1 + 2];
+ block_boxes1[3] = dev_boxes1[base1 + 3];
+ block_boxes1[4] = dev_boxes1[base1 + 4];
+
+ int base2 = b2 * 5;
+
+ block_boxes2[0] = dev_boxes2[base2 + 0];
+ block_boxes2[1] = dev_boxes2[base2 + 1];
+ block_boxes2[2] = dev_boxes2[base2 + 2];
+ block_boxes2[3] = dev_boxes2[base2 + 3];
+ block_boxes2[4] = dev_boxes2[base2 + 4];
+
+ dev_ious[index] =
+          single_box_iou_rotated<T>(block_boxes1, block_boxes2, mode_flag);
+ }
+ } else {
+ CUDA_1D_KERNEL_LOOP(index, n_boxes1 * n_boxes2) {
+ int b1 = index / n_boxes2;
+ int b2 = index % n_boxes2;
+
+ int base1 = b1 * 5;
+
+ float block_boxes1[5];
+ float block_boxes2[5];
+
+ block_boxes1[0] = dev_boxes1[base1 + 0];
+ block_boxes1[1] = dev_boxes1[base1 + 1];
+ block_boxes1[2] = dev_boxes1[base1 + 2];
+ block_boxes1[3] = dev_boxes1[base1 + 3];
+ block_boxes1[4] = dev_boxes1[base1 + 4];
+
+ int base2 = b2 * 5;
+
+ block_boxes2[0] = dev_boxes2[base2 + 0];
+ block_boxes2[1] = dev_boxes2[base2 + 1];
+ block_boxes2[2] = dev_boxes2[base2 + 2];
+ block_boxes2[3] = dev_boxes2[base2 + 3];
+ block_boxes2[4] = dev_boxes2[base2 + 4];
+
+ dev_ious[index] =
+          single_box_iou_rotated<T>(block_boxes1, block_boxes2, mode_flag);
+ }
+ }
+}
+
+#endif
diff --git a/mmcv/mmcv/ops/csrc/common/cuda/carafe_cuda_kernel.cuh b/mmcv/mmcv/ops/csrc/common/cuda/carafe_cuda_kernel.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..e7fa990fea1849f626baa0b81a726564373216a8
--- /dev/null
+++ b/mmcv/mmcv/ops/csrc/common/cuda/carafe_cuda_kernel.cuh
@@ -0,0 +1,332 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#ifndef CARAFE_CUDA_KERNEL_CUH
+#define CARAFE_CUDA_KERNEL_CUH
+
+#ifdef MMCV_USE_PARROTS
+#include "parrots_cuda_helper.hpp"
+#else
+#include "pytorch_cuda_helper.hpp"
+#endif
+
+#ifdef HIP_DIFF
+#define WARP_SIZE 64
+#else
+#define WARP_SIZE 32
+#endif
+#define THREADS_PER_PIXEL 32
+#define MAX_SHARED_MEMORY 49152
+#define MAX_SHARED_SCALAR_T 6144 // 49152 / 8 = 6144
+#define MAXIMIZE_KERNEL_SIZE true
+#define kTileDim 32
+#define kBlockRows 8
+#define FULL_MASK 0xffffffff
+
+inline int divideUP(const int x, const int y) { return (((x) + (y)-1) / (y)); }
+
+__device__ inline int Loc2Index(const int n, const int c, const int h,
+ const int w, const int channel_num,
+ const int height, const int width) {
+ int index = w + (h + (c + n * channel_num) * height) * width;
+ return index;
+}
+#ifndef HIP_DIFF
+/* TODO: move this to a common place */
+template <typename scalar_t>
+__device__ inline scalar_t min(scalar_t a, scalar_t b) {
+ return a < b ? a : b;
+}
+
+template <typename scalar_t>
+__device__ inline scalar_t max(scalar_t a, scalar_t b) {
+ return a > b ? a : b;
+}
+#endif
+template <typename scalar_t>
+__device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
+ for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
+#ifdef HIP_DIFF
+ val += __shfl_down(val, offset);
+#else
+ val += __shfl_down_sync(FULL_MASK, val, offset);
+#endif
+ return val;
+}
+
+template <>
+__device__ __forceinline__ phalf warpReduceSum(phalf val) {
+ for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
+#ifdef HIP_DIFF
+ __PHALF(val) += __shfl_down(FULL_MASK, val, offset);
+#else
+ __PHALF(val) +=
+ __shfl_down_sync(FULL_MASK, static_cast<__half>(__PHALF(val)), offset);
+#endif
+ return val;
+}
+
+// Splits the original matrix into submatrices with size 32 * 32.
+// Each block transposes one submatrix by loading it into shared memory.
+// Reference https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc/
+template <typename scalar_t>
+__global__ void BatchTranspose2DCUDAKernel(const int N, const int H,
+ const int W, const int dh,
+ const int dw,
+ const scalar_t *__restrict__ X,
+ scalar_t *__restrict__ Y) {
+ __shared__ scalar_t tile[kTileDim][kTileDim + 1];
+ const int n = blockIdx.x / (dh * dw);
+ const int k = blockIdx.x % (dh * dw);
+ const int r = k / dw;
+ const int c = k % dw;
+ const int offset = n * H * W;
+ int x = c * kTileDim + threadIdx.x;
+ int y = r * kTileDim + threadIdx.y;
+ if (x < W) {
+ for (int i = 0; threadIdx.y + i < kTileDim && y + i < H; i += kBlockRows) {
+ tile[threadIdx.y + i][threadIdx.x] = X[offset + (y + i) * W + x];
+ }
+ }
+ __syncthreads();
+ x = r * kTileDim + threadIdx.x;
+ y = c * kTileDim + threadIdx.y;
+ if (x < H) {
+ for (int i = 0; threadIdx.y + i < kTileDim && y + i < W; i += kBlockRows) {
+ Y[offset + (y + i) * H + x] = tile[threadIdx.x][threadIdx.y + i];
+ }
+ }
+}
+template <typename scalar_t>
+__global__ void CARAFEForward(
+ const int num_kernels, const scalar_t *__restrict__ bottom_data,
+ const scalar_t *__restrict__ bottom_masks, const int kernel_size,
+ const int group_size, const int scale_factor, const int channels,
+ const int down_height, const int down_width, const int height,
+ const int width, const int mask_channels, scalar_t *__restrict__ top_data) {
+#if MAXIMIZE_KERNEL_SIZE
+ __shared__ float shared_mask[MAX_SHARED_SCALAR_T * 2];
+#else
+ __shared__ scalar_t shared_mask[MAX_SHARED_SCALAR_T];
+#endif
+
+ int index = threadIdx.x + blockIdx.x * blockDim.x;
+ if (index > num_kernels - 1) {
+ return;
+ }
+ const int pixel_id = threadIdx.x / THREADS_PER_PIXEL;
+ const int split_id = threadIdx.x % THREADS_PER_PIXEL;
+ index = index / THREADS_PER_PIXEL;
+ const int pw = index % width;
+ const int ph = (index / width) % height;
+ const int n = index / width / height;
+
+ const int down_pw = pw / scale_factor;
+ const int down_ph = ph / scale_factor;
+
+ const int start_w = down_pw - (kernel_size - 1) / 2;
+ const int end_w = down_pw + (kernel_size - 1) / 2 + 1;
+ const int start_h = down_ph - (kernel_size - 1) / 2;
+ const int end_h = down_ph + (kernel_size - 1) / 2 + 1;
+ for (int c = split_id; c < mask_channels; c += THREADS_PER_PIXEL) {
+ int mask_index = Loc2Index(n, ph, pw, c, height, width, mask_channels);
+ shared_mask[c * WARP_SIZE + pixel_id] = bottom_masks[mask_index];
+ }
+ __syncthreads();
+
+ const int channels_per_group = ceilf(channels / (float)group_size);
+#pragma unroll
+ for (int c = split_id; c < channels; c += THREADS_PER_PIXEL) {
+ int mask_group = c / channels_per_group;
+ scalar_t output_val = 0;
+#pragma unroll
+ for (int iy = start_h; iy < end_h; iy++) {
+#pragma unroll
+ for (int ix = start_w; ix < end_w; ix++) {
+ if (iy < 0 || iy > down_height - 1 || ix < 0 || ix > down_width - 1) {
+ continue;
+ }
+ int mask_iy = iy - down_ph + (kernel_size - 1) / 2;
+ int mask_ix = ix - down_pw + (kernel_size - 1) / 2;
+ int mask_c =
+ (mask_group * kernel_size + mask_iy) * kernel_size + mask_ix;
+ int feat_index =
+ Loc2Index(n, iy, ix, c, down_height, down_width, channels);
+
+ output_val += bottom_data[feat_index] *
+ shared_mask[mask_c * WARP_SIZE + pixel_id];
+ }
+ }
+
+ int top_index = Loc2Index(n, ph, pw, c, height, width, channels);
+ top_data[top_index] = output_val;
+ }
+}
+
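+// Backward w.r.t. the features: every high-resolution location gathers the
+// contributions of all output pixels whose reassembly window covers it, using
+// the mirrored mask channel (2 * mask_group + 1) * k * k - c - 1. The host-side
+// backward then folds this result onto the low-resolution grid via FeatureSum.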
+template <typename scalar_t>
+__global__ void CARAFEBackward_Feature(
+ const int num_kernels, const scalar_t *__restrict__ top_diff,
+ const scalar_t *__restrict__ bottom_masks, const int kernel_size,
+ const int group_size, const int scale_factor, const int channels,
+ const int down_height, const int down_width, const int height,
+ const int width, const int mask_channels,
+ scalar_t *__restrict__ bottom_diff) {
+#if MAXIMIZE_KERNEL_SIZE
+ __shared__ float shared_mask[MAX_SHARED_SCALAR_T * 2];
+#else
+ __shared__ scalar_t shared_mask[MAX_SHARED_SCALAR_T];
+#endif
+
+ int index = threadIdx.x + blockIdx.x * blockDim.x;
+ if (index > num_kernels - 1) {
+ return;
+ }
+
+ const int pixel_id = threadIdx.x / THREADS_PER_PIXEL;
+ const int split_id = threadIdx.x % THREADS_PER_PIXEL;
+ // (n, c, ph, pw) is an element in the bottom_data
+ index = index / THREADS_PER_PIXEL;
+ const int pw = index % width;
+ const int ph = (index / width) % height;
+ const int n = index / width / height;
+
+ const int start_w = pw - (kernel_size - 1) * scale_factor / 2;
+ const int end_w = pw + (kernel_size - 1) * scale_factor / 2 + 1;
+ const int start_h = ph - (kernel_size - 1) * scale_factor / 2;
+ const int end_h = ph + (kernel_size - 1) * scale_factor / 2 + 1;
+ for (int c = split_id; c < mask_channels; c += THREADS_PER_PIXEL) {
+ const int mask_w = (c % kernel_size) * scale_factor;
+ const int mask_h = (c / kernel_size % kernel_size) * scale_factor;
+ const int mask_x = start_w + mask_w;
+ const int mask_y = start_h + mask_h;
+ if (mask_y < 0 || mask_y > height - 1 || mask_x < 0 || mask_x > width - 1) {
+ shared_mask[c * WARP_SIZE + pixel_id] = 0;
+ continue;
+ }
+ const int mask_group = c / (kernel_size * kernel_size);
+ const int mask_c = (2 * mask_group + 1) * kernel_size * kernel_size - c - 1;
+ int mask_index =
+ Loc2Index(n, mask_c, mask_y, mask_x, mask_channels, height, width);
+ shared_mask[c * WARP_SIZE + pixel_id] = bottom_masks[mask_index];
+ }
+ __syncthreads();
+ const int channels_per_group = ceilf(channels / (float)group_size);
+#pragma unroll
+ for (int c = split_id; c < channels; c += THREADS_PER_PIXEL) {
+ int mask_group = c / channels_per_group;
+ int top_index = Loc2Index(n, ph, pw, c, height, width, channels);
+ scalar_t output_val = 0;
+#pragma unroll
+ for (int iy = start_h; iy < end_h; iy += scale_factor) {
+#pragma unroll
+ for (int ix = start_w; ix < end_w; ix += scale_factor) {
+ if (iy < 0 || iy > height - 1 || ix < 0 || ix > width - 1) {
+ continue;
+ }
+ int mask_iy =
+ (iy - ph + (kernel_size - 1) * scale_factor / 2) / scale_factor;
+ int mask_ix =
+ (ix - pw + (kernel_size - 1) * scale_factor / 2) / scale_factor;
+ int mask_c =
+ (mask_group * kernel_size + mask_iy) * kernel_size + mask_ix;
+ int feat_index = Loc2Index(n, iy, ix, c, height, width, channels);
+ output_val +=
+ shared_mask[mask_c * WARP_SIZE + pixel_id] * top_diff[feat_index];
+ }
+ }
+ bottom_diff[top_index] = output_val;
+ }
+}
+
+template <typename scalar_t>
+__global__ void FeatureSum(const int num_kernels,
+ const scalar_t *__restrict__ input_data,
+ const int scale_factor, const int channels,
+ const int height, const int width,
+ scalar_t *__restrict__ output_data) {
+ int index = threadIdx.x + blockIdx.x * blockDim.x;
+ if (index > num_kernels - 1) {
+ return;
+ }
+ const int split_id = threadIdx.x % THREADS_PER_PIXEL;
+ index = index / THREADS_PER_PIXEL;
+ const int pw = index % width;
+ const int ph = (index / width) % height;
+ const int n = index / width / height;
+ for (int c = split_id; c < channels; c += THREADS_PER_PIXEL) {
+ scalar_t output_val = 0;
+ for (int iy = ph * scale_factor; iy < (ph + 1) * scale_factor; iy++) {
+ for (int ix = pw * scale_factor; ix < (pw + 1) * scale_factor; ix++) {
+ int input_id = Loc2Index(n, iy, ix, c, height * scale_factor,
+ width * scale_factor, channels);
+ output_val += input_data[input_id];
+ }
+ }
+ const int output_id = Loc2Index(n, ph, pw, c, height, width, channels);
+ output_data[output_id] = output_val;
+ }
+}
+
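+// Backward w.r.t. the masks: one warp per (pixel, mask channel). Each lane
+// accumulates top_diff * bottom_data over its share of the channel group and a
+// warp reduction lets lane 0 write the final mask gradient.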
+template <typename scalar_t>
+__global__ void CARAFEBackward_Mask(const int num_kernels,
+ const scalar_t *__restrict__ top_diff,
+ const scalar_t *__restrict__ bottom_data,
+ const int kernel_size, const int group_size,
+ const int scale_factor, const int channels,
+ const int down_height, const int down_width,
+ const int height, const int width,
+ const int mask_channels,
+ scalar_t *__restrict__ mask_diff) {
+ int index = threadIdx.x + blockIdx.x * blockDim.x;
+ if (index > num_kernels - 1) {
+ return;
+ }
+
+ const int lane_id = index % WARP_SIZE;
+ index = index / WARP_SIZE;
+ const int mask_c = index % mask_channels;
+ // (n, c, ph, pw) is an element in the bottom_data
+ index = index / mask_channels;
+ const int pw = index % width;
+ const int ph = (index / width) % height;
+ const int n = index / width / height;
+
+ const int down_pw = pw / scale_factor;
+ const int down_ph = ph / scale_factor;
+
+ const int mask_group = mask_c / (kernel_size * kernel_size);
+ const int mask_loc = mask_c % (kernel_size * kernel_size);
+
+ const int offset_x = mask_loc % kernel_size - (kernel_size - 1) / 2;
+ const int offset_y =
+ mask_loc / kernel_size % kernel_size - (kernel_size - 1) / 2;
+
+ const int down_x = down_pw + offset_x;
+ const int down_y = down_ph + offset_y;
+
+ scalar_t output_val = 0;
+
+ if (down_y >= 0 && down_y <= down_height - 1 && down_x >= 0 &&
+ down_x <= down_width - 1) {
+ const int channels_per_mask = ceilf(channels / (float)group_size);
+ const int start = channels_per_mask * mask_group;
+ const int end = min(channels_per_mask * (mask_group + 1), channels);
+ for (int c = start + lane_id; c < end; c += WARP_SIZE) {
+ int bottom_id =
+ Loc2Index(n, down_y, down_x, c, down_height, down_width, channels);
+ int top_id = Loc2Index(n, ph, pw, c, height, width, channels);
+ output_val += top_diff[top_id] * bottom_data[bottom_id];
+ }
+ }
+#ifdef HIP_DIFF
+ __syncthreads();
+#else
+ __syncwarp();
+#endif
+ output_val = warpReduceSum(output_val);
+ if (lane_id == 0) {
+ const int mask_id =
+ Loc2Index(n, ph, pw, mask_c, height, width, mask_channels);
+ mask_diff[mask_id] = output_val;
+ }
+}
+
+#endif // CARAFE_CUDA_KERNEL_CUH
diff --git a/mmcv/mmcv/ops/csrc/common/cuda/carafe_naive_cuda_kernel.cuh b/mmcv/mmcv/ops/csrc/common/cuda/carafe_naive_cuda_kernel.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..48230c632f223b736aa72a9d5fd682c97b3aa93a
--- /dev/null
+++ b/mmcv/mmcv/ops/csrc/common/cuda/carafe_naive_cuda_kernel.cuh
@@ -0,0 +1,111 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#ifndef CARAFE_NAIVE_CUDA_KERNEL_CUH
+#define CARAFE_NAIVE_CUDA_KERNEL_CUH
+
+#ifdef MMCV_USE_PARROTS
+#include "parrots_cuda_helper.hpp"
+#else
+#include "pytorch_cuda_helper.hpp"
+#endif
+
+__device__ inline int Loc2Index(const int n, const int c, const int h,
+ const int w, const int channel_num,
+ const int height, const int width) {
+ int index = w + (h + (c + n * channel_num) * height) * width;
+ return index;
+}
+
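+// Naive CARAFE forward: one thread per output element (n, c, ph, pw), directly
+// accumulating mask-weighted features from the k x k neighborhood of the
+// corresponding low-resolution pixel, without shared-memory staging.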
+template <typename scalar_t>
+__global__ void carafe_naive_forward_cuda_kernel(
+ const int nthreads, const scalar_t *bottom_data,
+ const scalar_t *bottom_masks, scalar_t *top_data, const int kernel_size,
+ const int group_size, const int scale_factor, const int channels,
+ const int height, const int width) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // (n, c, ph, pw) is an element in the bottom_data
+ int pw = index % width;
+ int ph = (index / width) % height;
+ int c = (index / width / height) % channels;
+ int n = index / width / height / channels;
+
+ int mask_channels = kernel_size * kernel_size * group_size;
+ int mask_group = c / (channels / group_size);
+
+ int down_pw = pw / scale_factor;
+ int down_ph = ph / scale_factor;
+ int down_width = width / scale_factor;
+ int down_height = height / scale_factor;
+ int start_w = down_pw - (kernel_size - 1) / 2;
+ int end_w = down_pw + (kernel_size - 1) / 2 + 1;
+ int start_h = down_ph - (kernel_size - 1) / 2;
+ int end_h = down_ph + (kernel_size - 1) / 2 + 1;
+
+ scalar_t output_val = 0;
+ for (int iy = start_h; iy < end_h; iy++) {
+ for (int ix = start_w; ix < end_w; ix++) {
+ if (iy < 0 || iy > down_height - 1 || ix < 0 || ix > down_width - 1) {
+ continue;
+ }
+ int mask_iy = iy - down_ph + (kernel_size - 1) / 2;
+ int mask_ix = ix - down_pw + (kernel_size - 1) / 2;
+ int mask_c =
+ (mask_group * kernel_size + mask_iy) * kernel_size + mask_ix;
+ int feat_index =
+ Loc2Index(n, c, iy, ix, channels, down_height, down_width);
+ int mask_index =
+ Loc2Index(n, mask_c, ph, pw, mask_channels, height, width);
+ output_val += bottom_data[feat_index] * bottom_masks[mask_index];
+ }
+ }
+ top_data[index] = output_val;
+ }
+}
+
+template <typename scalar_t>
+__global__ void carafe_naive_backward_cuda_kernel(
+ const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_data,
+ const scalar_t *bottom_masks, scalar_t *bottom_diff, scalar_t *mask_diff,
+ const int kernel_size, const int group_size, const int scale_factor,
+ const int channels, const int height, const int width) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // (n, c, ph, pw) is an element in the bottom_data
+ int pw = index % width;
+ int ph = (index / width) % height;
+ int c = (index / width / height) % channels;
+ int n = index / width / height / channels;
+
+ int mask_channels = kernel_size * kernel_size * group_size;
+ int mask_group = c / (channels / group_size);
+
+ int down_pw = pw / scale_factor;
+ int down_ph = ph / scale_factor;
+ int down_width = width / scale_factor;
+ int down_height = height / scale_factor;
+ int start_w = down_pw - (kernel_size - 1) / 2;
+ int end_w = down_pw + (kernel_size - 1) / 2 + 1;
+ int start_h = down_ph - (kernel_size - 1) / 2;
+ int end_h = down_ph + (kernel_size - 1) / 2 + 1;
+
+ for (int iy = start_h; iy < end_h; iy++) {
+ for (int ix = start_w; ix < end_w; ix++) {
+ if (iy < 0 || iy > down_height - 1 || ix < 0 || ix > down_width - 1) {
+ continue;
+ }
+ int mask_iy = iy - down_ph + (kernel_size - 1) / 2;
+ int mask_ix = ix - down_pw + (kernel_size - 1) / 2;
+ int mask_c =
+ (mask_group * kernel_size + mask_iy) * kernel_size + mask_ix;
+ int feat_index =
+ Loc2Index(n, c, iy, ix, channels, down_height, down_width);
+ int mask_index =
+ Loc2Index(n, mask_c, ph, pw, mask_channels, height, width);
+ atomicAdd(bottom_diff + feat_index,
+ bottom_masks[mask_index] * top_diff[index]);
+ atomicAdd(mask_diff + mask_index,
+ bottom_data[feat_index] * top_diff[index]);
+ }
+ }
+ }
+}
+
+#endif // CARAFE_NAIVE_CUDA_KERNEL_CUH
diff --git a/mmcv/mmcv/ops/csrc/common/cuda/chamfer_distance_cuda_kernel.cuh b/mmcv/mmcv/ops/csrc/common/cuda/chamfer_distance_cuda_kernel.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..89feea4a546a5093967f26393ca6be3b9fe6ae05
--- /dev/null
+++ b/mmcv/mmcv/ops/csrc/common/cuda/chamfer_distance_cuda_kernel.cuh
@@ -0,0 +1,101 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+// Modified from
+// https://github.com/chrdiller/pyTorchChamferDistance/blob/master/chamfer_distance/chamfer_distance.cu
+#ifndef CHAMFER_DISTANCE_CUDA_KERNEL_CUH
+#define CHAMFER_DISTANCE_CUDA_KERNEL_CUH
+
+#ifdef MMCV_USE_PARROTS
+#include "parrots_cuda_helper.hpp"
+#else
+#include "pytorch_cuda_helper.hpp"
+#endif
+
+#define MAX_SHARED_SCALAR_T 6144 // 49152 / 8 = 6144
+
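+// Chamfer distance forward (2-D point sets): xyz2 is tiled into shared memory
+// in chunks of THREADS_PER_BLOCK points; for every point of xyz the kernel
+// keeps the squared distance to its nearest neighbour and that neighbour's index.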
+template <typename scalar_t>
+__global__ void chamfer_distance_forward_cuda_kernel(int b, int n,
+ const scalar_t* xyz, int m,
+ const scalar_t* xyz2,
+ scalar_t* result,
+ int* result_i) {
+ __shared__ scalar_t buf[MAX_SHARED_SCALAR_T];
+ for (int i = blockIdx.x; i < b; i += gridDim.x) {
+ for (int k2 = 0; k2 < m; k2 += THREADS_PER_BLOCK) {
+ int end_k = min(m, k2 + THREADS_PER_BLOCK) - k2;
+ for (int j = threadIdx.x; j < end_k * 2; j += blockDim.x) {
+ buf[j] = xyz2[(i * m + k2) * 2 + j];
+ }
+ __syncthreads();
+ for (int j = threadIdx.x; j < n; j += blockDim.x * gridDim.y) {
+ scalar_t x1 = xyz[(i * n + j) * 2 + 0];
+ scalar_t y1 = xyz[(i * n + j) * 2 + 1];
+ int best_i = 0;
+ scalar_t best = 1e10;
+ int end_ka = end_k & (~2);
+ if (end_ka == THREADS_PER_BLOCK) {
+ for (int k = 0; k < THREADS_PER_BLOCK; k += 4) {
+#pragma unroll
+ for (int j = 0; j < 4; ++j) {
+ scalar_t x2 = buf[(k + j) * 2] - x1;
+ scalar_t y2 = buf[(k + j) * 2 + 1] - y1;
+ scalar_t d = x2 * x2 + y2 * y2;
+ if (d < best) {
+ best = d;
+ best_i = k + k2 + j;
+ }
+ }
+ }
+ } else {
+ for (int k = 0; k < end_ka; k += 4) {
+#pragma unroll
+ for (int j = 0; j < 4; ++j) {
+ scalar_t x2 = buf[(k + j) * 2] - x1;
+ scalar_t y2 = buf[(k + j) * 2 + 1] - y1;
+ scalar_t d = x2 * x2 + y2 * y2;
+ if (d < best) {
+ best = d;
+ best_i = k + k2 + j;
+ }
+ }
+ }
+ }
+ for (int k = end_ka; k < end_k; k++) {
+ scalar_t x2 = buf[k * 2 + 0] - x1;
+ scalar_t y2 = buf[k * 2 + 1] - y1;
+ scalar_t d = x2 * x2 + y2 * y2;
+ if (k == 0 || d < best) {
+ best = d;
+ best_i = k + k2;
+ }
+ }
+ if (k2 == 0 || result[(i * n + j)] > best) {
+ result[(i * n + j)] = best;
+ result_i[(i * n + j)] = best_i;
+ }
+ }
+ __syncthreads();
+ }
+ }
+}
+
+template <typename scalar_t>
+__global__ void chamfer_distance_backward_cuda_kernel(
+ int b, int n, const scalar_t* xyz1, int m, const scalar_t* xyz2,
+ const scalar_t* grad_dist1, const int* idx1, scalar_t* grad_xyz1,
+ scalar_t* grad_xyz2) {
+ for (int i = blockIdx.x; i < b; i += gridDim.x) {
+ for (int j = threadIdx.x; j < n; j += blockDim.x * gridDim.y) {
+ scalar_t x1 = xyz1[(i * n + j) * 2 + 0];
+ scalar_t y1 = xyz1[(i * n + j) * 2 + 1];
+ int j2 = idx1[i * n + j];
+ scalar_t x2 = xyz2[(i * m + j2) * 2 + 0];
+ scalar_t y2 = xyz2[(i * m + j2) * 2 + 1];
+ scalar_t g = grad_dist1[i * n + j] * 2;
+ atomicAdd(&(grad_xyz1[(i * n + j) * 2 + 0]), g * (x1 - x2));
+ atomicAdd(&(grad_xyz1[(i * n + j) * 2 + 1]), g * (y1 - y2));
+ atomicAdd(&(grad_xyz2[(i * m + j2) * 2 + 0]), -(g * (x1 - x2)));
+ atomicAdd(&(grad_xyz2[(i * m + j2) * 2 + 1]), -(g * (y1 - y2)));
+ }
+ }
+}
+#endif // CHAMFER_DISTANCE_CUDA_KERNEL_CUH
diff --git a/mmcv/mmcv/ops/csrc/common/cuda/common_cuda_helper.hpp b/mmcv/mmcv/ops/csrc/common/cuda/common_cuda_helper.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..b12aa9a26a2cc162fd89f68ccc97e17749090a41
--- /dev/null
+++ b/mmcv/mmcv/ops/csrc/common/cuda/common_cuda_helper.hpp
@@ -0,0 +1,120 @@
+#ifndef COMMON_CUDA_HELPER
+#define COMMON_CUDA_HELPER
+
+#include <cuda.h>
+
+#define CUDA_1D_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+ i += blockDim.x * gridDim.x)
+
+#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \
+ for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+ i += blockDim.x * gridDim.x) \
+ for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); \
+ j += blockDim.y * gridDim.y)
+
+#define CUDA_2D_KERNEL_BLOCK_LOOP(i, n, j, m) \
+ for (size_t i = blockIdx.x; i < (n); i += gridDim.x) \
+ for (size_t j = blockIdx.y; j < (m); j += gridDim.y)
+
+#define THREADS_PER_BLOCK 512
+
+inline int GET_BLOCKS(const int N, const int num_threads = THREADS_PER_BLOCK) {
+ int optimal_block_num = (N + num_threads - 1) / num_threads;
+ int max_block_num = 4096;
+ return min(optimal_block_num, max_block_num);
+}
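+// Typical launch pattern (sketch, names depend on the caller):
+//   kernel<<<GET_BLOCKS(N), THREADS_PER_BLOCK, 0, stream>>>(N, ...);
+// with CUDA_1D_KERNEL_LOOP(i, N) inside the kernel to grid-stride over N items.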
+
+template <typename T>
+__device__ T bilinear_interpolate(const T* input, const int height,
+ const int width, T y, T x,
+ const int index /* index for debug only*/) {
+ // deal with cases that inverse elements are out of feature map boundary
+ if (y < -1.0 || y > height || x < -1.0 || x > width) return 0;
+
+ if (y <= 0) y = 0;
+ if (x <= 0) x = 0;
+
+ int y_low = (int)y;
+ int x_low = (int)x;
+ int y_high;
+ int x_high;
+
+ if (y_low >= height - 1) {
+ y_high = y_low = height - 1;
+ y = (T)y_low;
+ } else {
+ y_high = y_low + 1;
+ }
+
+ if (x_low >= width - 1) {
+ x_high = x_low = width - 1;
+ x = (T)x_low;
+ } else {
+ x_high = x_low + 1;
+ }
+
+ T ly = y - y_low;
+ T lx = x - x_low;
+ T hy = 1. - ly, hx = 1. - lx;
+ // do bilinear interpolation
+ T v1 = input[y_low * width + x_low];
+ T v2 = input[y_low * width + x_high];
+ T v3 = input[y_high * width + x_low];
+ T v4 = input[y_high * width + x_high];
+ T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+ T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+ return val;
+}
+
+template <typename T>
+__device__ void bilinear_interpolate_gradient(
+ const int height, const int width, T y, T x, T& w1, T& w2, T& w3, T& w4,
+ int& x_low, int& x_high, int& y_low, int& y_high,
+ const int index /* index for debug only*/) {
+ // deal with cases that inverse elements are out of feature map boundary
+ if (y < -1.0 || y > height || x < -1.0 || x > width) {
+ // empty
+ w1 = w2 = w3 = w4 = 0.;
+ x_low = x_high = y_low = y_high = -1;
+ return;
+ }
+
+ if (y <= 0) y = 0;
+ if (x <= 0) x = 0;
+
+ y_low = (int)y;
+ x_low = (int)x;
+
+ if (y_low >= height - 1) {
+ y_high = y_low = height - 1;
+ y = (T)y_low;
+ } else {
+ y_high = y_low + 1;
+ }
+
+ if (x_low >= width - 1) {
+ x_high = x_low = width - 1;
+ x = (T)x_low;
+ } else {
+ x_high = x_low + 1;
+ }
+
+ T ly = y - y_low;
+ T lx = x - x_low;
+ T hy = 1. - ly, hx = 1. - lx;
+
+ // reference in forward
+ // T v1 = input[y_low * width + x_low];
+ // T v2 = input[y_low * width + x_high];
+ // T v3 = input[y_high * width + x_low];
+ // T v4 = input[y_high * width + x_high];
+ // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+ w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+ return;
+}
+#endif // COMMON_CUDA_HELPER
diff --git a/mmcv/mmcv/ops/csrc/common/cuda/convex_iou_cuda_kernel.cuh b/mmcv/mmcv/ops/csrc/common/cuda/convex_iou_cuda_kernel.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..2af96f7963ec347486ced942a5ef7cc4f187db8b
--- /dev/null
+++ b/mmcv/mmcv/ops/csrc/common/cuda/convex_iou_cuda_kernel.cuh
@@ -0,0 +1,831 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#ifndef CONVEX_IOU_CUDA_KERNEL_CUH
+#define CONVEX_IOU_CUDA_KERNEL_CUH
+
+#ifdef MMCV_USE_PARROTS
+#include "parrots_cuda_helper.hpp"
+#else
+#include "pytorch_cuda_helper.hpp"
+#endif
+
+#define MAXN 100
+#define NMAX 512
+__device__ const double EPS = 1E-8;
+
+__device__ inline int sig(double d) { return (d > EPS) - (d < -EPS); }
+
+struct Point {
+ double x, y;
+ __device__ Point() {}
+ __device__ Point(double x, double y) : x(x), y(y) {}
+};
+
+__device__ inline bool point_same(Point& a, Point& b) {
+ return sig(a.x - b.x) == 0 && sig(a.y - b.y) == 0;
+}
+
+__device__ inline void swap1(Point* a, Point* b) {
+ Point temp;
+ temp.x = a->x;
+ temp.y = a->y;
+
+ a->x = b->x;
+ a->y = b->y;
+
+ b->x = temp.x;
+ b->y = temp.y;
+}
+
+__device__ inline void reverse1(Point* a, const int n) {
+ for (int i = 0; i < (n - 1) / 2.0; i++) {
+ Point* j = &(a[i]);
+ Point* k = &(a[n - 1 - i]);
+ swap1(j, k);
+ }
+}
+
+__device__ inline double cross(Point o, Point a, Point b) {
+ return (a.x - o.x) * (b.y - o.y) - (b.x - o.x) * (a.y - o.y);
+}
+
+__device__ inline double dis(Point a, Point b) {
+ return (a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y);
+}
+__device__ inline double area(Point* ps, int n) {
+ ps[n] = ps[0];
+ double res = 0;
+ for (int i = 0; i < n; i++) {
+ res += ps[i].x * ps[i + 1].y - ps[i].y * ps[i + 1].x;
+ }
+ return res / 2.0;
+}
+__device__ inline double polygon_area_grad(Point* ps, int n,
+ int* polygon_to_pred_index,
+ int n_pred, double* grad_C) {
+ ps[n] = ps[0];
+ double partion_grad[4 * 30 + 2];
+ double res = 0;
+ for (int i = 0; i < n; i++) {
+ res += ps[i].x * ps[i + 1].y - ps[i].y * ps[i + 1].x;
+ partion_grad[i * 4 + 2] = ps[i + 1].y;
+ partion_grad[i * 4 + 3] = -ps[i + 1].x;
+ if (i != n - 1) {
+ partion_grad[i * 4 + 4] = -ps[i].y;
+ partion_grad[i * 4 + 5] = ps[i].x;
+ } else {
+ partion_grad[0] = -ps[i].y;
+ partion_grad[1] = ps[i].x;
+ }
+ }
+ for (int i = 0; i < n; i++) {
+ for (int j = 0; j < n_pred; j++) {
+ if (i == polygon_to_pred_index[j]) {
+ grad_C[2 * polygon_to_pred_index[j + n_pred]] =
+ (partion_grad[i * 4] + partion_grad[i * 4 + 2]) / 2;
+ break;
+ }
+ }
+ for (int j = 0; j < n_pred; j++) {
+ if (i == polygon_to_pred_index[j]) {
+ grad_C[2 * polygon_to_pred_index[j + n_pred] + 1] =
+ (partion_grad[i * 4 + 1] + partion_grad[i * 4 + 1 + 2]) / 2;
+ break;
+ }
+ }
+ }
+
+ return res / 2.0;
+}
+
+__device__ inline int lineCross(Point a, Point b, Point c, Point d, Point& p,
+ double* cut_grad, int m, int n, int i) {
+ double s1, s2;
+ double s2_s1_2;
+ double ds1_dxc, ds1_dyc, ds2_dxd, ds2_dyd;
+ double dxp_dxc, dxp_dyc, dxp_dxd, dxp_dyd, dyp_dxc, dyp_dyc, dyp_dxd, dyp_dyd;
+ s1 = cross(a, b, c);
+ s2 = cross(a, b, d);
+
+ ds1_dxc = -(b.y - a.y);
+ ds1_dyc = b.x - a.x;
+ ds2_dxd = ds1_dxc;
+ ds2_dyd = ds1_dyc;
+ s2_s1_2 = (s2 - s1) * (s2 - s1);
+
+ if (sig(s1) == 0 && sig(s2) == 0) return 2;
+ if (sig(s2 - s1) == 0) return 0;
+
+ dxp_dxc =
+ ((s2 - d.x * ds1_dxc) * (s2 - s1) - (c.x * s2 - d.x * s1) * (-ds1_dxc)) /
+ (s2_s1_2);
+ dxp_dyc =
+ ((0 - d.x * ds1_dyc) * (s2 - s1) - (c.x * s2 - d.x * s1) * (-ds1_dyc)) /
+ (s2_s1_2);
+ dxp_dxd =
+ ((c.x * ds2_dxd - s1) * (s2 - s1) - (c.x * s2 - d.x * s1) * (ds2_dxd)) /
+ (s2_s1_2);
+ dxp_dyd =
+ ((c.x * ds2_dyd - 0) * (s2 - s1) - (c.x * s2 - d.x * s1) * (ds2_dyd)) /
+ (s2_s1_2);
+
+ dyp_dxc =
+ ((0 - d.y * ds1_dxc) * (s2 - s1) - (c.y * s2 - d.y * s1) * (-ds1_dxc)) /
+ (s2_s1_2);
+ dyp_dyc =
+ ((s2 - d.y * ds1_dyc) * (s2 - s1) - (c.y * s2 - d.y * s1) * (-ds1_dyc)) /
+ (s2_s1_2);
+ dyp_dxd =
+ ((c.y * ds2_dxd - 0) * (s2 - s1) - (c.y * s2 - d.y * s1) * (ds2_dxd)) /
+ (s2_s1_2);
+ dyp_dyd =
+ ((c.y * ds2_dyd - s1) * (s2 - s1) - (c.y * s2 - d.y * s1) * (ds2_dyd)) /
+ (s2_s1_2);
+
+ p.x = (c.x * s2 - d.x * s1) / (s2 - s1);
+ p.y = (c.y * s2 - d.y * s1) / (s2 - s1);
+ if (i == n - 1) {
+ cut_grad[4 * n * m + 4 * i] = dxp_dxc; // + dyp_dxc;
+ cut_grad[4 * n * m + 4 * i + 1] = dyp_dxc;
+ cut_grad[4 * n * m + 4 * i + 2] = dxp_dyc; // + dyp_dyc;
+ cut_grad[4 * n * m + 4 * i + 3] = dyp_dyc;
+ cut_grad[4 * n * m + 0] = dxp_dxd; // + dyp_dxd;
+ cut_grad[4 * n * m + 1] = dyp_dxd;
+ cut_grad[4 * n * m + 2] = dxp_dyd; // + dyp_dyd;
+ cut_grad[4 * n * m + 3] = dyp_dyd;
+ } else {
+ cut_grad[4 * n * m + 4 * i] = dxp_dxc; // + dyp_dxc;
+ cut_grad[4 * n * m + 4 * i + 1] = dyp_dxc;
+ cut_grad[4 * n * m + 4 * i + 2] = dxp_dyc; // + dyp_dyc;
+ cut_grad[4 * n * m + 4 * i + 3] = dyp_dyc;
+ cut_grad[4 * n * m + 4 * (i + 1)] = dxp_dxd; // + dyp_dxd;
+ cut_grad[4 * n * m + 4 * (i + 1) + 1] = dyp_dxd;
+ cut_grad[4 * n * m + 4 * (i + 1) + 2] = dxp_dyd; // + dyp_dyd;
+ cut_grad[4 * n * m + 4 * (i + 1) + 3] = dyp_dyd;
+ }
+
+ return 1;
+}
+__device__ inline void polygon_cut(Point* p, int& n, Point a, Point b,
+ double* cut_grad) {
+ Point pp[MAXN];
+ double ccur_grad[MAXN] = {};
+ int m = 0;
+ p[n] = p[0];
+ int k = n;
+ for (int i = 0; i < n; i++) {
+ if (sig(cross(a, b, p[i])) > 0) {
+ pp[m] = p[i];
+ ccur_grad[4 * n * m + 4 * i] = 1.0;
+ ccur_grad[4 * n * m + 4 * i + 3] = 1.0;
+ m++;
+ }
+ if (sig(cross(a, b, p[i])) != sig(cross(a, b, p[i + 1]))) {
+ lineCross(a, b, p[i], p[i + 1], pp[m], ccur_grad, m, n, i);
+ m++;
+ }
+ }
+
+ n = 0;
+ for (int i = 0; i < m; i++) {
+ if (!i || !(point_same(pp[i], pp[i - 1]))) {
+ p[n] = pp[i];
+ for (int j = 0; j < 4 * k; j++) {
+ cut_grad[4 * k * n + j] = ccur_grad[4 * k * i + j];
+ }
+ n++;
+ }
+ }
+
+ while (n > 1 && point_same(p[n - 1], p[0])) n--;
+}
+
+__device__ inline double intersectArea(Point a, Point b, Point c, Point d,
+ double* grad_AB, int order,
+ int convex_n) {
+ Point o(0, 0);
+ int res_flag = 0;
+ int s1 = sig(cross(o, a, b));
+ int s2 = sig(cross(o, c, d));
+ if (s1 == 0 || s2 == 0) return 0.0;
+ if (s1 == -1) {
+ Point* i = &a;
+ Point* j = &b;
+ swap1(i, j);
+ res_flag = 1;
+ }
+ if (s2 == -1) {
+ Point* i = &c;
+ Point* j = &d;
+ swap1(i, j);
+ }
+ Point p[10] = {o, a, b};
+ int n = 3, n0 = 3, n1, n2, n3;
+ double cut_grad1[MAXN] = {};
+ double cut_grad2[MAXN] = {};
+ double cut_grad3[MAXN] = {};
+ double p1_p_grad[10][10] = {};
+ double p2_p1_grad[10][10] = {};
+ double p3_p2_grad[10][10] = {};
+
+ double p3_p1_grad[10][10] = {};
+ double p3_p_grad[10][10] = {};
+
+ // 1
+ polygon_cut(p, n, o, c, cut_grad1);
+ n1 = n;
+ for (int i = 0; i < n; i++) {
+ for (int j = 0; j < 4 * n0; j++) {
+ if (!(j % 2)) {
+ p1_p_grad[2 * i][j / 2] = cut_grad1[4 * n0 * i + j];
+ } else {
+ p1_p_grad[2 * i + 1][j / 2] = cut_grad1[4 * n0 * i + j];
+ }
+ }
+ }
+
+ // 2
+ polygon_cut(p, n, c, d, cut_grad2);
+ n2 = n;
+ for (int i = 0; i < n; i++) {
+ for (int j = 0; j < 4 * n1; j++) {
+ if (!(j % 2)) {
+ p2_p1_grad[2 * i][j / 2] = cut_grad2[4 * n1 * i + j];
+ } else {
+ p2_p1_grad[2 * i + 1][j / 2] = cut_grad2[4 * n1 * i + j];
+ }
+ }
+ }
+ // 3
+ polygon_cut(p, n, d, o, cut_grad3);
+ n3 = n;
+ for (int i = 0; i < n; i++) {
+ for (int j = 0; j < 4 * n2; j++) {
+ if (!(j % 2)) {
+ p3_p2_grad[2 * i][j / 2] = cut_grad3[4 * n2 * i + j];
+ } else {
+ p3_p2_grad[2 * i + 1][j / 2] = cut_grad3[4 * n2 * i + j];
+ }
+ }
+ }
+
+ // mul
+ // p3_p2(n3 * n2) * p2_p1(n2 * n1) = p3_p1 (n3 * n1)
+ for (int i = 0; i < 2 * n3; i++) {
+ for (int j = 0; j < 2 * n1; j++) {
+ double sum = 0.0;
+ for (int m = 0; m < 2 * n2; m++) {
+ sum = sum + p3_p2_grad[i][m] * p2_p1_grad[m][j];
+ }
+ p3_p1_grad[i][j] = sum;
+ }
+ }
+
+ // p3_p1 (n3 * n1) * p1_p (n1 * n0) = p3_p (n3 * n0)
+ for (int i = 0; i < 2 * n3; i++) {
+ for (int j = 0; j < 2 * n0; j++) {
+ double sum = 0.0;
+ for (int m = 0; m < 2 * n1; m++) {
+ sum = sum + p3_p1_grad[i][m] * p1_p_grad[m][j];
+ }
+ p3_p_grad[i][j] = sum;
+ }
+ }
+
+ // calculate S_grad
+ int polygon_index_box_index[20];
+ double grad_polygon[20];
+ double S_grad[6];
+
+ for (int i = 0; i < n3; i++) {
+ polygon_index_box_index[i] = i;
+ polygon_index_box_index[i + n3] = i;
+ }
+
+ double res =
+ polygon_area_grad(p, n3, polygon_index_box_index, n3, grad_polygon);
+
+ if (s1 * s2 == -1) {
+ for (int j = 0; j < 2 * 3; j++) {
+ double sum = 0.0;
+ for (int m = 0; m < 2 * n3; m++) {
+ sum = sum - grad_polygon[m] * p3_p_grad[m][j];
+ }
+ S_grad[j] = sum;
+ }
+
+ if (order != convex_n - 1) {
+ if (res_flag) {
+ grad_AB[2 * order] += S_grad[4];
+ grad_AB[2 * order + 1] += S_grad[5];
+ grad_AB[2 * order + 2] += S_grad[2];
+ grad_AB[2 * order + 3] += S_grad[3];
+
+ } else {
+ grad_AB[2 * order] += S_grad[2];
+ grad_AB[2 * order + 1] += S_grad[3];
+ grad_AB[2 * order + 2] += S_grad[4];
+ grad_AB[2 * order + 3] += S_grad[5];
+ }
+ } else {
+ if (res_flag) {
+ grad_AB[2 * order] += S_grad[4];
+ grad_AB[2 * order + 1] += S_grad[5];
+ grad_AB[0] += S_grad[2];
+ grad_AB[1] += S_grad[3];
+
+ } else {
+ grad_AB[2 * order] += S_grad[2];
+ grad_AB[2 * order + 1] += S_grad[3];
+ grad_AB[0] += S_grad[4];
+ grad_AB[1] += S_grad[5];
+ }
+ }
+ res = -res;
+ } else {
+ for (int j = 0; j < 2 * 3; j++) {
+ double sum = 0.0;
+ for (int m = 0; m < 2 * n3; m++) {
+ sum = sum + grad_polygon[m] * p3_p_grad[m][j];
+ }
+ S_grad[j] = sum;
+ }
+
+ if (order != convex_n - 1) {
+ if (res_flag) {
+ grad_AB[2 * order] += S_grad[4];
+ grad_AB[2 * order + 1] += S_grad[5];
+ grad_AB[2 * order + 2] += S_grad[2];
+ grad_AB[2 * order + 3] += S_grad[3];
+ } else {
+ grad_AB[2 * order] += S_grad[2];
+ grad_AB[2 * order + 1] += S_grad[3];
+ grad_AB[2 * order + 2] += S_grad[4];
+ grad_AB[2 * order + 3] += S_grad[5];
+ }
+ } else {
+ if (res_flag) {
+ grad_AB[2 * order] += S_grad[4];
+ grad_AB[2 * order + 1] += S_grad[5];
+ grad_AB[0] += S_grad[2];
+ grad_AB[1] += S_grad[3];
+ } else {
+ grad_AB[2 * order] += S_grad[2];
+ grad_AB[2 * order + 1] += S_grad[3];
+ grad_AB[0] += S_grad[4];
+ grad_AB[1] += S_grad[5];
+ }
+ }
+ }
+ return res;
+}
+
+__device__ inline double intersectAreaO(Point* ps1, int n1, Point* ps2, int n2,
+ double* grad_AB) {
+ if (area(ps1, n1) < 0) reverse1(ps1, n1);
+ if (area(ps2, n2) < 0) reverse1(ps2, n2);
+ ps1[n1] = ps1[0];
+ ps2[n2] = ps2[0];
+ double res = 0;
+ for (int i = 0; i < n1; i++) {
+ for (int j = 0; j < n2; j++) {
+ res +=
+ intersectArea(ps1[i], ps1[i + 1], ps2[j], ps2[j + 1], grad_AB, i, n1);
+ }
+ }
+ return res;
+}
+
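+// Gift-wrapping (Jarvis march) convex hull: after moving the bottom-most point
+// to index 0, the right and left chains up to the top-most point are traced
+// separately and concatenated back into in_poly; n_poly becomes the hull size.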
+__device__ inline void Jarvis(Point* in_poly, int& n_poly) {
+ Point p_max, p_k;
+ int max_index, k_index;
+ int Stack[NMAX] = {}, top1, top2;
+ double sign;
+ Point right_point[10], left_point[10];
+
+ for (int i = 0; i < n_poly; i++) {
+ if (in_poly[i].y < in_poly[0].y ||
+ in_poly[i].y == in_poly[0].y && in_poly[i].x < in_poly[0].x) {
+ Point* j = &(in_poly[0]);
+ Point* k = &(in_poly[i]);
+ swap1(j, k);
+ }
+ if (i == 0) {
+ p_max = in_poly[0];
+ max_index = 0;
+ }
+ if (in_poly[i].y > p_max.y ||
+ in_poly[i].y == p_max.y && in_poly[i].x > p_max.x) {
+ p_max = in_poly[i];
+ max_index = i;
+ }
+ }
+
+ if (max_index == 0) {
+ max_index = 1;
+ p_max = in_poly[max_index];
+ }
+
+ k_index = 0, Stack[0] = 0, top1 = 0;
+ while (k_index != max_index) {
+ p_k = p_max;
+ k_index = max_index;
+ for (int i = 1; i < n_poly; i++) {
+ sign = cross(in_poly[Stack[top1]], in_poly[i], p_k);
+ if ((sign > 0) || ((sign == 0) && (dis(in_poly[Stack[top1]], in_poly[i]) >
+ dis(in_poly[Stack[top1]], p_k)))) {
+ p_k = in_poly[i];
+ k_index = i;
+ }
+ }
+ top1++;
+ Stack[top1] = k_index;
+ }
+ for (int i = 0; i <= top1; i++) right_point[i] = in_poly[Stack[i]];
+
+ k_index = 0, Stack[0] = 0, top2 = 0;
+
+ while (k_index != max_index) {
+ p_k = p_max;
+ k_index = max_index;
+ for (int i = 1; i < n_poly; i++) {
+ sign = cross(in_poly[Stack[top2]], in_poly[i], p_k);
+ if ((sign < 0) || (sign == 0) && (dis(in_poly[Stack[top2]], in_poly[i]) >
+ dis(in_poly[Stack[top2]], p_k))) {
+ p_k = in_poly[i];
+ k_index = i;
+ }
+ }
+ top2++;
+ Stack[top2] = k_index;
+ }
+ for (int i = top2 - 1; i >= 0; i--) left_point[i] = in_poly[Stack[i]];
+
+ for (int i = 0; i < top1 + top2; i++) {
+ if (i <= top1) {
+ in_poly[i] = right_point[i];
+ } else {
+ in_poly[i] = left_point[top2 - (i - top1)];
+ }
+ }
+ n_poly = top1 + top2;
+}
+
+__device__ inline double intersectAreaPoly(Point* ps1, int n1, Point* ps2,
+ int n2, double* grad_C) {
+ Point polygon[MAXN];
+ int n = n1 + n2, n_poly = 0;
+ for (int i = 0; i < n1; i++) {
+ for (int j = 0; j < n - n1; j++) {
+ if (point_same(ps1[i], ps2[j])) {
+ for (int k = j; k < n - n1 - 1; k++) {
+ ps2[k] = ps2[k + 1];
+ }
+ n2--;
+ break;
+ }
+ }
+ }
+ n_poly = n1 + n2;
+ for (int i = 0; i < n_poly; i++) {
+ if (i < n1) {
+ polygon[i] = ps1[i];
+ } else {
+ polygon[i] = ps2[i - n1];
+ }
+ }
+
+ Jarvis(polygon, n_poly);
+
+ int polygon_to_pred_index[18] = {-1, -1, -1, -1, -1, -1, -1, -1, -1,
+ -1, -1, -1, -1, -1, -1, -1, -1, -1};
+ int n_pred = 0;
+ for (int i = 0; i < n_poly; i++) {
+ for (int j = 0; j < n1; j++) {
+ if (polygon[i].x == ps1[j].x && polygon[i].y == ps1[j].y) {
+ polygon_to_pred_index[n_pred] = i;
+ polygon_to_pred_index[n_pred + n1] = j;
+ n_pred += 1;
+ break;
+ }
+ }
+ }
+ if (n_pred == 0) {
+ double polygon_area = fabs(area(polygon, n_poly));
+ for (int i = 0; i < 18; i++) {
+ grad_C[i] = 0.0;
+ }
+ return polygon_area;
+ } else {
+ double polygon_area =
+ polygon_area_grad(polygon, n_poly, polygon_to_pred_index, n1, grad_C);
+ if (polygon_area < 0) {
+ for (int i = 0; i < 18; i++) {
+ grad_C[i] = -grad_C[i];
+ }
+ }
+ return fabs(polygon_area);
+ }
+}
+
+// convex_find and get the polygon_index_box_index
+__device__ inline void Jarvis_and_index(Point* in_poly, int& n_poly,
+ int* points_to_convex_ind) {
+ int n_input = n_poly;
+ Point input_poly[20];
+ for (int i = 0; i < n_input; i++) {
+ input_poly[i].x = in_poly[i].x;
+ input_poly[i].y = in_poly[i].y;
+ }
+ Point p_max, p_k;
+ int max_index, k_index;
+ int Stack[20], top1, top2;
+ double sign;
+ Point right_point[10], left_point[10];
+
+ for (int i = 0; i < n_poly; i++) {
+ if (in_poly[i].y < in_poly[0].y ||
+ in_poly[i].y == in_poly[0].y && in_poly[i].x < in_poly[0].x) {
+ Point* j = &(in_poly[0]);
+ Point* k = &(in_poly[i]);
+ swap1(j, k);
+ }
+ if (i == 0) {
+ p_max = in_poly[0];
+ max_index = 0;
+ }
+ if (in_poly[i].y > p_max.y ||
+ in_poly[i].y == p_max.y && in_poly[i].x > p_max.x) {
+ p_max = in_poly[i];
+ max_index = i;
+ }
+ }
+ if (max_index == 0) {
+ max_index = 1;
+ p_max = in_poly[max_index];
+ }
+
+ k_index = 0, Stack[0] = 0, top1 = 0;
+ while (k_index != max_index) {
+ p_k = p_max;
+ k_index = max_index;
+ for (int i = 1; i < n_poly; i++) {
+ sign = cross(in_poly[Stack[top1]], in_poly[i], p_k);
+ if ((sign > 0) || ((sign == 0) && (dis(in_poly[Stack[top1]], in_poly[i]) >
+ dis(in_poly[Stack[top1]], p_k)))) {
+ p_k = in_poly[i];
+ k_index = i;
+ }
+ }
+ top1++;
+ Stack[top1] = k_index;
+ }
+ for (int i = 0; i <= top1; i++) {
+ right_point[i] = in_poly[Stack[i]];
+ }
+
+ k_index = 0, Stack[0] = 0, top2 = 0;
+
+ while (k_index != max_index) {
+ p_k = p_max;
+ k_index = max_index;
+ for (int i = 1; i < n_poly; i++) {
+ sign = cross(in_poly[Stack[top2]], in_poly[i], p_k);
+ if ((sign < 0) || (sign == 0) && (dis(in_poly[Stack[top2]], in_poly[i]) >
+ dis(in_poly[Stack[top2]], p_k))) {
+ p_k = in_poly[i];
+ k_index = i;
+ }
+ }
+ top2++;
+ Stack[top2] = k_index;
+ }
+
+ for (int i = top2 - 1; i >= 0; i--) {
+ left_point[i] = in_poly[Stack[i]];
+ }
+
+ for (int i = 0; i < top1 + top2; i++) {
+ if (i <= top1) {
+ in_poly[i] = right_point[i];
+ } else {
+ in_poly[i] = left_point[top2 - (i - top1)];
+ }
+ }
+ n_poly = top1 + top2;
+ for (int i = 0; i < n_poly; i++) {
+ for (int j = 0; j < n_input; j++) {
+ if (point_same(in_poly[i], input_poly[j])) {
+ points_to_convex_ind[i] = j;
+ break;
+ }
+ }
+ }
+}
+
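+// Convex GIoU between the convex hull of a predicted 9-point set and a 4-point
+// ground-truth box, plus the analytic gradient of the GIoU w.r.t. the 18
+// predicted coordinates (written to point_grad).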
+template <typename T>
+__device__ inline float devrIoU(T const* const p, T const* const q,
+ T* point_grad, const int idx) {
+ Point ps1[MAXN], ps2[MAXN];
+
+ Point convex[MAXN];
+ for (int i = 0; i < 9; i++) {
+ convex[i].x = (double)p[i * 2];
+ convex[i].y = (double)p[i * 2 + 1];
+ }
+ int n_convex = 9;
+ int points_to_convex_ind[9] = {-1, -1, -1, -1, -1, -1, -1, -1, -1};
+ Jarvis_and_index(convex, n_convex, points_to_convex_ind);
+
+ int n1 = n_convex;
+ int n2 = 4;
+
+ for (int i = 0; i < n1; i++) {
+ ps1[i].x = (double)convex[i].x;
+ ps1[i].y = (double)convex[i].y;
+ }
+
+ for (int i = 0; i < n2; i++) {
+ ps2[i].x = (double)q[i * 2];
+ ps2[i].y = (double)q[i * 2 + 1];
+ }
+
+ int polygon_index_box_index[18];
+ for (int i = 0; i < n1; i++) {
+ polygon_index_box_index[i] = i;
+ polygon_index_box_index[i + n1] = i;
+ }
+
+ double grad_A[18] = {};
+ double grad_AB[18] = {};
+ double grad_C[18] = {};
+
+ double inter_area = intersectAreaO(ps1, n1, ps2, n2, grad_AB);
+ double S_pred =
+ polygon_area_grad(ps1, n1, polygon_index_box_index, n1, grad_A);
+ if (S_pred < 0) {
+ for (int i = 0; i < n_convex * 2; i++) {
+ grad_A[i] = -grad_A[i];
+ }
+ }
+ double union_area = fabs(S_pred) + fabs(area(ps2, n2)) - inter_area;
+
+ double iou = inter_area / union_area;
+ double polygon_area = intersectAreaPoly(ps1, n1, ps2, n2, grad_C);
+
+ // printf("%d:live\n", idx);
+ double rot_giou = iou - (polygon_area - union_area) / polygon_area;
+
+ float grad_point_temp[18] = {};
+
+ for (int i = 0; i < n_convex; i++) {
+ int grad_point = points_to_convex_ind[i];
+ grad_point_temp[2 * grad_point] =
+ (float)((union_area + inter_area) / (union_area * union_area) *
+ grad_AB[2 * i] -
+ iou / union_area * grad_A[2 * i] -
+ 1 / polygon_area * (grad_AB[2 * i] - grad_A[2 * i]) -
+ (union_area) / polygon_area / polygon_area * grad_C[2 * i]);
+ grad_point_temp[2 * grad_point + 1] =
+ (float)((union_area + inter_area) / (union_area * union_area) *
+ grad_AB[2 * i + 1] -
+ iou / union_area * grad_A[2 * i + 1] -
+ 1 / polygon_area * (grad_AB[2 * i + 1] - grad_A[2 * i + 1]) -
+ (union_area) / polygon_area / polygon_area * grad_C[2 * i + 1]);
+ }
+
+ for (int i = 0; i < 9; i++) {
+ point_grad[2 * i] = grad_point_temp[2 * i];
+ point_grad[2 * i + 1] = grad_point_temp[2 * i + 1];
+ }
+ return (float)rot_giou;
+}
+
+template <typename T>
+__global__ void convex_giou_cuda_kernel(const int ex_n_boxes,
+ const int gt_n_boxes, const T* ex_boxes,
+ const T* gt_boxes, T* point_grad) {
+ CUDA_1D_KERNEL_LOOP(index, ex_n_boxes) {
+ const T* cur_box = ex_boxes + index * 18;
+ const T* cur_gt_box = gt_boxes + index * 8;
+ T* cur_grad = point_grad + index * 19;
+ T giou = devrIoU(cur_box, cur_gt_box, cur_grad, threadIdx.x);
+ cur_grad[18] = giou;
+ }
+}
+
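+// The gradient-free overloads below serve convex_iou_cuda_kernel, which only
+// needs the IoU value and not its derivatives.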
+__device__ inline int lineCross(Point a, Point b, Point c, Point d, Point& p) {
+ double s1, s2;
+ s1 = cross(a, b, c);
+ s2 = cross(a, b, d);
+ if (sig(s1) == 0 && sig(s2) == 0) return 2;
+ if (sig(s2 - s1) == 0) return 0;
+ p.x = (c.x * s2 - d.x * s1) / (s2 - s1);
+ p.y = (c.y * s2 - d.y * s1) / (s2 - s1);
+ return 1;
+}
+
+__device__ inline void polygon_cut(Point* p, int& n, Point a, Point b) {
+ Point pp[MAXN];
+ int m = 0;
+ p[n] = p[0];
+ for (int i = 0; i < n; i++) {
+ if (sig(cross(a, b, p[i])) > 0) {
+ pp[m] = p[i];
+ m++;
+ }
+ if (sig(cross(a, b, p[i])) != sig(cross(a, b, p[i + 1]))) {
+ lineCross(a, b, p[i], p[i + 1], pp[m]);
+ m++;
+ }
+ }
+ n = 0;
+ for (int i = 0; i < m; i++) {
+ if (!i || !(point_same(pp[i], pp[i - 1]))) {
+ p[n] = pp[i];
+ n++;
+ }
+ }
+
+ while (n > 1 && point_same(p[n - 1], p[0])) n--;
+}
+
+__device__ inline double intersectArea(Point a, Point b, Point c, Point d) {
+ Point o(0, 0);
+ int s1 = sig(cross(o, a, b));
+ int s2 = sig(cross(o, c, d));
+ if (s1 == 0 || s2 == 0) return 0.0;
+ if (s1 == -1) {
+ Point* i = &a;
+ Point* j = &b;
+ swap1(i, j);
+ }
+ if (s2 == -1) {
+ Point* i = &c;
+ Point* j = &d;
+ swap1(i, j);
+ }
+ Point p[10] = {o, a, b};
+ int n = 3;
+
+ polygon_cut(p, n, o, c);
+ polygon_cut(p, n, c, d);
+ polygon_cut(p, n, d, o);
+ double res = area(p, n);
+ if (s1 * s2 == -1) res = -res;
+ return res;
+}
+__device__ inline double intersectAreaO(Point* ps1, int n1, Point* ps2,
+ int n2) {
+ if (area(ps1, n1) < 0) reverse1(ps1, n1);
+ if (area(ps2, n2) < 0) reverse1(ps2, n2);
+ ps1[n1] = ps1[0];
+ ps2[n2] = ps2[0];
+ double res = 0;
+ for (int i = 0; i < n1; i++) {
+ for (int j = 0; j < n2; j++) {
+ res += intersectArea(ps1[i], ps1[i + 1], ps2[j], ps2[j + 1]);
+ }
+ }
+ return res;
+}
+
+template <typename T>
+__device__ inline float devrIoU(T const* const p, T const* const q) {
+ Point ps1[MAXN], ps2[MAXN];
+ Point convex[MAXN];
+ for (int i = 0; i < 9; i++) {
+ convex[i].x = (double)p[i * 2];
+ convex[i].y = (double)p[i * 2 + 1];
+ }
+ int n_convex = 9;
+ int points_to_convex_ind[9] = {-1, -1, -1, -1, -1, -1, -1, -1, -1};
+ Jarvis_and_index(convex, n_convex, points_to_convex_ind);
+ int n1 = n_convex;
+ for (int i = 0; i < n1; i++) {
+ ps1[i].x = (double)convex[i].x;
+ ps1[i].y = (double)convex[i].y;
+ }
+ int n2 = 4;
+ for (int i = 0; i < n2; i++) {
+ ps2[i].x = (double)q[i * 2];
+ ps2[i].y = (double)q[i * 2 + 1];
+ }
+ double inter_area = intersectAreaO(ps1, n1, ps2, n2);
+ double S_pred = area(ps1, n1);
+ double union_area = fabs(S_pred) + fabs(area(ps2, n2)) - inter_area;
+ double iou = inter_area / union_area;
+ return (float)iou;
+}
+
+template <typename T>
+__global__ void convex_iou_cuda_kernel(const int ex_n_boxes,
+ const int gt_n_boxes, const T* ex_boxes,
+ const T* gt_boxes, T* iou) {
+ CUDA_1D_KERNEL_LOOP(index, ex_n_boxes) {
+ const T* cur_box = ex_boxes + index * 18;
+ for (int i = 0; i < gt_n_boxes; i++) {
+ iou[index * gt_n_boxes + i] = devrIoU(cur_box, gt_boxes + i * 8);
+ }
+ }
+}
+#endif // CONVEX_IOU_CUDA_KERNEL_CUH
diff --git a/mmcv/mmcv/ops/csrc/common/cuda/correlation_cuda.cuh b/mmcv/mmcv/ops/csrc/common/cuda/correlation_cuda.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..2f7f112989127da235cb35476e15b206d4c2e3d4
--- /dev/null
+++ b/mmcv/mmcv/ops/csrc/common/cuda/correlation_cuda.cuh
@@ -0,0 +1,225 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+// Modified from
+// https://github.com/ClementPinard/Pytorch-Correlation-extension/blob/master/Correlation_Module/correlation_cuda_kernel.cu
+// Original licence: Under MIT License
+
+#ifndef CORRELATION_CUDA
+#define CORRELATION_CUDA
+
+#ifdef MMCV_USE_PARROTS
+#include "parrots_cuda_helper.hpp"
+#else
+#include "pytorch_cuda_helper.hpp"
+#endif
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+// Using <torch/extension.h> is recommended in the official documentation in
+// https://pytorch.org/tutorials/advanced/cpp_extension.html#writing-the-c-op.
+// However, we use <torch/types.h> for compatibility with CUDA 9.0
+// Read https://github.com/pytorch/extension-cpp/issues/35 for more details.
+#include <torch/types.h>
+
+#include <iostream>
+#include <vector>
+
+using namespace torch;
+
+#define TensorAcc4R PackedTensorAccessor32<scalar_t, 4, RestrictPtrTraits>
+#define TensorAcc5R PackedTensorAccessor32<scalar_t, 5, RestrictPtrTraits>
+#define WITHIN_BOUNDS(x, y, H, W) (x >= 0 && x < H && y >= 0 && y < W)
+
+#define WARP_SIZE 32
+#define FULL_MASK 0xffffffff
+
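+// Correlation (cost volume) forward: each output pixel (n, h, w) is handled by
+// WARP_SIZE lanes (threadIdx.x) that split the channel-wise dot product between
+// the two feature maps for every patch displacement (ph, pw); the partial sums
+// are then combined with a shuffle-down reduction before lane 0 writes out.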
+template <typename scalar_t>
+__global__ void correlation_forward_cuda_kernel(
+ const TensorAcc4R rInput1, const TensorAcc4R rInput2, TensorAcc5R output,
+ int kH, int kW, int patchH, int patchW, int padH, int padW, int dilationH,
+ int dilationW, int dilation_patchH, int dilation_patchW, int dH, int dW) {
+ const int iH = rInput1.size(1);
+ const int iW = rInput1.size(2);
+ const int C = rInput1.size(3);
+
+ const int n = blockIdx.x;
+ const int h = blockIdx.y * blockDim.y + threadIdx.y;
+ const int w = blockIdx.z * blockDim.z + threadIdx.z;
+ const int thread = threadIdx.x;
+
+ const int start_i = -padH + h * dH;
+ const int start_j = -padW + w * dW;
+
+ const int patchRadH = dilation_patchH * (patchH - 1) / 2;
+ const int patchRadW = dilation_patchW * (patchW - 1) / 2;
+
+ for (int ph = 0; ph < patchH; ++ph) {
+ int ph_dilated = ph * dilation_patchH - patchRadH;
+ for (int pw = 0; pw < patchW; ++pw) {
+ int pw_dilated = pw * dilation_patchW - patchRadW;
+ scalar_t prod_sum = 0.0f;
+ for (int i = 0; i < kH; ++i) {
+ int i1 = start_i + i * dilationH;
+ int i2 = i1 + ph_dilated;
+ if
+ WITHIN_BOUNDS(i1, i2, iH, iH) {
+ for (int j = 0; j < kW; ++j) {
+ int j1 = start_j + j * dilationW;
+ int j2 = j1 + pw_dilated;
+ if
+ WITHIN_BOUNDS(j1, j2, iW, iW) {
+ for (int c = thread; c < C; c += WARP_SIZE) {
+ scalar_t v1 = rInput1[n][i1][j1][c];
+ scalar_t v2 = rInput2[n][i2][j2][c];
+ prod_sum += v1 * v2;
+ }
+ }
+ }
+ }
+ }
+ // accumulate
+ for (int offset = 16; offset > 0; offset /= 2)
+ prod_sum += __shfl_down_sync(FULL_MASK, float(prod_sum), offset);
+ if (thread == 0) {
+ output[n][ph][pw][h][w] = prod_sum;
+ }
+ }
+ }
+}
+
+template <typename scalar_t>
+__global__ void correlation_backward_cuda_kernel_input1(
+ const TensorAcc5R grad_output, const TensorAcc4R input2,
+ TensorAcc4R grad_input1, const int kH, const int kW, const int patchH,
+ const int patchW, const int padH, const int padW, const int dilationH,
+ const int dilationW, const int dilation_patchH, const int dilation_patchW,
+ const int dH, const int dW) {
+ const int iH = input2.size(1);
+ const int iW = input2.size(2);
+ const int C = input2.size(3);
+
+ const int H = grad_output.size(3);
+ const int W = grad_output.size(4);
+
+ const int patchRadH = (patchH - 1) / 2;
+ const int patchRadW = (patchW - 1) / 2;
+
+ const int n = blockIdx.x;
+ const int h = blockIdx.y;
+ const int w = blockIdx.z;
+
+ const int h_2 = h + padH;
+ const int w_2 = w + padW;
+ const int min_h = h_2 - kH * dilationH;
+ const int min_w = w_2 - kW * dilationW;
+
+ extern __shared__ __align__(sizeof(4)) unsigned char grad_cache_char[];
+  scalar_t *grad_cache = reinterpret_cast<scalar_t *>(grad_cache_char);
+ for (int i = threadIdx.x; i < patchH * patchW; i += blockDim.x) {
+ const int ph = i / patchW;
+ const int pw = i % patchW;
+ int i1 = h + dilation_patchH * (ph - patchRadH);
+ int j1 = w + dilation_patchW * (pw - patchRadW);
+
+ if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
+ scalar_t grad_val = 0.0f;
+ for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
+ int i2 = (h_3) / dH;
+ if (i2 * dH != h_3) continue;
+ for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
+ int j2 = (w_3) / dW;
+ if (j2 * dW != w_3) continue;
+ if (WITHIN_BOUNDS(i2, j2, H, W)) {
+ grad_val += grad_output[n][ph][pw][i2][j2];
+ }
+ }
+ }
+ grad_cache[i] = grad_val;
+ }
+ }
+ __syncthreads();
+
+ for (int c = threadIdx.x; c < C; c += blockDim.x) {
+ scalar_t grad_input_val = 0.0f;
+ for (int ph = 0; ph < patchH; ++ph) {
+ int i1 = h + dilation_patchH * (ph - patchRadH);
+ for (int pw = 0; pw < patchW; ++pw) {
+ int j1 = w + dilation_patchW * (pw - patchRadW);
+ if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
+ grad_input_val += input2[n][i1][j1][c] * grad_cache[ph * patchW + pw];
+ }
+ }
+ }
+ grad_input1[n][c][h][w] = grad_input_val;
+ }
+}
+
+template <typename scalar_t>
+__global__ void correlation_backward_cuda_kernel_input2(
+ const TensorAcc5R grad_output, const TensorAcc4R input1,
+ TensorAcc4R grad_input2, int kH, int kW, int patchH, int patchW, int padH,
+ int padW, int dilationH, int dilationW, int dilation_patchH,
+ int dilation_patchW, int dH, int dW) {
+ const int iH = input1.size(1);
+ const int iW = input1.size(2);
+ const int C = input1.size(3);
+
+ const int patchRadH = (patchH - 1) / 2;
+ const int patchRadW = (patchW - 1) / 2;
+
+ const int H = grad_output.size(3);
+ const int W = grad_output.size(4);
+
+ const int dilatedKH = kH * dilationH;
+ const int dilatedKW = kW * dilationW;
+
+ const int n = blockIdx.x;
+ const int h = blockIdx.y;
+ const int w = blockIdx.z;
+
+ extern __shared__ __align__(sizeof(4)) unsigned char grad_cache_char[];
+  scalar_t *grad_cache = reinterpret_cast<scalar_t *>(grad_cache_char);
+ for (int i = threadIdx.x; i < patchH * patchW; i += blockDim.x) {
+ const int ph = i / patchW;
+ const int pw = i % patchW;
+ int i1 = h - dilation_patchH * (ph - patchRadH);
+ int j1 = w - dilation_patchW * (pw - patchRadW);
+
+ if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
+ scalar_t grad_val = 0.0f;
+
+ const int h_2 = i1 + padH;
+ const int w_2 = j1 + padW;
+ const int min_h = h_2 - dilatedKH;
+ const int min_w = w_2 - dilatedKW;
+
+ for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
+ int i2 = (h_3) / dH;
+ if (i2 * dH != h_3) continue;
+ for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
+ int j2 = (w_3) / dW;
+ if (j2 * dW != w_3) continue;
+ if (WITHIN_BOUNDS(i2, j2, H, W)) {
+ grad_val += grad_output[n][ph][pw][i2][j2];
+ }
+ }
+ }
+ grad_cache[i] = grad_val;
+ }
+ }
+ __syncthreads();
+
+ for (int c = threadIdx.x; c < C; c += blockDim.x) {
+ scalar_t grad_input_val = 0.0f;
+ for (int ph = 0; ph < patchH; ++ph) {
+ int i1 = h - dilation_patchH * (ph - patchRadH);
+ for (int pw = 0; pw < patchW; ++pw) {
+ int j1 = w - dilation_patchW * (pw - patchRadW);
+ if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
+ grad_input_val += input1[n][i1][j1][c] * grad_cache[ph * patchW + pw];
+ }
+ }
+ }
+ grad_input2[n][c][h][w] = grad_input_val;
+ }
+}
+#endif
diff --git a/mmcv/mmcv/ops/csrc/common/cuda/deform_conv_cuda_kernel.cuh b/mmcv/mmcv/ops/csrc/common/cuda/deform_conv_cuda_kernel.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..6b4d1bbd85bad1b87ee5d6b8a3cd3b29e3cbc411
--- /dev/null
+++ b/mmcv/mmcv/ops/csrc/common/cuda/deform_conv_cuda_kernel.cuh
@@ -0,0 +1,367 @@
+/*!
+ ******************* BEGIN Caffe Copyright Notice and Disclaimer
+ *****************
+ *
+ * COPYRIGHT
+ *
+ * All contributions by the University of California:
+ * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
+ * All rights reserved.
+ *
+ * All other contributions:
+ * Copyright (c) 2014-2017, the respective contributors
+ * All rights reserved.
+ *
+ * Caffe uses a shared copyright model: each contributor holds copyright over
+ * their contributions to Caffe. The project versioning records all such
+ * contribution and copyright details. If a contributor wants to further mark
+ * their specific copyright on a particular contribution, they should indicate
+ * their copyright solely in the commit message of the change when it is
+ * committed.
+ *
+ * LICENSE
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
+ *FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ *DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ *SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ *CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ *OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ *OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * CONTRIBUTION AGREEMENT
+ *
+ * By contributing to the BVLC/caffe repository through pull-request, comment,
+ * or otherwise, the contributor releases their content to the
+ * license and copyright terms herein.
+ *
+ ***************** END Caffe Copyright Notice and Disclaimer
+ *********************
+ *
+ * Copyright (c) 2018 Microsoft
+ * Licensed under The MIT License [see LICENSE for details]
+ * \file modulated_deformable_im2col.cuh
+ * \brief Function definitions of converting an image to
+ * column matrix based on kernel, padding, dilation, and offset.
+ * These functions are mainly used in deformable convolution operators.
+ * \ref: https://arxiv.org/abs/1703.06211
+ * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
+ */
+
+// modified from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
+
+#ifndef DEFORM_CONV_CUDA_KERNEL_CUH
+#define DEFORM_CONV_CUDA_KERNEL_CUH
+
+#include <float.h>
+#ifdef MMCV_WITH_TRT
+#include "common_cuda_helper.hpp"
+#else // MMCV_WITH_TRT
+#ifdef MMCV_USE_PARROTS
+#include "parrots_cuda_helper.hpp"
+#else // MMCV_USE_PARROTS
+#include "pytorch_cuda_helper.hpp"
+#endif // MMCV_USE_PARROTS
+#endif // MMCV_WITH_TRT
+
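+// Bilinear sampling helper for deformable convolution: samples input at the
+// (possibly fractional) location (h, w), returning 0 outside the feature map.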
+template <typename T>
+__device__ T deformable_im2col_bilinear(const T *input, const int data_width,
+ const int height, const int width, T h,
+ T w) {
+ if (h <= -1 || height <= h || w <= -1 || width <= w) {
+ return 0;
+ }
+
+ int h_low = floorf(h);
+ int w_low = floorf(w);
+ int h_high = h_low + 1;
+ int w_high = w_low + 1;
+
+ T lh = h - h_low;
+ T lw = w - w_low;
+ T hh = 1 - lh, hw = 1 - lw;
+
+ T v1 = 0;
+ if (h_low >= 0 && w_low >= 0) v1 = input[h_low * data_width + w_low];
+ T v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ v2 = input[h_low * data_width + w_high];
+ T v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ v3 = input[h_high * data_width + w_low];
+ T v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ v4 = input[h_high * data_width + w_high];
+
+ T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+template <typename T>
+__device__ T get_gradient_weight(T argmax_h, T argmax_w, const int h,
+ const int w, const int height,
+ const int width) {
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
+ argmax_w >= width) {
+ // empty
+ return 0;
+ }
+
+ int argmax_h_low = floorf(argmax_h);
+ int argmax_w_low = floorf(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ T weight = 0;
+ if (h == argmax_h_low && w == argmax_w_low)
+ weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+ if (h == argmax_h_low && w == argmax_w_high)
+ weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+ if (h == argmax_h_high && w == argmax_w_low)
+ weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+ if (h == argmax_h_high && w == argmax_w_high)
+ weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+ return weight;
+}
+
+template <typename T>
+__device__ T get_coordinate_weight(T argmax_h, T argmax_w, const int height,
+ const int width, const T *im_data,
+ const int data_width, const int bp_dir) {
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
+ argmax_w >= width) {
+ // empty
+ return 0;
+ }
+
+ int argmax_h_low = floorf(argmax_h);
+ int argmax_w_low = floorf(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ T weight = 0;
+
+ if (bp_dir == 0) {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_w_low + 1 - argmax_w) *
+ im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += -1 * (argmax_w - argmax_w_low) *
+ im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += (argmax_w_low + 1 - argmax_w) *
+ im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_w - argmax_w_low) *
+ im_data[argmax_h_high * data_width + argmax_w_high];
+ } else if (bp_dir == 1) {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h_low + 1 - argmax_h) *
+ im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += (argmax_h_low + 1 - argmax_h) *
+ im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h - argmax_h_low) *
+ im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_h - argmax_h_low) *
+ im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+
+ return weight;
+}
+
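+// Deformable im2col: for each column entry, read the learned (offset_h,
+// offset_w) pair, bilinearly sample the input at the displaced kernel location,
+// and write the value into the column buffer that the convolution weights are
+// then applied to.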
+template <typename T>
+__global__ void deformable_im2col_gpu_kernel(
+ const int n, const T *data_im, const T *data_offset, const int height,
+ const int width, const int kernel_h, const int kernel_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group, const int batch_size,
+ const int num_channels, const int deformable_group, const int height_col,
+ const int width_col, T *data_col) {
+ CUDA_1D_KERNEL_LOOP(index, n) {
+ // index index of output matrix
+ const int w_col = index % width_col;
+ const int h_col = (index / width_col) % height_col;
+ const int b_col = (index / width_col / height_col) % batch_size;
+ const int c_im = (index / width_col / height_col) / batch_size;
+ const int c_col = c_im * kernel_h * kernel_w;
+
+ // compute deformable group index
+ const int deformable_group_index = c_im / channel_per_deformable_group;
+
+ const int h_in = h_col * stride_h - pad_h;
+ const int w_in = w_col * stride_w - pad_w;
+ T *data_col_ptr =
+ data_col +
+ ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+ const T *data_im_ptr =
+ data_im + (b_col * num_channels + c_im) * height * width;
+ const T *data_offset_ptr =
+ data_offset + (b_col * deformable_group + deformable_group_index) * 2 *
+ kernel_h * kernel_w * height_col * width_col;
+
+ for (int i = 0; i < kernel_h; ++i) {
+ for (int j = 0; j < kernel_w; ++j) {
+ const int data_offset_h_ptr =
+ ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+ const int data_offset_w_ptr =
+ ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
+ w_col;
+ const T offset_h = data_offset_ptr[data_offset_h_ptr];
+ const T offset_w = data_offset_ptr[data_offset_w_ptr];
+        T val = static_cast<T>(0);
+ const T h_im = h_in + i * dilation_h + offset_h;
+ const T w_im = w_in + j * dilation_w + offset_w;
+ if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+ val = deformable_im2col_bilinear(data_im_ptr, width, height, width,
+ h_im, w_im);
+ *data_col_ptr = val;
+ data_col_ptr += batch_size * height_col * width_col;
+ }
+ }
+ }
+}
+
+template <typename T>