import os
from typing import List

import numpy as np
import torch
from torch import Tensor
from torchmetrics import Metric

from mGPT.config import instantiate_from_config
from .utils import *  # provides calculate_multimodality_np

class MMMetrics(Metric):
    full_state_update = True

    def __init__(self,
                 cfg,
                 dataname='humanml3d',
                 mm_num_times=10,
                 dist_sync_on_step=True,
                 **kwargs):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.name = "MultiModality scores"
        self.cfg = cfg
        self.dataname = dataname
        self.mm_num_times = mm_num_times

        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("count_seq",
                       default=torch.tensor(0),
                       dist_reduce_fx="sum")

        self.metrics = ["MultiModality"]
        self.add_state("MultiModality",
                       default=torch.tensor(0.),
                       dist_reduce_fx="sum")

        # cached batches of generated-motion embeddings; with
        # dist_reduce_fx=None, torchmetrics gathers this list state across
        # processes instead of reducing it
        self.add_state("mm_motion_embeddings", default=[], dist_reduce_fx=None)

        # T2M evaluator networks
        self._get_t2m_evaluator(cfg)

    def _get_t2m_evaluator(self, cfg):
        """
        Load the T2M text encoder and motion encoders used for evaluation.
        """
        # init modules
        self.t2m_textencoder = instantiate_from_config(
            cfg.METRIC.TM2T.t2m_textencoder)
        self.t2m_moveencoder = instantiate_from_config(
            cfg.METRIC.TM2T.t2m_moveencoder)
        self.t2m_motionencoder = instantiate_from_config(
            cfg.METRIC.TM2T.t2m_motionencoder)

        # load pretrained weights
        dataname = "kit" if self.dataname == "kit" else "t2m"
        t2m_checkpoint = torch.load(os.path.join(
            cfg.METRIC.TM2T.t2m_path, dataname,
            "text_mot_match/model/finest.tar"),
                                    map_location="cpu")
        self.t2m_textencoder.load_state_dict(t2m_checkpoint["text_encoder"])
        self.t2m_moveencoder.load_state_dict(
            t2m_checkpoint["movement_encoder"])
        self.t2m_motionencoder.load_state_dict(
            t2m_checkpoint["motion_encoder"])

        # freeze params
        self.t2m_textencoder.eval()
        self.t2m_moveencoder.eval()
        self.t2m_motionencoder.eval()
        for p in self.t2m_textencoder.parameters():
            p.requires_grad = False
        for p in self.t2m_moveencoder.parameters():
            p.requires_grad = False
        for p in self.t2m_motionencoder.parameters():
            p.requires_grad = False
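
    # Note: only the movement and motion encoders are used by this metric (see
    # `get_motion_embeddings` below); the text encoder is loaded from the same
    # checkpoint but is never called here.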

    def compute(self, sanity_flag):
        count = self.count.item()
        count_seq = self.count_seq.item()

        # init metrics
        metrics = {metric: getattr(self, metric) for metric in self.metrics}

        # skip computation during the sanity-check stage
        if sanity_flag:
            return metrics

        # cat all embeddings: [num_prompts, num_repeats, nlatent*ndim]
        all_mm_motions = torch.cat(self.mm_motion_embeddings,
                                   dim=0).cpu().numpy()
        metrics['MultiModality'] = calculate_multimodality_np(
            all_mm_motions, self.mm_num_times)

        # reset
        self.reset()
        return {**metrics}
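
    # `calculate_multimodality_np` comes from `.utils` via the star import.
    # Under the standard text-to-motion protocol it samples `mm_num_times`
    # pairs of repeated generations per prompt and averages their pairwise
    # Euclidean distance; a minimal sketch, assuming that definition:
    #
    #     first = np.random.choice(num_repeats, mm_num_times, replace=False)
    #     second = np.random.choice(num_repeats, mm_num_times, replace=False)
    #     dist = np.linalg.norm(act[:, first] - act[:, second], axis=2)
    #     multimodality = dist.mean()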

    def update(
        self,
        feats_rst: Tensor,
        lengths_rst: List[int],
    ):
        self.count += sum(lengths_rst)
        self.count_seq += len(lengths_rst)

        # sort by length (descending), as required by the motion encoder
        align_idx = np.argsort(lengths_rst)[::-1].copy()
        feats_rst = feats_rst[align_idx]
        lengths_rst = np.array(lengths_rst)[align_idx]

        recmotion_embeddings = self.get_motion_embeddings(
            feats_rst, lengths_rst)

        # undo the sort so embeddings line up with the input order
        cache = [0] * len(lengths_rst)
        for i in range(len(lengths_rst)):
            cache[align_idx[i]] = recmotion_embeddings[i:i + 1]

        # store all mm motion embeddings: one [1, num_repeats, nlatent*ndim]
        # entry per update call
        mm_motion_embeddings = torch.cat(cache, dim=0).unsqueeze(0)
        self.mm_motion_embeddings.append(mm_motion_embeddings)

    def get_motion_embeddings(self, feats: Tensor, lengths: List[int]):
        # number of movement-encoder units per sequence
        m_lens = torch.tensor(lengths)
        m_lens = torch.div(m_lens,
                           self.cfg.DATASET.HUMANML3D.UNIT_LEN,
                           rounding_mode="floor")

        # drop the trailing 4 dims (foot contacts in the HumanML3D
        # representation) before the movement encoder
        mov = self.t2m_moveencoder(feats[..., :-4]).detach()
        emb = self.t2m_motionencoder(mov, m_lens)

        # [bs, nlatent*ndim] <= [bs, nlatent, ndim]
        return torch.flatten(emb, start_dim=1).detach()
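

# Usage sketch (an assumption about the calling convention, not part of this
# module): each `update` call should carry the repeated generations for a
# single prompt, so `compute` sees a [num_prompts, num_repeats, dim] stack.
# `model.sample` below is a hypothetical generation API.
#
#     metric = MMMetrics(cfg, mm_num_times=10)
#     for text in mm_prompts:
#         feats_rst, lengths_rst = model.sample(text, num_repeats=30)
#         metric.update(feats_rst, lengths_rst)
#     results = metric.compute(sanity_flag=False)
#     print(results["MultiModality"])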