import numpy as np
import os
import random
import torch
import time
import json
from os.path import join as pjoin

from mGPT.config import instantiate_from_config
from mGPT.losses.mgpt import GPTLosses
from mGPT.models.base import BaseModel
import mGPT.render.matplot.plot_3d_global as plot_3d


class MotionGPT(BaseModel):
    """
    Stage 1: motion tokenizer (VQ-VAE)
    Stage 2: motion-language pretraining
    Stage 3: motion-language instruction tuning
    """
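
    # The three stages in the docstring map onto the `stage` hyperparameter:
    #   'vae'                         -> train the motion tokenizer (train_vae_forward)
    #   'lm_pretrain' / 'lm_instruct' -> train the language model with the tokenizer frozen
    #   'lm_rl'                       -> reinforcement-learning fine-tuning (train_rl_forward)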
    def __init__(self,
                 cfg,
                 datamodule,
                 lm,
                 motion_vae,
                 codebook_size=512,
                 stage='vae',
                 debug=True,
                 condition='text',
                 task='t2m',
                 metrics_dict=['TM2TMetrics'],
                 **kwargs):

        self.save_hyperparameters(ignore='datamodule', logger=False)
        self.datamodule = datamodule
        super().__init__()

        # Build the motion tokenizer (VQ-VAE) from its config
        if motion_vae is not None:
            self.vae = instantiate_from_config(motion_vae)

        # Build the motion-language model from its config
        self.lm = instantiate_from_config(lm)

        # Freeze the motion tokenizer during the language-model stages
        if 'lm' in self.hparams.stage:
            self.vae.training = False
            for p in self.vae.parameters():
                p.requires_grad = False

        # Per-split losses
        self._losses = torch.nn.ModuleDict({
            split: GPTLosses(cfg, self.hparams.stage, self.datamodule.njoints)
            for split in ["losses_train", "losses_test", "losses_val"]
        })

        # Feature-to-joint transform provided by the datamodule
        self.feats2joints = datamodule.feats2joints

        # Codebook usage statistics
        self.codePred = []
        self.codeFrequency = torch.zeros((self.hparams.codebook_size, ))
    def forward(self, batch, task="t2m"):
        texts = batch["text"]
        lengths_ref = batch["length"]

        # Generate motion tokens (and output texts) directly from the prompts
        outputs, output_texts = self.lm.generate_direct(texts, do_sample=True)

        # Decode the generated tokens into motion features
        feats_rst_lst = []
        lengths = []
        max_len = 0

        for i in range(len(texts)):
            if task == "pred":
                # Motion prediction: decode the observed tokens followed by the generated ones
                motion = self.vae.decode(
                    torch.cat((batch["motion"][i], outputs[i])))
            elif task in ["t2m", "m2t", "inbetween"]:
                motion = self.vae.decode(outputs[i])
            else:
                raise NotImplementedError

            lengths.append(motion.shape[1])

            if motion.shape[1] > max_len:
                max_len = motion.shape[1]

            if task in ["t2m", "m2t", "pred"]:
                feats_rst_lst.append(motion)
            elif task == "inbetween":
                # Keep the given heading/tailing clips and only use the middle of the generation
                motion = torch.cat(
                    (batch["motion_heading"][i][None],
                     motion[:, lengths_ref[i] // 4:lengths_ref[i] // 4 * 3, ...],
                     batch["motion_tailing"][i][None]),
                    dim=1)
                feats_rst_lst.append(motion)

        # Pad all samples to the longest decoded length
        feats_rst = torch.zeros(
            (len(feats_rst_lst), max_len, motion.shape[-1])).to(self.device)

        for i in range(len(feats_rst_lst)):
            feats_rst[i, :feats_rst_lst[i].shape[1], ...] = feats_rst_lst[i]

        # Recover joints for visualization/evaluation
        joints_rst = self.feats2joints(feats_rst)

        outputs = {
            "texts": output_texts,
            "feats": feats_rst,
            "joints": joints_rst,
            "length": lengths
        }

        return outputs
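
    # Training step for the language-model stages: the LM consumes the prompt
    # texts together with the reference motion token ids; its outputs are
    # turned into losses by GPTLosses in allsplit_step().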
    def train_lm_forward(self, batch):
        tokens_ref = batch["motion"]
        texts = batch["text"]
        lengths = batch["length"]
        tasks = batch["tasks"]
        all_captions = batch['all_captions']
        if self.hparams.condition == 'caption':
            texts = [random.choice(all_captions[i]) for i in range(len(texts))]

        outputs = self.lm(texts, tokens_ref, lengths, tasks)

        return {'outputs': outputs}
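
    # Validation for text-to-motion: generate motion token ids from the texts,
    # decode them with the frozen VQ-VAE, and renormalize the features for the
    # T2M evaluators.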
    @torch.no_grad()
    def val_t2m_forward(self, batch):
        feats_ref = batch["motion"]
        texts = batch["text"]
        lengths = batch["length"]
        tasks = None
        if self.trainer.datamodule.is_mm:
            # Multimodality evaluation: repeat every sample MM_NUM_REPEATS times
            texts = texts * self.hparams.cfg.METRIC.MM_NUM_REPEATS
            feats_ref = feats_ref.repeat_interleave(
                self.hparams.cfg.METRIC.MM_NUM_REPEATS, dim=0)
            lengths = lengths * self.hparams.cfg.METRIC.MM_NUM_REPEATS
            instructions = pjoin(self.datamodule.hparams.data_root,
                                 'template_instructions.json')
            with open(instructions, 'r') as f:
                instructions = json.load(f)
            tasks = [instructions["Text-to-Motion"]["caption"]] * len(texts)

        if self.hparams.condition == 'caption':
            tasks = [{
                'input': ['<Caption_Placeholder>'],
                'output': ['']
            }] * len(texts)

        if self.hparams.cfg.DATASET.TASK_PATH:
            instructions = pjoin(self.hparams.cfg.DATASET.TASK_PATH)
            with open(instructions, 'r') as f:
                instructions = json.load(f)
            tasks = [instructions["Text-to-Motion"]["t2m"]] * len(texts)

        min_len = lengths.copy()

        outputs = self.lm.generate_conditional(texts,
                                               lengths=lengths,
                                               stage='test',
                                               tasks=tasks)

        feats_rst = torch.zeros_like(feats_ref)

        for i in range(len(texts)):
            # Clip token ids to the valid codebook range before decoding
            outputs[i] = torch.clamp(outputs[i], 0,
                                     self.hparams.codebook_size - 1)

            if len(outputs[i]) > 1:
                motion = self.vae.decode(outputs[i])
            else:
                motion = torch.zeros_like(feats_ref[i:i + 1, ...])

            min_len[i] = min(motion.shape[1], lengths[i])

            feats_rst[i:i + 1, :min_len[i], ...] = motion[:, :lengths[i]]

        # Recover joints for evaluation
        joints_ref = self.feats2joints(feats_ref)
        joints_rst = self.feats2joints(feats_rst)

        # Renormalize for the text-to-motion evaluators
        feats_ref = self.datamodule.renorm4t2m(feats_ref)
        feats_rst = self.datamodule.renorm4t2m(feats_rst)

        rs_set = {
            "m_ref": feats_ref,
            "m_rst": feats_rst,
            "joints_ref": joints_ref,
            "joints_rst": joints_rst,
            "length": min_len
        }

        return rs_set
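
    # Validation for motion-to-text: encode the reference motions into token
    # ids with the VQ-VAE, then let the language model generate captions.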
    @torch.no_grad()
    def val_m2t_forward(self, batch):
        self.hparams.metrics_dict = []

        feats_ref = batch["motion"]
        texts = batch["text"]
        lengths = batch["length"]
        all_captions = batch['all_captions']

        motion_tokens = []
        lengths_tokens = []
        for i in range(len(feats_ref)):
            motion_token, _ = self.vae.encode(feats_ref[i:i + 1])
            motion_tokens.append(motion_token[0])
            lengths_tokens.append(motion_token.shape[1])

        outputs = self.lm.generate_conditional(motion_tokens=motion_tokens,
                                               lengths=lengths_tokens,
                                               task="m2t",
                                               stage='test')

        rs_set = {
            "m_ref": feats_ref,
            "t_ref": all_captions,
            "t_pred": outputs,
            "length": lengths
        }

        return rs_set
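
    # Validation for motion prediction / in-betweening: the LM is conditioned
    # on the encoded reference tokens, and the decoded result is compared
    # against the reference motion by the metrics.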
    @torch.no_grad()
    def val_m2m_forward(self, batch, task="pred"):
        feats_ref = batch["motion"]
        lengths = batch["length"]

        motion_tokens = []
        lengths_tokens = []
        for i in range(len(feats_ref)):
            motion_token, _ = self.vae.encode(feats_ref[i:i + 1])
            motion_tokens.append(motion_token[0])

        outputs = self.lm.generate_conditional(motion_tokens=motion_tokens,
                                               lengths=lengths,
                                               task=task,
                                               stage='test')

        feats_rst = torch.zeros_like(feats_ref)
        min_len = lengths.copy()

        for i in range(len(lengths)):
            outputs[i] = torch.clamp(outputs[i], 0,
                                     self.hparams.codebook_size - 1)

            if len(outputs[i]) > 1:
                motion = self.vae.decode(outputs[i])
            else:
                motion = torch.zeros_like(feats_ref[i:i + 1, ...])

            min_len[i] = min(motion.shape[1], lengths[i])

            feats_rst[i:i + 1, :min_len[i], ...] = motion[:, :lengths[i]]

        joints_ref = self.feats2joints(feats_ref)
        joints_rst = self.feats2joints(feats_rst)

        feats_ref = self.datamodule.renorm4t2m(feats_ref)
        feats_rst = self.datamodule.renorm4t2m(feats_rst)

        rs_set = {
            "m_ref": feats_ref,
            "m_rst": feats_rst,
            "joints_ref": joints_ref,
            "joints_rst": joints_rst,
            "length": min_len
        }

        return rs_set
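
    # Training step for stage 1: reconstruct the reference features through
    # the VQ-VAE; the commitment loss and codebook perplexity come from the VAE.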
    def train_vae_forward(self, batch):
        feats_ref = batch["motion"]
        joints_ref = self.feats2joints(feats_ref)

        feats_rst, loss_commit, perplexity = self.vae(feats_ref)
        joints_rst = self.feats2joints(feats_rst)

        rs_set = {
            "m_ref": feats_ref,
            "joints_ref": joints_ref,
            "m_rst": feats_rst,
            "joints_rst": joints_rst,
            "loss_commit": loss_commit,
            "perplexity": perplexity,
        }
        return rs_set
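
    # Validation for the tokenizer: reconstruct each clip at its true length
    # and renormalize the features for the evaluators.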
    @torch.no_grad()
    def val_vae_forward(self, batch, split="train"):
        feats_ref = batch["motion"]
        lengths = batch["length"]

        if self.trainer.datamodule.is_mm:
            feats_ref = feats_ref.repeat_interleave(
                self.hparams.cfg.METRIC.MM_NUM_REPEATS, dim=0)
            lengths = lengths * self.hparams.cfg.METRIC.MM_NUM_REPEATS

        feats_rst = torch.zeros_like(feats_ref)

        for i in range(len(feats_ref)):
            if lengths[i] == 0:
                continue
            feats_pred, _, _ = self.vae(feats_ref[i:i + 1, :lengths[i]])
            feats_rst[i:i + 1, :feats_pred.shape[1], :] = feats_pred

            code_pred, _ = self.vae.encode(feats_ref[i:i + 1, :lengths[i]])

        joints_ref = self.feats2joints(feats_ref)
        joints_rst = self.feats2joints(feats_rst)

        feats_ref = self.datamodule.renorm4t2m(feats_ref)
        feats_rst = self.datamodule.renorm4t2m(feats_rst)

        rs_set = {
            "m_ref": feats_ref,
            "joints_ref": joints_ref,
            "m_rst": feats_rst,
            "joints_rst": joints_rst,
            "length": lengths,
        }

        return rs_set
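
    # Shared step for train/val/test: dispatch on the training stage and task,
    # update the split losses during training, and feed the appropriate
    # metrics during validation and testing.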
    def allsplit_step(self, split: str, batch, batch_idx):
        loss = None

        if self.hparams.stage == "vae" and split in ["train", "val"]:
            rs_set = self.train_vae_forward(batch)
            loss = self._losses['losses_' + split].update(rs_set)
        elif self.hparams.stage in ["lm_instruct", "lm_pretrain"
                                    ] and split in ["train"]:
            rs_set = self.train_lm_forward(batch)
            loss = self._losses['losses_' + split].update(rs_set)
        elif self.hparams.stage == 'lm_rl' and split in ['train']:
            rs_set = self.train_rl_forward(batch)
            loss = None

        if split in ["val", "test"]:
            if self.hparams.stage == "vae":
                rs_set = self.val_vae_forward(batch, split)
            elif self.hparams.stage in ["lm_instruct", "lm_pretrain", "lm_rl"]:
                if self.hparams.task == "t2m":
                    rs_set = self.val_t2m_forward(batch)
                elif self.hparams.task == "m2t":
                    rs_set = self.val_m2t_forward(batch)
                elif self.hparams.task in ["m2m", "pred", "inbetween"]:
                    rs_set = self.val_m2m_forward(batch, self.hparams.task)

            if self.hparams.task not in ["m2t"]:
                if self.trainer.datamodule.is_mm:
                    metrics_dicts = ['MMMetrics']
                else:
                    # Copy so the stored hyperparameter list is not mutated below
                    metrics_dicts = list(self.hparams.metrics_dict)

                if self.hparams.task not in ['pred', 'inbetween'
                                             ] and 'PredMetrics' in metrics_dicts:
                    metrics_dicts.remove('PredMetrics')

                for metric in metrics_dicts:
                    lengths = batch['length']
                    if metric == "TemosMetric":
                        getattr(self.metrics,
                                metric).update(rs_set["joints_rst"],
                                               rs_set["joints_ref"], lengths)
                    elif metric == "TM2TMetrics":
                        if self.hparams.stage in [
                                "lm_instruct", "lm_pretrain", "lm_rl"
                        ]:
                            word_embs = batch['word_embs']
                            pos_ohot = batch['pos_ohot']
                            text_lengths = batch['text_len']
                            if self.trainer.datamodule.is_mm:
                                word_embs = word_embs.repeat_interleave(
                                    self.hparams.cfg.METRIC.MM_NUM_REPEATS,
                                    dim=0)
                                pos_ohot = pos_ohot.repeat_interleave(
                                    self.hparams.cfg.METRIC.MM_NUM_REPEATS,
                                    dim=0)
                                text_lengths = text_lengths.repeat_interleave(
                                    self.hparams.cfg.METRIC.MM_NUM_REPEATS,
                                    dim=0)
                        else:
                            word_embs = None
                            pos_ohot = None
                            text_lengths = None

                        getattr(self.metrics, metric).update(
                            feats_ref=rs_set["m_ref"],
                            feats_rst=rs_set["m_rst"],
                            lengths_ref=lengths,
                            lengths_rst=rs_set['length'],
                            word_embs=word_embs,
                            pos_ohot=pos_ohot,
                            text_lengths=text_lengths,
                        )
                    elif metric == "UncondMetrics":
                        getattr(self.metrics, metric).update(
                            recmotion_embeddings=rs_set["lat_rm"],
                            gtmotion_embeddings=rs_set["lat_m"],
                            lengths=lengths,
                        )
                    elif metric == "MRMetrics":
                        getattr(self.metrics,
                                metric).update(rs_set["joints_rst"],
                                               rs_set["joints_ref"], lengths)
                    elif metric == "PredMetrics":
                        getattr(self.metrics,
                                metric).update(rs_set["joints_rst"],
                                               rs_set["joints_ref"], lengths)
                    elif metric == "MMMetrics":
                        getattr(self.metrics,
                                metric).update(rs_set["m_rst"],
                                               rs_set['length'])
                    else:
                        raise TypeError(f"Unsupported metric: {metric}")

            elif self.hparams.task == "m2t" and self.hparams.stage in [
                    "lm_instruct", "lm_pretrain", "lm_rl"
            ]:
                self.hparams.metrics_dict = metrics_dicts = ['M2TMetrics']
                for metric in metrics_dicts:
                    if metric == "M2TMetrics":
                        getattr(self.metrics, metric).update(
                            feats_ref=rs_set["m_ref"],
                            pred_texts=rs_set["t_pred"],
                            gt_texts=batch["all_captions"],
                            lengths=rs_set['length'],
                            word_embs=batch["word_embs"],
                            pos_ohot=batch["pos_ohot"],
                            text_lengths=batch["text_len"],
                        )

        # At test time, return the generated results
        if split in ["test"]:
            if self.hparams.task == "t2m":
                return rs_set["joints_rst"], rs_set["length"], rs_set[
                    "joints_ref"]
            elif self.hparams.task == "m2t":
                return rs_set["t_pred"], batch["length"]

        return loss