|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional |
|
|
|
import torch |
|
from torch import Tensor, nn |
|
from pathlib import Path |
|
import os |
|
|
|
class Rots2Rfeats(nn.Module):
    """Module that optionally standardizes features with stored statistics.

    When ``normalization`` is enabled, the tensors saved as
    ``rfeats_mean.pt`` and ``rfeats_std.pt`` inside ``path`` are loaded and
    registered as the buffers ``mean`` and ``std``. ``normalize`` and
    ``unnormalize`` then apply, respectively undo, the standardization.

    Args:
        path: Directory holding the statistics files. If its last component
            is ``separate_pairs``, the files are looked up in the parent
            directory instead. Required when ``normalization`` is true.
        normalization: Whether statistics are loaded and applied.
        eps: Constant added to ``std`` to avoid dividing by zero.
        **kwargs: Accepted for call-site compatibility and ignored.

    Raises:
        TypeError: If ``normalization`` is true and ``path`` is ``None``.
    """

    def __init__(self, path: Optional[str] = None,
                 normalization: bool = True,
                 eps: float = 1e-12,
                 **kwargs) -> None:
        if normalization and path is None:
            raise TypeError("You should provide a path if normalization is on.")

        super().__init__()
        self.normalization = normalization
        self.eps = eps

        if normalization:
            stats_dir = Path(path)
            # Only a *trailing* 'separate_pairs' component is dropped; the
            # statistics live one level above it. Using Path keeps absolute
            # paths absolute and tolerates a trailing slash.
            if stats_dir.name == 'separate_pairs':
                stats_dir = stats_dir.parent

            mean_path = stats_dir / "rfeats_mean.pt"
            std_path = stats_dir / "rfeats_std.pt"

            # NOTE(review): torch.load unpickles data; only point `path` at
            # trusted statistics files.
            self.register_buffer('mean', torch.load(mean_path))
            self.register_buffer('std', torch.load(std_path))

    def normalize(self, features: Tensor) -> Tensor:
        """Return ``(features - mean) / (std + eps)``.

        The input is returned unchanged when normalization is disabled.
        """
        if self.normalization:
            features = (features - self.mean) / (self.std + self.eps)
        return features

    def unnormalize(self, features: Tensor) -> Tensor:
        """Invert :meth:`normalize`: ``features * (std + eps) + mean``.

        The input is returned unchanged when normalization is disabled.
        """
        if self.normalization:
            # Use the same (std + eps) scale as `normalize` so that the two
            # methods are exact inverses, including where std is zero.
            features = features * (self.std + self.eps) + self.mean
        return features
|
|