import os
from typing import List

import numpy as np
import torch
from torch import Tensor
from torchmetrics import Metric

from mGPT.config import instantiate_from_config

from .utils import *


class MMMetrics(Metric):
    """MultiModality metric: measures the diversity of motions generated
    from repeated samplings of the same text prompt, using embeddings from
    the pretrained T2M motion encoder."""

    full_state_update = True

    def __init__(self,
                 cfg,
                 dataname='humanml3d',
                 mm_num_times=10,
                 dist_sync_on_step=True,
                 **kwargs):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.name = "MultiModality scores"
        self.cfg = cfg
        self.dataname = dataname
        self.mm_num_times = mm_num_times

        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("count_seq",
                       default=torch.tensor(0),
                       dist_reduce_fx="sum")

        self.metrics = ["MultiModality"]
        self.add_state("MultiModality",
                       default=torch.tensor(0.),
                       dist_reduce_fx="sum")

        # cached batches of motion embeddings, one entry per mm text
        self.add_state("mm_motion_embeddings", default=[], dist_reduce_fx=None)

        # T2M Evaluator
        self._get_t2m_evaluator(cfg)

    def _get_t2m_evaluator(self, cfg):
        """Load the pretrained T2M text encoder and motion encoders used
        for evaluation."""
        # init modules
        self.t2m_textencoder = instantiate_from_config(
            cfg.METRIC.TM2T.t2m_textencoder)
        self.t2m_moveencoder = instantiate_from_config(
            cfg.METRIC.TM2T.t2m_moveencoder)
        self.t2m_motionencoder = instantiate_from_config(
            cfg.METRIC.TM2T.t2m_motionencoder)

        # load pretrained weights
        dataname = "kit" if self.dataname == "kit" else "t2m"
        t2m_checkpoint = torch.load(os.path.join(
            cfg.METRIC.TM2T.t2m_path, dataname,
            "text_mot_match/model/finest.tar"),
                                    map_location="cpu")
        self.t2m_textencoder.load_state_dict(t2m_checkpoint["text_encoder"])
        self.t2m_moveencoder.load_state_dict(
            t2m_checkpoint["movement_encoder"])
        self.t2m_motionencoder.load_state_dict(
            t2m_checkpoint["motion_encoder"])

        # freeze params
        self.t2m_textencoder.eval()
        self.t2m_moveencoder.eval()
        self.t2m_motionencoder.eval()
        for p in self.t2m_textencoder.parameters():
            p.requires_grad = False
        for p in self.t2m_moveencoder.parameters():
            p.requires_grad = False
        for p in self.t2m_motionencoder.parameters():
            p.requires_grad = False
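
    # Expected checkpoint layout on disk, as implied by the load call above
    # (the root directory comes from the project config):
    #   {cfg.METRIC.TM2T.t2m_path}/
    #       t2m/text_mot_match/model/finest.tar   # HumanML3D evaluators
    #       kit/text_mot_match/model/finest.tar   # KIT-ML evaluators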

    def compute(self, sanity_flag):
        count = self.count.item()
        count_seq = self.count_seq.item()

        # init metrics
        metrics = {metric: getattr(self, metric) for metric in self.metrics}

        # skip the actual computation during the sanity-check stage
        if sanity_flag:
            return metrics

        # cat all cached embeddings: one [1, num_repeats, dim] entry per text
        all_mm_motions = torch.cat(self.mm_motion_embeddings,
                                   dim=0).cpu().numpy()
        metrics['MultiModality'] = calculate_multimodality_np(
            all_mm_motions, self.mm_num_times)

        # reset accumulated states
        self.reset()
        return {**metrics}
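
    # For reference, a minimal sketch of what `calculate_multimodality_np`
    # (imported from .utils) typically computes in T2M-style evaluation
    # code; this is an assumption about its behavior, not the authoritative
    # implementation:
    #
    #   def calculate_multimodality_np(activation, multimodality_times):
    #       # activation: [num_texts, num_repeats, dim]
    #       assert activation.ndim == 3
    #       assert activation.shape[1] > multimodality_times
    #       num_per_sent = activation.shape[1]
    #       first = np.random.choice(num_per_sent, multimodality_times,
    #                                replace=False)
    #       second = np.random.choice(num_per_sent, multimodality_times,
    #                                 replace=False)
    #       dist = np.linalg.norm(activation[:, first] - activation[:, second],
    #                             axis=2)
    #       return dist.mean()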

    def update(
        self,
        feats_rst: Tensor,
        lengths_rst: List[int],
    ):
        self.count += sum(lengths_rst)
        self.count_seq += len(lengths_rst)

        # sort motions by length in descending order before encoding
        align_idx = np.argsort(lengths_rst)[::-1].copy()
        feats_rst = feats_rst[align_idx]
        lengths_rst = np.array(lengths_rst)[align_idx]

        recmotion_embeddings = self.get_motion_embeddings(
            feats_rst, lengths_rst)

        # undo the sort so embeddings line up with the original batch order
        cache = [0] * len(lengths_rst)
        for i in range(len(lengths_rst)):
            cache[align_idx[i]] = recmotion_embeddings[i:i + 1]

        # store all mm motion embeddings for this text: [1, num_repeats, dim]
        mm_motion_embeddings = torch.cat(cache, dim=0).unsqueeze(0)
        self.mm_motion_embeddings.append(mm_motion_embeddings)
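
    # Worked example of the reorder bookkeeping above (hypothetical values):
    #   lengths_rst = [3, 7, 5]  ->  align_idx = [1, 2, 0]
    #   embeddings come back in length order [7, 5, 3]; writing
    #   cache[align_idx[i]] = emb[i] puts them back at positions [1, 2, 0],
    #   restoring the original batch order before stacking.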

    def get_motion_embeddings(self, feats: Tensor, lengths: List[int]):
        # number of movement tokens after temporal downsampling by UNIT_LEN
        m_lens = torch.tensor(lengths)
        m_lens = torch.div(m_lens,
                           self.cfg.DATASET.HUMANML3D.UNIT_LEN,
                           rounding_mode="floor")
        # drop the last 4 feature channels (foot contacts in HumanML3D)
        mov = self.t2m_moveencoder(feats[..., :-4]).detach()
        emb = self.t2m_motionencoder(mov, m_lens)

        # [bs, nlatent*ndim] <= [bs, nlatent, ndim]
        return torch.flatten(emb, start_dim=1).detach()
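
# A minimal usage sketch (hypothetical config and tensors, for illustration
# only; MMMetrics is normally driven by the surrounding evaluation loop):
#
#   metric = MMMetrics(cfg, dataname="humanml3d", mm_num_times=10)
#   for feats_rst, lengths_rst in mm_generated_batches:  # hypothetical
#       # feats_rst: [num_repeats, max_len, nfeats] generations for one text
#       metric.update(feats_rst, lengths_rst)
#   scores = metric.compute(sanity_flag=False)
#   print(scores["MultiModality"])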