diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b --- /dev/null +++ b/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f716b4f112b65a121a616dc2e6121e041b6fd564 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +**/*.pyc +.idea/ +__pycache__/ + +deps/ +datasets/ +experiments_t2m/ +experiments_t2m_test/ +experiments_control/ +experiments_control_test/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f03fbb7cd30918f38ed7bca1b67b752db82d9359 --- /dev/null +++ b/LICENSE @@ -0,0 +1,25 @@ +Copyright Tsinghua University and Shanghai AI Laboratory. All Rights Reserved. + +License for Non-commercial Scientific Research Purposes. + +For more information see . +If you use this software, please cite the corresponding publications +listed on the above website. + +Permission to use, copy, modify, and distribute this software and its +documentation for educational, research, and non-profit purposes only. +Any modification based on this work must be open-source and prohibited +for commercial, pornographic, military, or surveillance use. + +The authors grant you a non-exclusive, worldwide, non-transferable, +non-sublicensable, revocable, royalty-free, and limited license under +our copyright interests to reproduce, distribute, and create derivative +works of the text, videos, and codes solely for your non-commercial +research purposes. + +You must retain, in the source form of any derivative works that you +distribute, all copyright, patent, trademark, and attribution notices +from the source form of this work. + +For commercial uses of this software, please send email to all people +in the author list. 
\ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..510d1ffc7a7f77cb400b170e9984fe75243ffa13 --- /dev/null +++ b/README.md @@ -0,0 +1,14 @@ +--- +title: MotionLCM +emoji: 🏃 +colorFrom: yellow +colorTo: gray +sdk: gradio +sdk_version: 3.24.1 +app_file: app.py +pinned: false +license: other +python_version: 3.10.12 +--- + +Check out the configuration reference at https://huggingface.co./docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..85719e76a128affff56be86a04ed56c815dd58e1 --- /dev/null +++ b/app.py @@ -0,0 +1,258 @@ +import os +import time +import random +import datetime +import os.path as osp +from functools import partial + +import torch +import gradio as gr +from omegaconf import OmegaConf + +from mld.config import get_module_config +from mld.data.get_data import get_datasets +from mld.models.modeltype.mld import MLD +from mld.utils.utils import set_seed +from mld.data.humanml.utils.plot_script import plot_3d_motion + +WEBSITE = """ +
<div align="center">
+<h2>MotionLCM: Real-time Controllable Motion Generation via Latent Consistency Model</h2>
+<p>
+Wenxun Dai<sup>1</sup> &emsp; Ling-Hao Chen<sup>1</sup> &emsp; Jingbo Wang<sup>2</sup> &emsp; Jinpeng Liu<sup>1</sup> &emsp; Bo Dai<sup>2</sup> &emsp; Yansong Tang<sup>1</sup>
+</p>
+<p>
+<sup>1</sup>Tsinghua University &emsp; <sup>2</sup>Shanghai AI Laboratory
+</p>
+</div>
+"""
+
+WEBSITE_bottom = """
+<div align="center">
+<p>Space adapted from TMR and MoMask.</p>
+</div>
+""" + +EXAMPLES = [ + "a person does a jump", + "a person waves both arms in the air.", + "The person takes 4 steps backwards.", + "this person bends forward as if to bow.", + "The person was pushed but did not fall.", + "a man walks forward in a snake like pattern.", + "a man paces back and forth along the same line.", + "with arms out to the sides a person walks forward", + "A man bends down and picks something up with his right hand.", + "The man walked forward, spun right on one foot and walked back to his original position.", + "a person slightly bent over with right hand pressing against the air walks forward slowly" +] + +CSS = """ +.contour_video { + display: flex; + flex-direction: column; + justify-content: center; + align-items: center; + z-index: var(--layer-5); + border-radius: var(--block-radius); + background: var(--background-fill-primary); + padding: 0 var(--size-6); + max-height: var(--size-screen-h); + overflow: hidden; +} +""" + +if not os.path.exists("./experiments_t2m/"): + os.system("bash prepare/download_pretrained_models.sh") +if not os.path.exists('./deps/glove/'): + os.system("bash prepare/download_glove.sh") +if not os.path.exists('./deps/sentence-t5-large/'): + os.system("bash prepare/prepare_t5.sh") +if not os.path.exists('./deps/t2m/'): + os.system("bash prepare/download_t2m_evaluators.sh") +if not os.path.exists('./datasets/humanml3d/'): + os.system("bash prepare/prepare_tiny_humanml3d.sh") + +DEFAULT_TEXT = "A person is " +MAX_VIDEOS = 12 +T2M_CFG = "./configs/motionlcm_t2m.yaml" + +device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +cfg = OmegaConf.load(T2M_CFG) +cfg_model = get_module_config(cfg.model, cfg.model.target) +cfg = OmegaConf.merge(cfg, cfg_model) +set_seed(1949) + +name_time_str = osp.join(cfg.NAME, datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")) +output_dir = osp.join(cfg.TEST_FOLDER, name_time_str) +vis_dir = osp.join(output_dir, 'samples') +os.makedirs(output_dir, exist_ok=False) +os.makedirs(vis_dir, exist_ok=False) + +state_dict = torch.load(cfg.TEST.CHECKPOINTS, map_location="cpu")["state_dict"] +print("Loading checkpoints from {}".format(cfg.TEST.CHECKPOINTS)) + +lcm_key = 'denoiser.time_embedding.cond_proj.weight' +is_lcm = False +if lcm_key in state_dict: + is_lcm = True + time_cond_proj_dim = state_dict[lcm_key].shape[1] + cfg.model.denoiser.params.time_cond_proj_dim = time_cond_proj_dim +print(f'Is LCM: {is_lcm}') + +cfg.model.is_controlnet = False + +datasets = get_datasets(cfg, phase="test")[0] +model = MLD(cfg, datasets) +model.to(device) +model.eval() +model.load_state_dict(state_dict) + + +@torch.no_grad() +def generate(text, motion_len, num_videos): + batch = {"text": [text] * num_videos, "length": [motion_len] * num_videos} + + s = time.time() + joints, _ = model(batch) + runtime = round(time.time() - s, 3) + runtime_info = f'Inference {len(joints)} motions, runtime: {runtime}s, device: {device}' + path = [] + for i in range(num_videos): + uid = random.randrange(999999999) + video_path = osp.join(vis_dir, f"sample_{uid}.mp4") + plot_3d_motion(video_path, joints[i].detach().cpu().numpy(), '', fps=20) + path.append(video_path) + return path, runtime_info + + +# HTML component +def get_video_html(path, video_id, width=700, height=700): + video_html = f""" + +""" + return video_html + + +def generate_component(generate_function, text, motion_len, num_videos): + if text == DEFAULT_TEXT or text == "" or text is None: + return [None for _ in range(MAX_VIDEOS)] + [None] + + motion_len = 
max(36, min(int(float(motion_len) * 20), 196)) + paths, info = generate_function(text, motion_len, num_videos) + htmls = [get_video_html(path, idx) for idx, path in enumerate(paths)] + htmls = htmls + [None for _ in range(max(0, MAX_VIDEOS - num_videos))] + return htmls + [info] + + +theme = gr.themes.Default(primary_hue="purple", secondary_hue="gray") +generate_and_show = partial(generate_component, generate) + +with gr.Blocks(css=CSS, theme=theme) as demo: + gr.HTML(WEBSITE) + videos = [] + + with gr.Row(): + with gr.Column(scale=3): + text = gr.Textbox( + show_label=True, + label="Text prompt", + value=DEFAULT_TEXT, + ) + + with gr.Row(): + with gr.Column(scale=1): + motion_len = gr.Textbox( + show_label=True, + label="Motion length (in seconds, <=9.8s)", + value=5, + info="Any length exceeding 9.8s will be restricted to 9.8s.", + ) + with gr.Column(scale=1): + num_videos = gr.Radio( + [1, 4, 8, 12], + label="Videos", + value=8, + info="Number of videos to generate.", + ) + + gen_btn = gr.Button("Generate", variant="primary") + clear = gr.Button("Clear", variant="secondary") + + results = gr.Textbox(show_label=True, + label='Inference info (runtime and device)', + info='Real-time inference cannot be achieved using the free CPU. Local GPU deployment is recommended.', + interactive=False) + + with gr.Column(scale=2): + def generate_example(text, motion_len, num_videos): + return generate_and_show(text, motion_len, num_videos) + + examples = gr.Examples( + examples=[[x, None, None] for x in EXAMPLES], + inputs=[text, motion_len, num_videos], + examples_per_page=12, + run_on_click=False, + cache_examples=False, + fn=generate_example, + outputs=[], + ) + + for _ in range(3): + with gr.Row(): + for _ in range(4): + video = gr.HTML() + videos.append(video) + + # gr.HTML(WEBSITE_bottom) + # connect the examples to the output + # a bit hacky + examples.outputs = videos + + def load_example(example_id): + processed_example = examples.non_none_processed_examples[example_id] + return gr.utils.resolve_singleton(processed_example) + + examples.dataset.click( + load_example, + inputs=[examples.dataset], + outputs=examples.inputs_with_examples, # type: ignore + show_progress=False, + postprocess=False, + queue=False, + ).then(fn=generate_example, inputs=examples.inputs, outputs=videos + [results]) + + gen_btn.click( + fn=generate_and_show, + inputs=[text, motion_len, num_videos], + outputs=videos + [results], + ) + text.submit( + fn=generate_and_show, + inputs=[text, motion_len, num_videos], + outputs=videos + [results], + ) + + def clear_videos(): + return [None for _ in range(MAX_VIDEOS)] + [DEFAULT_TEXT] + [None] + + clear.click(fn=clear_videos, outputs=videos + [text] + [results]) + +demo.launch() diff --git a/configs/mld_control.yaml b/configs/mld_control.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fc076974a2275b4e4f328fb704b6fec95c7a56c4 --- /dev/null +++ b/configs/mld_control.yaml @@ -0,0 +1,105 @@ +FOLDER: './experiments_control' +TEST_FOLDER: './experiments_control_test' + +NAME: 'mld_humanml' + +TRAIN: + DATASETS: ['humanml3d'] + BATCH_SIZE: 128 + SPLIT: 'train' + NUM_WORKERS: 8 + PERSISTENT_WORKERS: true + SEED_VALUE: 1234 + PRETRAINED: 'experiments_t2m/mld_humanml/mld_humanml.ckpt' + + validation_steps: -1 + validation_epochs: 50 + checkpointing_steps: -1 + checkpointing_epochs: 50 + max_train_steps: -1 + max_train_epochs: 1000 + learning_rate: 1e-4 + learning_rate_spatial: 1e-4 + lr_scheduler: "cosine" + lr_warmup_steps: 1000 + adam_beta1: 0.9 + adam_beta2: 
0.999 + adam_weight_decay: 0.0 + adam_epsilon: 1e-08 + max_grad_norm: 1.0 + +EVAL: + DATASETS: ['humanml3d'] + BATCH_SIZE: 32 + SPLIT: 'test' + NUM_WORKERS: 12 + +TEST: + DATASETS: ['humanml3d'] + BATCH_SIZE: 1 + SPLIT: 'test' + NUM_WORKERS: 12 + + CHECKPOINTS: 'experiments_control/mld_humanml/mld_humanml.ckpt' + + # Testing Args + REPLICATION_TIMES: 1 + MM_NUM_SAMPLES: 100 + MM_NUM_REPEATS: 30 + MM_NUM_TIMES: 10 + DIVERSITY_TIMES: 300 + MAX_NUM_SAMPLES: 1024 + +DATASET: + SMPL_PATH: './deps/smpl' + WORD_VERTILIZER_PATH: './deps/glove/' + HUMANML3D: + PICK_ONE_TEXT: true + FRAME_RATE: 20.0 + UNIT_LEN: 4 + ROOT: './datasets/humanml3d' + SPLIT_ROOT: './datasets/humanml3d' + SAMPLER: + MAX_LEN: 196 + MIN_LEN: 40 + MAX_TEXT_LEN: 20 + +METRIC: + DIST_SYNC_ON_STEP: true + TYPE: ['TM2TMetrics', 'ControlMetrics'] + +model: + target: 'modules_mld' + latent_dim: [1, 256] + guidance_scale: 7.5 + guidance_uncondp: 0.1 + + # ControlNet Args + is_controlnet: true + is_controlnet_temporal: false + training_control_joint: [0] + testing_control_joint: [0] + training_density: 'random' + testing_density: 100 + control_scale: 1.0 + vaeloss: true + vaeloss_type: 'sum' + cond_ratio: 1.0 + rot_ratio: 0.0 + + t2m_textencoder: + dim_word: 300 + dim_pos_ohot: 15 + dim_text_hidden: 512 + dim_coemb_hidden: 512 + + t2m_motionencoder: + dim_move_hidden: 512 + dim_move_latent: 512 + dim_motion_hidden: 1024 + dim_motion_latent: 512 + + bert_path: './deps/distilbert-base-uncased' + clip_path: './deps/clip-vit-large-patch14' + t5_path: './deps/sentence-t5-large' + t2m_path: './deps/t2m/' diff --git a/configs/mld_t2m_infer.yaml b/configs/mld_t2m_infer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..22d533c3b3d6eff17bffd7f16f43a8d81c21bfcd --- /dev/null +++ b/configs/mld_t2m_infer.yaml @@ -0,0 +1,72 @@ +FOLDER: './experiments_t2m' +TEST_FOLDER: './experiments_t2m_test' + +NAME: 'mld_humanml' + +TRAIN: + DATASETS: ['humanml3d'] + BATCH_SIZE: 1 + NUM_WORKERS: 8 + PERSISTENT_WORKERS: true + SEED_VALUE: 1234 + +EVAL: + DATASETS: ['humanml3d'] + BATCH_SIZE: 32 + SPLIT: test + NUM_WORKERS: 12 + +TEST: + DATASETS: ['humanml3d'] + SPLIT: test + BATCH_SIZE: 1 + NUM_WORKERS: 12 + + CHECKPOINTS: 'experiments_t2m/mld_humanml/mld_humanml.ckpt' + + # Testing Args + REPLICATION_TIMES: 20 + MM_NUM_SAMPLES: 100 + MM_NUM_REPEATS: 30 + MM_NUM_TIMES: 10 + DIVERSITY_TIMES: 300 + +DATASET: + SMPL_PATH: './deps/smpl' + WORD_VERTILIZER_PATH: './deps/glove/' + HUMANML3D: + PICK_ONE_TEXT: true + FRAME_RATE: 20.0 + UNIT_LEN: 4 + ROOT: './datasets/humanml3d' + SPLIT_ROOT: './datasets/humanml3d' + SAMPLER: + MAX_LEN: 196 + MIN_LEN: 40 + MAX_TEXT_LEN: 20 + +METRIC: + DIST_SYNC_ON_STEP: True + TYPE: ['TM2TMetrics'] + +model: + target: 'modules_mld' + latent_dim: [1, 256] + guidance_scale: 7.5 + + t2m_textencoder: + dim_word: 300 + dim_pos_ohot: 15 + dim_text_hidden: 512 + dim_coemb_hidden: 512 + + t2m_motionencoder: + dim_move_hidden: 512 + dim_move_latent: 512 + dim_motion_hidden: 1024 + dim_motion_latent: 512 + + bert_path: './deps/distilbert-base-uncased' + clip_path: './deps/clip-vit-large-patch14' + t5_path: './deps/sentence-t5-large' + t2m_path: './deps/t2m/' diff --git a/configs/modules/denoiser.yaml b/configs/modules/denoiser.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1a70ed3f67b95d79c6245a3cb2ecfdd6bdfb214e --- /dev/null +++ b/configs/modules/denoiser.yaml @@ -0,0 +1,16 @@ +denoiser: + target: mld.models.architectures.mld_denoiser.MldDenoiser + params: + text_encoded_dim: 768 + ff_size: 
1024 + num_layers: 9 + num_heads: 4 + dropout: 0.1 + normalize_before: false + activation: 'gelu' + flip_sin_to_cos: true + return_intermediate_dec: false + position_embedding: 'learned' + arch: 'trans_enc' + freq_shift: 0 + latent_dim: ${model.latent_dim} diff --git a/configs/modules/motion_vae.yaml b/configs/modules/motion_vae.yaml new file mode 100644 index 0000000000000000000000000000000000000000..43ba84ba7e89b5d8314276d935d0040163623612 --- /dev/null +++ b/configs/modules/motion_vae.yaml @@ -0,0 +1,13 @@ +motion_vae: + target: mld.models.architectures.mld_vae.MldVae + params: + arch: 'encoder_decoder' + ff_size: 1024 + num_layers: 9 + num_heads: 4 + dropout: 0.1 + normalize_before: false + activation: 'gelu' + position_embedding: 'learned' + latent_dim: ${model.latent_dim} + nfeats: ${DATASET.NFEATS} diff --git a/configs/modules/scheduler.yaml b/configs/modules/scheduler.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1e1ef19aa4f19070f086cfb2feaee4effd085568 --- /dev/null +++ b/configs/modules/scheduler.yaml @@ -0,0 +1,20 @@ +scheduler: + target: diffusers.LCMScheduler + num_inference_timesteps: 1 + params: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: 'scaled_linear' + clip_sample: false + set_alpha_to_one: false + +noise_scheduler: + target: diffusers.DDPMScheduler + params: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: 'scaled_linear' + variance_type: 'fixed_small' + clip_sample: false diff --git a/configs/modules/text_encoder.yaml b/configs/modules/text_encoder.yaml new file mode 100644 index 0000000000000000000000000000000000000000..edd739430f38f25078fce73c0ec1d350bc6c4d39 --- /dev/null +++ b/configs/modules/text_encoder.yaml @@ -0,0 +1,5 @@ +text_encoder: + target: mld.models.architectures.mld_clip.MldTextEncoder + params: + last_hidden_state: false # if true, the last hidden state is used as the text embedding + modelpath: ${model.t5_path} diff --git a/configs/modules/traj_encoder.yaml b/configs/modules/traj_encoder.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9c2a3eded617a78cc397a96fedcd2438200a948f --- /dev/null +++ b/configs/modules/traj_encoder.yaml @@ -0,0 +1,12 @@ +traj_encoder: + target: mld.models.architectures.mld_traj_encoder.MldTrajEncoder + params: + ff_size: 1024 + num_layers: 9 + num_heads: 4 + dropout: 0.1 + normalize_before: false + activation: 'gelu' + position_embedding: 'learned' + latent_dim: ${model.latent_dim} + nfeats: ${DATASET.NJOINTS} diff --git a/configs/modules_mld/denoiser.yaml b/configs/modules_mld/denoiser.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1a70ed3f67b95d79c6245a3cb2ecfdd6bdfb214e --- /dev/null +++ b/configs/modules_mld/denoiser.yaml @@ -0,0 +1,16 @@ +denoiser: + target: mld.models.architectures.mld_denoiser.MldDenoiser + params: + text_encoded_dim: 768 + ff_size: 1024 + num_layers: 9 + num_heads: 4 + dropout: 0.1 + normalize_before: false + activation: 'gelu' + flip_sin_to_cos: true + return_intermediate_dec: false + position_embedding: 'learned' + arch: 'trans_enc' + freq_shift: 0 + latent_dim: ${model.latent_dim} diff --git a/configs/modules_mld/motion_vae.yaml b/configs/modules_mld/motion_vae.yaml new file mode 100644 index 0000000000000000000000000000000000000000..43ba84ba7e89b5d8314276d935d0040163623612 --- /dev/null +++ b/configs/modules_mld/motion_vae.yaml @@ -0,0 +1,13 @@ +motion_vae: + target: mld.models.architectures.mld_vae.MldVae + params: + arch: 'encoder_decoder' 
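+    # latent_dim and nfeats below are resolved via OmegaConf interpolation
+    # (${model.latent_dim}, ${DATASET.NFEATS}) when the module configs are merged in mld/config.py.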
+ ff_size: 1024 + num_layers: 9 + num_heads: 4 + dropout: 0.1 + normalize_before: false + activation: 'gelu' + position_embedding: 'learned' + latent_dim: ${model.latent_dim} + nfeats: ${DATASET.NFEATS} diff --git a/configs/modules_mld/scheduler.yaml b/configs/modules_mld/scheduler.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5a1f4cccbec9c64450bc387cf7e3c718923ff149 --- /dev/null +++ b/configs/modules_mld/scheduler.yaml @@ -0,0 +1,23 @@ +scheduler: + target: diffusers.DDIMScheduler + num_inference_timesteps: 50 + eta: 0.0 + params: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: 'scaled_linear' + clip_sample: false + # below are for ddim + set_alpha_to_one: false + steps_offset: 1 + +noise_scheduler: + target: diffusers.DDPMScheduler + params: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: 'scaled_linear' + variance_type: 'fixed_small' + clip_sample: false diff --git a/configs/modules_mld/text_encoder.yaml b/configs/modules_mld/text_encoder.yaml new file mode 100644 index 0000000000000000000000000000000000000000..edd739430f38f25078fce73c0ec1d350bc6c4d39 --- /dev/null +++ b/configs/modules_mld/text_encoder.yaml @@ -0,0 +1,5 @@ +text_encoder: + target: mld.models.architectures.mld_clip.MldTextEncoder + params: + last_hidden_state: false # if true, the last hidden state is used as the text embedding + modelpath: ${model.t5_path} diff --git a/configs/modules_mld/traj_encoder.yaml b/configs/modules_mld/traj_encoder.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9c2a3eded617a78cc397a96fedcd2438200a948f --- /dev/null +++ b/configs/modules_mld/traj_encoder.yaml @@ -0,0 +1,12 @@ +traj_encoder: + target: mld.models.architectures.mld_traj_encoder.MldTrajEncoder + params: + ff_size: 1024 + num_layers: 9 + num_heads: 4 + dropout: 0.1 + normalize_before: false + activation: 'gelu' + position_embedding: 'learned' + latent_dim: ${model.latent_dim} + nfeats: ${DATASET.NJOINTS} diff --git a/configs/motionlcm_control.yaml b/configs/motionlcm_control.yaml new file mode 100644 index 0000000000000000000000000000000000000000..63220d03f36f8eb14586f31bcfec67cbedf06aba --- /dev/null +++ b/configs/motionlcm_control.yaml @@ -0,0 +1,105 @@ +FOLDER: './experiments_control' +TEST_FOLDER: './experiments_control_test' + +NAME: 'motionlcm_humanml' + +TRAIN: + DATASETS: ['humanml3d'] + BATCH_SIZE: 128 + SPLIT: 'train' + NUM_WORKERS: 8 + PERSISTENT_WORKERS: true + SEED_VALUE: 1234 + PRETRAINED: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml.ckpt' + + validation_steps: -1 + validation_epochs: 50 + checkpointing_steps: -1 + checkpointing_epochs: 50 + max_train_steps: -1 + max_train_epochs: 1000 + learning_rate: 1e-4 + learning_rate_spatial: 1e-4 + lr_scheduler: "cosine" + lr_warmup_steps: 1000 + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_weight_decay: 0.0 + adam_epsilon: 1e-08 + max_grad_norm: 1.0 + +EVAL: + DATASETS: ['humanml3d'] + BATCH_SIZE: 32 + SPLIT: 'test' + NUM_WORKERS: 12 + +TEST: + DATASETS: ['humanml3d'] + BATCH_SIZE: 1 + SPLIT: 'test' + NUM_WORKERS: 12 + + CHECKPOINTS: 'experiments_control/motionlcm_humanml/motionlcm_humanml.ckpt' + + # Testing Args + REPLICATION_TIMES: 1 + MM_NUM_SAMPLES: 100 + MM_NUM_REPEATS: 30 + MM_NUM_TIMES: 10 + DIVERSITY_TIMES: 300 + MAX_NUM_SAMPLES: 1024 + +DATASET: + SMPL_PATH: './deps/smpl' + WORD_VERTILIZER_PATH: './deps/glove/' + HUMANML3D: + PICK_ONE_TEXT: true + FRAME_RATE: 20.0 + UNIT_LEN: 4 + ROOT: './datasets/humanml3d' + SPLIT_ROOT: 
'./datasets/humanml3d' + SAMPLER: + MAX_LEN: 196 + MIN_LEN: 40 + MAX_TEXT_LEN: 20 + +METRIC: + DIST_SYNC_ON_STEP: true + TYPE: ['TM2TMetrics', 'ControlMetrics'] + +model: + target: 'modules' + latent_dim: [1, 256] + guidance_scale: 7.5 + guidance_uncondp: 0.0 + + # ControlNet Args + is_controlnet: true + is_controlnet_temporal: false + training_control_joint: [0] + testing_control_joint: [0] + training_density: 'random' + testing_density: 100 + control_scale: 1.0 + vaeloss: true + vaeloss_type: 'sum' + cond_ratio: 1.0 + rot_ratio: 0.0 + + t2m_textencoder: + dim_word: 300 + dim_pos_ohot: 15 + dim_text_hidden: 512 + dim_coemb_hidden: 512 + + t2m_motionencoder: + dim_move_hidden: 512 + dim_move_latent: 512 + dim_motion_hidden: 1024 + dim_motion_latent: 512 + + bert_path: './deps/distilbert-base-uncased' + clip_path: './deps/clip-vit-large-patch14' + t5_path: './deps/sentence-t5-large' + t2m_path: './deps/t2m/' diff --git a/configs/motionlcm_t2m.yaml b/configs/motionlcm_t2m.yaml new file mode 100644 index 0000000000000000000000000000000000000000..99076b34dee6fddf8f6aaf7882ae03963c910528 --- /dev/null +++ b/configs/motionlcm_t2m.yaml @@ -0,0 +1,100 @@ +FOLDER: './experiments_t2m' +TEST_FOLDER: './experiments_t2m_test' + +NAME: 'motionlcm_humanml' + +TRAIN: + DATASETS: ['humanml3d'] + BATCH_SIZE: 256 + SPLIT: 'train' + NUM_WORKERS: 8 + PERSISTENT_WORKERS: true + SEED_VALUE: 1234 + PRETRAINED: 'experiments_t2m/mld_humanml/mld_humanml.ckpt' + + validation_steps: -1 + validation_epochs: 50 + checkpointing_steps: -1 + checkpointing_epochs: 50 + max_train_steps: -1 + max_train_epochs: 1000 + learning_rate: 2e-4 + lr_scheduler: "cosine" + lr_warmup_steps: 1000 + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_weight_decay: 0.0 + adam_epsilon: 1e-08 + max_grad_norm: 1.0 + + # Latent Consistency Distillation Specific Arguments + w_min: 5.0 + w_max: 15.0 + num_ddim_timesteps: 50 + loss_type: 'huber' + huber_c: 0.001 + unet_time_cond_proj_dim: 256 + ema_decay: 0.95 + +EVAL: + DATASETS: ['humanml3d'] + BATCH_SIZE: 32 + SPLIT: 'test' + NUM_WORKERS: 12 + +TEST: + DATASETS: ['humanml3d'] + BATCH_SIZE: 1 + SPLIT: 'test' + NUM_WORKERS: 12 + + CHECKPOINTS: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml.ckpt' + + # Testing Args + REPLICATION_TIMES: 20 + MM_NUM_SAMPLES: 100 + MM_NUM_REPEATS: 30 + MM_NUM_TIMES: 10 + DIVERSITY_TIMES: 300 + +DATASET: + SMPL_PATH: './deps/smpl' + WORD_VERTILIZER_PATH: './deps/glove/' + HUMANML3D: + PICK_ONE_TEXT: true + FRAME_RATE: 20.0 + UNIT_LEN: 4 + ROOT: './datasets/humanml3d' + SPLIT_ROOT: './datasets/humanml3d' + SAMPLER: + MAX_LEN: 196 + MIN_LEN: 40 + MAX_TEXT_LEN: 20 + +METRIC: + DIST_SYNC_ON_STEP: true + TYPE: ['TM2TMetrics'] + +model: + target: 'modules' + latent_dim: [1, 256] + guidance_scale: 7.5 + guidance_uncondp: 0.0 + is_controlnet: false + + t2m_textencoder: + dim_word: 300 + dim_pos_ohot: 15 + dim_text_hidden: 512 + dim_coemb_hidden: 512 + + t2m_motionencoder: + dim_move_hidden: 512 + dim_move_latent: 512 + dim_motion_hidden: 1024 + dim_motion_latent: 512 + + bert_path: './deps/distilbert-base-uncased' + clip_path: './deps/clip-vit-large-patch14' + t5_path: './deps/sentence-t5-large' + t2m_path: './deps/t2m/' diff --git a/demo.py b/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..62424598163db92e9ff0a3817d5242d4c2fc0dd0 --- /dev/null +++ b/demo.py @@ -0,0 +1,154 @@ +import os +import pickle +import sys +import datetime +import logging +import os.path as osp + +from omegaconf import OmegaConf + +import torch + +from mld.config import 
parse_args +from mld.data.get_data import get_datasets +from mld.models.modeltype.mld import MLD +from mld.utils.utils import set_seed, move_batch_to_device +from mld.data.humanml.utils.plot_script import plot_3d_motion +from mld.utils.temos_utils import remove_padding + + +def load_example_input(text_path: str) -> tuple: + with open(text_path, "r") as f: + lines = f.readlines() + + count = 0 + texts, lens = [], [] + # Strips the newline character + for line in lines: + count += 1 + s = line.strip() + s_l = s.split(" ")[0] + s_t = s[(len(s_l) + 1):] + lens.append(int(s_l)) + texts.append(s_t) + return texts, lens + + +def main(): + cfg = parse_args() + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + set_seed(cfg.TRAIN.SEED_VALUE) + + name_time_str = osp.join(cfg.NAME, "demo_" + datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")) + output_dir = osp.join(cfg.TEST_FOLDER, name_time_str) + vis_dir = osp.join(output_dir, 'samples') + os.makedirs(output_dir, exist_ok=False) + os.makedirs(vis_dir, exist_ok=False) + + steam_handler = logging.StreamHandler(sys.stdout) + file_handler = logging.FileHandler(osp.join(output_dir, 'output.log')) + logging.basicConfig(level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[steam_handler, file_handler]) + logger = logging.getLogger(__name__) + + OmegaConf.save(cfg, osp.join(output_dir, 'config.yaml')) + + state_dict = torch.load(cfg.TEST.CHECKPOINTS, map_location="cpu")["state_dict"] + logger.info("Loading checkpoints from {}".format(cfg.TEST.CHECKPOINTS)) + + lcm_key = 'denoiser.time_embedding.cond_proj.weight' + is_lcm = False + if lcm_key in state_dict: + is_lcm = True + time_cond_proj_dim = state_dict[lcm_key].shape[1] + cfg.model.denoiser.params.time_cond_proj_dim = time_cond_proj_dim + logger.info(f'Is LCM: {is_lcm}') + + cn_key = "controlnet.controlnet_cond_embedding.0.weight" + is_controlnet = True if cn_key in state_dict else False + cfg.model.is_controlnet = is_controlnet + logger.info(f'Is Controlnet: {is_controlnet}') + + datasets = get_datasets(cfg, phase="test")[0] + model = MLD(cfg, datasets) + model.to(device) + model.eval() + model.load_state_dict(state_dict) + + # example only support text-to-motion + if cfg.example is not None and not is_controlnet: + text, length = load_example_input(cfg.example) + for t, l in zip(text, length): + logger.info(f"{l}: {t}") + + batch = {"length": length, "text": text} + + for rep_i in range(cfg.replication): + with torch.no_grad(): + joints, _ = model(batch) + + num_samples = len(joints) + batch_id = 0 + for i in range(num_samples): + res = dict() + pkl_path = osp.join(vis_dir, f"batch_id_{batch_id}_sample_id_{i}_length_{length[i]}_rep_{rep_i}.pkl") + res['joints'] = joints[i].detach().cpu().numpy() + res['text'] = text[i] + res['length'] = length[i] + res['hint'] = None + with open(pkl_path, 'wb') as f: + pickle.dump(res, f) + logger.info(f"Motions are generated here:\n{pkl_path}") + + if not cfg.no_plot: + plot_3d_motion(pkl_path.replace('.pkl', '.mp4'), joints[i].detach().cpu().numpy(), text[i], fps=20) + + else: + test_dataloader = datasets.test_dataloader() + for rep_i in range(cfg.replication): + for batch_id, batch in enumerate(test_dataloader): + batch = move_batch_to_device(batch, device) + with torch.no_grad(): + joints, joints_ref = model(batch) + + num_samples = len(joints) + text = batch['text'] + length = batch['length'] + if 'hint' in batch: + hint = batch['hint'] + mask_hint = 
hint.view(hint.shape[0], hint.shape[1], model.njoints, 3).sum(dim=-1, keepdim=True) != 0 + hint = model.datamodule.denorm_spatial(hint) + hint = hint.view(hint.shape[0], hint.shape[1], model.njoints, 3) * mask_hint + hint = remove_padding(hint, lengths=length) + else: + hint = None + + for i in range(num_samples): + res = dict() + pkl_path = osp.join(vis_dir, f"batch_id_{batch_id}_sample_id_{i}_length_{length[i]}_rep_{rep_i}.pkl") + res['joints'] = joints[i].detach().cpu().numpy() + res['text'] = text[i] + res['length'] = length[i] + res['hint'] = hint[i].detach().cpu().numpy() if hint is not None else None + with open(pkl_path, 'wb') as f: + pickle.dump(res, f) + logger.info(f"Motions are generated here:\n{pkl_path}") + + if not cfg.no_plot: + plot_3d_motion(pkl_path.replace('.pkl', '.mp4'), joints[i].detach().cpu().numpy(), + text[i], fps=20, hint=hint[i].detach().cpu().numpy() if hint is not None else None) + + if rep_i == 0: + res['joints'] = joints_ref[i].detach().cpu().numpy() + with open(pkl_path.replace('.pkl', '_ref.pkl'), 'wb') as f: + pickle.dump(res, f) + logger.info(f"Motions are generated here:\n{pkl_path.replace('.pkl', '_ref.pkl')}") + if not cfg.no_plot: + plot_3d_motion(pkl_path.replace('.pkl', '_ref.mp4'), joints_ref[i].detach().cpu().numpy(), + text[i], fps=20, hint=hint[i].detach().cpu().numpy() if hint is not None else None) + + +if __name__ == "__main__": + main() diff --git a/fit.py b/fit.py new file mode 100644 index 0000000000000000000000000000000000000000..62908cb1b101759e452712a24286b527ce57d651 --- /dev/null +++ b/fit.py @@ -0,0 +1,134 @@ +# borrow from optimization https://github.com/wangsen1312/joints2smpl +import os +import argparse +import pickle + +import h5py +import natsort +import smplx + +import torch + +from mld.transforms.joints2rots import config +from mld.transforms.joints2rots.smplify import SMPLify3D + +parser = argparse.ArgumentParser() +parser.add_argument("--pkl", type=str, default=None, help="pkl motion file") +parser.add_argument("--dir", type=str, default=None, help="pkl motion folder") +parser.add_argument("--num_smplify_iters", type=int, default=150, help="num of smplify iters") +parser.add_argument("--cuda", type=bool, default=True, help="enables cuda") +parser.add_argument("--gpu_ids", type=int, default=0, help="choose gpu ids") +parser.add_argument("--num_joints", type=int, default=22, help="joint number") +parser.add_argument("--joint_category", type=str, default="AMASS", help="use correspondence") +parser.add_argument("--fix_foot", type=str, default="False", help="fix foot or not") +opt = parser.parse_args() +print(opt) + +if opt.pkl: + paths = [opt.pkl] +elif opt.dir: + paths = [] + file_list = natsort.natsorted(os.listdir(opt.dir)) + for item in file_list: + if item.endswith('.pkl') and not item.endswith("_mesh.pkl"): + paths.append(os.path.join(opt.dir, item)) +else: + raise ValueError(f'{opt.pkl} and {opt.dir} are both None!') + +for path in paths: + # load joints + if os.path.exists(path.replace('.pkl', '_mesh.pkl')): + print(f"{path} is rendered! 
skip!") + continue + + with open(path, 'rb') as f: + data = pickle.load(f) + + joints = data['joints'] + # load predefined something + device = torch.device("cuda:" + str(opt.gpu_ids) if opt.cuda else "cpu") + print(config.SMPL_MODEL_DIR) + smplxmodel = smplx.create( + config.SMPL_MODEL_DIR, + model_type="smpl", + gender="neutral", + ext="pkl", + batch_size=joints.shape[0], + ).to(device) + + # load the mean pose as original + smpl_mean_file = config.SMPL_MEAN_FILE + + file = h5py.File(smpl_mean_file, "r") + init_mean_pose = ( + torch.from_numpy(file["pose"][:]) + .unsqueeze(0).repeat(joints.shape[0], 1) + .float() + .to(device) + ) + init_mean_shape = ( + torch.from_numpy(file["shape"][:]) + .unsqueeze(0).repeat(joints.shape[0], 1) + .float() + .to(device) + ) + cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0).to(device) + + # initialize SMPLify + smplify = SMPLify3D( + smplxmodel=smplxmodel, + batch_size=joints.shape[0], + joints_category=opt.joint_category, + num_iters=opt.num_smplify_iters, + device=device, + ) + print("initialize SMPLify3D done!") + + print("Start SMPLify!") + keypoints_3d = torch.Tensor(joints).to(device).float() + + if opt.joint_category == "AMASS": + confidence_input = torch.ones(opt.num_joints) + # make sure the foot and ankle + if opt.fix_foot: + confidence_input[7] = 1.5 + confidence_input[8] = 1.5 + confidence_input[10] = 1.5 + confidence_input[11] = 1.5 + else: + print("Such category not settle down!") + + # ----- from initial to fitting ------- + ( + new_opt_vertices, + new_opt_joints, + new_opt_pose, + new_opt_betas, + new_opt_cam_t, + new_opt_joint_loss, + ) = smplify( + init_mean_pose.detach(), + init_mean_shape.detach(), + cam_trans_zero.detach(), + keypoints_3d, + conf_3d=confidence_input.to(device) + ) + + # fix shape + betas = torch.zeros_like(new_opt_betas) + root = keypoints_3d[:, 0, :] + + output = smplxmodel( + betas=betas, + global_orient=new_opt_pose[:, :3], + body_pose=new_opt_pose[:, 3:], + transl=root, + return_verts=True, + ) + vertices = output.vertices.detach().cpu().numpy() + data['vertices'] = vertices + + save_file = path.replace('.pkl', '_mesh.pkl') + with open(save_file, 'wb') as f: + pickle.dump(data, f) + print(f'vertices saved in {save_file}') diff --git a/mld/__init__.py b/mld/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/config.py b/mld/config.py new file mode 100644 index 0000000000000000000000000000000000000000..0b1ee2fe13753bf5c2696f1285173d8e7110bf65 --- /dev/null +++ b/mld/config.py @@ -0,0 +1,47 @@ +import os +import importlib +from typing import Type, TypeVar +from argparse import ArgumentParser + +from omegaconf import OmegaConf, DictConfig + + +def get_module_config(cfg_model: DictConfig, path: str = "modules") -> DictConfig: + files = os.listdir(f'./configs/{path}/') + for file in files: + if file.endswith('.yaml'): + with open(f'./configs/{path}/' + file, 'r') as f: + cfg_model.merge_with(OmegaConf.load(f)) + return cfg_model + + +def get_obj_from_str(string: str, reload: bool = False) -> Type: + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def instantiate_from_config(config: DictConfig) -> TypeVar: + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def parse_args() -> DictConfig: + parser = ArgumentParser() + 
parser.add_argument("--cfg", type=str, required=True, help="config file") + + # Demo Args + parser.add_argument('--example', type=str, required=False, help="input text and lengths with txt format") + parser.add_argument('--no-plot', action="store_true", required=False, help="whether plot the skeleton-based motion") + parser.add_argument('--replication', type=int, default=1, help="the number of replication of sampling") + args = parser.parse_args() + + cfg = OmegaConf.load(args.cfg) + cfg_model = get_module_config(cfg.model, cfg.model.target) + cfg = OmegaConf.merge(cfg, cfg_model) + + cfg.example = args.example + cfg.no_plot = args.no_plot + cfg.replication = args.replication + return cfg diff --git a/mld/data/HumanML3D.py b/mld/data/HumanML3D.py new file mode 100644 index 0000000000000000000000000000000000000000..cb1d63ae89cf5f51ef65e8ddc2845c962dcdc2ec --- /dev/null +++ b/mld/data/HumanML3D.py @@ -0,0 +1,79 @@ +import copy +from typing import Callable, Optional + +import numpy as np +from omegaconf import DictConfig + +import torch + +from .base import BASEDataModule +from .humanml.dataset import Text2MotionDatasetV2 +from .humanml.scripts.motion_process import recover_from_ric + + +class HumanML3DDataModule(BASEDataModule): + + def __init__(self, + cfg: DictConfig, + batch_size: int, + num_workers: int, + collate_fn: Optional[Callable] = None, + persistent_workers: bool = True, + phase: str = "train", + **kwargs) -> None: + super().__init__(batch_size=batch_size, + num_workers=num_workers, + collate_fn=collate_fn, + persistent_workers=persistent_workers) + self.hparams = copy.deepcopy(kwargs) + self.name = "humanml3d" + self.njoints = 22 + if phase == "text_only": + raise NotImplementedError + else: + self.Dataset = Text2MotionDatasetV2 + self.cfg = cfg + + sample_overrides = {"tiny": True, "progress_bar": False} + self._sample_set = self.get_sample_set(overrides=sample_overrides) + self.nfeats = self._sample_set.nfeats + + def denorm_spatial(self, hint: torch.Tensor) -> torch.Tensor: + raw_mean = torch.tensor(self._sample_set.raw_mean).to(hint) + raw_std = torch.tensor(self._sample_set.raw_std).to(hint) + hint = hint * raw_std + raw_mean + return hint + + def norm_spatial(self, hint: torch.Tensor) -> torch.Tensor: + raw_mean = torch.tensor(self._sample_set.raw_mean).to(hint) + raw_std = torch.tensor(self._sample_set.raw_std).to(hint) + hint = (hint - raw_mean) / raw_std + return hint + + def feats2joints(self, features: torch.Tensor) -> torch.Tensor: + mean = torch.tensor(self.hparams['mean']).to(features) + std = torch.tensor(self.hparams['std']).to(features) + features = features * std + mean + return recover_from_ric(features, self.njoints) + + def renorm4t2m(self, features: torch.Tensor) -> torch.Tensor: + # renorm to t2m norms for using t2m evaluators + ori_mean = torch.tensor(self.hparams['mean']).to(features) + ori_std = torch.tensor(self.hparams['std']).to(features) + eval_mean = torch.tensor(self.hparams['mean_eval']).to(features) + eval_std = torch.tensor(self.hparams['std_eval']).to(features) + features = features * ori_std + ori_mean + features = (features - eval_mean) / eval_std + return features + + def mm_mode(self, mm_on: bool = True) -> None: + if mm_on: + self.is_mm = True + self.name_list = self.test_dataset.name_list + self.mm_list = np.random.choice(self.name_list, + self.cfg.TEST.MM_NUM_SAMPLES, + replace=False) + self.test_dataset.name_list = self.mm_list + else: + self.is_mm = False + self.test_dataset.name_list = self.name_list diff --git a/mld/data/Kit.py 
b/mld/data/Kit.py new file mode 100644 index 0000000000000000000000000000000000000000..735e03b26b66768ec288aa4c6ae7d4993d8525de --- /dev/null +++ b/mld/data/Kit.py @@ -0,0 +1,79 @@ +import copy +from typing import Callable, Optional + +import numpy as np +from omegaconf import DictConfig + +import torch + +from .base import BASEDataModule +from .humanml.dataset import Text2MotionDatasetV2 +from .humanml.scripts.motion_process import recover_from_ric + + +class KitDataModule(BASEDataModule): + + def __init__(self, + cfg: DictConfig, + batch_size: int, + num_workers: int, + collate_fn: Optional[Callable] = None, + persistent_workers: bool = True, + phase: str = "train", + **kwargs) -> None: + super().__init__(batch_size=batch_size, + num_workers=num_workers, + collate_fn=collate_fn, + persistent_workers=persistent_workers) + self.hparams = copy.deepcopy(kwargs) + self.name = 'kit' + self.njoints = 21 + if phase == 'text_only': + raise NotImplementedError + else: + self.Dataset = Text2MotionDatasetV2 + self.cfg = cfg + + sample_overrides = {"tiny": True, "progress_bar": False} + self._sample_set = self.get_sample_set(overrides=sample_overrides) + self.nfeats = self._sample_set.nfeats + + def denorm_spatial(self, hint: torch.Tensor) -> torch.Tensor: + raw_mean = torch.tensor(self._sample_set.raw_mean).to(hint) + raw_std = torch.tensor(self._sample_set.raw_std).to(hint) + hint = hint * raw_std + raw_mean + return hint + + def norm_spatial(self, hint: torch.Tensor) -> torch.Tensor: + raw_mean = torch.tensor(self._sample_set.raw_mean).to(hint) + raw_std = torch.tensor(self._sample_set.raw_std).to(hint) + hint = (hint - raw_mean) / raw_std + return hint + + def feats2joints(self, features: torch.Tensor) -> torch.Tensor: + mean = torch.tensor(self.hparams['mean']).to(features) + std = torch.tensor(self.hparams['std']).to(features) + features = features * std + mean + return recover_from_ric(features, self.njoints) + + def renorm4t2m(self, features: torch.Tensor) -> torch.Tensor: + # renorm to t2m norms for using t2m evaluators + ori_mean = torch.tensor(self.hparams['mean']).to(features) + ori_std = torch.tensor(self.hparams['std']).to(features) + eval_mean = torch.tensor(self.hparams['mean_eval']).to(features) + eval_std = torch.tensor(self.hparams['std_eval']).to(features) + features = features * ori_std + ori_mean + features = (features - eval_mean) / eval_std + return features + + def mm_mode(self, mm_on: bool = True) -> None: + if mm_on: + self.is_mm = True + self.name_list = self.test_dataset.name_list + self.mm_list = np.random.choice(self.name_list, + self.cfg.TEST.MM_NUM_SAMPLES, + replace=False) + self.test_dataset.name_list = self.mm_list + else: + self.is_mm = False + self.test_dataset.name_list = self.name_list diff --git a/mld/data/__init__.py b/mld/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/data/base.py b/mld/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..c880f7c97e6891e99f7371b71c8070d02f4c88dc --- /dev/null +++ b/mld/data/base.py @@ -0,0 +1,65 @@ +import copy +from os.path import join as pjoin +from typing import Any, Callable + +from torch.utils.data import DataLoader + +from .humanml.dataset import Text2MotionDatasetV2 + + +class BASEDataModule: + def __init__(self, collate_fn: Callable, batch_size: int, + num_workers: int, persistent_workers: bool) -> None: + super(BASEDataModule, self).__init__() + self.dataloader_options = { + "batch_size": 
batch_size, + "num_workers": num_workers, + "collate_fn": collate_fn, + "persistent_workers": persistent_workers + } + self.is_mm = False + + def get_sample_set(self, overrides: dict) -> Text2MotionDatasetV2: + sample_params = copy.deepcopy(self.hparams) + sample_params.update(overrides) + split_file = pjoin( + eval(f"self.cfg.DATASET.{self.name.upper()}.SPLIT_ROOT"), + self.cfg.EVAL.SPLIT + ".txt", + ) + return self.Dataset(split_file=split_file, **sample_params) + + def __getattr__(self, item: str) -> Any: + if item.endswith("_dataset") and not item.startswith("_"): + subset = item[:-len("_dataset")] + item_c = "_" + item + if item_c not in self.__dict__: + + subset = subset.upper() if subset != "val" else "EVAL" + split = eval(f"self.cfg.{subset}.SPLIT") + split_file = pjoin( + eval(f"self.cfg.DATASET.{self.name.upper()}.SPLIT_ROOT"), + eval(f"self.cfg.{subset}.SPLIT") + ".txt", + ) + self.__dict__[item_c] = self.Dataset(split_file=split_file, + split=split, + **self.hparams) + return getattr(self, item_c) + classname = self.__class__.__name__ + raise AttributeError(f"'{classname}' object has no attribute '{item}'") + + def train_dataloader(self) -> DataLoader: + return DataLoader(self.train_dataset, shuffle=True, **self.dataloader_options) + + def val_dataloader(self) -> DataLoader: + dataloader_options = self.dataloader_options.copy() + dataloader_options["batch_size"] = self.cfg.EVAL.BATCH_SIZE + dataloader_options["num_workers"] = self.cfg.EVAL.NUM_WORKERS + dataloader_options["shuffle"] = False + return DataLoader(self.val_dataset, **dataloader_options) + + def test_dataloader(self) -> DataLoader: + dataloader_options = self.dataloader_options.copy() + dataloader_options["batch_size"] = 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE + dataloader_options["num_workers"] = self.cfg.TEST.NUM_WORKERS + dataloader_options["shuffle"] = False + return DataLoader(self.test_dataset, **dataloader_options) diff --git a/mld/data/get_data.py b/mld/data/get_data.py new file mode 100644 index 0000000000000000000000000000000000000000..622338a262161afda60595dbb8cad28ec32e62f9 --- /dev/null +++ b/mld/data/get_data.py @@ -0,0 +1,93 @@ +from os.path import join as pjoin +from typing import Callable, Optional + +import numpy as np + +from omegaconf import DictConfig + +from .humanml.utils.word_vectorizer import WordVectorizer +from .HumanML3D import HumanML3DDataModule +from .Kit import KitDataModule +from .base import BASEDataModule +from .utils import mld_collate + + +def get_mean_std(phase: str, cfg: DictConfig, dataset_name: str) -> tuple[np.ndarray, np.ndarray]: + name = "t2m" if dataset_name == "humanml3d" else dataset_name + assert name in ["t2m", "kit"] + if phase in ["val"]: + if name == 't2m': + data_root = pjoin(cfg.model.t2m_path, name, "Comp_v6_KLD01", "meta") + elif name == 'kit': + data_root = pjoin(cfg.model.t2m_path, name, "Comp_v6_KLD005", "meta") + else: + raise ValueError("Only support t2m and kit") + mean = np.load(pjoin(data_root, "mean.npy")) + std = np.load(pjoin(data_root, "std.npy")) + else: + data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT") + mean = np.load(pjoin(data_root, "Mean.npy")) + std = np.load(pjoin(data_root, "Std.npy")) + + return mean, std + + +def get_WordVectorizer(cfg: DictConfig, phase: str, dataset_name: str) -> Optional[WordVectorizer]: + if phase not in ["text_only"]: + if dataset_name.lower() in ["humanml3d", "kit"]: + return WordVectorizer(cfg.DATASET.WORD_VERTILIZER_PATH, "our_vab") + else: + raise ValueError("Only support WordVectorizer for 
HumanML3D") + else: + return None + + +def get_collate_fn(name: str) -> Callable: + if name.lower() in ["humanml3d", "kit"]: + return mld_collate + else: + raise NotImplementedError + + +dataset_module_map = {"humanml3d": HumanML3DDataModule, "kit": KitDataModule} +motion_subdir = {"humanml3d": "new_joint_vecs", "kit": "new_joint_vecs"} + + +def get_datasets(cfg: DictConfig, phase: str = "train") -> list[BASEDataModule]: + dataset_names = eval(f"cfg.{phase.upper()}.DATASETS") + datasets = [] + for dataset_name in dataset_names: + if dataset_name.lower() in ["humanml3d", "kit"]: + data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT") + mean, std = get_mean_std(phase, cfg, dataset_name) + mean_eval, std_eval = get_mean_std("val", cfg, dataset_name) + wordVectorizer = get_WordVectorizer(cfg, phase, dataset_name) + collate_fn = get_collate_fn(dataset_name) + dataset = dataset_module_map[dataset_name.lower()]( + cfg=cfg, + batch_size=cfg.TRAIN.BATCH_SIZE, + num_workers=cfg.TRAIN.NUM_WORKERS, + collate_fn=collate_fn, + persistent_workers=cfg.TRAIN.PERSISTENT_WORKERS, + mean=mean, + std=std, + mean_eval=mean_eval, + std_eval=std_eval, + w_vectorizer=wordVectorizer, + text_dir=pjoin(data_root, "texts"), + motion_dir=pjoin(data_root, motion_subdir[dataset_name]), + max_motion_length=cfg.DATASET.SAMPLER.MAX_LEN, + min_motion_length=cfg.DATASET.SAMPLER.MIN_LEN, + max_text_len=cfg.DATASET.SAMPLER.MAX_TEXT_LEN, + unit_length=eval( + f"cfg.DATASET.{dataset_name.upper()}.UNIT_LEN"), + model_kwargs=cfg.model + ) + datasets.append(dataset) + + elif dataset_name.lower() in ["humanact12", 'uestc', "amass"]: + raise NotImplementedError + + cfg.DATASET.NFEATS = datasets[0].nfeats + cfg.DATASET.NJOINTS = datasets[0].njoints + return datasets diff --git a/mld/data/humanml/__init__.py b/mld/data/humanml/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/data/humanml/common/quaternion.py b/mld/data/humanml/common/quaternion.py new file mode 100644 index 0000000000000000000000000000000000000000..8d0135e523300aa970e3d797b666557172fdbe06 --- /dev/null +++ b/mld/data/humanml/common/quaternion.py @@ -0,0 +1,29 @@ +import torch + + +def qinv(q: torch.Tensor) -> torch.Tensor: + assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' + mask = torch.ones_like(q) + mask[..., 1:] = -mask[..., 1:] + return q * mask + + +def qrot(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """ + Rotate vector(s) v about the rotation described by quaternion(s) q. + Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, + where * denotes any number of dimensions. + Returns a tensor of shape (*, 3). 
+ """ + assert q.shape[-1] == 4 + assert v.shape[-1] == 3 + assert q.shape[:-1] == v.shape[:-1] + + original_shape = list(v.shape) + q = q.contiguous().view(-1, 4) + v = v.contiguous().view(-1, 3) + + qvec = q[:, 1:] + uv = torch.cross(qvec, v, dim=1) + uuv = torch.cross(qvec, uv, dim=1) + return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape) diff --git a/mld/data/humanml/dataset.py b/mld/data/humanml/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d9af8ea0be3d18b6cd7929694bcf88af42bcf365 --- /dev/null +++ b/mld/data/humanml/dataset.py @@ -0,0 +1,290 @@ +import codecs as cs +import random +from os.path import join as pjoin + +import numpy as np +from rich.progress import track + +import torch +from torch.utils import data + +from mld.data.humanml.scripts.motion_process import recover_from_ric +from .utils.word_vectorizer import WordVectorizer + + +class Text2MotionDatasetV2(data.Dataset): + + def __init__( + self, + mean: np.ndarray, + std: np.ndarray, + split_file: str, + w_vectorizer: WordVectorizer, + max_motion_length: int, + min_motion_length: int, + max_text_len: int, + unit_length: int, + motion_dir: str, + text_dir: str, + tiny: bool = False, + progress_bar: bool = True, + **kwargs, + ) -> None: + self.w_vectorizer = w_vectorizer + self.max_motion_length = max_motion_length + self.min_motion_length = min_motion_length + self.max_text_len = max_text_len + self.unit_length = unit_length + + data_dict = {} + id_list = [] + with cs.open(split_file, "r") as f: + for line in f.readlines(): + id_list.append(line.strip()) + self.id_list = id_list + + if tiny: + progress_bar = False + maxdata = 10 + else: + maxdata = 1e10 + + if progress_bar: + enumerator = enumerate( + track( + id_list, + f"Loading HumanML3D {split_file.split('/')[-1].split('.')[0]}", + )) + else: + enumerator = enumerate(id_list) + count = 0 + bad_count = 0 + new_name_list = [] + length_list = [] + for i, name in enumerator: + if count > maxdata: + break + try: + motion = np.load(pjoin(motion_dir, name + ".npy")) + if (len(motion)) < self.min_motion_length or (len(motion) >= 200): + bad_count += 1 + continue + text_data = [] + flag = False + with cs.open(pjoin(text_dir, name + ".txt")) as f: + for line in f.readlines(): + text_dict = {} + line_split = line.strip().split("#") + caption = line_split[0] + tokens = line_split[1].split(" ") + f_tag = float(line_split[2]) + to_tag = float(line_split[3]) + f_tag = 0.0 if np.isnan(f_tag) else f_tag + to_tag = 0.0 if np.isnan(to_tag) else to_tag + + text_dict["caption"] = caption + text_dict["tokens"] = tokens + if f_tag == 0.0 and to_tag == 0.0: + flag = True + text_data.append(text_dict) + else: + try: + n_motion = motion[int(f_tag * 20):int(to_tag * + 20)] + if (len(n_motion) + ) < self.min_motion_length or ( + (len(n_motion) >= 200)): + continue + new_name = ( + random.choice("ABCDEFGHIJKLMNOPQRSTUVW") + + "_" + name) + while new_name in data_dict: + new_name = (random.choice( + "ABCDEFGHIJKLMNOPQRSTUVW") + "_" + + name) + data_dict[new_name] = { + "motion": n_motion, + "length": len(n_motion), + "text": [text_dict], + } + new_name_list.append(new_name) + length_list.append(len(n_motion)) + except: + print(line_split) + print(line_split[2], line_split[3], f_tag, to_tag, name) + + if flag: + data_dict[name] = { + "motion": motion, + "length": len(motion), + "text": text_data, + } + new_name_list.append(name) + length_list.append(len(motion)) + count += 1 + except: + pass + + name_list, length_list = zip( + *sorted(zip(new_name_list, 
length_list), key=lambda x: x[1])) + + self.mean = mean + self.std = std + + self.mode = None + model_params = kwargs['model_kwargs'] + if 'is_controlnet' in model_params and model_params.is_controlnet is True: + if 'test' in split_file or 'val' in split_file: + self.mode = 'eval' + else: + self.mode = 'train' + + self.t_ctrl = model_params.is_controlnet_temporal + spatial_norm_path = './datasets/humanml_spatial_norm' + self.raw_mean = np.load(pjoin(spatial_norm_path, 'Mean_raw.npy')) + self.raw_std = np.load(pjoin(spatial_norm_path, 'Std_raw.npy')) + + self.training_control_joint = np.array(model_params.training_control_joint) + self.testing_control_joint = np.array(model_params.testing_control_joint) + + self.training_density = model_params.training_density + self.testing_density = model_params.testing_density + + self.length_arr = np.array(length_list) + self.data_dict = data_dict + self.nfeats = motion.shape[1] + self.name_list = name_list + + def __len__(self) -> int: + return len(self.name_list) + + def random_mask(self, joints: np.ndarray, n_joints: int = 22) -> np.ndarray: + choose_joint = self.testing_control_joint + + length = joints.shape[0] + density = self.testing_density + if density in [1, 2, 5]: + choose_seq_num = density + else: + choose_seq_num = int(length * density / 100) + + if self.t_ctrl: + choose_seq = np.arange(0, choose_seq_num) + else: + choose_seq = np.random.choice(length, choose_seq_num, replace=False) + choose_seq.sort() + + mask_seq = np.zeros((length, n_joints, 3)).astype(bool) + + for cj in choose_joint: + mask_seq[choose_seq, cj] = True + + # normalize + joints = (joints - self.raw_mean.reshape(n_joints, 3)) / self.raw_std.reshape(n_joints, 3) + joints = joints * mask_seq + return joints + + def random_mask_train(self, joints: np.ndarray, n_joints: int = 22) -> np.ndarray: + if self.t_ctrl: + choose_joint = self.training_control_joint + else: + num_joints = len(self.training_control_joint) + num_joints_control = 1 + choose_joint = np.random.choice(num_joints, num_joints_control, replace=False) + choose_joint = self.training_control_joint[choose_joint] + + length = joints.shape[0] + + if self.training_density == 'random': + choose_seq_num = np.random.choice(length - 1, 1) + 1 + else: + choose_seq_num = int(length * random.uniform(self.training_density[0], self.training_density[1]) / 100) + + if self.t_ctrl: + choose_seq = np.arange(0, choose_seq_num) + else: + choose_seq = np.random.choice(length, choose_seq_num, replace=False) + choose_seq.sort() + + mask_seq = np.zeros((length, n_joints, 3)).astype(bool) + + for cj in choose_joint: + mask_seq[choose_seq, cj] = True + + # normalize + joints = (joints - self.raw_mean.reshape(n_joints, 3)) / self.raw_std.reshape(n_joints, 3) + joints = joints * mask_seq + return joints + + def __getitem__(self, idx: int) -> tuple: + data = self.data_dict[self.name_list[idx]] + motion, m_length, text_list = data["motion"], data["length"], data["text"] + # Randomly select a caption + text_data = random.choice(text_list) + caption, tokens = text_data["caption"], text_data["tokens"] + + if len(tokens) < self.max_text_len: + # pad with "unk" + tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"] + sent_len = len(tokens) + tokens = tokens + ["unk/OTHER" + ] * (self.max_text_len + 2 - sent_len) + else: + # crop + tokens = tokens[:self.max_text_len] + tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"] + sent_len = len(tokens) + pos_one_hots = [] + word_embeddings = [] + for token in tokens: + word_emb, pos_oh = self.w_vectorizer[token] + 
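            # word_emb: GloVe word embedding of the token; pos_oh: one-hot part-of-speech vector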
pos_one_hots.append(pos_oh[None, :]) + word_embeddings.append(word_emb[None, :]) + pos_one_hots = np.concatenate(pos_one_hots, axis=0) + word_embeddings = np.concatenate(word_embeddings, axis=0) + + # Crop the motions in to times of 4, and introduce small variations + if self.unit_length < 10: + coin2 = np.random.choice(["single", "single", "double"]) + else: + coin2 = "single" + + if coin2 == "double": + m_length = (m_length // self.unit_length - 1) * self.unit_length + elif coin2 == "single": + m_length = (m_length // self.unit_length) * self.unit_length + idx = random.randint(0, len(motion) - m_length) + motion = motion[idx:idx + m_length] + + hint = None + if self.mode is not None: + n_joints = 22 if motion.shape[-1] == 263 else 21 + # hint is global position of the controllable joints + joints = recover_from_ric(torch.from_numpy(motion).float(), n_joints) + joints = joints.numpy() + + # control any joints at any time + if self.mode == 'train': + hint = self.random_mask_train(joints, n_joints) + else: + hint = self.random_mask(joints, n_joints) + + hint = hint.reshape(hint.shape[0], -1) + + "Z Normalization" + motion = (motion - self.mean) / self.std + + # debug check nan + if np.any(np.isnan(motion)): + raise ValueError("nan in motion") + + return ( + word_embeddings, + pos_one_hots, + caption, + sent_len, + motion, + m_length, + "_".join(tokens), + hint + ) diff --git a/mld/data/humanml/scripts/motion_process.py b/mld/data/humanml/scripts/motion_process.py new file mode 100644 index 0000000000000000000000000000000000000000..0ebcef746e94c586d13c8b859ae5e059890d788a --- /dev/null +++ b/mld/data/humanml/scripts/motion_process.py @@ -0,0 +1,51 @@ +import torch + +from ..common.quaternion import qinv, qrot + + +# Recover global angle and positions for rotation dataset +# root_rot_velocity (B, seq_len, 1) +# root_linear_velocity (B, seq_len, 2) +# root_y (B, seq_len, 1) +# ric_data (B, seq_len, (joint_num - 1)*3) +# rot_data (B, seq_len, (joint_num - 1)*6) +# local_velocity (B, seq_len, joint_num*3) +# foot contact (B, seq_len, 4) +def recover_root_rot_pos(data: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + rot_vel = data[..., 0] + r_rot_ang = torch.zeros_like(rot_vel).to(data.device) + '''Get Y-axis rotation from rotation velocity''' + r_rot_ang[..., 1:] = rot_vel[..., :-1] + r_rot_ang = torch.cumsum(r_rot_ang, dim=-1) + + r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device) + r_rot_quat[..., 0] = torch.cos(r_rot_ang) + r_rot_quat[..., 2] = torch.sin(r_rot_ang) + + r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device) + r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3] + '''Add Y-axis rotation to root position''' + r_pos = qrot(qinv(r_rot_quat), r_pos) + + r_pos = torch.cumsum(r_pos, dim=-2) + + r_pos[..., 1] = data[..., 3] + return r_rot_quat, r_pos + + +def recover_from_ric(data: torch.Tensor, joints_num: int) -> torch.Tensor: + r_rot_quat, r_pos = recover_root_rot_pos(data) + positions = data[..., 4:(joints_num - 1) * 3 + 4] + positions = positions.view(positions.shape[:-1] + (-1, 3)) + + '''Add Y-axis rotation to local joints''' + positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions) + + '''Add root XZ to joints''' + positions[..., 0] += r_pos[..., 0:1] + positions[..., 2] += r_pos[..., 2:3] + + '''Concat root and joints''' + positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2) + + return positions diff --git a/mld/data/humanml/utils/__init__.py b/mld/data/humanml/utils/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/data/humanml/utils/paramUtil.py b/mld/data/humanml/utils/paramUtil.py new file mode 100644 index 0000000000000000000000000000000000000000..019ea21f1a1ce17996c81aa559460b9733029ceb --- /dev/null +++ b/mld/data/humanml/utils/paramUtil.py @@ -0,0 +1,62 @@ +import numpy as np + +# Define a kinematic tree for the skeletal structure +kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]] + +kit_raw_offsets = np.array( + [ + [0, 0, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [-1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [0, 0, 1], + [0, 0, 1], + [-1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [0, 0, 1], + [0, 0, 1] + ] +) + +t2m_raw_offsets = np.array([[0, 0, 0], + [1, 0, 0], + [-1, 0, 0], + [0, 1, 0], + [0, -1, 0], + [0, -1, 0], + [0, 1, 0], + [0, -1, 0], + [0, -1, 0], + [0, 1, 0], + [0, 0, 1], + [0, 0, 1], + [0, 1, 0], + [1, 0, 0], + [-1, 0, 0], + [0, 0, 1], + [0, -1, 0], + [0, -1, 0], + [0, -1, 0], + [0, -1, 0], + [0, -1, 0], + [0, -1, 0]]) + +t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], + [9, 13, 16, 18, 20]] +t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]] +t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]] + +kit_tgt_skel_id = '03950' + +t2m_tgt_skel_id = '000021' diff --git a/mld/data/humanml/utils/plot_script.py b/mld/data/humanml/utils/plot_script.py new file mode 100644 index 0000000000000000000000000000000000000000..dcee7fc2497de4546ed62a9c9da5c74dc0a3372e --- /dev/null +++ b/mld/data/humanml/utils/plot_script.py @@ -0,0 +1,98 @@ +from textwrap import wrap +from typing import Optional + +import numpy as np + +import matplotlib.pyplot as plt +import mpl_toolkits.mplot3d.axes3d as p3 +from matplotlib.animation import FuncAnimation +from mpl_toolkits.mplot3d.art3d import Poly3DCollection + +import mld.data.humanml.utils.paramUtil as paramUtil + +skeleton = paramUtil.t2m_kinematic_chain + + +def plot_3d_motion(save_path: str, joints: np.ndarray, title: str, + figsize: tuple[int, int] = (3, 3), + fps: int = 120, radius: int = 3, kinematic_tree: list = skeleton, + hint: Optional[np.ndarray] = None) -> None: + + title = '\n'.join(wrap(title, 20)) + + def init(): + ax.set_xlim3d([-radius / 2, radius / 2]) + ax.set_ylim3d([0, radius]) + ax.set_zlim3d([-radius / 3., radius * 2 / 3.]) + fig.suptitle(title, fontsize=10) + ax.grid(b=False) + + def plot_xzPlane(minx, maxx, miny, minz, maxz): + # Plot a plane XZ + verts = [ + [minx, miny, minz], + [minx, miny, maxz], + [maxx, miny, maxz], + [maxx, miny, minz] + ] + xz_plane = Poly3DCollection([verts]) + xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5)) + ax.add_collection3d(xz_plane) + + # (seq_len, joints_num, 3) + data = joints.copy().reshape(len(joints), -1, 3) + + data *= 1.3 # scale for visualization + if hint is not None: + mask = hint.sum(-1) != 0 + hint = hint[mask] + hint *= 1.3 + + fig = plt.figure(figsize=figsize) + plt.tight_layout() + ax = p3.Axes3D(fig) + init() + MINS = data.min(axis=0).min(axis=0) + MAXS = data.max(axis=0).max(axis=0) + colors = ["#DD5A37", "#D69E00", "#B75A39", "#DD5A37", "#D69E00", + "#FF6D00", "#FF6D00", "#FF6D00", "#FF6D00", "#FF6D00", + "#DDB50E", "#DDB50E", "#DDB50E", "#DDB50E", "#DDB50E", 
] + + frame_number = data.shape[0] + + height_offset = MINS[1] + data[:, :, 1] -= height_offset + if hint is not None: + hint[..., 1] -= height_offset + trajec = data[:, 0, [0, 2]] + + data[..., 0] -= data[:, 0:1, 0] + data[..., 2] -= data[:, 0:1, 2] + + def update(index): + ax.lines = [] + ax.collections = [] + ax.view_init(elev=120, azim=-90) + ax.dist = 7.5 + plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1], + MAXS[2] - trajec[index, 1]) + + if hint is not None: + ax.scatter(hint[..., 0] - trajec[index, 0], hint[..., 1], hint[..., 2] - trajec[index, 1], color="#80B79A") + + for i, (chain, color) in enumerate(zip(kinematic_tree, colors)): + if i < 5: + linewidth = 4.0 + else: + linewidth = 2.0 + ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth, + color=color) + + plt.axis('off') + ax.set_xticklabels([]) + ax.set_yticklabels([]) + ax.set_zticklabels([]) + + ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False) + ani.save(save_path, fps=fps) + plt.close() diff --git a/mld/data/humanml/utils/word_vectorizer.py b/mld/data/humanml/utils/word_vectorizer.py new file mode 100644 index 0000000000000000000000000000000000000000..a95b28b1f8ad5c02fc4f1ec29f0be071eb22db7f --- /dev/null +++ b/mld/data/humanml/utils/word_vectorizer.py @@ -0,0 +1,82 @@ +import pickle +from os.path import join as pjoin + +import numpy as np + + +POS_enumerator = { + 'VERB': 0, + 'NOUN': 1, + 'DET': 2, + 'ADP': 3, + 'NUM': 4, + 'AUX': 5, + 'PRON': 6, + 'ADJ': 7, + 'ADV': 8, + 'Loc_VIP': 9, + 'Body_VIP': 10, + 'Obj_VIP': 11, + 'Act_VIP': 12, + 'Desc_VIP': 13, + 'OTHER': 14, +} + +Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward', + 'up', 'down', 'straight', 'curve') + +Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh') + +Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball') + +Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn', + 'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll', + 'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb') + +Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily', + 'angrily', 'sadly') + +VIP_dict = { + 'Loc_VIP': Loc_list, + 'Body_VIP': Body_list, + 'Obj_VIP': Obj_List, + 'Act_VIP': Act_list, + 'Desc_VIP': Desc_list, +} + + +class WordVectorizer(object): + def __init__(self, meta_root: str, prefix: str) -> None: + vectors = np.load(pjoin(meta_root, '%s_data.npy' % prefix)) + words = pickle.load(open(pjoin(meta_root, '%s_words.pkl' % prefix), 'rb')) + word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl' % prefix), 'rb')) + self.word2vec = {w: vectors[word2idx[w]] for w in words} + + def _get_pos_ohot(self, pos: str) -> np.ndarray: + pos_vec = np.zeros(len(POS_enumerator)) + if pos in POS_enumerator: + pos_vec[POS_enumerator[pos]] = 1 + else: + pos_vec[POS_enumerator['OTHER']] = 1 + return pos_vec + + def __len__(self) -> int: + return len(self.word2vec) + + def __getitem__(self, item: str) -> tuple: + word, pos = item.split('/') + if word in self.word2vec: + word_vec = self.word2vec[word] + vip_pos = None + for key, values in VIP_dict.items(): + if word in values: + 
vip_pos = key + break + if vip_pos is not None: + pos_vec = self._get_pos_ohot(vip_pos) + else: + pos_vec = self._get_pos_ohot(pos) + else: + word_vec = self.word2vec['unk'] + pos_vec = self._get_pos_ohot('OTHER') + return word_vec, pos_vec diff --git a/mld/data/utils.py b/mld/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e4ed84566271aa6f3c5f0b4cea02850c544a7734 --- /dev/null +++ b/mld/data/utils.py @@ -0,0 +1,38 @@ +import torch + + +def collate_tensors(batch: list) -> torch.Tensor: + dims = batch[0].dim() + max_size = [max([b.size(i) for b in batch]) for i in range(dims)] + size = (len(batch), ) + tuple(max_size) + canvas = batch[0].new_zeros(size=size) + for i, b in enumerate(batch): + sub_tensor = canvas[i] + for d in range(dims): + sub_tensor = sub_tensor.narrow(d, 0, b.size(d)) + sub_tensor.add_(b) + return canvas + + +def mld_collate(batch: list) -> dict: + notnone_batches = [b for b in batch if b is not None] + notnone_batches.sort(key=lambda x: x[3], reverse=True) + adapted_batch = { + "motion": + collate_tensors([torch.tensor(b[4]).float() for b in notnone_batches]), + "text": [b[2] for b in notnone_batches], + "length": [b[5] for b in notnone_batches], + "word_embs": + collate_tensors([torch.tensor(b[0]).float() for b in notnone_batches]), + "pos_ohot": + collate_tensors([torch.tensor(b[1]).float() for b in notnone_batches]), + "text_len": + collate_tensors([torch.tensor(b[3]) for b in notnone_batches]), + "tokens": [b[6] for b in notnone_batches], + } + + # collate trajectory + if notnone_batches[0][-1] is not None: + adapted_batch['hint'] = collate_tensors([torch.tensor(b[-1]).float() for b in notnone_batches]) + + return adapted_batch diff --git a/mld/launch/__init__.py b/mld/launch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/launch/blender.py b/mld/launch/blender.py new file mode 100644 index 0000000000000000000000000000000000000000..09a7fae39721c918050bd15fb3356b58d34b78e9 --- /dev/null +++ b/mld/launch/blender.py @@ -0,0 +1,23 @@ +# Fix blender path +import os +import sys +from argparse import ArgumentParser + +sys.path.append(os.path.expanduser("~/.local/lib/python3.9/site-packages")) + + +# Monkey patch argparse such that +# blender / python parsing works +def parse_args(self, args=None, namespace=None): + if args is not None: + return self.parse_args_bak(args=args, namespace=namespace) + try: + idx = sys.argv.index("--") + args = sys.argv[idx + 1:] # the list after '--' + except ValueError as e: # '--' not in the list: + args = [] + return self.parse_args_bak(args=args, namespace=namespace) + + +setattr(ArgumentParser, 'parse_args_bak', ArgumentParser.parse_args) +setattr(ArgumentParser, 'parse_args', parse_args) diff --git a/mld/models/__init__.py b/mld/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/models/architectures/__init__.py b/mld/models/architectures/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/models/architectures/mld_clip.py b/mld/models/architectures/mld_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..fe9b3545e363a651e33270054378fefdaccb2815 --- /dev/null +++ b/mld/models/architectures/mld_clip.py @@ -0,0 +1,72 @@ +import torch +import torch.nn as nn + +from transformers import AutoModel, AutoTokenizer +from 
sentence_transformers import SentenceTransformer + + +class MldTextEncoder(nn.Module): + + def __init__(self, modelpath: str, last_hidden_state: bool = False) -> None: + super().__init__() + + if 't5' in modelpath: + self.text_model = SentenceTransformer(modelpath) + self.tokenizer = self.text_model.tokenizer + else: + self.tokenizer = AutoTokenizer.from_pretrained(modelpath) + self.text_model = AutoModel.from_pretrained(modelpath) + + self.max_length = self.tokenizer.model_max_length + if "clip" in modelpath: + self.text_encoded_dim = self.text_model.config.text_config.hidden_size + if last_hidden_state: + self.name = "clip_hidden" + else: + self.name = "clip" + elif "bert" in modelpath: + self.name = "bert" + self.text_encoded_dim = self.text_model.config.hidden_size + elif 't5' in modelpath: + self.name = 't5' + else: + raise ValueError(f"Model {modelpath} not supported") + + def forward(self, texts: list[str]) -> torch.Tensor: + # get prompt text embeddings + if self.name in ["clip", "clip_hidden"]: + text_inputs = self.tokenizer( + texts, + padding="max_length", + truncation=True, + max_length=self.max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + # split into max length Clip can handle + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + text_input_ids = text_input_ids[:, :self.tokenizer.model_max_length] + elif self.name == "bert": + text_inputs = self.tokenizer(texts, return_tensors="pt", padding=True) + + if self.name == "clip": + # (batch_Size, text_encoded_dim) + text_embeddings = self.text_model.get_text_features( + text_input_ids.to(self.text_model.device)) + # (batch_Size, 1, text_encoded_dim) + text_embeddings = text_embeddings.unsqueeze(1) + elif self.name == "clip_hidden": + # (batch_Size, seq_length , text_encoded_dim) + text_embeddings = self.text_model.text_model( + text_input_ids.to(self.text_model.device)).last_hidden_state + elif self.name == "bert": + # (batch_Size, seq_length , text_encoded_dim) + text_embeddings = self.text_model( + **text_inputs.to(self.text_model.device)).last_hidden_state + elif self.name == 't5': + text_embeddings = self.text_model.encode(texts, show_progress_bar=False, convert_to_tensor=True, batch_size=len(texts)) + text_embeddings = text_embeddings.unsqueeze(1) + else: + raise NotImplementedError(f"Model {self.name} not implemented") + + return text_embeddings diff --git a/mld/models/architectures/mld_denoiser.py b/mld/models/architectures/mld_denoiser.py new file mode 100644 index 0000000000000000000000000000000000000000..5a31b35963747a1536ba8616482e5e7101d5bfb7 --- /dev/null +++ b/mld/models/architectures/mld_denoiser.py @@ -0,0 +1,172 @@ +from typing import Optional, Union + +import torch +import torch.nn as nn + +from mld.models.architectures.tools.embeddings import (TimestepEmbedding, + Timesteps) +from mld.models.operator.cross_attention import (SkipTransformerEncoder, + TransformerDecoder, + TransformerDecoderLayer, + TransformerEncoder, + TransformerEncoderLayer) +from mld.models.operator.position_encoding import build_position_encoding + + +class MldDenoiser(nn.Module): + + def __init__(self, + latent_dim: list = [1, 256], + ff_size: int = 1024, + num_layers: int = 6, + num_heads: int = 4, + dropout: float = 0.1, + normalize_before: bool = False, + activation: str = "gelu", + flip_sin_to_cos: bool = True, + return_intermediate_dec: bool = False, + position_embedding: str = "learned", + arch: str = "trans_enc", + freq_shift: float = 0, + text_encoded_dim: int = 768, + time_cond_proj_dim: 
int = None, + is_controlnet: bool = False) -> None: + + super().__init__() + + self.latent_dim = latent_dim[-1] + self.text_encoded_dim = text_encoded_dim + + self.arch = arch + self.time_cond_proj_dim = time_cond_proj_dim + + self.time_proj = Timesteps(text_encoded_dim, flip_sin_to_cos, freq_shift) + self.time_embedding = TimestepEmbedding(text_encoded_dim, self.latent_dim, cond_proj_dim=time_cond_proj_dim) + if text_encoded_dim != self.latent_dim: + self.emb_proj = nn.Sequential(nn.ReLU(), nn.Linear(text_encoded_dim, self.latent_dim)) + + self.query_pos = build_position_encoding( + self.latent_dim, position_embedding=position_embedding) + + if self.arch == "trans_enc": + encoder_layer = TransformerEncoderLayer( + self.latent_dim, + num_heads, + ff_size, + dropout, + activation, + normalize_before, + ) + encoder_norm = None if is_controlnet else nn.LayerNorm(self.latent_dim) + self.encoder = SkipTransformerEncoder(encoder_layer, num_layers, encoder_norm, + return_intermediate=is_controlnet) + + elif self.arch == "trans_dec": + assert not is_controlnet, f"controlnet not supported in architecture: 'trans_dec'" + self.mem_pos = build_position_encoding( + self.latent_dim, position_embedding=position_embedding) + + decoder_layer = TransformerDecoderLayer( + self.latent_dim, + num_heads, + ff_size, + dropout, + activation, + normalize_before, + ) + decoder_norm = nn.LayerNorm(self.latent_dim) + self.decoder = TransformerDecoder( + decoder_layer, + num_layers, + decoder_norm, + return_intermediate=return_intermediate_dec, + ) + + else: + raise ValueError(f"Not supported architecture: {self.arch}!") + + self.is_controlnet = is_controlnet + + def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + if self.is_controlnet: + self.controlnet_cond_embedding = nn.Sequential( + nn.Linear(self.latent_dim, self.latent_dim), + nn.Linear(self.latent_dim, self.latent_dim), + zero_module(nn.Linear(self.latent_dim, self.latent_dim)) + ) + + self.controlnet_down_mid_blocks = nn.ModuleList([ + zero_module(nn.Linear(self.latent_dim, self.latent_dim)) for _ in range(num_layers)]) + + def forward(self, + sample: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep_cond: Optional[torch.Tensor] = None, + controlnet_cond: Optional[torch.Tensor] = None, + controlnet_residuals: Optional[list[torch.Tensor]] = None + ) -> Union[torch.Tensor, list[torch.Tensor]]: + + # 0. dimension matching + # sample [latent_dim[0], batch_size, latent_dim] <= [batch_size, latent_dim[0], latent_dim[1]] + sample = sample.permute(1, 0, 2) + + # 1. check if controlnet + if self.is_controlnet: + controlnet_cond = controlnet_cond.permute(1, 0, 2) + sample = sample + self.controlnet_cond_embedding(controlnet_cond) + + # 2. time_embedding + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timestep.expand(sample.shape[1]).clone() + time_emb = self.time_proj(timesteps) + time_emb = time_emb.to(dtype=sample.dtype) + # [1, bs, latent_dim] <= [bs, latent_dim] + time_emb = self.time_embedding(time_emb, timestep_cond).unsqueeze(0) + + # 3. 
condition + time embedding + # text_emb [seq_len, batch_size, text_encoded_dim] <= [batch_size, seq_len, text_encoded_dim] + encoder_hidden_states = encoder_hidden_states.permute(1, 0, 2) + text_emb = encoder_hidden_states # [num_words, bs, latent_dim] + # text embedding projection + if self.text_encoded_dim != self.latent_dim: + # [1 or 2, bs, latent_dim] <= [1 or 2, bs, text_encoded_dim] + text_emb_latent = self.emb_proj(text_emb) + else: + text_emb_latent = text_emb + emb_latent = torch.cat((time_emb, text_emb_latent), 0) + + # 4. transformer + if self.arch == "trans_enc": + xseq = torch.cat((sample, emb_latent), axis=0) + + xseq = self.query_pos(xseq) + tokens = self.encoder(xseq, controlnet_residuals=controlnet_residuals) + + if self.is_controlnet: + control_res_samples = [] + for res, block in zip(tokens, self.controlnet_down_mid_blocks): + r = block(res) + control_res_samples.append(r) + return control_res_samples + + sample = tokens[:sample.shape[0]] + + elif self.arch == "trans_dec": + # tgt - [1 or 5 or 10, bs, latent_dim] + # memory - [token_num, bs, latent_dim] + sample = self.query_pos(sample) + emb_latent = self.mem_pos(emb_latent) + sample = self.decoder(tgt=sample, memory=emb_latent).squeeze(0) + + else: + raise TypeError(f"{self.arch} is not supported") + + # 5. [batch_size, latent_dim[0], latent_dim[1]] <= [latent_dim[0], batch_size, latent_dim[1]] + sample = sample.permute(1, 0, 2) + + return sample diff --git a/mld/models/architectures/mld_traj_encoder.py b/mld/models/architectures/mld_traj_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..08ec1e91029cf5b2477512ec9b07367c1ef38669 --- /dev/null +++ b/mld/models/architectures/mld_traj_encoder.py @@ -0,0 +1,78 @@ +from typing import Optional + +import torch +import torch.nn as nn + +from mld.models.operator.cross_attention import SkipTransformerEncoder, TransformerEncoderLayer +from mld.models.operator.position_encoding import build_position_encoding +from mld.utils.temos_utils import lengths_to_mask + + +class MldTrajEncoder(nn.Module): + + def __init__(self, + nfeats: int, + latent_dim: list = [1, 256], + ff_size: int = 1024, + num_layers: int = 9, + num_heads: int = 4, + dropout: float = 0.1, + normalize_before: bool = False, + activation: str = "gelu", + position_embedding: str = "learned") -> None: + + super().__init__() + self.latent_size = latent_dim[0] + self.latent_dim = latent_dim[-1] + + self.skel_embedding = nn.Linear(nfeats * 3, self.latent_dim) + + self.query_pos_encoder = build_position_encoding( + self.latent_dim, position_embedding=position_embedding) + + encoder_layer = TransformerEncoderLayer( + self.latent_dim, + num_heads, + ff_size, + dropout, + activation, + normalize_before, + ) + encoder_norm = nn.LayerNorm(self.latent_dim) + self.encoder = SkipTransformerEncoder(encoder_layer, num_layers, + encoder_norm) + + self.global_motion_token = nn.Parameter( + torch.randn(self.latent_size, self.latent_dim)) + + def forward(self, features: torch.Tensor, lengths: Optional[list[int]] = None, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + + if lengths is None and mask is None: + lengths = [len(feature) for feature in features] + mask = lengths_to_mask(lengths, features.device) + + bs, nframes, nfeats = features.shape + + x = features + # Embed each human poses into latent vectors + x = self.skel_embedding(x) + + # Switch sequence and batch_size because the input of + # Pytorch Transformer is [Sequence, Batch size, ...] 
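A shape-level usage sketch for MldTrajEncoder; the 22-joint (66-dimensional) hint layout and the batch/frame sizes are assumptions that mirror how the ControlNet branch calls this encoder:

import torch
from mld.models.architectures.mld_traj_encoder import MldTrajEncoder

traj_encoder = MldTrajEncoder(nfeats=22, latent_dim=[1, 256])
bs, nframes = 4, 196
hint = torch.randn(bs, nframes, 22 * 3)        # masked global joint positions, flattened per frame
mask = hint.sum(-1) != 0                       # frames that actually carry a control signal
global_token = traj_encoder(hint, mask=mask)   # [latent_size, bs, latent_dim] == [1, 4, 256]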
+ x = x.permute(1, 0, 2) # now it is [nframes, bs, latent_dim] + + # Each batch has its own set of tokens + dist = torch.tile(self.global_motion_token[:, None, :], (1, bs, 1)) + + # create a bigger mask, to allow attend to emb + dist_masks = torch.ones((bs, dist.shape[0]), dtype=torch.bool, device=x.device) + aug_mask = torch.cat((dist_masks, mask), 1) + + # adding the embedding token for all sequences + xseq = torch.cat((dist, x), 0) + + xseq = self.query_pos_encoder(xseq) + global_token = self.encoder(xseq, src_key_padding_mask=~aug_mask)[:dist.shape[0]] + + return global_token diff --git a/mld/models/architectures/mld_vae.py b/mld/models/architectures/mld_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..aee5a5a6eb2ec39bfd62a600c8be792a9a884d81 --- /dev/null +++ b/mld/models/architectures/mld_vae.py @@ -0,0 +1,154 @@ +from typing import Optional + +import torch +import torch.nn as nn +from torch.distributions.distribution import Distribution + +from mld.models.operator.cross_attention import ( + SkipTransformerEncoder, + SkipTransformerDecoder, + TransformerDecoder, + TransformerDecoderLayer, + TransformerEncoder, + TransformerEncoderLayer, +) +from mld.models.operator.position_encoding import build_position_encoding +from mld.utils.temos_utils import lengths_to_mask + + +class MldVae(nn.Module): + + def __init__(self, + nfeats: int, + latent_dim: list = [1, 256], + ff_size: int = 1024, + num_layers: int = 9, + num_heads: int = 4, + dropout: float = 0.1, + arch: str = "encoder_decoder", + normalize_before: bool = False, + activation: str = "gelu", + position_embedding: str = "learned") -> None: + + super().__init__() + + self.latent_size = latent_dim[0] + self.latent_dim = latent_dim[-1] + input_feats = nfeats + output_feats = nfeats + self.arch = arch + + self.query_pos_encoder = build_position_encoding( + self.latent_dim, position_embedding=position_embedding) + + encoder_layer = TransformerEncoderLayer( + self.latent_dim, + num_heads, + ff_size, + dropout, + activation, + normalize_before, + ) + encoder_norm = nn.LayerNorm(self.latent_dim) + self.encoder = SkipTransformerEncoder(encoder_layer, num_layers, + encoder_norm) + + if self.arch == "all_encoder": + decoder_norm = nn.LayerNorm(self.latent_dim) + self.decoder = SkipTransformerEncoder(encoder_layer, num_layers, + decoder_norm) + elif self.arch == 'encoder_decoder': + self.query_pos_decoder = build_position_encoding( + self.latent_dim, position_embedding=position_embedding) + + decoder_layer = TransformerDecoderLayer( + self.latent_dim, + num_heads, + ff_size, + dropout, + activation, + normalize_before, + ) + decoder_norm = nn.LayerNorm(self.latent_dim) + self.decoder = SkipTransformerDecoder(decoder_layer, num_layers, + decoder_norm) + else: + raise ValueError(f"Not support architecture: {self.arch}!") + + self.global_motion_token = nn.Parameter( + torch.randn(self.latent_size * 2, self.latent_dim)) + + self.skel_embedding = nn.Linear(input_feats, self.latent_dim) + self.final_layer = nn.Linear(self.latent_dim, output_feats) + + def forward(self, features: torch.Tensor, + lengths: Optional[list[int]] = None) -> tuple[torch.Tensor, torch.Tensor, Distribution]: + z, dist = self.encode(features, lengths) + feats_rst = self.decode(z, lengths) + return feats_rst, z, dist + + def encode(self, features: torch.Tensor, + lengths: Optional[list[int]] = None) -> tuple[torch.Tensor, Distribution]: + if lengths is None: + lengths = [len(feature) for feature in features] + + device = features.device + + bs, 
nframes, nfeats = features.shape + mask = lengths_to_mask(lengths, device) + + x = features + # Embed each human poses into latent vectors + x = self.skel_embedding(x) + + # Switch sequence and batch_size because the input of + # Pytorch Transformer is [Sequence, Batch size, ...] + x = x.permute(1, 0, 2) # now it is [nframes, bs, latent_dim] + + # Each batch has its own set of tokens + dist = torch.tile(self.global_motion_token[:, None, :], (1, bs, 1)) + + # create a bigger mask, to allow attend to emb + dist_masks = torch.ones((bs, dist.shape[0]), dtype=torch.bool, device=x.device) + aug_mask = torch.cat((dist_masks, mask), 1) + + # adding the embedding token for all sequences + xseq = torch.cat((dist, x), 0) + + xseq = self.query_pos_encoder(xseq) + dist = self.encoder(xseq, src_key_padding_mask=~aug_mask)[:dist.shape[0]] + + mu = dist[0:self.latent_size, ...] + logvar = dist[self.latent_size:, ...] + + # resampling + std = logvar.exp().pow(0.5) + dist = torch.distributions.Normal(mu, std) + latent = dist.rsample() + return latent, dist + + def decode(self, z: torch.Tensor, lengths: list[int]) -> torch.Tensor: + mask = lengths_to_mask(lengths, z.device) + bs, nframes = mask.shape + queries = torch.zeros(nframes, bs, self.latent_dim, device=z.device) + + if self.arch == "all_encoder": + xseq = torch.cat((z, queries), axis=0) + z_mask = torch.ones((bs, self.latent_size), dtype=torch.bool, device=z.device) + aug_mask = torch.cat((z_mask, mask), axis=1) + xseq = self.query_pos_decoder(xseq) + output = self.decoder(xseq, src_key_padding_mask=~aug_mask)[z.shape[0]:] + + elif self.arch == "encoder_decoder": + queries = self.query_pos_decoder(queries) + output = self.decoder( + tgt=queries, + memory=z, + tgt_key_padding_mask=~mask) + + output = self.final_layer(output) + # zero for padded area + output[~mask.T] = 0 + # Pytorch Transformer: [Sequence, Batch size, ...] 
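A minimal encode/decode round trip for MldVae as defined above; the 263-dimensional features match HumanML3D, while the batch size and lengths are illustrative:

import torch
from mld.models.architectures.mld_vae import MldVae

vae = MldVae(nfeats=263, latent_dim=[1, 256])
feats = torch.randn(2, 196, 263)         # [bs, nframes, nfeats]; the shorter sample is padded
lengths = [196, 120]                     # true number of frames per sample
z, dist = vae.encode(feats, lengths)     # z: [1, bs, 256], dist: torch.distributions.Normal
feats_rst = vae.decode(z, lengths)       # [bs, 196, 263]; frames beyond each length are zeroed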
+ feats = output.permute(1, 0, 2) + return feats diff --git a/mld/models/architectures/t2m_motionenc.py b/mld/models/architectures/t2m_motionenc.py new file mode 100644 index 0000000000000000000000000000000000000000..f12367822a22e4d60a701120f616c83cb16c2f3e --- /dev/null +++ b/mld/models/architectures/t2m_motionenc.py @@ -0,0 +1,58 @@ +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pack_padded_sequence + + +class MovementConvEncoder(nn.Module): + def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None: + super(MovementConvEncoder, self).__init__() + self.main = nn.Sequential( + nn.Conv1d(input_size, hidden_size, 4, 2, 1), + nn.Dropout(0.2, inplace=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv1d(hidden_size, output_size, 4, 2, 1), + nn.Dropout(0.2, inplace=True), + nn.LeakyReLU(0.2, inplace=True), + ) + self.out_net = nn.Linear(output_size, output_size) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + return self.out_net(outputs) + + +class MotionEncoderBiGRUCo(nn.Module): + def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None: + super(MotionEncoderBiGRUCo, self).__init__() + + self.input_emb = nn.Linear(input_size, hidden_size) + self.gru = nn.GRU( + hidden_size, hidden_size, batch_first=True, bidirectional=True + ) + self.output_net = nn.Sequential( + nn.Linear(hidden_size * 2, hidden_size), + nn.LayerNorm(hidden_size), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(hidden_size, output_size), + ) + + self.hidden_size = hidden_size + self.hidden = nn.Parameter( + torch.randn((2, 1, self.hidden_size), requires_grad=True) + ) + + def forward(self, inputs: torch.Tensor, m_lens: torch.Tensor) -> torch.Tensor: + num_samples = inputs.shape[0] + + input_embs = self.input_emb(inputs) + hidden = self.hidden.repeat(1, num_samples, 1) + + cap_lens = m_lens.data.tolist() + emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True) + + gru_seq, gru_last = self.gru(emb, hidden) + + gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) + + return self.output_net(gru_last) diff --git a/mld/models/architectures/t2m_textenc.py b/mld/models/architectures/t2m_textenc.py new file mode 100644 index 0000000000000000000000000000000000000000..d98fb36f3d6364f00196fa3b7221a88791838584 --- /dev/null +++ b/mld/models/architectures/t2m_textenc.py @@ -0,0 +1,43 @@ +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pack_padded_sequence + + +class TextEncoderBiGRUCo(nn.Module): + def __init__(self, word_size: int, pos_size: int, hidden_size: int, output_size: int) -> None: + super(TextEncoderBiGRUCo, self).__init__() + + self.pos_emb = nn.Linear(pos_size, word_size) + self.input_emb = nn.Linear(word_size, hidden_size) + self.gru = nn.GRU( + hidden_size, hidden_size, batch_first=True, bidirectional=True + ) + self.output_net = nn.Sequential( + nn.Linear(hidden_size * 2, hidden_size), + nn.LayerNorm(hidden_size), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(hidden_size, output_size), + ) + + self.hidden_size = hidden_size + self.hidden = nn.Parameter( + torch.randn((2, 1, self.hidden_size), requires_grad=True) + ) + + def forward(self, word_embs: torch.Tensor, pos_onehot: torch.Tensor, + cap_lens: torch.Tensor) -> torch.Tensor: + num_samples = word_embs.shape[0] + + pos_embs = self.pos_emb(pos_onehot) + inputs = word_embs + pos_embs + input_embs = self.input_emb(inputs) + hidden = self.hidden.repeat(1, num_samples, 1) + + 
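A shape-level sketch of calling the evaluator text encoder above; the 300-d GloVe word vectors and 15 POS categories follow the WordVectorizer, while the 512-d hidden/output sizes are assumptions taken from the usual HumanML3D evaluator configuration:

import torch
from mld.models.architectures.t2m_textenc import TextEncoderBiGRUCo

text_enc = TextEncoderBiGRUCo(word_size=300, pos_size=15, hidden_size=512, output_size=512)
word_embs = torch.randn(4, 22, 300)        # [bs, max_text_len + 2, word_size]
pos_onehot = torch.randn(4, 22, 15)        # POS one-hots from the WordVectorizer
cap_lens = torch.tensor([22, 18, 12, 7])   # sorted descending, as mld_collate guarantees
text_emb = text_enc(word_embs, pos_onehot, cap_lens)   # [bs, output_size] co-embedding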
cap_lens = cap_lens.data.tolist() + emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True) + + gru_seq, gru_last = self.gru(emb, hidden) + + gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) + + return self.output_net(gru_last) diff --git a/mld/models/architectures/tools/embeddings.py b/mld/models/architectures/tools/embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..783f9ddf958d94cefdebbf75572550ebe1e678ba --- /dev/null +++ b/mld/models/architectures/tools/embeddings.py @@ -0,0 +1,89 @@ +import math +from typing import Optional + +import torch +import torch.nn as nn + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> torch.Tensor: + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class TimestepEmbedding(nn.Module): + def __init__(self, channel: int, time_embed_dim: int, + act_fn: str = "silu", cond_proj_dim: Optional[int] = None) -> None: + super().__init__() + + # distill CFG + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, channel, bias=False) + self.cond_proj.weight.data.fill_(0.0) + else: + self.cond_proj = None + + self.linear_1 = nn.Linear(channel, time_embed_dim) + self.act = None + if act_fn == "silu": + self.act = nn.SiLU() + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) + + def forward(self, sample: torch.Tensor, timestep_cond: Optional[torch.Tensor] = None) -> torch.Tensor: + if timestep_cond is not None: + sample = sample + self.cond_proj(timestep_cond) + + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + return sample + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, + downscale_freq_shift: float) -> None: + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb diff --git a/mld/models/metrics/__init__.py b/mld/models/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03142920ad70603f0cbd46d2737590f31a41a3bf --- /dev/null +++ b/mld/models/metrics/__init__.py @@ -0,0 +1,3 @@ +from .tm2t import TM2TMetrics +from .mm import MMMetrics +from .cm import ControlMetrics diff --git a/mld/models/metrics/cm.py b/mld/models/metrics/cm.py new file mode 100644 index 0000000000000000000000000000000000000000..7f395430c67343898998c253ea75394ae002c7d4 --- /dev/null +++ 
b/mld/models/metrics/cm.py @@ -0,0 +1,55 @@ +import torch +from torchmetrics import Metric +from torchmetrics.utilities import dim_zero_cat + +from mld.utils.temos_utils import remove_padding +from .utils import calculate_skating_ratio, calculate_trajectory_error, control_l2 + + +class ControlMetrics(Metric): + + def __init__(self, dist_sync_on_step: bool = True) -> None: + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.name = "control_metrics" + + self.add_state("count_seq", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("skate_ratio_sum", default=torch.tensor(0.), dist_reduce_fx="sum") + self.add_state("dist_sum", default=torch.tensor(0.), dist_reduce_fx="sum") + self.add_state("traj_err", default=[], dist_reduce_fx="cat") + self.traj_err_key = ["traj_fail_20cm", "traj_fail_50cm", "kps_fail_20cm", "kps_fail_50cm", "kps_mean_err(m)"] + + def compute(self) -> dict: + count_seq = self.count_seq.item() + + metrics = dict() + metrics['Skating Ratio'] = self.skate_ratio_sum / count_seq + metrics['Control L2 dist'] = self.dist_sum / count_seq + traj_err = dim_zero_cat(self.traj_err).mean(0) + + for (k, v) in zip(self.traj_err_key, traj_err): + metrics[k] = v + + return {**metrics} + + def update(self, joints: torch.Tensor, hint: torch.Tensor, + mask_hint: torch.Tensor, lengths: list[int]) -> None: + self.count_seq += len(lengths) + + joints_no_padding = remove_padding(joints, lengths) + for j in joints_no_padding: + skate_ratio, _ = calculate_skating_ratio(j.unsqueeze(0).permute(0, 2, 3, 1)) + self.skate_ratio_sum += skate_ratio[0] + + joints_np = joints.cpu().numpy() + hint_np = hint.cpu().numpy() + mask_hint_np = mask_hint.cpu().numpy() + + for j, h, m in zip(joints_np, hint_np, mask_hint_np): + control_error = control_l2(j[None], h[None], m[None]) + mean_error = control_error.sum() / m.sum() + self.dist_sum += mean_error + control_error = control_error.reshape(-1) + m = m.reshape(-1) + err_np = calculate_trajectory_error(control_error, mean_error, m) + self.traj_err.append(torch.tensor(err_np[None], device=joints.device)) diff --git a/mld/models/metrics/mm.py b/mld/models/metrics/mm.py new file mode 100644 index 0000000000000000000000000000000000000000..5064885ecc396cca960947a4e95f9a930256ccd1 --- /dev/null +++ b/mld/models/metrics/mm.py @@ -0,0 +1,46 @@ +import torch +from torchmetrics import Metric +from torchmetrics.utilities import dim_zero_cat + +from .utils import calculate_multimodality_np + + +class MMMetrics(Metric): + + def __init__(self, mm_num_times: int = 10, dist_sync_on_step: bool = True) -> None: + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.name = "MultiModality scores" + + self.mm_num_times = mm_num_times + + self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("count_seq", + default=torch.tensor(0), + dist_reduce_fx="sum") + + self.metrics = ["MultiModality"] + self.add_state("MultiModality", + default=torch.tensor(0.), + dist_reduce_fx="sum") + + # cached batches + self.add_state("mm_motion_embeddings", default=[], dist_reduce_fx='cat') + + def compute(self) -> dict: + # init metrics + metrics = {metric: getattr(self, metric) for metric in self.metrics} + + # cat all embeddings + all_mm_motions = dim_zero_cat(self.mm_motion_embeddings).cpu().numpy() + metrics['MultiModality'] = calculate_multimodality_np( + all_mm_motions, self.mm_num_times) + + return {**metrics} + + def update(self, mm_motion_embeddings: torch.Tensor, lengths: list[int]) -> None: + self.count += sum(lengths) + 
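A self-contained sketch of feeding ControlMetrics defined above; the [bs, seq, 22, 3] joint layout and the way the hint and its mask are fabricated here are assumptions for illustration (in the pipeline they come from the dataset's hint field):

import torch
from mld.models.metrics import ControlMetrics

metric = ControlMetrics(dist_sync_on_step=False)
joints = torch.randn(2, 60, 22, 3)              # generated global joints in metres
hint = torch.zeros_like(joints)                 # sparse control targets
hint[:, ::10, 0] = joints[:, ::10, 0] + 0.05    # pelvis keyframes, offset by 5 cm per axis
mask_hint = hint != 0                           # True only at controlled entries
metric.update(joints, hint, mask_hint, lengths=[60, 60])
print(metric.compute())   # Skating Ratio, Control L2 dist, traj_fail_*, kps_* entries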
self.count_seq += len(lengths) + + # store all mm motion embeddings + self.mm_motion_embeddings.append(mm_motion_embeddings) diff --git a/mld/models/metrics/tm2t.py b/mld/models/metrics/tm2t.py new file mode 100644 index 0000000000000000000000000000000000000000..5850286c217e76bc530830f40f170dc52130a7ad --- /dev/null +++ b/mld/models/metrics/tm2t.py @@ -0,0 +1,148 @@ +from torchmetrics import Metric +from torchmetrics.utilities import dim_zero_cat + +from .utils import * + + +class TM2TMetrics(Metric): + + def __init__(self, + top_k: int = 3, + R_size: int = 32, + diversity_times: int = 300, + dist_sync_on_step: bool = True) -> None: + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.name = "matching, fid, and diversity scores" + + self.top_k = top_k + self.R_size = R_size + self.diversity_times = diversity_times + + self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("count_seq", + default=torch.tensor(0), + dist_reduce_fx="sum") + + self.metrics = [] + # Matching scores + self.add_state("Matching_score", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.add_state("gt_Matching_score", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.Matching_metrics = ["Matching_score", "gt_Matching_score"] + for k in range(1, top_k + 1): + self.add_state( + f"R_precision_top_{str(k)}", + default=torch.tensor(0.0), + dist_reduce_fx="sum", + ) + self.Matching_metrics.append(f"R_precision_top_{str(k)}") + for k in range(1, top_k + 1): + self.add_state( + f"gt_R_precision_top_{str(k)}", + default=torch.tensor(0.0), + dist_reduce_fx="sum", + ) + self.Matching_metrics.append(f"gt_R_precision_top_{str(k)}") + + self.metrics.extend(self.Matching_metrics) + + # FID + self.add_state("FID", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.metrics.append("FID") + + # Diversity + self.add_state("Diversity", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.add_state("gt_Diversity", + default=torch.tensor(0.0), + dist_reduce_fx="sum") + self.metrics.extend(["Diversity", "gt_Diversity"]) + + # cached batches + self.add_state("text_embeddings", default=[], dist_reduce_fx='cat') + self.add_state("recmotion_embeddings", default=[], dist_reduce_fx='cat') + self.add_state("gtmotion_embeddings", default=[], dist_reduce_fx='cat') + + def compute(self) -> dict: + count_seq = self.count_seq.item() + + # init metrics + metrics = {metric: getattr(self, metric) for metric in self.metrics} + + # cat all embeddings + shuffle_idx = torch.randperm(count_seq) + all_texts = dim_zero_cat(self.text_embeddings, axis=0).cpu()[shuffle_idx, :] + all_genmotions = dim_zero_cat(self.recmotion_embeddings, axis=0).cpu()[shuffle_idx, :] + all_gtmotions = dim_zero_cat(self.gtmotion_embeddings, axis=0).cpu()[shuffle_idx, :] + + # Compute r-precision + assert count_seq > self.R_size + top_k_mat = torch.zeros((self.top_k,)) + for i in range(count_seq // self.R_size): + # [bs=32, 1*256] + group_texts = all_texts[i * self.R_size:(i + 1) * self.R_size] + # [bs=32, 1*256] + group_motions = all_genmotions[i * self.R_size:(i + 1) * self.R_size] + # [bs=32, 32] + dist_mat = euclidean_distance_matrix(group_texts, group_motions).nan_to_num() + # print(dist_mat[:5]) + self.Matching_score += dist_mat.trace() + argmax = torch.argsort(dist_mat, dim=1) + top_k_mat += calculate_top_k(argmax, top_k=self.top_k).sum(axis=0) + R_count = count_seq // self.R_size * self.R_size + metrics["Matching_score"] = self.Matching_score / R_count + for k in range(self.top_k): + 
metrics[f"R_precision_top_{str(k + 1)}"] = top_k_mat[k] / R_count + + # Compute r-precision with gt + assert count_seq > self.R_size + top_k_mat = torch.zeros((self.top_k,)) + for i in range(count_seq // self.R_size): + # [bs=32, 1*256] + group_texts = all_texts[i * self.R_size:(i + 1) * self.R_size] + # [bs=32, 1*256] + group_motions = all_gtmotions[i * self.R_size:(i + 1) * self.R_size] + # [bs=32, 32] + dist_mat = euclidean_distance_matrix(group_texts, group_motions).nan_to_num() + # match score + self.gt_Matching_score += dist_mat.trace() + argmax = torch.argsort(dist_mat, dim=1) + top_k_mat += calculate_top_k(argmax, top_k=self.top_k).sum(axis=0) + metrics["gt_Matching_score"] = self.gt_Matching_score / R_count + for k in range(self.top_k): + metrics[f"gt_R_precision_top_{str(k + 1)}"] = top_k_mat[k] / R_count + + # tensor -> numpy for FID + all_genmotions = all_genmotions.numpy() + all_gtmotions = all_gtmotions.numpy() + + # Compute fid + mu, cov = calculate_activation_statistics_np(all_genmotions) + gt_mu, gt_cov = calculate_activation_statistics_np(all_gtmotions) + metrics["FID"] = calculate_frechet_distance_np(gt_mu, gt_cov, mu, cov) + + # Compute diversity + assert count_seq > self.diversity_times + metrics["Diversity"] = calculate_diversity_np(all_genmotions, self.diversity_times) + metrics["gt_Diversity"] = calculate_diversity_np(all_gtmotions, self.diversity_times) + + return {**metrics} + + def update( + self, + text_embeddings: torch.Tensor, + recmotion_embeddings: torch.Tensor, + gtmotion_embeddings: torch.Tensor, + lengths: list[int]) -> None: + self.count += sum(lengths) + self.count_seq += len(lengths) + + # store all texts and motions + self.text_embeddings.append(text_embeddings.detach()) + self.recmotion_embeddings.append(recmotion_embeddings.detach()) + self.gtmotion_embeddings.append(gtmotion_embeddings.detach()) diff --git a/mld/models/metrics/utils.py b/mld/models/metrics/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a1fd18c13f9cb87ad1d029920538e0c2f36fc97a --- /dev/null +++ b/mld/models/metrics/utils.py @@ -0,0 +1,246 @@ +import numpy as np + +import scipy.linalg +from scipy.ndimage import uniform_filter1d + +import torch +from torch import linalg + + +# Text-to-Motion + +# (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train +def euclidean_distance_matrix(matrix1: torch.Tensor, matrix2: torch.Tensor) -> torch.Tensor: + """ + Params: + -- matrix1: N1 x D + -- matrix2: N2 x D + Returns: + -- dists: N1 x N2 + dists[i, j] == distance(matrix1[i], matrix2[j]) + """ + assert matrix1.shape[1] == matrix2.shape[1] + d1 = -2 * torch.mm(matrix1, matrix2.T) # shape (num_test, num_train) + d2 = torch.sum(torch.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1) + d3 = torch.sum(torch.square(matrix2), axis=1) # shape (num_train, ) + dists = torch.sqrt(d1 + d2 + d3) # broadcasting + return dists + + +def euclidean_distance_matrix_np(matrix1: np.ndarray, matrix2: np.ndarray) -> np.ndarray: + """ + Params: + -- matrix1: N1 x D + -- matrix2: N2 x D + Returns: + -- dists: N1 x N2 + dists[i, j] == distance(matrix1[i], matrix2[j]) + """ + assert matrix1.shape[1] == matrix2.shape[1] + d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train) + d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1) + d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, ) + dists = np.sqrt(d1 + d2 + d3) # broadcasting + return dists + + +def calculate_top_k(mat: torch.Tensor, top_k: int) -> torch.Tensor: + size 
= mat.shape[0] + gt_mat = (torch.unsqueeze(torch.arange(size), 1).to(mat.device).repeat_interleave(size, 1)) + bool_mat = mat == gt_mat + correct_vec = False + top_k_list = [] + for i in range(top_k): + correct_vec = correct_vec | bool_mat[:, i] + top_k_list.append(correct_vec[:, None]) + top_k_mat = torch.cat(top_k_list, dim=1) + return top_k_mat + + +def calculate_activation_statistics(activations: torch.Tensor) -> tuple: + """ + Params: + -- activation: num_samples x dim_feat + Returns: + -- mu: dim_feat + -- sigma: dim_feat x dim_feat + """ + activations = activations.cpu().numpy() + mu = np.mean(activations, axis=0) + sigma = np.cov(activations, rowvar=False) + return mu, sigma + + +def calculate_activation_statistics_np(activations: np.ndarray) -> tuple: + """ + Params: + -- activation: num_samples x dim_feat + Returns: + -- mu: dim_feat + -- sigma: dim_feat x dim_feat + """ + mu = np.mean(activations, axis=0) + cov = np.cov(activations, rowvar=False) + return mu, cov + + +def calculate_frechet_distance_np( + mu1: np.ndarray, + sigma1: np.ndarray, + mu2: np.ndarray, + sigma2: np.ndarray, + eps: float = 1e-6) -> float: + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + Returns: + -- : The Frechet Distance. 
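A quick numerical sanity check of calculate_activation_statistics_np and calculate_frechet_distance_np, with arbitrary sample counts and embedding size:

import numpy as np
from mld.models.metrics.utils import (calculate_activation_statistics_np,
                                      calculate_frechet_distance_np)

rng = np.random.default_rng(0)
gen = rng.normal(0.0, 1.0, size=(512, 16))   # stand-in generated motion embeddings
gt = rng.normal(0.1, 1.0, size=(512, 16))    # stand-in ground-truth embeddings

mu_g, cov_g = calculate_activation_statistics_np(gen)
mu_r, cov_r = calculate_activation_statistics_np(gt)
fid = calculate_frechet_distance_np(mu_r, cov_r, mu_g, cov_g)
# a small positive value; it approaches 0 as the two sets of statistics coincide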
+ """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert (mu1.shape == mu2.shape + ), "Training and test mean vectors have different lengths" + assert (sigma1.shape == sigma2.shape + ), "Training and test covariances have different dimensions" + + diff = mu1 - mu2 + # Product might be almost singular + covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ("fid calculation produces singular product; " + "adding %s to diagonal of cov estimates") % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean + + +def calculate_diversity(activation: torch.Tensor, diversity_times: int) -> float: + assert len(activation.shape) == 2 + assert activation.shape[0] > diversity_times + num_samples = activation.shape[0] + + first_indices = np.random.choice(num_samples, + diversity_times, + replace=False) + second_indices = np.random.choice(num_samples, + diversity_times, + replace=False) + dist = linalg.norm(activation[first_indices] - activation[second_indices], + axis=1) + return dist.mean() + + +def calculate_diversity_np(activation: np.ndarray, diversity_times: int) -> float: + assert len(activation.shape) == 2 + assert activation.shape[0] > diversity_times + num_samples = activation.shape[0] + + first_indices = np.random.choice(num_samples, + diversity_times, + replace=False) + second_indices = np.random.choice(num_samples, + diversity_times, + replace=False) + dist = scipy.linalg.norm(activation[first_indices] - + activation[second_indices], + axis=1) + return dist.mean() + + +def calculate_multimodality_np(activation: np.ndarray, multimodality_times: int) -> float: + assert len(activation.shape) == 3 + assert activation.shape[1] > multimodality_times + num_per_sent = activation.shape[1] + + first_dices = np.random.choice(num_per_sent, + multimodality_times, + replace=False) + second_dices = np.random.choice(num_per_sent, + multimodality_times, + replace=False) + dist = scipy.linalg.norm(activation[:, first_dices] - + activation[:, second_dices], + axis=2) + return dist.mean() + + +# Motion Control + +def calculate_skating_ratio(motions: torch.Tensor) -> tuple: + thresh_height = 0.05 # 10 + fps = 20.0 + thresh_vel = 0.50 # 20 cm /s + avg_window = 5 # frames + + # 10 left, 11 right foot. 
XZ plane, y up + # motions [bs, 22, 3, max_len] + verts_feet = motions[:, [10, 11], :, :].detach().cpu().numpy() # [bs, 2, 3, max_len] + verts_feet_plane_vel = np.linalg.norm(verts_feet[:, :, [0, 2], 1:] - verts_feet[:, :, [0, 2], :-1], + axis=2) * fps # [bs, 2, max_len-1] + # [bs, 2, max_len-1] + vel_avg = uniform_filter1d(verts_feet_plane_vel, axis=-1, size=avg_window, mode='constant', origin=0) + + verts_feet_height = verts_feet[:, :, 1, :] # [bs, 2, max_len] + # If feet touch ground in adjacent frames + feet_contact = np.logical_and((verts_feet_height[:, :, :-1] < thresh_height), + (verts_feet_height[:, :, 1:] < thresh_height)) # [bs, 2, max_len - 1] + # skate velocity + skate_vel = feet_contact * vel_avg + + # it must both skating in the current frame + skating = np.logical_and(feet_contact, (verts_feet_plane_vel > thresh_vel)) + # and also skate in the windows of frames + skating = np.logical_and(skating, (vel_avg > thresh_vel)) + + # Both feet slide + skating = np.logical_or(skating[:, 0, :], skating[:, 1, :]) # [bs, max_len -1] + skating_ratio = np.sum(skating, axis=1) / skating.shape[1] + + return skating_ratio, skate_vel + + +def calculate_trajectory_error(dist_error: np.ndarray, mean_err_traj: np.ndarray, + mask: np.ndarray, strict: bool = True) -> np.ndarray: + if strict: + # Traj fails if any of the key frame fails + traj_fail_02 = 1.0 - (dist_error <= 0.2).all() + traj_fail_05 = 1.0 - (dist_error <= 0.5).all() + else: + # Traj fails if the mean error of all keyframes more than the threshold + traj_fail_02 = (mean_err_traj > 0.2) + traj_fail_05 = (mean_err_traj > 0.5) + all_fail_02 = (dist_error > 0.2).sum() / mask.sum() + all_fail_05 = (dist_error > 0.5).sum() / mask.sum() + + return np.array([traj_fail_02, traj_fail_05, all_fail_02, all_fail_05, dist_error.sum() / mask.sum()]) + + +def control_l2(motion: np.ndarray, hint: np.ndarray, hint_mask: np.ndarray) -> np.ndarray: + loss = np.linalg.norm((motion - hint) * hint_mask, axis=-1) + return loss diff --git a/mld/models/modeltype/__init__.py b/mld/models/modeltype/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/models/modeltype/base.py b/mld/models/modeltype/base.py new file mode 100644 index 0000000000000000000000000000000000000000..d24a11f6c9875243e91aa9fd4d1f2a8ecbed49af --- /dev/null +++ b/mld/models/modeltype/base.py @@ -0,0 +1,83 @@ +from typing import Any +from collections import OrderedDict + +import numpy as np + +import torch.nn as nn + +from mld.models.metrics import TM2TMetrics, MMMetrics, ControlMetrics + + +class BaseModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.times = [] + self.text_encoder_times = [] + self.diffusion_times = [] + self.vae_decode_times = [] + self.all_lengths = [] + + def test_step(self, batch: dict) -> None: + test_batch_size = self.cfg.TEST.BATCH_SIZE + if len(self.times) > 0: + inference_aits = round(np.mean(self.times) / test_batch_size, 5) + inference_aits_text = round(np.mean(self.text_encoder_times) / test_batch_size, 5) + inference_aits_diff = round(np.mean(self.diffusion_times) / test_batch_size, 5) + inference_aits_vae = round(np.mean(self.vae_decode_times) / test_batch_size, 5) + print(f"\nAverage Inference Time per Sentence ({test_batch_size*len(self.times)}): {inference_aits}\n" + f"(Text: {inference_aits_text}, Diff: {inference_aits_diff}, VAE: {inference_aits_vae})") + print(f"Average length: {round(np.mean(self.all_lengths), 5)}") + return 
self.allsplit_step("test", batch) + + def allsplit_epoch_end(self) -> dict: + res = dict() + if self.datamodule.is_mm and "TM2TMetrics" in self.metrics_dict: + metrics_dicts = ['MMMetrics'] + else: + metrics_dicts = self.metrics_dict + for metric in metrics_dicts: + metrics_dict = getattr(self, metric).compute() + # reset metrics + getattr(self, metric).reset() + res.update({ + f"Metrics/{metric}": value.item() + for metric, value in metrics_dict.items() + }) + return res + + def on_save_checkpoint(self, checkpoint: dict) -> None: + state_dict = checkpoint['state_dict'] + clip_k = [] + for k, v in state_dict.items(): + if 'text_encoder' in k: + clip_k.append(k) + for k in clip_k: + del checkpoint['state_dict'][k] + + def load_state_dict(self, state_dict: dict, strict: bool = True) -> Any: + clip_state_dict = self.text_encoder.state_dict() + new_state_dict = OrderedDict() + for k, v in clip_state_dict.items(): + new_state_dict['text_encoder.' + k] = v + for k, v in state_dict.items(): + if 'text_encoder' not in k: + new_state_dict[k] = v + + return super().load_state_dict(new_state_dict, strict) + + def configure_metrics(self) -> None: + for metric in self.metrics_dict: + if metric == "TM2TMetrics": + self.TM2TMetrics = TM2TMetrics( + diversity_times=self.cfg.TEST.DIVERSITY_TIMES, + dist_sync_on_step=self.cfg.METRIC.DIST_SYNC_ON_STEP, + ) + elif metric == 'ControlMetrics': + self.ControlMetrics = ControlMetrics(dist_sync_on_step=self.cfg.METRIC.DIST_SYNC_ON_STEP) + else: + raise NotImplementedError(f"Do not support Metric Type {metric}") + + if "TM2TMetrics" in self.metrics_dict: + self.MMMetrics = MMMetrics( + mm_num_times=self.cfg.TEST.MM_NUM_TIMES, + dist_sync_on_step=self.cfg.METRIC.DIST_SYNC_ON_STEP) diff --git a/mld/models/modeltype/mld.py b/mld/models/modeltype/mld.py new file mode 100644 index 0000000000000000000000000000000000000000..dae2e9c93a6d87dcfeb2374ab471c00f2ce038a0 --- /dev/null +++ b/mld/models/modeltype/mld.py @@ -0,0 +1,543 @@ +import os +import time +import inspect +import logging +from typing import Optional + +import tqdm +import numpy as np +from omegaconf import DictConfig + +import torch +import torch.nn.functional as F + +from mld.data.base import BASEDataModule +from mld.config import instantiate_from_config +from mld.models.architectures import mld_denoiser, mld_vae, t2m_motionenc, t2m_textenc +from mld.utils.temos_utils import lengths_to_mask, remove_padding +from mld.utils.utils import count_parameters, get_guidance_scale_embedding, extract_into_tensor, sum_flat + +from .base import BaseModel + +logger = logging.getLogger(__name__) + + +class MLD(BaseModel): + def __init__(self, cfg: DictConfig, datamodule: BASEDataModule) -> None: + super().__init__() + + self.cfg = cfg + self.nfeats = cfg.DATASET.NFEATS + self.njoints = cfg.DATASET.NJOINTS + self.latent_dim = cfg.model.latent_dim + self.guidance_scale = cfg.model.guidance_scale + self.guidance_uncondp = cfg.model.guidance_uncondp + self.datamodule = datamodule + + self.text_encoder = instantiate_from_config(cfg.model.text_encoder) + self.vae = instantiate_from_config(cfg.model.motion_vae) + self.denoiser = instantiate_from_config(cfg.model.denoiser) + + self.scheduler = instantiate_from_config(cfg.model.scheduler) + self.noise_scheduler = instantiate_from_config(cfg.model.noise_scheduler) + + self._get_t2m_evaluator(cfg) + + self.metrics_dict = cfg.METRIC.TYPE + self.configure_metrics() + + self.feats2joints = datamodule.feats2joints + + self.is_controlnet = cfg.model.is_controlnet + self.alphas = 
torch.sqrt(self.noise_scheduler.alphas_cumprod) + self.sigmas = torch.sqrt(1 - self.noise_scheduler.alphas_cumprod) + + self.l2_loss = lambda a, b: (a - b) ** 2 + if self.is_controlnet: + c_cfg = self.cfg.model.denoiser.copy() + c_cfg['params']['is_controlnet'] = True + self.controlnet = instantiate_from_config(c_cfg) + self.training_control_joint = cfg.model.training_control_joint + self.testing_control_joint = cfg.model.testing_control_joint + self.training_density = cfg.model.training_density + self.testing_density = cfg.model.testing_density + self.control_scale = cfg.model.control_scale + self.vaeloss = cfg.model.vaeloss + self.vaeloss_type = cfg.model.vaeloss_type + self.cond_ratio = cfg.model.cond_ratio + self.rot_ratio = cfg.model.rot_ratio + + self.is_controlnet_temporal = cfg.model.is_controlnet_temporal + + self.traj_encoder = instantiate_from_config(cfg.model.traj_encoder) + + logger.info(f"control scale: {self.control_scale}, vaeloss: {self.vaeloss}, " + f"cond_ratio: {self.cond_ratio}, rot_ratio: {self.rot_ratio}, " + f"vaeloss_type: {self.vaeloss_type}") + logger.info(f"is_controlnet_temporal: {self.is_controlnet_temporal}") + logger.info(f"training_control_joint: {self.training_control_joint}") + logger.info(f"testing_control_joint: {self.testing_control_joint}") + logger.info(f"training_density: {self.training_density}") + logger.info(f"testing_density: {self.testing_density}") + + time.sleep(2) + + self.summarize_parameters() + + @property + def do_classifier_free_guidance(self) -> bool: + return self.guidance_scale > 1 and self.denoiser.time_cond_proj_dim is None + + def summarize_parameters(self) -> None: + logger.info(f'VAE Encoder: {count_parameters(self.vae.encoder)}M') + logger.info(f'VAE Decoder: {count_parameters(self.vae.decoder)}M') + logger.info(f'Denoiser: {count_parameters(self.denoiser)}M') + + if self.is_controlnet: + vae = count_parameters(self.traj_encoder) + controlnet = count_parameters(self.controlnet) + logger.info(f'ControlNet: {controlnet}M') + logger.info(f'Spatial VAE: {vae}M') + + def _get_t2m_evaluator(self, cfg: DictConfig) -> None: + self.t2m_moveencoder = t2m_motionenc.MovementConvEncoder( + input_size=cfg.DATASET.NFEATS - 4, + hidden_size=cfg.model.t2m_motionencoder.dim_move_hidden, + output_size=cfg.model.t2m_motionencoder.dim_move_latent, + ) + + self.t2m_motionencoder = t2m_motionenc.MotionEncoderBiGRUCo( + input_size=cfg.model.t2m_motionencoder.dim_move_latent, + hidden_size=cfg.model.t2m_motionencoder.dim_motion_hidden, + output_size=cfg.model.t2m_motionencoder.dim_motion_latent, + ) + + self.t2m_textencoder = t2m_textenc.TextEncoderBiGRUCo( + word_size=cfg.model.t2m_textencoder.dim_word, + pos_size=cfg.model.t2m_textencoder.dim_pos_ohot, + hidden_size=cfg.model.t2m_textencoder.dim_text_hidden, + output_size=cfg.model.t2m_textencoder.dim_coemb_hidden, + ) + + # load pretrained + dataname = cfg.TEST.DATASETS[0] + dataname = "t2m" if dataname == "humanml3d" else dataname + t2m_checkpoint = torch.load( + os.path.join(cfg.model.t2m_path, dataname, "text_mot_match/model/finest.tar"), map_location='cpu') + self.t2m_textencoder.load_state_dict(t2m_checkpoint["text_encoder"]) + self.t2m_moveencoder.load_state_dict(t2m_checkpoint["movement_encoder"]) + self.t2m_motionencoder.load_state_dict(t2m_checkpoint["motion_encoder"]) + + # freeze params + self.t2m_textencoder.eval() + self.t2m_moveencoder.eval() + self.t2m_motionencoder.eval() + for p in self.t2m_textencoder.parameters(): + p.requires_grad = False + for p in 
self.t2m_moveencoder.parameters(): + p.requires_grad = False + for p in self.t2m_motionencoder.parameters(): + p.requires_grad = False + + def forward(self, batch: dict) -> tuple: + texts = batch["text"] + lengths = batch["length"] + + if self.do_classifier_free_guidance: + texts = [""] * len(texts) + texts + + text_emb = self.text_encoder(texts) + + hint = batch['hint'] if 'hint' in batch else None # control signals + z = self._diffusion_reverse(text_emb, hint) + + with torch.no_grad(): + feats_rst = self.vae.decode(z, lengths) + joints = self.feats2joints(feats_rst.detach().cpu()) + joints = remove_padding(joints, lengths) + + joints_ref = None + if 'motion' in batch: + feats_ref = batch['motion'] + joints_ref = self.feats2joints(feats_ref.detach().cpu()) + joints_ref = remove_padding(joints_ref, lengths) + + return joints, joints_ref + + def predicted_origin(self, model_output, timesteps, sample): + self.alphas = self.alphas.to(model_output.device) + self.sigmas = self.sigmas.to(model_output.device) + sigmas = extract_into_tensor(self.sigmas, timesteps, sample.shape) + alphas = extract_into_tensor(self.alphas, timesteps, sample.shape) + pred_x_0 = (sample - sigmas * model_output) / alphas + return pred_x_0 + + def _diffusion_reverse(self, encoder_hidden_states: torch.Tensor, hint: torch.Tensor = None) -> torch.Tensor: + + controlnet_cond = None + if self.is_controlnet: + hint_mask = hint.sum(-1) != 0 + controlnet_cond = self.traj_encoder(hint, mask=hint_mask) + controlnet_cond = controlnet_cond.permute(1, 0, 2) + + # init latents + bsz = encoder_hidden_states.shape[0] + if self.do_classifier_free_guidance: + bsz = bsz // 2 + + latents = torch.randn( + (bsz, self.latent_dim[0], self.latent_dim[-1]), + device=encoder_hidden_states.device, + dtype=torch.float, + ) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + # set timesteps + self.scheduler.set_timesteps( + self.cfg.model.scheduler.num_inference_timesteps) + timesteps = self.scheduler.timesteps.to(encoder_hidden_states.device) + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, and between [0, 1] + extra_step_kwargs = {} + if "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys()): + extra_step_kwargs["eta"] = self.cfg.model.scheduler.eta + + timestep_cond = None + if self.denoiser.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(latents.shape[0]) + timestep_cond = get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.denoiser.time_cond_proj_dim + ).to(device=latents.device, dtype=latents.dtype) + + # reverse + for i, t in tqdm.tqdm(enumerate(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = (torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + controlnet_residuals = None + if self.is_controlnet: + if self.do_classifier_free_guidance: + controlnet_prompt_embeds = encoder_hidden_states.chunk(2)[1] + else: + controlnet_prompt_embeds = encoder_hidden_states + + controlnet_residuals = self.controlnet( + latents, + t, + timestep_cond=timestep_cond, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=controlnet_cond) + + if self.do_classifier_free_guidance: + controlnet_residuals = 
[torch.cat([torch.zeros_like(d), d * self.control_scale], dim=1) + for d in controlnet_residuals] + else: + controlnet_residuals = [d * self.control_scale for d in controlnet_residuals] + + # predict the noise residual + noise_pred = self.denoiser( + sample=latent_model_input, + timestep=t, + timestep_cond=timestep_cond, + encoder_hidden_states=encoder_hidden_states, + controlnet_residuals=controlnet_residuals) + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # [batch_size, 1, latent_dim] -> [1, batch_size, latent_dim] + latents = latents.permute(1, 0, 2) + return latents + + def _diffusion_process(self, latents: torch.Tensor, encoder_hidden_states: torch.Tensor, + hint: torch.Tensor = None) -> dict: + + controlnet_cond = None + if self.is_controlnet: + hint = hint + hint_mask = hint.sum(-1) != 0 + controlnet_cond = self.traj_encoder(hint, mask=hint_mask) + controlnet_cond = controlnet_cond.permute(1, 0, 2) + + # [n_token, batch_size, latent_dim] -> [batch_size, n_token, latent_dim] + latents = latents.permute(1, 0, 2) + + timestep_cond = None + if self.denoiser.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(latents.shape[0]) + timestep_cond = get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.denoiser.time_cond_proj_dim + ).to(device=latents.device, dtype=latents.dtype) + + # Sample noise that we'll add to the latents + # [batch_size, n_token, latent_dim] + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each motion + timesteps = torch.randint( + 0, + self.noise_scheduler.config.num_train_timesteps, + (bsz,), + device=latents.device, + ) + timesteps = timesteps.long() + # Add noise to the latents according to the noise magnitude at each timestep + noisy_latents = self.noise_scheduler.add_noise(latents.clone(), noise, + timesteps) + + controlnet_residuals = None + if self.is_controlnet: + controlnet_residuals = self.controlnet( + sample=noisy_latents, + timestep=timesteps, + timestep_cond=timestep_cond, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=controlnet_cond) + + # Predict the noise residual + noise_pred = self.denoiser( + sample=noisy_latents, + timestep=timesteps, + timestep_cond=timestep_cond, + encoder_hidden_states=encoder_hidden_states, + controlnet_residuals=controlnet_residuals) + + model_pred = self.predicted_origin(noise_pred, timesteps, noisy_latents) + + n_set = { + "noise": noise, + "noise_pred": noise_pred, + "model_pred": model_pred, + "model_gt": latents + } + return n_set + + def masked_l2(self, a: torch.Tensor, b: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + loss = self.l2_loss(a, b) + loss = sum_flat(loss * mask.float()) + n_entries = a.shape[-1] + non_zero_elements = sum_flat(mask) * n_entries + mse_loss = loss / non_zero_elements + return mse_loss.mean() + + def train_diffusion_forward(self, batch: dict) -> dict: + feats_ref = batch["motion"] + lengths = batch["length"] + + # motion encode + with torch.no_grad(): + z, dist = self.vae.encode(feats_ref, lengths) + + text = batch["text"] + # classifier free guidance: randomly drop text during training + text = [ + "" if np.random.rand(1) < self.guidance_uncondp else i + for i in text + ] + # text encode + 
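+ # Note (descriptive comment): a fraction guidance_uncondp of the captions is replaced with "" above, so the
+ # denoiser also learns an unconditional branch; at inference, when classifier-free guidance is active,
+ # the two predictions are combined in _diffusion_reverse as noise_uncond + guidance_scale * (noise_text - noise_uncond).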
cond_emb = self.text_encoder(text) + + # diffusion process return with noise and noise_pred + hint = batch['hint'] if 'hint' in batch else None # control signals + n_set = self._diffusion_process(z, cond_emb, hint) + + loss_dict = dict() + + if self.denoiser.time_cond_proj_dim is not None: + # LCM + model_pred = n_set['model_pred'] + target = n_set['model_gt'] + diff_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # DM + model_pred = n_set['noise'] + target = n_set['noise_pred'] + diff_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + loss_dict['diff_loss'] = diff_loss + + if self.vaeloss: + z_pred = n_set['model_pred'] + feats_rst = self.vae.decode(z_pred.transpose(0, 1), lengths) + joints_rst = self.feats2joints(feats_rst) + joints_rst = joints_rst.view(joints_rst.shape[0], joints_rst.shape[1], -1) + joints_rst = self.datamodule.norm_spatial(joints_rst) + joints_rst = joints_rst.view(joints_rst.shape[0], joints_rst.shape[1], self.njoints, 3) + hint = batch['hint'] + hint = hint.view(hint.shape[0], hint.shape[1], self.njoints, 3) + mask_hint = hint.sum(dim=-1, keepdim=True) != 0 + + if self.cond_ratio != 0: + if self.vaeloss_type == 'mean': + cond_loss = (self.l2_loss(joints_rst, hint) * mask_hint).mean() + loss_dict['cond_loss'] = self.cond_ratio * cond_loss + elif self.vaeloss_type == 'sum': + cond_loss = (self.l2_loss(joints_rst, hint).sum(-1, keepdims=True) * mask_hint).sum() / mask_hint.sum() + loss_dict['cond_loss'] = self.cond_ratio * cond_loss + elif self.vaeloss_type == 'mask': + cond_loss = self.masked_l2(joints_rst, hint, mask_hint) + loss_dict['cond_loss'] = self.cond_ratio * cond_loss + else: + raise ValueError(f'Unsupported vaeloss_type: {self.vaeloss_type}') + else: + loss_dict['cond_loss'] = torch.tensor(0., device=diff_loss.device) + + if self.rot_ratio != 0: + mask_rot = lengths_to_mask(lengths, feats_rst.device).unsqueeze(-1) + if self.vaeloss_type == 'mean': + rot_loss = (self.l2_loss(feats_rst, feats_ref) * mask_rot).mean() + loss_dict['rot_loss'] = self.rot_ratio * rot_loss + elif self.vaeloss_type == 'sum': + rot_loss = (self.l2_loss(feats_rst, feats_ref).sum(-1, keepdims=True) * mask_rot).sum() / mask_rot.sum() + rot_loss = rot_loss / self.nfeats + loss_dict['rot_loss'] = self.rot_ratio * rot_loss + elif self.vaeloss_type == 'mask': + rot_loss = self.masked_l2(feats_rst, feats_ref, mask_rot) + loss_dict['rot_loss'] = self.rot_ratio * rot_loss + else: + raise ValueError(f'Unsupported vaeloss_type: {self.vaeloss_type}') + else: + loss_dict['rot_loss'] = torch.tensor(0., device=diff_loss.device) + + else: + loss_dict['cond_loss'] = torch.tensor(0., device=diff_loss.device) + loss_dict['rot_loss'] = torch.tensor(0., device=diff_loss.device) + + return loss_dict + + def t2m_eval(self, batch: dict) -> dict: + texts = batch["text"] + motions = batch["motion"].detach().clone() + lengths = batch["length"] + word_embs = batch["word_embs"].detach().clone() + pos_ohot = batch["pos_ohot"].detach().clone() + text_lengths = batch["text_len"].detach().clone() + + # start time + start = time.time() + + if self.datamodule.is_mm: + texts = texts * self.cfg.TEST.MM_NUM_REPEATS + motions = motions.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, dim=0) + lengths = lengths * self.cfg.TEST.MM_NUM_REPEATS + word_embs = word_embs.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, dim=0) + pos_ohot = pos_ohot.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, dim=0) + text_lengths = 
text_lengths.repeat_interleave(self.cfg.TEST.MM_NUM_REPEATS, dim=0) + + if self.do_classifier_free_guidance: + texts = [""] * len(texts) + texts + + text_st = time.time() + text_emb = self.text_encoder(texts) + text_et = time.time() + self.text_encoder_times.append(text_et - text_st) + + diff_st = time.time() + hint = batch['hint'] if 'hint' in batch else None # control signals + z = self._diffusion_reverse(text_emb, hint) + diff_et = time.time() + self.diffusion_times.append(diff_et - diff_st) + + with torch.no_grad(): + vae_st = time.time() + feats_rst = self.vae.decode(z, lengths) + vae_et = time.time() + self.vae_decode_times.append(vae_et - vae_st) + + self.all_lengths.extend(lengths) + + # end time + end = time.time() + self.times.append(end - start) + + # joints recover + joints_rst = self.feats2joints(feats_rst) + joints_ref = self.feats2joints(motions) + + # renorm for t2m evaluators + feats_rst = self.datamodule.renorm4t2m(feats_rst) + motions = self.datamodule.renorm4t2m(motions) + + # t2m motion encoder + m_lens = lengths.copy() + m_lens = torch.tensor(m_lens, device=motions.device) + align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() + motions = motions[align_idx] + feats_rst = feats_rst[align_idx] + m_lens = m_lens[align_idx] + m_lens = torch.div(m_lens, self.cfg.DATASET.HUMANML3D.UNIT_LEN, + rounding_mode="floor") + + recons_mov = self.t2m_moveencoder(feats_rst[..., :-4]).detach() + recons_emb = self.t2m_motionencoder(recons_mov, m_lens) + motion_mov = self.t2m_moveencoder(motions[..., :-4]).detach() + motion_emb = self.t2m_motionencoder(motion_mov, m_lens) + + # t2m text encoder + text_emb = self.t2m_textencoder(word_embs, pos_ohot, + text_lengths)[align_idx] + + rs_set = { + "m_ref": motions, + "m_rst": feats_rst, + "lat_t": text_emb, + "lat_m": motion_emb, + "lat_rm": recons_emb, + "joints_ref": joints_ref, + "joints_rst": joints_rst, + } + + if 'hint' in batch: + hint = batch['hint'] + mask_hint = hint.view(hint.shape[0], hint.shape[1], self.njoints, 3).sum(dim=-1, keepdim=True) != 0 + hint = self.datamodule.denorm_spatial(hint) + hint = hint.view(hint.shape[0], hint.shape[1], self.njoints, 3) * mask_hint + rs_set['hint'] = hint + rs_set['mask_hint'] = mask_hint + else: + rs_set['hint'] = None + + return rs_set + + def allsplit_step(self, split: str, batch: dict) -> Optional[dict]: + if split == "train": + loss_dict = self.train_diffusion_forward(batch) + return loss_dict + + # Compute the metrics + if split == "test": + rs_set = self.t2m_eval(batch) + + # MultiModality evaluation + if self.datamodule.is_mm: + metrics_dicts = ['MMMetrics'] + else: + metrics_dicts = self.metrics_dict + + for metric in metrics_dicts: + if metric == "TM2TMetrics": + getattr(self, metric).update( + # lat_t, latent encoded from text + # lat_rm, latent encoded from reconstructed motion + # lat_m, latent encoded from gt motion + rs_set["lat_t"], + rs_set["lat_rm"], + rs_set["lat_m"], + batch["length"], + ) + elif metric == "MMMetrics": + getattr(self, metric).update(rs_set["lat_rm"].unsqueeze(0), + batch["length"]) + elif metric == 'ControlMetrics': + assert rs_set['hint'] is not None + getattr(self, metric).update(rs_set["joints_rst"], rs_set['hint'], + rs_set['mask_hint'], batch['length']) + else: + raise TypeError(f"Not support this metric {metric}") diff --git a/mld/models/operator/__init__.py b/mld/models/operator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/mld/models/operator/cross_attention.py b/mld/models/operator/cross_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..433800ab5b636b474f32f322127e043a794cdceb --- /dev/null +++ b/mld/models/operator/cross_attention.py @@ -0,0 +1,396 @@ +import copy +from typing import Optional, Callable + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +class SkipTransformerEncoder(nn.Module): + def __init__(self, encoder_layer: nn.Module, num_layers: int, + norm: Optional[nn.Module] = None, return_intermediate: bool = False) -> None: + super().__init__() + self.d_model = encoder_layer.d_model + + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + assert num_layers % 2 == 1 + + num_block = (num_layers - 1) // 2 + self.input_blocks = _get_clones(encoder_layer, num_block) + self.middle_block = _get_clone(encoder_layer) + self.output_blocks = _get_clones(encoder_layer, num_block) + self.linear_blocks = _get_clones(nn.Linear(2 * self.d_model, self.d_model), num_block) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src: torch.Tensor, + mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + pos: Optional[torch.Tensor] = None, + controlnet_residuals: Optional[list[torch.Tensor]] = None) -> torch.Tensor: + x = src + intermediate = [] + index = 0 + xs = [] + for module in self.input_blocks: + x = module(x, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + + if controlnet_residuals is not None: + x = x + controlnet_residuals[index] + index += 1 + + xs.append(x) + + if self.return_intermediate: + intermediate.append(x) + + x = self.middle_block(x, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + + if controlnet_residuals is not None: + x = x + controlnet_residuals[index] + index += 1 + + if self.return_intermediate: + intermediate.append(x) + + for (module, linear) in zip(self.output_blocks, self.linear_blocks): + x = torch.cat([x, xs.pop()], dim=-1) + x = linear(x) + x = module(x, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + + if controlnet_residuals is not None: + x = x + controlnet_residuals[index] + index += 1 + + if self.return_intermediate: + intermediate.append(x) + + if self.norm is not None: + x = self.norm(x) + + if self.return_intermediate: + return torch.stack(intermediate) + + return x + + +class SkipTransformerDecoder(nn.Module): + def __init__(self, decoder_layer: nn.Module, num_layers: int, + norm: Optional[nn.Module] = None) -> None: + super().__init__() + self.d_model = decoder_layer.d_model + + self.num_layers = num_layers + self.norm = norm + + assert num_layers % 2 == 1 + + num_block = (num_layers - 1) // 2 + self.input_blocks = _get_clones(decoder_layer, num_block) + self.middle_block = _get_clone(decoder_layer) + self.output_blocks = _get_clones(decoder_layer, num_block) + self.linear_blocks = _get_clones(nn.Linear(2 * self.d_model, self.d_model), num_block) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: 
Optional[torch.Tensor] = None, + pos: Optional[torch.Tensor] = None, + query_pos: Optional[torch.Tensor] = None) -> torch.Tensor: + x = tgt + + xs = [] + for module in self.input_blocks: + x = module(x, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + xs.append(x) + + x = self.middle_block(x, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + + for (module, linear) in zip(self.output_blocks, self.linear_blocks): + x = torch.cat([x, xs.pop()], dim=-1) + x = linear(x) + x = module(x, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + + if self.norm is not None: + x = self.norm(x) + + return x + + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer: nn.Module, num_layers: int, + norm: Optional[nn.Module] = None) -> None: + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, src: torch.Tensor, + mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + pos: Optional[torch.Tensor] = None) -> torch.Tensor: + output = src + + for layer in self.layers: + output = layer(output, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + + def __init__(self, decoder_layer: nn.Module, num_layers: int, norm: Optional[nn.Module] = None, + return_intermediate: bool = False) -> None: + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward(self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None, + pos: Optional[torch.Tensor] = None, + query_pos: Optional[torch.Tensor] = None) -> torch.Tensor: + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + if self.return_intermediate: + intermediate.append(output) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + + def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, + activation: str = "relu", normalize_before: bool = False) -> None: + super().__init__() + self.d_model = d_model + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + 
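+ # norm1 normalizes the self-attention sub-block and norm2 (below) the feed-forward sub-block;
+ # normalize_before selects pre-norm (forward_pre) or post-norm (forward_post) placement.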
self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor: torch.Tensor, pos: Optional[Tensor] = None) -> torch.Tensor: + return tensor if pos is None else tensor + pos + + def forward_post(self, + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + pos: Optional[torch.Tensor] = None) -> torch.Tensor: + + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre(self, + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + pos: Optional[torch.Tensor] = None) -> torch.Tensor: + + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward(self, + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + pos: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, + activation: str = "relu", normalize_before: bool = False) -> None: + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.d_model = d_model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor: torch.Tensor, pos: Optional[Tensor] = None) -> torch.Tensor: + return tensor if pos is None else tensor + pos + + def forward_post(self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None, + pos: Optional[torch.Tensor] = None, + query_pos: Optional[torch.Tensor] = None) -> torch.Tensor: + + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + 
key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre(self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None, + pos: Optional[torch.Tensor] = None, + query_pos: Optional[torch.Tensor] = None) -> torch.Tensor: + + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward(self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None, + pos: Optional[torch.Tensor] = None, + query_pos: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.normalize_before: + return self.forward_pre(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + + +def _get_clone(module) -> nn.Module: + return copy.deepcopy(module) + + +def _get_clones(module, N) -> nn.ModuleList: + return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) + + +def _get_activation_fn(activation: str) -> Callable: + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") diff --git a/mld/models/operator/position_encoding.py b/mld/models/operator/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..a0cee22d0aa7ec8ea3415bc63fcb3330e5f7fc5c --- /dev/null +++ b/mld/models/operator/position_encoding.py @@ -0,0 +1,57 @@ +import numpy as np + +import torch +import torch.nn as nn + + +class PositionEmbeddingSine1D(nn.Module): + + def __init__(self, d_model: int, max_len: int = 500, batch_first: bool = False) -> None: + super().__init__() + self.batch_first = batch_first + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange( + 0, d_model, 2).float() * (-np.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + + self.register_buffer('pe', pe) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.batch_first: + x = x + self.pe.permute(1, 0, 2)[:, :x.shape[1], :] + else: + x = x + self.pe[:x.shape[0], :] + 
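+ # pe is stored as [max_len, 1, d_model]; slicing it to the sequence length broadcasts
+ # the same sinusoidal table across the batch dimension.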
return x + + +class PositionEmbeddingLearned1D(nn.Module): + + def __init__(self, d_model: int, max_len: int = 500, batch_first: bool = False) -> None: + super().__init__() + self.batch_first = batch_first + self.pe = nn.Parameter(torch.zeros(max_len, 1, d_model)) + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.uniform_(self.pe) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.batch_first: + x = x + self.pe.permute(1, 0, 2)[:, :x.shape[1], :] + else: + x = x + self.pe[:x.shape[0], :] + return x + + +def build_position_encoding(N_steps: int, position_embedding: str = "sine") -> nn.Module: + if position_embedding == 'sine': + position_embedding = PositionEmbeddingSine1D(N_steps) + elif position_embedding == 'learned': + position_embedding = PositionEmbeddingLearned1D(N_steps) + else: + raise ValueError(f"not supported {position_embedding}") + return position_embedding diff --git a/mld/render/__init__.py b/mld/render/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/render/blender/__init__.py b/mld/render/blender/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a82255db45b763479586f83f0b7c904387b814ba --- /dev/null +++ b/mld/render/blender/__init__.py @@ -0,0 +1 @@ +from .render import render diff --git a/mld/render/blender/camera.py b/mld/render/blender/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..bc3938ca2e3946438d3a77260cd777b3d4598a0a --- /dev/null +++ b/mld/render/blender/camera.py @@ -0,0 +1,35 @@ +import bpy + + +class Camera: + def __init__(self, first_root, mode): + camera = bpy.data.objects['Camera'] + + # initial position + camera.location.x = 7.36 + camera.location.y = -6.93 + camera.location.z = 5.6 + + # wider point of view + if mode == "sequence": + camera.data.lens = 65 + elif mode == "frame": + camera.data.lens = 130 + elif mode == "video": + camera.data.lens = 110 + + self.mode = mode + self.camera = camera + + self.camera.location.x += first_root[0] + self.camera.location.y += first_root[1] + + self._root = first_root + + def update(self, new_root): + delta_root = new_root - self._root + + self.camera.location.x += delta_root[0] + self.camera.location.y += delta_root[1] + + self._root = new_root diff --git a/mld/render/blender/floor.py b/mld/render/blender/floor.py new file mode 100644 index 0000000000000000000000000000000000000000..cf4b44a2d204de6f76f3bddd9215219484f2a94d --- /dev/null +++ b/mld/render/blender/floor.py @@ -0,0 +1,63 @@ +import bpy +from .materials import floor_mat + + +def plot_floor(data, big_plane=True): + # Create a floor + minx, miny, _ = data.min(axis=(0, 1)) + maxx, maxy, _ = data.max(axis=(0, 1)) + + location = ((maxx + minx)/2, (maxy + miny)/2, 0) + # a little bit bigger + scale = (1.08*(maxx - minx)/2, 1.08*(maxy - miny)/2, 1) + + bpy.ops.mesh.primitive_plane_add(size=2, enter_editmode=False, align='WORLD', location=location, scale=(1, 1, 1)) + + bpy.ops.transform.resize(value=scale, orient_type='GLOBAL', orient_matrix=((1, 0, 0), (0, 1, 0), (0, 0, 1)), orient_matrix_type='GLOBAL', + constraint_axis=(False, True, False), mirror=True, use_proportional_edit=False, + proportional_edit_falloff='SMOOTH', proportional_size=1, use_proportional_connected=False, + use_proportional_projected=False, release_confirm=True) + obj = bpy.data.objects["Plane"] + obj.name = "SmallPlane" + obj.data.name = "SmallPlane" + + if not big_plane: + obj.active_material = 
floor_mat(color=(0.2, 0.2, 0.2, 1)) + else: + obj.active_material = floor_mat(color=(0.1, 0.1, 0.1, 1)) + + if big_plane: + location = ((maxx + minx)/2, (maxy + miny)/2, -0.01) + bpy.ops.mesh.primitive_plane_add(size=2, enter_editmode=False, align='WORLD', location=location, scale=(1, 1, 1)) + + bpy.ops.transform.resize(value=[2*x for x in scale], orient_type='GLOBAL', orient_matrix=((1, 0, 0), (0, 1, 0), (0, 0, 1)), orient_matrix_type='GLOBAL', + constraint_axis=(False, True, False), mirror=True, use_proportional_edit=False, + proportional_edit_falloff='SMOOTH', proportional_size=1, use_proportional_connected=False, + use_proportional_projected=False, release_confirm=True) + + obj = bpy.data.objects["Plane"] + obj.name = "BigPlane" + obj.data.name = "BigPlane" + obj.active_material = floor_mat(color=(0.2, 0.2, 0.2, 1)) + + +def show_trajectory(coords): + for i, coord in enumerate(coords): + import matplotlib + cmap = matplotlib.cm.get_cmap('Greens') + begin = 0.45 + end = 1.0 + frac = i / len(coords) + rgb_color = cmap(begin + (end - begin) * frac) + + x, y, z = coord + bpy.ops.mesh.primitive_uv_sphere_add(radius=0.04, location=(x, y, z)) + obj = bpy.context.active_object + + mat = bpy.data.materials.new(name="SphereMaterial") + obj.data.materials.append(mat) + mat.use_nodes = True + bsdf = mat.node_tree.nodes["Principled BSDF"] + bsdf.inputs['Base Color'].default_value = rgb_color + + bpy.ops.object.mode_set(mode='OBJECT') diff --git a/mld/render/blender/materials.py b/mld/render/blender/materials.py new file mode 100644 index 0000000000000000000000000000000000000000..2152b433a0decea7d368d5b0c3da9f749ec1aaa9 --- /dev/null +++ b/mld/render/blender/materials.py @@ -0,0 +1,70 @@ +import bpy + + +def clear_material(material): + if material.node_tree: + material.node_tree.links.clear() + material.node_tree.nodes.clear() + + +def colored_material_diffuse_BSDF(r, g, b, a=1, roughness=0.127451): + materials = bpy.data.materials + material = materials.new(name="body") + material.use_nodes = True + clear_material(material) + nodes = material.node_tree.nodes + links = material.node_tree.links + output = nodes.new(type='ShaderNodeOutputMaterial') + diffuse = nodes.new(type='ShaderNodeBsdfDiffuse') + diffuse.inputs["Color"].default_value = (r, g, b, a) + diffuse.inputs["Roughness"].default_value = roughness + links.new(diffuse.outputs['BSDF'], output.inputs['Surface']) + return material + + +# keys: +# ['Base Color', 'Subsurface', 'Subsurface Radius', 'Subsurface Color', 'Metallic', 'Specular', 'Specular Tint', 'Roughness', 'Anisotropic', 'Anisotropic Rotation', 'Sheen', 1Sheen Tint', 'Clearcoat', 'Clearcoat Roughness', 'IOR', 'Transmission', 'Transmission Roughness', 'Emission', 'Emission Strength', 'Alpha', 'Normal', 'Clearcoat Normal', 'Tangent'] +DEFAULT_BSDF_SETTINGS = {"Subsurface": 0.15, + "Subsurface Radius": [1.1, 0.2, 0.1], + "Metallic": 0.3, + "Specular": 0.5, + "Specular Tint": 0.5, + "Roughness": 0.75, + "Anisotropic": 0.25, + "Anisotropic Rotation": 0.25, + "Sheen": 0.75, + "Sheen Tint": 0.5, + "Clearcoat": 0.5, + "Clearcoat Roughness": 0.5, + "IOR": 1.450, + "Transmission": 0.1, + "Transmission Roughness": 0.1, + "Emission": (0, 0, 0, 1), + "Emission Strength": 0.0, + "Alpha": 1.0} + + +def body_material(r, g, b, a=1, name="body", oldrender=True): + if oldrender: + material = colored_material_diffuse_BSDF(r, g, b, a=a) + else: + materials = bpy.data.materials + material = materials.new(name=name) + material.use_nodes = True + nodes = material.node_tree.nodes + diffuse = 
nodes["Principled BSDF"] + inputs = diffuse.inputs + + settings = DEFAULT_BSDF_SETTINGS.copy() + settings["Base Color"] = (r, g, b, a) + settings["Subsurface Color"] = (r, g, b, a) + settings["Subsurface"] = 0.0 + + for setting, val in settings.items(): + inputs[setting].default_value = val + + return material + + +def floor_mat(color=(0.1, 0.1, 0.1, 1), roughness=0.127451): + return colored_material_diffuse_BSDF(color[0], color[1], color[2], a=color[3], roughness=roughness) diff --git a/mld/render/blender/meshes.py b/mld/render/blender/meshes.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b8998332daa3cb14142c1a1a7e07461876339a --- /dev/null +++ b/mld/render/blender/meshes.py @@ -0,0 +1,82 @@ +import numpy as np + +from .materials import body_material + +# Orange +GEN_SMPL = body_material(0.658, 0.214, 0.0114) +# Green +GT_SMPL = body_material(0.035, 0.415, 0.122) + + +class Meshes: + def __init__(self, data, gt, mode, trajectory, faces_path, always_on_floor, oldrender=True): + data, trajectory = prepare_meshes(data, trajectory, always_on_floor=always_on_floor) + + self.faces = np.load(faces_path) + print(faces_path) + self.data = data + self.mode = mode + self.oldrender = oldrender + + self.N = len(data) + # self.trajectory = data[:, :, [0, 1]].mean(1) + self.trajectory = trajectory + + if gt: + self.mat = GT_SMPL + else: + self.mat = GEN_SMPL + + def get_sequence_mat(self, frac): + import matplotlib + # cmap = matplotlib.cm.get_cmap('Blues') + cmap = matplotlib.cm.get_cmap('Oranges') + # begin = 0.60 + # end = 0.90 + begin = 0.50 + end = 0.90 + rgb_color = cmap(begin + (end-begin)*frac) + mat = body_material(*rgb_color, oldrender=self.oldrender) + return mat + + def get_root(self, index): + return self.data[index].mean(0) + + def get_mean_root(self): + return self.data.mean((0, 1)) + + def load_in_blender(self, index, mat): + vertices = self.data[index] + faces = self.faces + name = f"{str(index).zfill(4)}" + + from .tools import load_numpy_vertices_into_blender + load_numpy_vertices_into_blender(vertices, faces, name, mat) + + return name + + def __len__(self): + return self.N + + +def prepare_meshes(data, trajectory, always_on_floor=False): + # Swap axis (gravity=Z instead of Y) + data = data[..., [2, 0, 1]] + + if trajectory is not None: + trajectory = trajectory[..., [2, 0, 1]] + mask = trajectory.sum(-1) != 0 + trajectory = trajectory[mask] + + # Remove the floor + height_offset = data[..., 2].min() + data[..., 2] -= height_offset + + if trajectory is not None: + trajectory[..., 2] -= height_offset + + # Put all the body on the floor + if always_on_floor: + data[..., 2] -= data[..., 2].min(1)[:, None] + + return data, trajectory diff --git a/mld/render/blender/render.py b/mld/render/blender/render.py new file mode 100644 index 0000000000000000000000000000000000000000..4f43ec749e791ca5677f909813a13a862cb61f01 --- /dev/null +++ b/mld/render/blender/render.py @@ -0,0 +1,142 @@ +import os +import shutil + +import bpy + +from .camera import Camera +from .floor import plot_floor, show_trajectory +from .sampler import get_frameidx +from .scene import setup_scene +from .tools import delete_objs + +from mld.render.video import Video + + +def prune_begin_end(data, perc): + to_remove = int(len(data) * perc) + if to_remove == 0: + return data + return data[to_remove:-to_remove] + + +def render_current_frame(path): + bpy.context.scene.render.filepath = path + bpy.ops.render.render(use_viewport=True, write_still=True) + + +def render(npydata, trajectory, path, mode, 
faces_path, gt=False, + exact_frame=None, num=8, always_on_floor=False, denoising=True, + oldrender=True, res="high", accelerator='gpu', device=[0], fps=20): + + if mode == 'video': + if always_on_floor: + frames_folder = path.replace(".pkl", "_of_frames") + else: + frames_folder = path.replace(".pkl", "_frames") + + if os.path.exists(frames_folder.replace("_frames", ".mp4")) or os.path.exists(frames_folder): + print(f"pkl is rendered or under rendering {path}") + return + + os.makedirs(frames_folder, exist_ok=False) + + elif mode == 'sequence': + path = path.replace('.pkl', '.png') + img_name, ext = os.path.splitext(path) + if always_on_floor: + img_name += "_of" + img_path = f"{img_name}{ext}" + if os.path.exists(img_path): + print(f"pkl is rendered or under rendering {img_path}") + return + + elif mode == 'frame': + path = path.replace('.pkl', '.png') + img_name, ext = os.path.splitext(path) + if always_on_floor: + img_name += "_of" + img_path = f"{img_name}_{exact_frame}{ext}" + if os.path.exists(img_path): + print(f"pkl is rendered or under rendering {img_path}") + return + else: + raise ValueError(f'Invalid mode: {mode}') + + # Setup the scene (lights / render engine / resolution etc) + setup_scene(res=res, denoising=denoising, oldrender=oldrender, accelerator=accelerator, device=device) + + # remove X% of beginning and end + # as it is almost always static + # in this part + # if mode == "sequence": + # perc = 0.2 + # npydata = prune_begin_end(npydata, perc) + + from .meshes import Meshes + data = Meshes(npydata, gt=gt, mode=mode, trajectory=trajectory, + faces_path=faces_path, always_on_floor=always_on_floor) + + # Number of frames possible to render + nframes = len(data) + + # Show the trajectory + if trajectory is not None: + show_trajectory(data.trajectory) + + # Create a floor + plot_floor(data.data, big_plane=False) + + # initialize the camera + camera = Camera(first_root=data.get_root(0), mode=mode) + + frameidx = get_frameidx(mode=mode, nframes=nframes, + exact_frame=exact_frame, + frames_to_keep=num) + + nframes_to_render = len(frameidx) + + # center the camera to the middle + if mode == "sequence": + camera.update(data.get_mean_root()) + + imported_obj_names = [] + for index, frameidx in enumerate(frameidx): + if mode == "sequence": + frac = index / (nframes_to_render - 1) + mat = data.get_sequence_mat(frac) + else: + mat = data.mat + camera.update(data.get_root(frameidx)) + + islast = index == (nframes_to_render - 1) + + obj_name = data.load_in_blender(frameidx, mat) + name = f"{str(index).zfill(4)}" + + if mode == "video": + path = os.path.join(frames_folder, f"frame_{name}.png") + else: + path = img_path + + if mode == "sequence": + imported_obj_names.extend(obj_name) + elif mode == "frame": + camera.update(data.get_root(frameidx)) + + if mode != "sequence" or islast: + render_current_frame(path) + delete_objs(obj_name) + + # remove every object created + delete_objs(imported_obj_names) + delete_objs(["Plane", "myCurve", "Cylinder"]) + + if mode == "video": + video = Video(frames_folder, fps=fps) + vid_path = frames_folder.replace("_frames", ".mp4") + video.save(out_path=vid_path) + shutil.rmtree(frames_folder) + print(f"remove tmp fig folder and save video in {vid_path}") + + else: + print(f"Frame generated at: {img_path}") diff --git a/mld/render/blender/sampler.py b/mld/render/blender/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..149c6d99d05e5614bc291e0bec32940c4bbbc8ed --- /dev/null +++ b/mld/render/blender/sampler.py @@ -0,0 +1,17 
@@ +import numpy as np + + +def get_frameidx(mode, nframes, exact_frame, frames_to_keep): + if mode == "sequence": + frameidx = np.linspace(0, nframes - 1, frames_to_keep) + frameidx = np.round(frameidx).astype(int) + frameidx = list(frameidx) + return frameidx + elif mode == "frame": + index_frame = int(exact_frame*nframes) + frameidx = [index_frame] + elif mode == "video": + frameidx = range(0, nframes) + else: + raise ValueError(f"Not support {mode} render mode") + return frameidx diff --git a/mld/render/blender/scene.py b/mld/render/blender/scene.py new file mode 100644 index 0000000000000000000000000000000000000000..79c09731756322e70c4433c26337e5a5f22e856a --- /dev/null +++ b/mld/render/blender/scene.py @@ -0,0 +1,94 @@ +import bpy + + +def setup_renderer(denoising=True, oldrender=True, accelerator="gpu", device=[0]): + bpy.context.scene.render.engine = "CYCLES" + bpy.data.scenes[0].render.engine = "CYCLES" + if accelerator.lower() == "gpu": + bpy.context.preferences.addons[ + "cycles" + ].preferences.compute_device_type = "CUDA" + bpy.context.scene.cycles.device = "GPU" + i = 0 + bpy.context.preferences.addons["cycles"].preferences.get_devices() + for d in bpy.context.preferences.addons["cycles"].preferences.devices: + if i in device: # gpu id + d["use"] = 1 + print(d["name"], "".join(str(i) for i in device)) + else: + d["use"] = 0 + i += 1 + + if denoising: + bpy.context.scene.cycles.use_denoising = True + + bpy.context.scene.render.tile_x = 256 + bpy.context.scene.render.tile_y = 256 + bpy.context.scene.cycles.samples = 64 + + if not oldrender: + bpy.context.scene.view_settings.view_transform = "Standard" + bpy.context.scene.render.film_transparent = True + bpy.context.scene.display_settings.display_device = "sRGB" + bpy.context.scene.view_settings.gamma = 1.2 + bpy.context.scene.view_settings.exposure = -0.75 + + +# Setup scene +def setup_scene( + res="high", denoising=True, oldrender=True, accelerator="gpu", device=[0] +): + scene = bpy.data.scenes["Scene"] + assert res in ["ultra", "high", "med", "low"] + if res == "high": + scene.render.resolution_x = 1280 + scene.render.resolution_y = 1024 + elif res == "med": + scene.render.resolution_x = 1280 // 2 + scene.render.resolution_y = 1024 // 2 + elif res == "low": + scene.render.resolution_x = 1280 // 4 + scene.render.resolution_y = 1024 // 4 + elif res == "ultra": + scene.render.resolution_x = 1280 * 2 + scene.render.resolution_y = 1024 * 2 + + scene.render.film_transparent = True + world = bpy.data.worlds["World"] + world.use_nodes = True + bg = world.node_tree.nodes["Background"] + bg.inputs[0].default_value[:3] = (1.0, 1.0, 1.0) + bg.inputs[1].default_value = 1.0 + + # Remove default cube + if "Cube" in bpy.data.objects: + bpy.data.objects["Cube"].select_set(True) + bpy.ops.object.delete() + + bpy.ops.object.light_add( + type="SUN", align="WORLD", location=(0, 0, 0), scale=(1, 1, 1) + ) + bpy.data.objects["Sun"].data.energy = 1.5 + + # rotate camera + bpy.ops.object.empty_add( + type="PLAIN_AXES", align="WORLD", location=(0, 0, 0), scale=(1, 1, 1) + ) + bpy.ops.transform.resize( + value=(10, 10, 10), + orient_type="GLOBAL", + orient_matrix=((1, 0, 0), (0, 1, 0), (0, 0, 1)), + orient_matrix_type="GLOBAL", + mirror=True, + use_proportional_edit=False, + proportional_edit_falloff="SMOOTH", + proportional_size=1, + use_proportional_connected=False, + use_proportional_projected=False, + ) + bpy.ops.object.select_all(action="DESELECT") + + setup_renderer( + denoising=denoising, oldrender=oldrender, accelerator=accelerator, 
device=device + ) + return scene diff --git a/mld/render/blender/tools.py b/mld/render/blender/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..57ea7f2c59a41dd633fa37a56dc317733945373e --- /dev/null +++ b/mld/render/blender/tools.py @@ -0,0 +1,40 @@ +import bpy +import numpy as np + + +# see this for more explanation +# https://gist.github.com/iyadahmed/7c7c0fae03c40bd87e75dc7059e35377 +# This should be solved with new version of blender +class ndarray_pydata(np.ndarray): + def __bool__(self) -> bool: + return len(self) > 0 + + +def load_numpy_vertices_into_blender(vertices, faces, name, mat): + mesh = bpy.data.meshes.new(name) + mesh.from_pydata(vertices, [], faces.view(ndarray_pydata)) + mesh.validate() + + obj = bpy.data.objects.new(name, mesh) + bpy.context.scene.collection.objects.link(obj) + + bpy.ops.object.select_all(action='DESELECT') + obj.select_set(True) + obj.active_material = mat + bpy.context.view_layer.objects.active = obj + bpy.ops.object.shade_smooth() + bpy.ops.object.select_all(action='DESELECT') + return True + + +def delete_objs(names): + if not isinstance(names, list): + names = [names] + # bpy.ops.object.mode_set(mode='OBJECT') + bpy.ops.object.select_all(action='DESELECT') + for obj in bpy.context.scene.objects: + for name in names: + if obj.name.startswith(name) or obj.name.endswith(name): + obj.select_set(True) + bpy.ops.object.delete() + bpy.ops.object.select_all(action='DESELECT') diff --git a/mld/render/video.py b/mld/render/video.py new file mode 100644 index 0000000000000000000000000000000000000000..451170bb67b73f3856fbc70c3fda3b51f151ac49 --- /dev/null +++ b/mld/render/video.py @@ -0,0 +1,66 @@ +import moviepy.editor as mp +import os +import imageio + + +def mask_png(frames): + for frame in frames: + im = imageio.imread(frame) + im[im[:, :, 3] < 1, :] = 255 + imageio.imwrite(frame, im[:, :, 0:3]) + return + + +class Video: + def __init__(self, frame_path: str, fps: float = 12.5, res="high"): + frame_path = str(frame_path) + self.fps = fps + + self._conf = {"codec": "libx264", + "fps": self.fps, + "audio_codec": "aac", + "temp_audiofile": "temp-audio.m4a", + "remove_temp": True} + + if res == "low": + bitrate = "500k" + else: + bitrate = "5000k" + + self._conf = {"bitrate": bitrate, + "fps": self.fps} + + # Load video + # video = mp.VideoFileClip(video1_path, audio=False) + # Load with frames + frames = [os.path.join(frame_path, x) + for x in sorted(os.listdir(frame_path))] + + # mask background white for videos + mask_png(frames) + + video = mp.ImageSequenceClip(frames, fps=fps) + self.video = video + self.duration = video.duration + + def add_text(self, text): + # needs ImageMagick + video_text = mp.TextClip(text, + font='Amiri', + color='white', + method='caption', + align="center", + size=(self.video.w, None), + fontsize=30) + video_text = video_text.on_color(size=(self.video.w, video_text.h + 5), + color=(0, 0, 0), + col_opacity=0.6) + # video_text = video_text.set_pos('bottom') + video_text = video_text.set_pos('top') + + self.video = mp.CompositeVideoClip([self.video, video_text]) + + def save(self, out_path): + out_path = str(out_path) + self.video.subclip(0, self.duration).write_videofile( + out_path, **self._conf) diff --git a/mld/transforms/__init__.py b/mld/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/transforms/joints2rots/__init__.py b/mld/transforms/joints2rots/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/transforms/joints2rots/config.py b/mld/transforms/joints2rots/config.py new file mode 100644 index 0000000000000000000000000000000000000000..cd15c7581a68ed6e97f32a6f070d4bbd5a7befb8 --- /dev/null +++ b/mld/transforms/joints2rots/config.py @@ -0,0 +1,31 @@ +# Map joints Name to SMPL joints idx +JOINT_MAP = { + 'MidHip': 0, + 'LHip': 1, 'LKnee': 4, 'LAnkle': 7, 'LFoot': 10, + 'RHip': 2, 'RKnee': 5, 'RAnkle': 8, 'RFoot': 11, + 'LShoulder': 16, 'LElbow': 18, 'LWrist': 20, 'LHand': 22, + 'RShoulder': 17, 'RElbow': 19, 'RWrist': 21, 'RHand': 23, + 'spine1': 3, 'spine2': 6, 'spine3': 9, 'Neck': 12, 'Head': 15, + 'LCollar': 13, 'RCollar': 14 +} + +full_smpl_idx = range(24) +key_smpl_idx = [0, 1, 4, 7, 2, 5, 8, 17, 19, 21, 16, 18, 20] + +AMASS_JOINT_MAP = { + 'MidHip': 0, + 'LHip': 1, 'LKnee': 4, 'LAnkle': 7, 'LFoot': 10, + 'RHip': 2, 'RKnee': 5, 'RAnkle': 8, 'RFoot': 11, + 'LShoulder': 16, 'LElbow': 18, 'LWrist': 20, + 'RShoulder': 17, 'RElbow': 19, 'RWrist': 21, + 'spine1': 3, 'spine2': 6, 'spine3': 9, 'Neck': 12, 'Head': 15, + 'LCollar': 13, 'RCollar': 14, +} +amass_idx = range(22) +amass_smpl_idx = range(22) + +SMPL_MODEL_DIR = "./deps/smpl_models" +GMM_MODEL_DIR = "./deps/smpl_models" +SMPL_MEAN_FILE = "./deps/smpl_models/neutral_smpl_mean_params.h5" +# for collision +Part_Seg_DIR = "./deps/smpl_models/smplx_parts_segm.pkl" diff --git a/mld/transforms/joints2rots/customloss.py b/mld/transforms/joints2rots/customloss.py new file mode 100644 index 0000000000000000000000000000000000000000..c7f3a922184c1efc5a3198cfe120c0e2ab7b6554 --- /dev/null +++ b/mld/transforms/joints2rots/customloss.py @@ -0,0 +1,100 @@ +import torch + +from mld.transforms.joints2rots import config + + +def gmof(x, sigma): + """ + Geman-McClure error function + """ + x_squared = x ** 2 + sigma_squared = sigma ** 2 + return (sigma_squared * x_squared) / (sigma_squared + x_squared) + + +def angle_prior(pose): + """ + Angle prior that penalizes unnatural bending of the knees and elbows + """ + # We subtract 3 because pose does not include the global rotation of the model + return torch.exp( + pose[:, [55 - 3, 58 - 3, 12 - 3, 15 - 3]] * torch.tensor([1., -1., -1, -1.], device=pose.device)) ** 2 + + +def body_fitting_loss_3d(body_pose, preserve_pose, + betas, model_joints, camera_translation, + j3d, pose_prior, + joints3d_conf, + sigma=100, pose_prior_weight=4.78 * 1.5, + shape_prior_weight=5.0, angle_prior_weight=15.2, + joint_loss_weight=500.0, + pose_preserve_weight=0.0, + use_collision=False, + model_vertices=None, model_faces=None, + search_tree=None, pen_distance=None, filter_faces=None, + collision_loss_weight=1000 + ): + """ + Loss function for body fitting + """ + batch_size = body_pose.shape[0] + + joint3d_error = gmof((model_joints + camera_translation) - j3d, sigma) + + joint3d_loss_part = (joints3d_conf ** 2) * joint3d_error.sum(dim=-1) + joint3d_loss = ((joint_loss_weight ** 2) * joint3d_loss_part).sum(dim=-1) + + # Pose prior loss + pose_prior_loss = (pose_prior_weight ** 2) * pose_prior(body_pose, betas) + # Angle prior for knees and elbows + angle_prior_loss = (angle_prior_weight ** 2) * angle_prior(body_pose).sum(dim=-1) + # Regularizer to prevent betas from taking large values + shape_prior_loss = (shape_prior_weight ** 2) * (betas ** 2).sum(dim=-1) + + collision_loss = 0.0 + # Calculate the loss due to interpenetration + if use_collision: + triangles = torch.index_select( + model_vertices, 1, + 
model_faces).view(batch_size, -1, 3, 3) + + with torch.no_grad(): + collision_idxs = search_tree(triangles) + + # Remove unwanted collisions + if filter_faces is not None: + collision_idxs = filter_faces(collision_idxs) + + if collision_idxs.ge(0).sum().item() > 0: + collision_loss = torch.sum(collision_loss_weight * pen_distance(triangles, collision_idxs)) + + pose_preserve_loss = (pose_preserve_weight ** 2) * ((body_pose - preserve_pose) ** 2).sum(dim=-1) + + total_loss = joint3d_loss + pose_prior_loss + angle_prior_loss + shape_prior_loss + collision_loss + pose_preserve_loss + + return total_loss.sum() + + +def camera_fitting_loss_3d(model_joints, camera_t, camera_t_est, + j3d, joints_category="orig", depth_loss_weight=100.0): + """ + Loss function for camera optimization. + """ + model_joints = model_joints + camera_t + gt_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder'] + gt_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] + + if joints_category == "orig": + select_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints] + elif joints_category == "AMASS": + select_joints_ind = [config.AMASS_JOINT_MAP[joint] for joint in gt_joints] + else: + print("NO SUCH JOINTS CATEGORY!") + + j3d_error_loss = (j3d[:, select_joints_ind] - model_joints[:, gt_joints_ind]) ** 2 + + # Loss that penalizes deviation from depth estimate + depth_loss = (depth_loss_weight ** 2) * (camera_t - camera_t_est) ** 2 + + total_loss = j3d_error_loss + depth_loss + return total_loss.sum() diff --git a/mld/transforms/joints2rots/prior.py b/mld/transforms/joints2rots/prior.py new file mode 100644 index 0000000000000000000000000000000000000000..70e8d19558661c712f1c3de82340d994c837dd3d --- /dev/null +++ b/mld/transforms/joints2rots/prior.py @@ -0,0 +1,208 @@ +import os +import sys +import pickle + +import numpy as np + +import torch +import torch.nn as nn + +DEFAULT_DTYPE = torch.float32 + + +def create_prior(prior_type, **kwargs): + if prior_type == 'gmm': + prior = MaxMixturePrior(**kwargs) + elif prior_type == 'l2': + return L2Prior(**kwargs) + elif prior_type == 'angle': + return SMPLifyAnglePrior(**kwargs) + elif prior_type == 'none' or prior_type is None: + # Don't use any pose prior + def no_prior(*args, **kwargs): + return 0.0 + prior = no_prior + else: + raise ValueError('Prior {}'.format(prior_type) + ' is not implemented') + return prior + + +class SMPLifyAnglePrior(nn.Module): + def __init__(self, dtype=DEFAULT_DTYPE, **kwargs): + super(SMPLifyAnglePrior, self).__init__() + + # Indices for the rotation angle of + # 55: left elbow, 90deg bend at -np.pi/2 + # 58: right elbow, 90deg bend at np.pi/2 + # 12: left knee, 90deg bend at np.pi/2 + # 15: right knee, 90deg bend at np.pi/2 + angle_prior_idxs = np.array([55, 58, 12, 15], dtype=np.int64) + angle_prior_idxs = torch.tensor(angle_prior_idxs, dtype=torch.long) + self.register_buffer('angle_prior_idxs', angle_prior_idxs) + + angle_prior_signs = np.array([1, -1, -1, -1], + dtype=np.float32 if dtype == torch.float32 + else np.float64) + angle_prior_signs = torch.tensor(angle_prior_signs, + dtype=dtype) + self.register_buffer('angle_prior_signs', angle_prior_signs) + + def forward(self, pose, with_global_pose=False): + ''' Returns the angle prior loss for the given pose + + Args: + pose: (Bx[23 + 1] * 3) torch tensor with the axis-angle + representation of the rotations of the joints of the SMPL model. + Kwargs: + with_global_pose: Whether the pose vector also contains the global + orientation of the SMPL model. 
If not then the indices must be + corrected. + Returns: + A size (B) tensor containing the angle prior loss for each element + in the batch. + ''' + angle_prior_idxs = self.angle_prior_idxs - (not with_global_pose) * 3 + return torch.exp(pose[:, angle_prior_idxs] * + self.angle_prior_signs).pow(2) + + +class L2Prior(nn.Module): + def __init__(self, dtype=DEFAULT_DTYPE, reduction='sum', **kwargs): + super(L2Prior, self).__init__() + + def forward(self, module_input, *args): + return torch.sum(module_input.pow(2)) + + +class MaxMixturePrior(nn.Module): + + def __init__(self, prior_folder='prior', + num_gaussians=6, dtype=DEFAULT_DTYPE, epsilon=1e-16, + use_merged=True, + **kwargs): + super(MaxMixturePrior, self).__init__() + + if dtype == DEFAULT_DTYPE: + np_dtype = np.float32 + elif dtype == torch.float64: + np_dtype = np.float64 + else: + print('Unknown float type {}, exiting!'.format(dtype)) + sys.exit(-1) + + self.num_gaussians = num_gaussians + self.epsilon = epsilon + self.use_merged = use_merged + gmm_fn = 'gmm_{:02d}.pkl'.format(num_gaussians) + + full_gmm_fn = os.path.join(prior_folder, gmm_fn) + if not os.path.exists(full_gmm_fn): + print('The path to the mixture prior "{}"'.format(full_gmm_fn) + + ' does not exist, exiting!') + sys.exit(-1) + + with open(full_gmm_fn, 'rb') as f: + gmm = pickle.load(f, encoding='latin1') + + if type(gmm) == dict: + means = gmm['means'].astype(np_dtype) + covs = gmm['covars'].astype(np_dtype) + weights = gmm['weights'].astype(np_dtype) + elif 'sklearn.mixture.gmm.GMM' in str(type(gmm)): + means = gmm.means_.astype(np_dtype) + covs = gmm.covars_.astype(np_dtype) + weights = gmm.weights_.astype(np_dtype) + else: + print('Unknown type for the prior: {}, exiting!'.format(type(gmm))) + sys.exit(-1) + + self.register_buffer('means', torch.tensor(means, dtype=dtype)) + + self.register_buffer('covs', torch.tensor(covs, dtype=dtype)) + + precisions = [np.linalg.inv(cov) for cov in covs] + precisions = np.stack(precisions).astype(np_dtype) + + self.register_buffer('precisions', + torch.tensor(precisions, dtype=dtype)) + + # The constant term: + sqrdets = np.array([(np.sqrt(np.linalg.det(c))) + for c in gmm['covars']]) + const = (2 * np.pi)**(69 / 2.) 
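+ # (2 * pi)^(D / 2) with D = 69 is the Gaussian normalization factor for the SMPL body
+ # pose (23 joints x 3 axis-angle values). Combined with the relative sqrt-determinants
+ # below, it turns the mixture weights into nll_weights, which enter merged_log_likelihood
+ # as 0.5 * (x - mu_k)^T P_k (x - mu_k) - log(nll_weights_k).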
+ + nll_weights = np.asarray(gmm['weights'] / (const * + (sqrdets / sqrdets.min()))) + nll_weights = torch.tensor(nll_weights, dtype=dtype).unsqueeze(dim=0) + self.register_buffer('nll_weights', nll_weights) + + weights = torch.tensor(gmm['weights'], dtype=dtype).unsqueeze(dim=0) + self.register_buffer('weights', weights) + + self.register_buffer('pi_term', + torch.log(torch.tensor(2 * np.pi, dtype=dtype))) + + cov_dets = [np.log(np.linalg.det(cov.astype(np_dtype)) + epsilon) + for cov in covs] + self.register_buffer('cov_dets', + torch.tensor(cov_dets, dtype=dtype)) + + # The dimensionality of the random variable + self.random_var_dim = self.means.shape[1] + + def get_mean(self): + ''' Returns the mean of the mixture ''' + mean_pose = torch.matmul(self.weights, self.means) + return mean_pose + + def merged_log_likelihood(self, pose, betas): + diff_from_mean = pose.unsqueeze(dim=1) - self.means + + prec_diff_prod = torch.einsum('mij,bmj->bmi', + [self.precisions, diff_from_mean]) + diff_prec_quadratic = (prec_diff_prod * diff_from_mean).sum(dim=-1) + + curr_loglikelihood = 0.5 * diff_prec_quadratic - \ + torch.log(self.nll_weights) + # curr_loglikelihood = 0.5 * (self.cov_dets.unsqueeze(dim=0) + + # self.random_var_dim * self.pi_term + + # diff_prec_quadratic + # ) - torch.log(self.weights) + + min_likelihood, _ = torch.min(curr_loglikelihood, dim=1) + return min_likelihood + + def log_likelihood(self, pose, betas, *args, **kwargs): + ''' Create graph operation for negative log-likelihood calculation + ''' + likelihoods = [] + + for idx in range(self.num_gaussians): + mean = self.means[idx] + prec = self.precisions[idx] + cov = self.covs[idx] + diff_from_mean = pose - mean + + curr_loglikelihood = torch.einsum('bj,ji->bi', + [diff_from_mean, prec]) + curr_loglikelihood = torch.einsum('bi,bi->b', + [curr_loglikelihood, + diff_from_mean]) + cov_term = torch.log(torch.det(cov) + self.epsilon) + curr_loglikelihood += 0.5 * (cov_term + + self.random_var_dim * + self.pi_term) + likelihoods.append(curr_loglikelihood) + + log_likelihoods = torch.stack(likelihoods, dim=1) + min_idx = torch.argmin(log_likelihoods, dim=1) + weight_component = self.nll_weights[:, min_idx] + weight_component = -torch.log(weight_component) + + return weight_component + log_likelihoods[:, min_idx] + + def forward(self, pose, betas): + if self.use_merged: + return self.merged_log_likelihood(pose, betas) + else: + return self.log_likelihood(pose, betas) diff --git a/mld/transforms/joints2rots/smplify.py b/mld/transforms/joints2rots/smplify.py new file mode 100644 index 0000000000000000000000000000000000000000..c5ed9dc87dcffb78257f88e997f5c4aabffcce20 --- /dev/null +++ b/mld/transforms/joints2rots/smplify.py @@ -0,0 +1,274 @@ +import os +import pickle + +from tqdm import tqdm + +import torch + +from mld.transforms.joints2rots import config +from mld.transforms.joints2rots.customloss import camera_fitting_loss_3d, body_fitting_loss_3d +from mld.transforms.joints2rots.prior import MaxMixturePrior + + +@torch.no_grad() +def guess_init_3d(model_joints, j3d, joints_category="orig"): + """ + Initialize the camera translation via triangle similarity, by using the torso joints . 
+    """
+    # get the indexed four torso joints
+    gt_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder']
+    gt_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints]
+
+    if joints_category == "orig":
+        joints_ind_category = [config.JOINT_MAP[joint] for joint in gt_joints]
+    elif joints_category == "AMASS":
+        joints_ind_category = [config.AMASS_JOINT_MAP[joint] for joint in gt_joints]
+    else:
+        print("NO SUCH JOINTS CATEGORY!")
+
+    sum_init_t = (j3d[:, joints_ind_category] - model_joints[:, gt_joints_ind]).sum(dim=1)
+    init_t = sum_init_t / 4.0
+    return init_t
+
+
+# SMPLify 3D
+class SMPLify3D():
+    """Implementation of SMPLify, using 3D joints."""
+
+    def __init__(self,
+                 smplxmodel,
+                 step_size=1e-2,
+                 batch_size=1,
+                 num_iters=100,
+                 use_collision=False,
+                 use_lbfgs=True,
+                 joints_category="orig",
+                 device=torch.device('cuda:0'),
+                 ):
+
+        # Store options
+        self.batch_size = batch_size
+        self.device = device
+        self.step_size = step_size
+
+        self.num_iters = num_iters
+        # --- choose optimizer
+        self.use_lbfgs = use_lbfgs
+        # GMM pose prior
+        self.pose_prior = MaxMixturePrior(prior_folder=config.GMM_MODEL_DIR,
+                                          num_gaussians=8,
+                                          dtype=torch.float32).to(device)
+        # collision part
+        self.use_collision = use_collision
+        if self.use_collision:
+            self.part_segm_fn = config.Part_Seg_DIR
+
+        # load the SMPL-X model
+        self.smpl = smplxmodel
+
+        self.model_faces = smplxmodel.faces_tensor.view(-1)
+
+        # select the joint category
+        self.joints_category = joints_category
+
+        if joints_category == "orig":
+            self.smpl_index = config.full_smpl_idx
+            self.corr_index = config.full_smpl_idx
+        elif joints_category == "AMASS":
+            self.smpl_index = config.amass_smpl_idx
+            self.corr_index = config.amass_idx
+        else:
+            self.smpl_index = None
+            self.corr_index = None
+            print("NO SUCH JOINTS CATEGORY!")
+
+    # ---- the main fitting function ------
+    def __call__(self, init_pose, init_betas, init_cam_t, j3d, conf_3d=1.0, seq_ind=0):
+        """Perform body fitting.
+ Input: + init_pose: SMPL pose estimate + init_betas: SMPL betas estimate + init_cam_t: Camera translation estimate + j3d: joints 3d aka keypoints + conf_3d: confidence for 3d joints + seq_ind: index of the sequence + Returns: + vertices: Vertices of optimized shape + joints: 3D joints of optimized shape + pose: SMPL pose parameters of optimized shape + betas: SMPL beta parameters of optimized shape + camera_translation: Camera translation + """ + + # # # add the mesh inter-section to avoid + search_tree = None + pen_distance = None + filter_faces = None + + if self.use_collision: + from mesh_intersection.bvh_search_tree import BVH + import mesh_intersection.loss as collisions_loss + from mesh_intersection.filter_faces import FilterFaces + + search_tree = BVH(max_collisions=8) + + pen_distance = collisions_loss.DistanceFieldPenetrationLoss( + sigma=0.5, point2plane=False, vectorized=True, penalize_outside=True) + + if self.part_segm_fn: + # Read the part segmentation + part_segm_fn = os.path.expandvars(self.part_segm_fn) + with open(part_segm_fn, 'rb') as faces_parents_file: + face_segm_data = pickle.load(faces_parents_file, encoding='latin1') + faces_segm = face_segm_data['segm'] + faces_parents = face_segm_data['parents'] + # Create the module used to filter invalid collision pairs + filter_faces = FilterFaces( + faces_segm=faces_segm, faces_parents=faces_parents, + ign_part_pairs=None).to(device=self.device) + + # Split SMPL pose to body pose and global orientation + body_pose = init_pose[:, 3:].detach().clone() + global_orient = init_pose[:, :3].detach().clone() + betas = init_betas.detach().clone() + + # use guess 3d to get the initial + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + + init_cam_t = guess_init_3d(model_joints, j3d, self.joints_category).unsqueeze(1).detach() + camera_translation = init_cam_t.clone() + + preserve_pose = init_pose[:, 3:].detach().clone() + # -------------Step 1: Optimize camera translation and body orientation-------- + # Optimize only camera translation and body orientation + body_pose.requires_grad = False + betas.requires_grad = False + global_orient.requires_grad = True + camera_translation.requires_grad = True + + camera_opt_params = [global_orient, camera_translation] + + if self.use_lbfgs: + camera_optimizer = torch.optim.LBFGS(camera_opt_params, max_iter=self.num_iters, + lr=self.step_size, line_search_fn='strong_wolfe') + for i in range(10): + def closure(): + camera_optimizer.zero_grad() + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + + loss = camera_fitting_loss_3d(model_joints, camera_translation, + init_cam_t, j3d, self.joints_category) + loss.backward() + return loss + + camera_optimizer.step(closure) + else: + camera_optimizer = torch.optim.Adam(camera_opt_params, lr=self.step_size, betas=(0.9, 0.999)) + + for i in range(20): + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + + loss = camera_fitting_loss_3d(model_joints[:, self.smpl_index], camera_translation, + init_cam_t, j3d[:, self.corr_index], self.joints_category) + camera_optimizer.zero_grad() + loss.backward() + camera_optimizer.step() + + # Fix camera translation after optimizing camera + # --------Step 2: Optimize body joints -------------------------- + # Optimize only the body pose and global orientation of the body + 
body_pose.requires_grad = True + global_orient.requires_grad = True + camera_translation.requires_grad = True + + # --- if we use the sequence, fix the shape + if seq_ind == 0: + betas.requires_grad = True + body_opt_params = [body_pose, betas, global_orient, camera_translation] + else: + betas.requires_grad = False + body_opt_params = [body_pose, global_orient, camera_translation] + + if self.use_lbfgs: + body_optimizer = torch.optim.LBFGS(body_opt_params, max_iter=self.num_iters, + lr=self.step_size, line_search_fn='strong_wolfe') + + for i in tqdm(range(self.num_iters), desc=f"LBFGS iter: "): + def closure(): + body_optimizer.zero_grad() + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + model_vertices = smpl_output.vertices + + loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], + camera_translation, + j3d[:, self.corr_index], self.pose_prior, + joints3d_conf=conf_3d, + joint_loss_weight=600.0, + pose_preserve_weight=5.0, + use_collision=self.use_collision, + model_vertices=model_vertices, model_faces=self.model_faces, + search_tree=search_tree, pen_distance=pen_distance, + filter_faces=filter_faces) + loss.backward() + return loss + + body_optimizer.step(closure) + else: + body_optimizer = torch.optim.Adam(body_opt_params, lr=self.step_size, betas=(0.9, 0.999)) + + for i in range(self.num_iters): + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas) + model_joints = smpl_output.joints + model_vertices = smpl_output.vertices + + loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], + camera_translation, + j3d[:, self.corr_index], self.pose_prior, + joints3d_conf=conf_3d, + joint_loss_weight=600.0, + use_collision=self.use_collision, + model_vertices=model_vertices, model_faces=self.model_faces, + search_tree=search_tree, pen_distance=pen_distance, + filter_faces=filter_faces) + body_optimizer.zero_grad() + loss.backward() + body_optimizer.step() + + # Get final loss value + with torch.no_grad(): + smpl_output = self.smpl(global_orient=global_orient, + body_pose=body_pose, + betas=betas, return_full_pose=True) + model_joints = smpl_output.joints + model_vertices = smpl_output.vertices + + final_loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], + camera_translation, + j3d[:, self.corr_index], self.pose_prior, + joints3d_conf=conf_3d, + joint_loss_weight=600.0, + use_collision=self.use_collision, model_vertices=model_vertices, + model_faces=self.model_faces, + search_tree=search_tree, pen_distance=pen_distance, + filter_faces=filter_faces) + + vertices = smpl_output.vertices.detach() + joints = smpl_output.joints.detach() + pose = torch.cat([global_orient, body_pose], dim=-1).detach() + betas = betas.detach() + + return vertices, joints, pose, betas, camera_translation, final_loss diff --git a/mld/utils/__init__.py b/mld/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mld/utils/temos_utils.py b/mld/utils/temos_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..22b62750d1dd0dca0ca2b85af58109043b0559c1 --- /dev/null +++ b/mld/utils/temos_utils.py @@ -0,0 +1,18 @@ +import torch + + +def lengths_to_mask(lengths: list[int], + device: torch.device, + max_len: int = None) -> torch.Tensor: + lengths = torch.tensor(lengths, 
device=device) + max_len = max_len if max_len else max(lengths) + mask = torch.arange(max_len, device=device).expand( + len(lengths), max_len) < lengths.unsqueeze(1) + return mask + + +def remove_padding(tensors: torch.Tensor, lengths: list[int]) -> list: + return [ + tensor[:tensor_length] + for tensor, tensor_length in zip(tensors, lengths) + ] diff --git a/mld/utils/utils.py b/mld/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..81aef88275cb15b5bb320382e36690101919bd2e --- /dev/null +++ b/mld/utils/utils.py @@ -0,0 +1,66 @@ +import random + +import numpy as np + +from rich import get_console +from rich.table import Table + +import torch +import torch.nn as nn + + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def print_table(title: str, metrics: dict) -> None: + table = Table(title=title) + + table.add_column("Metrics", style="cyan", no_wrap=True) + table.add_column("Value", style="magenta") + + for key, value in metrics.items(): + table.add_row(key, str(value)) + + console = get_console() + console.print(table, justify="center") + + +def move_batch_to_device(batch: dict, device: torch.device) -> dict: + for key in batch.keys(): + if isinstance(batch[key], torch.Tensor): + batch[key] = batch[key].to(device) + return batch + + +def count_parameters(module: nn.Module) -> float: + num_params = sum(p.numel() for p in module.parameters()) + return round(num_params / 1e6, 3) + + +def get_guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 512, + dtype: torch.dtype = torch.float32) -> torch.Tensor: + assert len(w.shape) == 1 + w = w * 1000.0 + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + +def extract_into_tensor(a: torch.Tensor, t: torch.Tensor, x_shape: torch.Size) -> torch.Tensor: + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def sum_flat(tensor: torch.Tensor) -> torch.Tensor: + return tensor.sum(dim=list(range(1, len(tensor.shape)))) diff --git a/prepare/download_glove.sh b/prepare/download_glove.sh new file mode 100644 index 0000000000000000000000000000000000000000..7bb914f20ef6f1ac3e3979a9b7eab6f18c282f0f --- /dev/null +++ b/prepare/download_glove.sh @@ -0,0 +1,12 @@ +mkdir -p deps/ +cd deps/ + +echo -e "Downloading glove (in use by the evaluators)" +gdown --fuzzy https://drive.google.com/file/d/1cmXKUT31pqd7_XpJAiWEo1K81TMYHA5n/view?usp=sharing +rm -rf glove + +unzip glove.zip +echo -e "Cleaning\n" +rm glove.zip + +echo -e "Downloading done!" \ No newline at end of file diff --git a/prepare/download_pretrained_models.sh b/prepare/download_pretrained_models.sh new file mode 100644 index 0000000000000000000000000000000000000000..cffe1709779f7d58add9b537eaa271447a65f182 --- /dev/null +++ b/prepare/download_pretrained_models.sh @@ -0,0 +1,14 @@ +echo -e "Downloading experiments_t2m!" +# experiments_t2m +gdown --fuzzy https://drive.google.com/file/d/1RpNumy8g2X2lI4H-kO9xkZVYRDDgyu6Y/view?usp=sharing +unzip experiments_t2m.zip + +echo -e "Downloading experiments_control!" 
+# experiments_control +gdown --fuzzy https://drive.google.com/file/d/1_xrSGWHo4pRz-AJj22rt8qmRX0SflJ6V/view?usp=sharing +unzip experiments_control.zip + +rm experiments_t2m.zip +rm experiments_control.zip + +echo -e "Downloading done!" diff --git a/prepare/download_smpl_models.sh b/prepare/download_smpl_models.sh new file mode 100644 index 0000000000000000000000000000000000000000..33777c17720c0df4abfb10afc5926684fb8e157a --- /dev/null +++ b/prepare/download_smpl_models.sh @@ -0,0 +1,12 @@ +mkdir -p deps/ +cd deps/ + +echo -e "Downloading smpl models" +gdown --fuzzy https://drive.google.com/file/d/1J2pTxrar_q689Du5r3jES343fZUmCs_y/view?usp=sharing +rm -rf smpl_models + +unzip smpl_models.zip +echo -e "Cleaning\n" +rm smpl_models.zip + +echo -e "Downloading done!" \ No newline at end of file diff --git a/prepare/download_t2m_evaluators.sh b/prepare/download_t2m_evaluators.sh new file mode 100644 index 0000000000000000000000000000000000000000..4f38467b6d43686592b14ba5b538bf1788ee3595 --- /dev/null +++ b/prepare/download_t2m_evaluators.sh @@ -0,0 +1,13 @@ +mkdir -p deps/ +cd deps/ + +echo "The t2m evaluators will be stored in the './deps' folder" + +echo "Downloading" +gdown --fuzzy https://drive.google.com/file/d/16hyR4XlEyksVyNVjhIWK684Lrm_7_pvX/view?usp=sharing +echo "Extracting" +unzip t2m.zip +echo "Cleaning" +rm t2m.zip + +echo "Downloading done!" diff --git a/prepare/prepare_bert.sh b/prepare/prepare_bert.sh new file mode 100644 index 0000000000000000000000000000000000000000..a959290e55210d47e34a243460d4eacb54b18757 --- /dev/null +++ b/prepare/prepare_bert.sh @@ -0,0 +1,4 @@ +cd deps/ +git lfs install +git clone https://huggingface.co./distilbert-base-uncased +cd .. diff --git a/prepare/prepare_clip.sh b/prepare/prepare_clip.sh new file mode 100644 index 0000000000000000000000000000000000000000..529e91035e7245b41de91e5dfe19af6546236ea2 --- /dev/null +++ b/prepare/prepare_clip.sh @@ -0,0 +1,4 @@ +cd deps/ +git lfs install +git clone https://huggingface.co./openai/clip-vit-large-patch14 +cd .. diff --git a/prepare/prepare_t5.sh b/prepare/prepare_t5.sh new file mode 100644 index 0000000000000000000000000000000000000000..8a487c8048af37fbcfd73d482bad2602d6a760bb --- /dev/null +++ b/prepare/prepare_t5.sh @@ -0,0 +1,4 @@ +cd deps/ +git lfs install +git clone https://huggingface.co./sentence-transformers/sentence-t5-large +cd .. diff --git a/prepare/prepare_tiny_humanml3d.sh b/prepare/prepare_tiny_humanml3d.sh new file mode 100644 index 0000000000000000000000000000000000000000..a9691edc94dfeaf434e3848bd83a1ccf52fb5c71 --- /dev/null +++ b/prepare/prepare_tiny_humanml3d.sh @@ -0,0 +1,6 @@ +mkdir -p datasets/ +cd datasets/ + +gdown --fuzzy https://drive.google.com/file/d/1Mg_3RnWmRt0tk_lyLRRiOZg1W-Fu4wLL/view?usp=sharing +unzip humanml3d_tiny.zip +rm humanml3d_tiny.zip diff --git a/render.py b/render.py new file mode 100644 index 0000000000000000000000000000000000000000..d7b298bdad2dfc2d543f2c11c86e8c193050502e --- /dev/null +++ b/render.py @@ -0,0 +1,85 @@ +import os +import pickle +import random +import sys +import natsort +from argparse import ArgumentParser + +try: + import bpy + sys.path.append(os.path.dirname(bpy.data.filepath)) +except ImportError: + raise ImportError( + "Blender is not properly installed or not launch properly. 
See README.md to have instruction on how to install and use blender.") + +import mld.launch.blender # noqa +from mld.render.blender import render + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument("--pkl", type=str, default=None, help="pkl motion file") + parser.add_argument("--dir", type=str, default=None, help="pkl motion folder") + parser.add_argument("--mode", type=str, default="sequence", help="render target: video, sequence, frame") + parser.add_argument("--res", type=str, default="high") + parser.add_argument("--denoising", type=bool, default=True) + parser.add_argument("--oldrender", type=bool, default=True) + parser.add_argument("--accelerator", type=str, default='gpu', help='accelerator device') + parser.add_argument("--device", type=int, nargs='+', default=[0], help='gpu ids') + parser.add_argument("--faces_path", type=str, default='./deps/smpl_models/smplh/smplh.faces') + parser.add_argument("--always_on_floor", action="store_true", help='put all the body on the floor (not recommended)') + parser.add_argument("--gt", type=str, default=False, help='green for gt, otherwise orange') + parser.add_argument("--fps", type=int, default=20, help="the frame rate of the rendered video") + parser.add_argument("--num", type=int, default=8, help="the number of frames rendered in 'sequence' mode") + parser.add_argument("--exact_frame", type=float, default=0.5, help="the frame id selected under 'frame' mode ([0, 1])") + cfg = parser.parse_args() + return cfg + + +def render_cli() -> None: + cfg = parse_args() + + if cfg.pkl: + paths = [cfg.pkl] + elif cfg.dir: + paths = [] + file_list = natsort.natsorted(os.listdir(cfg.dir)) + begin_id = random.randrange(0, len(file_list)) + file_list = file_list[begin_id:] + file_list[:begin_id] + + for item in file_list: + if item.endswith("_mesh.pkl"): + paths.append(os.path.join(cfg.dir, item)) + else: + raise ValueError(f'{cfg.pkl} and {cfg.dir} are both None!') + + for path in paths: + try: + with open(path, 'rb') as f: + pkl = pickle.load(f) + data = pkl['vertices'] + trajectory = pkl['hint'] + + except FileNotFoundError: + print(f"{path} not found") + continue + + render( + data, + trajectory, + path, + exact_frame=cfg.exact_frame, + num=cfg.num, + mode=cfg.mode, + faces_path=cfg.faces_path, + always_on_floor=cfg.always_on_floor, + oldrender=cfg.oldrender, + res=cfg.res, + gt=cfg.gt, + accelerator=cfg.accelerator, + device=cfg.device, + fps=cfg.fps) + + +if __name__ == "__main__": + render_cli() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..c1fa323f2f5ef0492246c9576fc9e402e1fc3268 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +--extra-index-url https://download.pytorch.org/whl/cu116 +torch==1.13.1 +gdown +omegaconf +rich +torchmetrics==1.3.2 +scipy==1.11.2 +matplotlib==3.3.4 +transformers==4.35.2 +sentence-transformers==2.2.2 +diffusers==0.24.0 +tensorboard==2.15.1 +h5py==3.11.0 +smplx==0.1.28 +chumpy==0.70 +numpy==1.23.1 +natsort==8.4.0 \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 0000000000000000000000000000000000000000..e4d903f301954d1f131c705b085c9c881511ddda --- /dev/null +++ b/test.py @@ -0,0 +1,125 @@ +import os +import sys +import json +import datetime +import logging +import os.path as osp + +import numpy as np +from tqdm.auto import tqdm +from omegaconf import OmegaConf + +import torch +from torch.utils.data import DataLoader + +from mld.config import parse_args +from mld.data.get_data import 
get_datasets +from mld.models.modeltype.mld import MLD +from mld.utils.utils import print_table, set_seed, move_batch_to_device + + +def get_metric_statistics(values: np.ndarray, replication_times: int) -> tuple: + mean = np.mean(values, axis=0) + std = np.std(values, axis=0) + conf_interval = 1.96 * std / np.sqrt(replication_times) + return mean, conf_interval + + +@torch.no_grad() +def test_one_epoch(model: MLD, dataloader: DataLoader, device: torch.device) -> dict: + for batch in tqdm(dataloader): + batch = move_batch_to_device(batch, device) + model.test_step(batch) + metrics = model.allsplit_epoch_end() + return metrics + + +def main(): + cfg = parse_args() + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + set_seed(cfg.TRAIN.SEED_VALUE) + + name_time_str = osp.join(cfg.NAME, datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")) + output_dir = osp.join(cfg.TEST_FOLDER, name_time_str) + os.makedirs(output_dir, exist_ok=False) + + steam_handler = logging.StreamHandler(sys.stdout) + file_handler = logging.FileHandler(osp.join(output_dir, 'output.log')) + logging.basicConfig(level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[steam_handler, file_handler]) + logger = logging.getLogger(__name__) + + OmegaConf.save(cfg, osp.join(output_dir, 'config.yaml')) + + state_dict = torch.load(cfg.TEST.CHECKPOINTS, map_location="cpu")["state_dict"] + logger.info("Loading checkpoints from {}".format(cfg.TEST.CHECKPOINTS)) + + lcm_key = 'denoiser.time_embedding.cond_proj.weight' + is_lcm = False + if lcm_key in state_dict: + is_lcm = True + time_cond_proj_dim = state_dict[lcm_key].shape[1] + cfg.model.denoiser.params.time_cond_proj_dim = time_cond_proj_dim + logger.info(f'Is LCM: {is_lcm}') + + cn_key = "controlnet.controlnet_cond_embedding.0.weight" + is_controlnet = True if cn_key in state_dict else False + cfg.model.is_controlnet = is_controlnet + logger.info(f'Is Controlnet: {is_controlnet}') + + datasets = get_datasets(cfg, phase="test")[0] + test_dataloader = datasets.test_dataloader() + model = MLD(cfg, datasets) + model.to(device) + model.eval() + model.load_state_dict(state_dict) + + all_metrics = {} + replication_times = cfg.TEST.REPLICATION_TIMES + max_num_samples = cfg.TEST.get('MAX_NUM_SAMPLES', len(test_dataloader.dataset)) + name_list = test_dataloader.dataset.name_list + # calculate metrics + for i in range(replication_times): + chosen_list = np.random.choice(name_list, max_num_samples, replace=False) + test_dataloader.dataset.name_list = chosen_list + + metrics_type = ", ".join(cfg.METRIC.TYPE) + logger.info(f"Evaluating {metrics_type} - Replication {i}") + metrics = test_one_epoch(model, test_dataloader, device) + + if "TM2TMetrics" in metrics_type: + test_dataloader.dataset.name_list = name_list + # mm metrics + logger.info(f"Evaluating MultiModality - Replication {i}") + datasets.mm_mode(True) + mm_metrics = test_one_epoch(model, test_dataloader, device) + metrics.update(mm_metrics) + datasets.mm_mode(False) + + print_table(f"Metrics@Replication-{i}", metrics) + logger.info(metrics) + + for key, item in metrics.items(): + if key not in all_metrics: + all_metrics[key] = [item] + else: + all_metrics[key] += [item] + + all_metrics_new = dict() + for key, item in all_metrics.items(): + mean, conf_interval = get_metric_statistics(np.array(item), replication_times) + all_metrics_new[key + "/mean"] = mean + all_metrics_new[key + "/conf_interval"] = conf_interval + 
print_table(f"Mean Metrics", all_metrics_new) + all_metrics_new.update(all_metrics) + # save metrics to file + metric_file = osp.join(output_dir, f"metrics.json") + with open(metric_file, "w", encoding="utf-8") as f: + json.dump(all_metrics_new, f, indent=4) + logger.info(f"Testing done, the metrics are saved to {str(metric_file)}") + + +if __name__ == "__main__": + main() diff --git a/train_motion_control.py b/train_motion_control.py new file mode 100644 index 0000000000000000000000000000000000000000..be37ca1574e9c66d29f1d5eb9307ff5a507793f5 --- /dev/null +++ b/train_motion_control.py @@ -0,0 +1,204 @@ +import os +import sys +import logging +import datetime +import os.path as osp + +from tqdm.auto import tqdm +from omegaconf import OmegaConf + +import torch +import diffusers +import transformers +from torch.utils.tensorboard import SummaryWriter +from diffusers.optimization import get_scheduler + +from mld.config import parse_args +from mld.data.get_data import get_datasets +from mld.models.modeltype.mld import MLD +from mld.utils.utils import print_table, set_seed, move_batch_to_device + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def main(): + cfg = parse_args() + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + set_seed(cfg.TRAIN.SEED_VALUE) + + name_time_str = osp.join(cfg.NAME, datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")) + output_dir = osp.join(cfg.FOLDER, name_time_str) + os.makedirs(output_dir, exist_ok=False) + os.makedirs(f"{output_dir}/checkpoints", exist_ok=False) + + writer = SummaryWriter(output_dir) + + stream_handler = logging.StreamHandler(sys.stdout) + file_handler = logging.FileHandler(osp.join(output_dir, 'output.log')) + handlers = [file_handler, stream_handler] + logging.basicConfig(level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=handlers) + logger = logging.getLogger(__name__) + + OmegaConf.save(cfg, osp.join(output_dir, 'config.yaml')) + + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + + assert cfg.model.is_controlnet, "cfg.model.is_controlnet must be true for controlling!" 
+ + datasets = get_datasets(cfg)[0] + train_dataloader = datasets.train_dataloader() + val_dataloader = datasets.val_dataloader() + + logger.info(f"Loading pretrained model: {cfg.TRAIN.PRETRAINED}") + state_dict = torch.load(cfg.TRAIN.PRETRAINED, map_location="cpu")["state_dict"] + lcm_key = 'denoiser.time_embedding.cond_proj.weight' + is_lcm = False + if lcm_key in state_dict: + is_lcm = True + time_cond_proj_dim = state_dict[lcm_key].shape[1] + cfg.model.denoiser.params.time_cond_proj_dim = time_cond_proj_dim + logger.info(f'Is LCM: {is_lcm}') + + model = MLD(cfg, datasets) + logger.info(model.load_state_dict(state_dict, strict=False)) + logger.info(model.controlnet.load_state_dict(model.denoiser.state_dict(), strict=False)) + + model.vae.requires_grad_(False) + model.text_encoder.requires_grad_(False) + model.denoiser.requires_grad_(False) + model.vae.eval() + model.text_encoder.eval() + model.denoiser.eval() + model.to(device) + + controlnet_params = list(model.controlnet.parameters()) + traj_encoder_params = list(model.traj_encoder.parameters()) + params = controlnet_params + traj_encoder_params + params_to_optimize = [{'params': controlnet_params, 'lr': cfg.TRAIN.learning_rate}, + {'params': traj_encoder_params, 'lr': cfg.TRAIN.learning_rate_spatial}] + + logger.info("learning_rate: {}, learning_rate_spatial: {}". + format(cfg.TRAIN.learning_rate, cfg.TRAIN.learning_rate_spatial)) + + optimizer = torch.optim.AdamW( + params_to_optimize, + betas=(cfg.TRAIN.adam_beta1, cfg.TRAIN.adam_beta2), + weight_decay=cfg.TRAIN.adam_weight_decay, + eps=cfg.TRAIN.adam_epsilon) + + if cfg.TRAIN.max_train_steps == -1: + assert cfg.TRAIN.max_train_epochs != -1 + cfg.TRAIN.max_train_steps = cfg.TRAIN.max_train_epochs * len(train_dataloader) + + if cfg.TRAIN.checkpointing_steps == -1: + assert cfg.TRAIN.checkpointing_epochs != -1 + cfg.TRAIN.checkpointing_steps = cfg.TRAIN.checkpointing_epochs * len(train_dataloader) + + if cfg.TRAIN.validation_steps == -1: + assert cfg.TRAIN.validation_epochs != -1 + cfg.TRAIN.validation_steps = cfg.TRAIN.validation_epochs * len(train_dataloader) + + lr_scheduler = get_scheduler( + cfg.TRAIN.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.TRAIN.lr_warmup_steps, + num_training_steps=cfg.TRAIN.max_train_steps) + + # Train! 
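+    # max_train_steps, checkpointing_steps and validation_steps were resolved
+    # above from their epoch-based counterparts when set to -1, e.g.
+    # max_train_steps = max_train_epochs * len(train_dataloader).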
+ logger.info("***** Running training *****") + logging.info(f" Num examples = {len(train_dataloader.dataset)}") + logging.info(f" Num Epochs = {cfg.TRAIN.max_train_epochs}") + logging.info(f" Instantaneous batch size per device = {cfg.TRAIN.BATCH_SIZE}") + logging.info(f" Total optimization steps = {cfg.TRAIN.max_train_steps}") + + global_step = 0 + first_epoch = 0 + + progress_bar = tqdm(range(0, cfg.TRAIN.max_train_steps), desc="Steps") + + @torch.no_grad() + def validation(): + model.controlnet.eval() + model.traj_encoder.eval() + + for val_batch in tqdm(val_dataloader): + val_batch = move_batch_to_device(val_batch, device) + model.allsplit_step('test', val_batch) + metrics = model.allsplit_epoch_end() + min_val_km = metrics['Metrics/kps_mean_err(m)'] + min_val_tj = metrics['Metrics/traj_fail_50cm'] + print_table(f'Metrics@Step-{global_step}', metrics) + for k, v in metrics.items(): + writer.add_scalar(k, v, global_step=global_step) + + model.controlnet.train() + model.traj_encoder.train() + return min_val_km, min_val_tj + + min_km, min_tj = validation() + + for epoch in range(first_epoch, cfg.TRAIN.max_train_epochs): + for step, batch in enumerate(train_dataloader): + batch = move_batch_to_device(batch, device) + loss_dict = model.allsplit_step('train', batch) + + diff_loss = loss_dict['diff_loss'] + cond_loss = loss_dict['cond_loss'] + rot_loss = loss_dict['rot_loss'] + loss = diff_loss + cond_loss + rot_loss + + loss.backward() + torch.nn.utils.clip_grad_norm_(params, cfg.TRAIN.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + progress_bar.update(1) + global_step += 1 + + if global_step % cfg.TRAIN.checkpointing_steps == 0: + save_path = os.path.join(output_dir, 'checkpoints', f"checkpoint-{global_step}.ckpt") + ckpt = dict(state_dict=model.state_dict()) + model.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + logger.info(f"Saved state to {save_path}") + + if global_step % cfg.TRAIN.validation_steps == 0: + cur_km, cur_tj = validation() + if cur_km < min_km: + min_km = cur_km + save_path = os.path.join(output_dir, 'checkpoints', f"checkpoint-{global_step}-km-{round(cur_km, 3)}.ckpt") + ckpt = dict(state_dict=model.state_dict()) + model.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + logger.info(f"Saved state to {save_path} with km:{round(cur_km, 3)}") + + if cur_tj < min_tj: + min_tj = cur_tj + save_path = os.path.join(output_dir, 'checkpoints', f"checkpoint-{global_step}-tj-{round(cur_tj, 3)}.ckpt") + ckpt = dict(state_dict=model.state_dict()) + model.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + logger.info(f"Saved state to {save_path} with tj:{round(cur_tj, 3)}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], + "diff_loss": diff_loss.detach().item(), 'cond_loss': cond_loss.detach().item(), 'rot_loss': rot_loss.detach().item()} + progress_bar.set_postfix(**logs) + for k, v in logs.items(): + writer.add_scalar(k, v, global_step=global_step) + + if global_step >= cfg.TRAIN.max_train_steps: + break + + save_path = os.path.join(output_dir, 'checkpoints', f"checkpoint-last.ckpt") + ckpt = dict(state_dict=model.state_dict()) + model.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + + +if __name__ == "__main__": + main() diff --git a/train_motionlcm.py b/train_motionlcm.py new file mode 100644 index 0000000000000000000000000000000000000000..0a7edaf36fe4163942a551cc7a74165549d9ee05 --- /dev/null +++ b/train_motionlcm.py @@ -0,0 +1,427 @@ +import os +import sys +import 
logging +import datetime +import os.path as osp +from typing import Generator + +import numpy as np +from tqdm.auto import tqdm +from omegaconf import OmegaConf + +import torch +import diffusers +import transformers +import torch.nn.functional as F +from torch.utils.tensorboard import SummaryWriter +from diffusers.optimization import get_scheduler + +from mld.config import parse_args, instantiate_from_config +from mld.data.get_data import get_datasets +from mld.models.modeltype.mld import MLD +from mld.utils.utils import print_table, set_seed, move_batch_to_device + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 512, + dtype: torch.dtype = torch.float32) -> torch.Tensor: + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + +def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor: + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + +def scalings_for_boundary_conditions(timestep: torch.Tensor, sigma_data: float = 0.5, + timestep_scaling: float = 10.0) -> tuple: + c_skip = sigma_data ** 2 / ((timestep * timestep_scaling) ** 2 + sigma_data ** 2) + c_out = (timestep * timestep_scaling) / ((timestep * timestep_scaling) ** 2 + sigma_data ** 2) ** 0.5 + return c_skip, c_out + + +def predicted_origin( + model_output: torch.Tensor, + timesteps: torch.Tensor, + sample: torch.Tensor, + prediction_type: str, + alphas: torch.Tensor, + sigmas: torch.Tensor +) -> torch.Tensor: + if prediction_type == "epsilon": + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + pred_x_0 = (sample - sigmas * model_output) / alphas + elif prediction_type == "v_prediction": + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + pred_x_0 = alphas * sample - sigmas * model_output + else: + raise ValueError(f"Prediction type {prediction_type} currently not supported.") + + return pred_x_0 + + +def extract_into_tensor(a: torch.Tensor, t: torch.Tensor, x_shape: torch.Size) -> torch.Tensor: + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +class DDIMSolver: + def __init__(self, alpha_cumprods: np.ndarray, timesteps: int = 1000, ddim_timesteps: int = 50) -> None: + # DDIM sampling parameters + step_ratio = timesteps // ddim_timesteps + self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1 + self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps] + self.ddim_alpha_cumprods_prev = np.asarray( + [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist() + ) + # convert to torch tensors + self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long() + self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods) + self.ddim_alpha_cumprods_prev 
= torch.from_numpy(self.ddim_alpha_cumprods_prev) + + def to(self, device: torch.device) -> "DDIMSolver": + self.ddim_timesteps = self.ddim_timesteps.to(device) + self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device) + self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device) + return self + + def ddim_step(self, pred_x0: torch.Tensor, pred_noise: torch.Tensor, + timestep_index: torch.Tensor) -> torch.Tensor: + alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape) + dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise + x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt + return x_prev + + +@torch.no_grad() +def update_ema(target_params: Generator, source_params: Generator, rate: float = 0.99) -> None: + for tgt, src in zip(target_params, source_params): + tgt.detach().mul_(rate).add_(src, alpha=1 - rate) + + +def main(): + cfg = parse_args() + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + set_seed(cfg.TRAIN.SEED_VALUE) + + name_time_str = osp.join(cfg.NAME, datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")) + output_dir = osp.join(cfg.FOLDER, name_time_str) + os.makedirs(output_dir, exist_ok=False) + os.makedirs(f"{output_dir}/checkpoints", exist_ok=False) + + writer = SummaryWriter(output_dir) + + stream_handler = logging.StreamHandler(sys.stdout) + file_handler = logging.FileHandler(osp.join(output_dir, 'output.log')) + handlers = [file_handler, stream_handler] + logging.basicConfig(level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=handlers) + logger = logging.getLogger(__name__) + + OmegaConf.save(cfg, osp.join(output_dir, 'config.yaml')) + + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + + logger.info(f'Training guidance scale range (w): [{cfg.TRAIN.w_min}, {cfg.TRAIN.w_max}]') + logger.info(f'EMA rate (mu): {cfg.TRAIN.ema_decay}') + logger.info(f'Skipping interval (k): {1000 / cfg.TRAIN.num_ddim_timesteps}') + logger.info(f'Loss type (huber or l2): {cfg.TRAIN.loss_type}') + + datasets = get_datasets(cfg)[0] + train_dataloader = datasets.train_dataloader() + val_dataloader = datasets.val_dataloader() + + logger.info(f"Loading pretrained model: {cfg.TRAIN.PRETRAINED}") + state_dict = torch.load(cfg.TRAIN.PRETRAINED, map_location="cpu")["state_dict"] + base_model = MLD(cfg, datasets) + base_model.load_state_dict(state_dict) + + noise_scheduler = base_model.noise_scheduler + alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) + sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) + solver = DDIMSolver( + noise_scheduler.alphas_cumprod.numpy(), + timesteps=noise_scheduler.config.num_train_timesteps, + ddim_timesteps=cfg.TRAIN.num_ddim_timesteps, + ) + + base_model.to(device) + + vae = base_model.vae + text_encoder = base_model.text_encoder + teacher_unet = base_model.denoiser + + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + teacher_unet.requires_grad_(False) + + # Apply CFG here (Important!!!) 
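+    # Giving the student denoiser a time_cond_proj_dim means the guidance
+    # scale w (embedded via guidance_scale_embedding) is fed to the network
+    # as an extra conditioning signal, so the distilled model bakes
+    # classifier-free guidance into a single forward pass instead of running
+    # separate conditional and unconditional branches at inference time.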
+ cfg.model.denoiser.params.time_cond_proj_dim = cfg.TRAIN.unet_time_cond_proj_dim + unet = instantiate_from_config(cfg.model.denoiser) + unet.load_state_dict(teacher_unet.state_dict(), strict=False) + target_unet = instantiate_from_config(cfg.model.denoiser) + target_unet.load_state_dict(teacher_unet.state_dict(), strict=False) + + # Only evaluate the online network + base_model.denoiser = unet + + unet = unet.to(device) + target_unet = target_unet.to(device) + target_unet.requires_grad_(False) + + # Also move the alpha and sigma noise schedules to device + alpha_schedule = alpha_schedule.to(device) + sigma_schedule = sigma_schedule.to(device) + solver = solver.to(device) + + optimizer = torch.optim.AdamW( + unet.parameters(), + lr=cfg.TRAIN.learning_rate, + betas=(cfg.TRAIN.adam_beta1, cfg.TRAIN.adam_beta2), + weight_decay=cfg.TRAIN.adam_weight_decay, + eps=cfg.TRAIN.adam_epsilon) + + if cfg.TRAIN.max_train_steps == -1: + assert cfg.TRAIN.max_train_epochs != -1 + cfg.TRAIN.max_train_steps = cfg.TRAIN.max_train_epochs * len(train_dataloader) + + if cfg.TRAIN.checkpointing_steps == -1: + assert cfg.TRAIN.checkpointing_epochs != -1 + cfg.TRAIN.checkpointing_steps = cfg.TRAIN.checkpointing_epochs * len(train_dataloader) + + if cfg.TRAIN.validation_steps == -1: + assert cfg.TRAIN.validation_epochs != -1 + cfg.TRAIN.validation_steps = cfg.TRAIN.validation_epochs * len(train_dataloader) + + lr_scheduler = get_scheduler( + cfg.TRAIN.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.TRAIN.lr_warmup_steps, + num_training_steps=cfg.TRAIN.max_train_steps) + + uncond_prompt_embeds = text_encoder([""] * cfg.TRAIN.BATCH_SIZE) + + # Train! + logger.info("***** Running training *****") + logging.info(f" Num examples = {len(train_dataloader.dataset)}") + logging.info(f" Num Epochs = {cfg.TRAIN.max_train_epochs}") + logging.info(f" Instantaneous batch size per device = {cfg.TRAIN.BATCH_SIZE}") + logging.info(f" Total optimization steps = {cfg.TRAIN.max_train_steps}") + + global_step = 0 + first_epoch = 0 + + progress_bar = tqdm(range(0, cfg.TRAIN.max_train_steps), desc="Steps") + + @torch.no_grad() + def validation(): + base_model.eval() + for val_batch in tqdm(val_dataloader): + val_batch = move_batch_to_device(val_batch, device) + base_model.allsplit_step('test', val_batch) + metrics = base_model.allsplit_epoch_end() + max_val_rp1 = metrics['Metrics/R_precision_top_1'] + min_val_fid = metrics['Metrics/FID'] + print_table(f'Metrics@Step-{global_step}', metrics) + for k, v in metrics.items(): + writer.add_scalar(k, v, global_step=global_step) + base_model.train() + return max_val_rp1, min_val_fid + + max_rp1, min_fid = validation() + + for epoch in range(first_epoch, cfg.TRAIN.max_train_epochs): + for step, batch in enumerate(train_dataloader): + batch = move_batch_to_device(batch, device) + feats_ref = batch["motion"] + lengths = batch["length"] + text = batch['text'] + + # Encode motions to latents + with torch.no_grad(): + latents, _ = vae.encode(feats_ref, lengths) + latents = latents.permute(1, 0, 2) + + prompt_embeds = text_encoder(text) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias. 
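+            # Worked example, assuming 1000 training timesteps and
+            # num_ddim_timesteps = 50 (the DDIMSolver default): topk = 20,
+            # index ~ U{0, ..., 49}, start_timesteps = ddim_timesteps[index]
+            # is t_{n+k}, and timesteps = start_timesteps - 20 is t_n
+            # (clamped to 0 below).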
+ topk = noise_scheduler.config.num_train_timesteps // cfg.TRAIN.num_ddim_timesteps + index = torch.randint(0, cfg.TRAIN.num_ddim_timesteps, (bsz,), device=latents.device).long() + start_timesteps = solver.ddim_timesteps[index] + timesteps = start_timesteps - topk + timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) + + # Get boundary scalings for start_timesteps and (end) timesteps. + c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps) + c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] + c_skip, c_out = scalings_for_boundary_conditions(timesteps) + c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps) + + # Sample a random guidance scale w from U[w_min, w_max] and embed it + w = (cfg.TRAIN.w_max - cfg.TRAIN.w_min) * torch.rand((bsz,)) + cfg.TRAIN.w_min + w_embedding = guidance_scale_embedding(w, embedding_dim=cfg.TRAIN.unet_time_cond_proj_dim) + w = append_dims(w, latents.ndim) + # Move to U-Net device and dtype + w = w.to(device=latents.device, dtype=latents.dtype) + w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype) + + # Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k} + noise_pred = unet( + noisy_model_input, + start_timesteps, + timestep_cond=w_embedding, + encoder_hidden_states=prompt_embeds) + + pred_x_0 = predicted_origin( + noise_pred, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule) + + model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0 + + # Use the ODE solver to predict the k-th step in the augmented PF-ODE trajectory after + # noisy_latents with both the conditioning embedding c and unconditional embedding 0 + # Get teacher model prediction on noisy_latents and conditional embedding + with torch.no_grad(): + cond_teacher_output = teacher_unet( + noisy_model_input, + start_timesteps, + encoder_hidden_states=prompt_embeds) + cond_pred_x0 = predicted_origin( + cond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule) + + # Get teacher model prediction on noisy_latents and unconditional embedding + uncond_teacher_output = teacher_unet( + noisy_model_input, + start_timesteps, + encoder_hidden_states=uncond_prompt_embeds[:bsz]) + uncond_pred_x0 = predicted_origin( + uncond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule) + + # Perform "CFG" to get z_prev estimate (using the LCM paper's CFG formulation) + pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) + pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output) + x_prev = solver.ddim_step(pred_x0, pred_noise, index) + + # Get target LCM prediction on z_prev, w, c, t_n + with torch.no_grad(): + target_noise_pred = target_unet( + x_prev.float(), + timesteps, + timestep_cond=w_embedding, + encoder_hidden_states=prompt_embeds) + pred_x_0 = predicted_origin( + target_noise_pred, + timesteps, + x_prev, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule) + target = c_skip * x_prev + c_out * pred_x_0 + + # Calculate loss + if 
cfg.TRAIN.loss_type == "l2": + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + elif cfg.TRAIN.loss_type == "huber": + loss = torch.mean( + torch.sqrt( + (model_pred.float() - target.float()) ** 2 + cfg.TRAIN.huber_c ** 2) - cfg.TRAIN.huber_c + ) + + # Back propagate on the online student model (`unet`) + loss.backward() + torch.nn.utils.clip_grad_norm_(unet.parameters(), cfg.TRAIN.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Make EMA update to target student model parameters + update_ema(target_unet.parameters(), unet.parameters(), cfg.TRAIN.ema_decay) + progress_bar.update(1) + global_step += 1 + + if global_step % cfg.TRAIN.checkpointing_steps == 0: + save_path = os.path.join(output_dir, 'checkpoints', f"checkpoint-{global_step}.ckpt") + ckpt = dict(state_dict=base_model.state_dict()) + base_model.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + logger.info(f"Saved state to {save_path}") + + if global_step % cfg.TRAIN.validation_steps == 0: + cur_rp1, cur_fid = validation() + if cur_rp1 > max_rp1: + max_rp1 = cur_rp1 + save_path = os.path.join(output_dir, 'checkpoints', + f"checkpoint-{global_step}-rp1-{round(cur_rp1, 3)}.ckpt") + ckpt = dict(state_dict=base_model.state_dict()) + base_model.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + logger.info(f"Saved state to {save_path} with rp1:{round(cur_rp1, 3)}") + + if cur_fid < min_fid: + min_fid = cur_fid + save_path = os.path.join(output_dir, 'checkpoints', + f"checkpoint-{global_step}-fid-{round(cur_fid, 3)}.ckpt") + ckpt = dict(state_dict=base_model.state_dict()) + base_model.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + logger.info(f"Saved state to {save_path} with fid:{round(cur_fid, 3)}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + writer.add_scalar('loss', logs['loss'], global_step=global_step) + writer.add_scalar('lr', logs['lr'], global_step=global_step) + + if global_step >= cfg.TRAIN.max_train_steps: + break + + save_path = os.path.join(output_dir, 'checkpoints', f"checkpoint-last.ckpt") + ckpt = dict(state_dict=base_model.state_dict()) + base_model.on_save_checkpoint(ckpt) + torch.save(ckpt, save_path) + + +if __name__ == "__main__": + main()
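+
+
+# Note on the boundary scalings used above: with sigma_data = 0.5 and
+# timestep_scaling = 10, c_skip(t) = sigma_data**2 / ((10 * t)**2 + sigma_data**2)
+# and c_out(t) = 10 * t / sqrt((10 * t)**2 + sigma_data**2), so at t = 0 the
+# consistency parameterisation c_skip * x + c_out * pred_x_0 reduces to the
+# identity mapping (c_skip = 1, c_out = 0), which is the boundary condition
+# required for consistency distillation.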