|
from typing import List |
|
|
|
import torch |
|
from torch import Tensor |
|
from torchmetrics import Metric |
|
|
|
from .utils import * |
|
|
|
|
|
|
|
class MRMetrics(Metric):
    """Accumulate motion-reconstruction errors between predicted and reference joints.

    Three error sums are accumulated across ``update`` calls and averaged in
    ``compute``:

    * ``MPJPE``   - per-frame error from ``calc_mpjpe`` (optionally with
      ``align_inds=[0]``),
    * ``PAMPJPE`` - per-frame error from ``calc_pampjpe``,
    * ``ACCEL``   - per-frame error from ``calc_accel``.

    All states use ``dist_reduce_fx="sum"``, so partial sums and counts from
    several processes are added together before averaging.
    """

    def __init__(self,
                 njoints,
                 jointstype: str = "mmm",
                 force_in_meter: bool = True,
                 align_root: bool = True,
                 dist_sync_on_step=True,
                 **kwargs):
        """
        Args:
            njoints: Accepted for interface compatibility; not used here.
            jointstype: Skeleton name. Root alignment for MPJPE is applied
                only when this is ``'mmm'`` or ``'humanml3d'``.
            force_in_meter: If True, ``compute`` multiplies every result by
                1000 (presumably metres -> millimetres; confirm the input
                units with callers). If False, results are returned unscaled.
            align_root: Whether to pass ``align_inds=[0]`` to ``calc_mpjpe``
                (for the skeleton types listed above).
            dist_sync_on_step: Forwarded to ``torchmetrics.Metric``.
            **kwargs: Accepted and ignored.
        """
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.name = 'Motion Reconstructions'
        self.jointstype = jointstype
        self.align_root = align_root
        self.force_in_meter = force_in_meter

        # Total number of valid frames, and number of sequences, seen so far.
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("count_seq",
                       default=torch.tensor(0),
                       dist_reduce_fx="sum")

        # Running sums of the per-frame errors; averaged in compute().
        self.add_state("MPJPE",
                       default=torch.tensor([0.0]),
                       dist_reduce_fx="sum")
        self.add_state("PAMPJPE",
                       default=torch.tensor([0.0]),
                       dist_reduce_fx="sum")
        self.add_state("ACCEL",
                       default=torch.tensor([0.0]),
                       dist_reduce_fx="sum")

        self.MR_metrics = ["MPJPE", "PAMPJPE", "ACCEL"]
        self.metrics = self.MR_metrics

    def compute(self, sanity_flag=False):
        """Return the averaged metrics and reset the accumulated state.

        Args:
            sanity_flag: Accepted for caller compatibility; not used. It
                defaults to False so ``compute()`` may be called without it.

        Returns:
            dict mapping ``"MPJPE"``, ``"PAMPJPE"`` and ``"ACCEL"`` to
            1-element tensors. Each value is multiplied by 1000 when
            ``force_in_meter`` is set. If nothing has been accumulated, the
            division by a zero count yields NaN.
        """
        factor = 1000.0 if self.force_in_meter else 1.0

        count = self.count
        count_seq = self.count_seq

        mr_metrics = {}
        mr_metrics["MPJPE"] = self.MPJPE / count * factor
        mr_metrics["PAMPJPE"] = self.PAMPJPE / count * factor
        # The denominator assumes calc_accel yields L - 2 values for a
        # sequence of L frames, i.e. sum(L) - 2 * num_sequences in total.
        # It is exact only when every sequence has at least 2 frames.
        mr_metrics["ACCEL"] = self.ACCEL / (count - 2 * count_seq) * factor

        self.reset()
        return mr_metrics

    def update(self, joints_rst: Tensor, joints_ref: Tensor,
               lengths: List[int]):
        """Accumulate errors for a batch of padded joint sequences.

        Args:
            joints_rst: Predicted joints, a 4-D tensor. Presumably
                (batch, frames, joints, 3) - TODO confirm.
            joints_ref: Reference joints, same shape as ``joints_rst``.
            lengths: Number of valid frames for each sequence in the batch.
                Frames at index ``lengths[i]`` and beyond are treated as
                padding and excluded from every error sum.
        """
        assert joints_rst.shape == joints_ref.shape
        assert joints_rst.dim() == 4

        self.count += sum(lengths)
        self.count_seq += len(lengths)

        # Detach from the graph and move to CPU before calling the metric
        # helpers (presumably to avoid device issues inside them - confirm).
        rst = joints_rst.detach().cpu()
        ref = joints_ref.detach().cpu()

        # Root alignment (joint index 0) only for these skeleton types.
        if self.align_root and self.jointstype in ['mmm', 'humanml3d']:
            align_inds = [0]
        else:
            align_inds = None

        for i, length in enumerate(lengths):
            # Trim padding so each sum covers exactly the frames counted in
            # self.count, matching the denominators used in compute().
            rst_i = rst[i, :length]
            ref_i = ref[i, :length]
            self.MPJPE += torch.sum(
                calc_mpjpe(rst_i, ref_i, align_inds=align_inds))
            self.PAMPJPE += torch.sum(calc_pampjpe(rst_i, ref_i))
            self.ACCEL += torch.sum(calc_accel(rst_i, ref_i))
|
|