File size: 3,062 Bytes
4409449 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
from typing import List
import torch
from torch import Tensor
from torchmetrics import Metric
from .utils import *
# Motion prediction / in-betweening error metrics (APD, ADE, FDE)
class PredMetrics(Metric):
    """Motion prediction / in-betweening error metrics (APD, ADE, FDE).

    For every sequence in a batch, predicted joints are compared with the
    reference joints on an evaluation window chosen by ``task`` (``l`` is the
    sequence's entry in ``lengths``):

    * ``"pred"``: frames ``[int(l * cfg.ABLATION.predict_ratio), l)``.
    * ``"inbetween"``: frames ``[r, l - r)`` with
      ``r = int(l * cfg.ABLATION.inbetween_ratio)``.
    * anything else: all frames ``[0, l)`` (a warning is printed).

    The per-frame error is the L2 norm of the difference between the flattened
    per-frame joint vectors (all joints and coordinates of a frame concatenated
    into one vector), not a per-joint average. ADE is the mean of that error
    over the window, and FDE is the error on the window's last frame. Both are
    summed over sequences and divided by the number of sequences in
    :meth:`compute`.

    Note: ``APD`` is registered and reported, but :meth:`update` never
    accumulates it, so it is always reported as 0.

    Args:
        cfg: Config object; ``cfg.ABLATION.predict_ratio`` or
            ``cfg.ABLATION.inbetween_ratio`` is read depending on ``task``.
        njoints: Not used by this class.
        jointstype: Stored on the instance; not used by this class.
        force_in_meter: Stored on the instance; not used by this class.
        align_root: Stored on the instance; not used by this class.
        dist_sync_on_step: Forwarded to ``torchmetrics.Metric``.
        task: ``"pred"`` or ``"inbetween"``.
        **kwargs: Ignored.
    """

    def __init__(self,
                 cfg,
                 njoints: int = 22,
                 jointstype: str = "mmm",
                 force_in_meter: bool = True,
                 align_root: bool = True,
                 dist_sync_on_step=True,
                 task: str = "pred",
                 **kwargs):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.name = 'Motion Prediction'
        self.cfg = cfg
        self.jointstype = jointstype
        self.align_root = align_root
        self.task = task
        self.force_in_meter = force_in_meter

        # Number of valid frames seen (sum of lengths).
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
        # Number of sequences seen; the denominator used in compute().
        self.add_state("count_seq",
                       default=torch.tensor(0),
                       dist_reduce_fx="sum")
        # Running sums of the per-sequence metrics.
        self.add_state("APD",
                       default=torch.tensor([0.0]),
                       dist_reduce_fx="sum")
        self.add_state("ADE",
                       default=torch.tensor([0.0]),
                       dist_reduce_fx="sum")
        self.add_state("FDE",
                       default=torch.tensor([0.0]),
                       dist_reduce_fx="sum")

        self.MR_metrics = ["APD", "ADE", "FDE"]

        # All metric
        self.metrics = self.MR_metrics

    def compute(self, sanity_flag):
        """Return per-sequence averages of the accumulated sums, then reset.

        Args:
            sanity_flag: Accepted for interface compatibility; unused.

        Returns:
            dict mapping ``"APD"``, ``"ADE"`` and ``"FDE"`` to 1-element
            tensors (sum divided by the number of sequences seen).
        """
        count_seq = self.count_seq
        mr_metrics = {}
        mr_metrics["APD"] = self.APD / count_seq
        mr_metrics["ADE"] = self.ADE / count_seq
        mr_metrics["FDE"] = self.FDE / count_seq

        # Reset
        self.reset()
        return mr_metrics

    def update(self, joints_rst: Tensor, joints_ref: Tensor,
               lengths: List[int]):
        """Accumulate ADE/FDE sums for one batch.

        Args:
            joints_rst: Predicted joints; 4-D, presumably
                ``(bs, seq, njoints, 3)``.
            joints_ref: Reference joints, same shape as ``joints_rst``.
            lengths: Per-sequence frame counts, assumed to be the number of
                valid (non-padded) frames of each of the ``bs`` sequences.
        """
        assert joints_rst.shape == joints_ref.shape
        assert joints_rst.dim() == 4
        # (bs, seq, njoint=22, 3)

        self.count += sum(lengths)
        self.count_seq += len(lengths)

        # Concatenate all joints/coordinates of a frame: (bs, seq, njoints * 3).
        rst = torch.flatten(joints_rst, start_dim=2)
        ref = torch.flatten(joints_ref, start_dim=2)

        # Warn once per batch rather than once per sequence.
        if self.task not in ("pred", "inbetween"):
            print(f"Task {self.task} not implemented.")

        for i, l in enumerate(lengths):
            start, end = self._eval_window(l)
            # The window never extends past ``l``, so padded frames are excluded.
            diff = rst[i, start:end] - ref[i, start:end]

            # Per-frame L2 error of the flattened pose vector, shape (1, T).
            dist = torch.linalg.norm(diff, dim=-1)[None]

            ade = dist.mean(dim=1)  # mean over the window's frames -> (1,)
            fde = dist[:, -1]       # error on the window's last frame -> (1,)
            self.ADE = self.ADE + ade
            self.FDE = self.FDE + fde

    def _eval_window(self, length: int):
        """Return the ``(start, end)`` frame range evaluated for one sequence."""
        if self.task == "pred":
            return int(length * self.cfg.ABLATION.predict_ratio), length
        if self.task == "inbetween":
            margin = int(length * self.cfg.ABLATION.inbetween_ratio)
            return margin, length - margin
        # Unknown task: fall back to every valid frame of the sequence.
        return 0, length
|