""" Author: Luigi Piccinelli Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) """ from typing import Any, Dict, List, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F FNS = { "sqrt": torch.sqrt, "log": torch.log, "log1": lambda x: torch.log(x + 1), "linear": lambda x: x, "square": torch.square, "disp": lambda x: 1 / x, } FNS_INV = { "sqrt": torch.square, "log": torch.exp, "log1": lambda x: torch.exp(x) - 1, "linear": lambda x: x, "square": torch.sqrt, "disp": lambda x: 1 / x, } def masked_mean_var(data: torch.Tensor, mask: torch.Tensor, dim: List[int]): if mask is None: return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True) mask = mask.float() mask_sum = torch.sum(mask, dim=dim, keepdim=True) mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp( mask_sum, min=1.0 ) mask_var = torch.sum( mask * (data - mask_mean) ** 2, dim=dim, keepdim=True ) / torch.clamp(mask_sum, min=1.0) return mask_mean.squeeze(dim), mask_var.squeeze(dim) def masked_mean(data: torch.Tensor, mask: torch.Tensor | None, dim: List[int]): if mask is None: return data.mean(dim=dim, keepdim=True) mask = mask.float() mask_sum = torch.sum(mask, dim=dim, keepdim=True) mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp( mask_sum, min=1.0 ) return mask_mean def masked_mae(data: torch.Tensor, mask: torch.Tensor, dim: Tuple[int, ...]): if mask is None: return data.abs().mean(dim=dim, keepdim=True) mask = mask.float() mask_sum = torch.sum(mask, dim=dim, keepdim=True) mask_mean = torch.sum(data.abs() * mask, dim=dim, keepdim=True) / torch.clamp( mask_sum, min=1.0 ) return mask_mean def masked_mse(data: torch.Tensor, mask: torch.Tensor, dim: Tuple[int, ...]): if mask is None: return (data**2).mean(dim=dim, keepdim=True) mask = mask.float() mask_sum = torch.sum(mask, dim=dim, keepdim=True) mask_mean = torch.sum((data**2) * mask, dim=dim, keepdim=True) / torch.clamp( mask_sum, min=1.0 ) return mask_mean def masked_median(data: torch.Tensor, mask: torch.Tensor, dim: List[int]): ndim = data.ndim data = data.flatten(ndim - len(dim)) mask = mask.flatten(ndim - len(dim)) mask_median = torch.median(data[mask], dim=-1).values return mask_median def masked_median_mad(data: torch.Tensor, mask: torch.Tensor): data = data.flatten() mask = mask.flatten() mask_median = torch.median(data[mask]) n_samples = torch.clamp(torch.sum(mask.float()), min=1.0) mask_mad = torch.sum((data[mask] - mask_median).abs()) / n_samples return mask_median, mask_mad def masked_weighted_mean_var( data: torch.Tensor, mask: torch.Tensor, weights: torch.Tensor, dim: Tuple[int, ...] 
def masked_weighted_mean_var(
    data: torch.Tensor, mask: torch.Tensor, weights: torch.Tensor, dim: Tuple[int, ...]
):
    if mask is None:
        return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
    mask = mask.float()
    mask_mean = torch.sum(data * mask * weights, dim=dim, keepdim=True) / torch.sum(
        mask * weights, dim=dim, keepdim=True
    ).clamp(min=1.0)
    # V1**2 - V2, with V1: sum_i w_i and V2: sum_i w_i**2
    denom = torch.sum(weights * mask, dim=dim, keepdim=True).square() - torch.sum(
        (mask * weights).square(), dim=dim, keepdim=True
    )
    # The correction is V1 / (V1**2 - V2): with w_i = 1 it reduces to
    # N / (N**2 - N) = 1 / (N - 1), i.e. the unbiased variance estimator.
    correction_factor = torch.sum(mask * weights, dim=dim, keepdim=True) / denom.clamp(
        min=1.0
    )
    mask_var = correction_factor * torch.sum(
        weights * mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
    )
    return mask_mean, mask_var


def masked_mean_var_q(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
    if mask is None:
        return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
    mask = mask.float()
    mask_sum = torch.sum(mask, dim=dim, keepdim=True)
    mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
        mask_sum, min=1.0
    )
    mask_var = torch.sum(
        mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
    ) / torch.clamp(mask_sum, min=1.0)
    return mask_mean, mask_var


class SILog(nn.Module):
    def __init__(
        self,
        weight: float,
        scale_pred_weight: float = 0.15,
        output_fn: str = "sqrt",
        input_fn: str = "log",
        legacy: bool = False,
        abs_rel: bool = False,
        norm: bool = False,
        eps: float = 1e-5,
    ):
        super().__init__()
        assert output_fn in FNS
        self.name: str = self.__class__.__name__
        self.weight: float = weight
        self.scale_pred_weight: float = scale_pred_weight
        self.dims = (-4, -3, -2, -1) if legacy else (-2, -1)
        self.output_fn = FNS[output_fn]
        self.input_fn = FNS[input_fn]
        self.abs_rel = abs_rel
        self.norm = norm
        self.eps: float = eps

    @torch.cuda.amp.autocast(enabled=False)
    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        interpolate: bool = True,
        scale_inv: torch.Tensor | None = None,
        ss_inv: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        if interpolate:
            input = F.interpolate(
                input, target.shape[-2:], mode="bilinear", align_corners=False
            )
        if mask is not None:
            mask = mask.to(torch.bool)
        if ss_inv is not None:
            ss_inv = ~ss_inv

        if input.shape[1] > 1:
            input_ = torch.cat(
                [input[:, :-1], self.input_fn(input[:, -1:].clamp(min=self.eps))], dim=1
            )
            target_ = torch.cat(
                [target[:, :-1], self.input_fn(target[:, -1:].clamp(min=self.eps))],
                dim=1,
            )
            error = torch.norm(input_ - target_, dim=1, keepdim=True)
        else:
            input_ = self.input_fn(input.clamp(min=self.eps))
            target_ = self.input_fn(target.clamp(min=self.eps))
            error = input_ - target_

        # note: mean and var were previously inverted here
        mean_error, var_error = masked_mean_var(data=error, mask=mask, dim=self.dims)
        if self.abs_rel:
            scale_error = (input - target).abs()[:, -1:] / target[:, -1:].clip(
                min=self.eps
            )
            scale_error = masked_mean(data=scale_error, mask=mask, dim=self.dims)
        else:
            scale_error = mean_error**2

        if var_error.ndim > 1:
            var_error = var_error.sum(dim=1)
            scale_error = scale_error.sum(dim=1)

        # if scale-invariant: mask out the scale error only;
        # if scale-and-shift-invariant: mask the full loss
        if scale_inv is not None:
            scale_error = (1 - scale_inv.int()) * scale_error
        scale_error = self.scale_pred_weight * scale_error
        loss = var_error + scale_error
        out_loss = self.output_fn(loss.clamp(min=self.eps))
        out_loss = masked_mean(data=out_loss, mask=ss_inv, dim=(0,))
        return out_loss.mean()

    @classmethod
    def build(cls, config: Dict[str, Any]):
        obj = cls(
            weight=config["weight"],
            legacy=config["legacy"],
            output_fn=config["output_fn"],
            input_fn=config["input_fn"],
            norm=config.get("norm", False),
            scale_pred_weight=config.get("gamma", 0.15),
            abs_rel=config.get("abs_rel", False),
        )
        return obj
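# A hedged usage sketch for SILog. The config keys mirror build(); the shapes
# and value ranges below are illustrative assumptions, not part of the module:
def _example_silog() -> None:
    loss_fn = SILog.build(
        {"weight": 1.0, "legacy": False, "output_fn": "sqrt", "input_fn": "log"}
    )
    pred = torch.rand(2, 1, 32, 32) + 0.1  # predicted depth, coarse resolution
    gt = torch.rand(2, 1, 64, 64) + 0.1  # ground-truth depth, full resolution
    valid = gt > 0.2  # per-pixel validity mask
    loss = loss_fn(pred, gt, mask=valid)  # pred is upsampled to gt's size
    assert loss.ndim == 0  # scalar scale-invariant log loss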
class MSE(nn.Module):
    def __init__(
        self,
        weight: float = 1.0,
        input_fn: str = "linear",
        output_fn: str = "linear",
    ):
        super().__init__()
        self.name: str = self.__class__.__name__
        self.output_fn = FNS[output_fn]
        self.input_fn = FNS[input_fn]
        self.weight: float = weight
        self.eps = 1e-6

    @torch.cuda.amp.autocast(enabled=False)
    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        mask: torch.Tensor | None = None,
        batch_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        input = input[..., : target.shape[-1]]  # B N C or B H W C
        error = self.input_fn(input + self.eps) - self.input_fn(target + self.eps)
        abs_error = torch.square(error).sum(dim=-1)
        mean_error = masked_mean(data=abs_error, mask=mask, dim=(-1,)).mean(dim=-1)
        batched_error = masked_mean(
            self.output_fn(mean_error.clamp(self.eps)), batch_mask, dim=(0,)
        )
        return batched_error.mean(), mean_error.detach()

    @classmethod
    def build(cls, config: Dict[str, Any]):
        obj = cls(
            weight=config["weight"],
            output_fn=config["output_fn"],
            input_fn=config["input_fn"],
        )
        return obj
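# A hedged usage sketch for MSE on point-like predictions. The shapes are
# assumptions; note the loss also returns the detached per-sample error:
def _example_mse() -> None:
    loss_fn = MSE.build({"weight": 1.0, "output_fn": "sqrt", "input_fn": "linear"})
    pred = torch.rand(2, 100, 3)  # B N C predictions
    gt = torch.rand(2, 100, 3)  # matching targets
    valid = torch.ones(2, 100, dtype=torch.bool)  # per-point validity
    loss, per_sample = loss_fn(pred, gt, mask=valid)
    assert loss.ndim == 0 and per_sample.shape == (2,)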
class SelfCons(nn.Module):
    def __init__(
        self,
        weight: float,
        scale_pred_weight: float = 0.15,
        output_fn: str = "sqrt",
        input_fn: str = "log",
        abs_rel: bool = False,
        norm: bool = False,
        eps: float = 1e-5,
    ):
        super().__init__()
        assert output_fn in FNS
        self.name: str = self.__class__.__name__
        self.weight: float = weight
        self.scale_pred_weight: float = scale_pred_weight
        self.dims = (-2, -1)
        self.output_fn = FNS[output_fn]
        self.input_fn = FNS[input_fn]
        self.abs_rel = abs_rel
        self.norm = norm
        self.eps: float = eps

    @torch.cuda.amp.autocast(enabled=False)
    def forward(
        self,
        input: torch.Tensor,
        mask: torch.Tensor,
        metas: List[Dict[str, torch.Tensor]],
    ) -> torch.Tensor:
        # The batch interleaves two augmented views per image: view 0 is
        # warped onto view 1's geometry (flip, zoom, translation), then the
        # two predictions are penalized for disagreeing, symmetrically.
        chunks = input.shape[0] // 2
        device = input.device
        mask = F.interpolate(mask.float(), size=input.shape[-2:], mode="nearest")
        rescales = input.shape[-2] / torch.tensor(
            [x["resized_shape"][0] for x in metas], device=device
        )
        cams = torch.cat([x["K_target"] for x in metas], dim=0).to(device)
        flips = torch.tensor([x["flip"] for x in metas], device=device)
        iters = zip(
            input.chunk(chunks),
            mask.chunk(chunks),
            cams.chunk(chunks),
            rescales.chunk(chunks),
            flips.chunk(chunks),
        )
        inputs0, inputs1, masks = [], [], []
        for i, (pair_input, pair_mask, pair_cam, pair_rescale, pair_flip) in enumerate(
            iters
        ):
            mask0, mask1 = pair_mask
            input0, input1 = pair_input
            cam0, cam1 = pair_cam
            rescale0, rescale1 = pair_rescale
            flip0, flip1 = pair_flip

            fx_0 = cam0[0, 0] * rescale0
            fx_1 = cam1[0, 0] * rescale1
            cx_0 = (cam0[0, 2] - 0.5) * rescale0 + 0.5
            cx_1 = (cam1[0, 2] - 0.5) * rescale1 + 0.5
            cy_0 = (cam0[1, 2] - 0.5) * rescale0 + 0.5
            cy_1 = (cam1[1, 2] - 0.5) * rescale1 + 0.5

            # flip image
            if flip0 ^ flip1:
                input0 = torch.flip(input0, dims=(2,))
                mask0 = torch.flip(mask0, dims=(2,))
                cx_0 = input0.shape[-1] - cx_0

            # calc zoom
            zoom_x = float(fx_1 / fx_0)

            # apply zoom
            input0 = F.interpolate(
                input0.unsqueeze(0),
                scale_factor=zoom_x,
                mode="bilinear",
                align_corners=True,
            ).squeeze(0)
            mask0 = F.interpolate(
                mask0.unsqueeze(0), scale_factor=zoom_x, mode="nearest"
            ).squeeze(0)

            # calc translation
            change_left = int(cx_1 - (cx_0 - 0.5) * zoom_x - 0.5)
            change_top = int(cy_1 - (cy_0 - 0.5) * zoom_x - 0.5)
            change_right = input1.shape[-1] - change_left - input0.shape[-1]
            change_bottom = input1.shape[-2] - change_top - input0.shape[-2]

            # apply translation: pad where the target view extends further,
            # crop where the warped view overshoots
            pad_left = max(0, change_left)
            pad_right = max(0, change_right)
            pad_top = max(0, change_top)
            pad_bottom = max(0, change_bottom)

            crop_left = max(0, -change_left)
            crop_right = max(0, -change_right)
            crop_top = max(0, -change_top)
            crop_bottom = max(0, -change_bottom)

            input0 = F.pad(
                input0,
                (pad_left, pad_right, pad_top, pad_bottom),
                mode="constant",
                value=0,
            )
            mask0 = F.pad(
                mask0,
                (pad_left, pad_right, pad_top, pad_bottom),
                mode="constant",
                value=0,
            )
            input0 = input0[
                :,
                crop_top : input0.shape[-2] - crop_bottom,
                crop_left : input0.shape[-1] - crop_right,
            ]
            mask0 = mask0[
                :,
                crop_top : mask0.shape[-2] - crop_bottom,
                crop_left : mask0.shape[-1] - crop_right,
            ]

            mask = torch.logical_and(mask0, mask1)

            inputs0.append(input0)
            inputs1.append(input1)
            masks.append(mask)

        inputs0 = torch.stack(inputs0, dim=0)
        inputs1 = torch.stack(inputs1, dim=0)
        masks = torch.stack(masks, dim=0)
        loss1 = self.loss(inputs0, inputs1.detach(), masks)
        loss2 = self.loss(inputs1, inputs0.detach(), masks)
        return torch.cat([loss1, loss2], dim=0).mean()

    def loss(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        loss = masked_mean(
            (input - target).square().mean(dim=1), mask=mask, dim=(-2, -1)
        )
        return self.output_fn(loss + self.eps)

    @classmethod
    def build(cls, config: Dict[str, Any]):
        obj = cls(
            weight=config["weight"],
            output_fn=config["output_fn"],
            input_fn=config["input_fn"],
        )
        return obj
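# A hedged sketch of the paired inputs and metadata SelfCons.forward consumes:
# two augmented views of each image stacked along the batch dim, each carrying
# its resized shape, target intrinsics, and flip flag. The key names follow
# forward(); all concrete values below are illustrative assumptions:
def _example_selfcons() -> None:
    loss_fn = SelfCons.build(
        {"weight": 1.0, "output_fn": "sqrt", "input_fn": "linear"}
    )
    preds = torch.rand(2, 1, 32, 32)  # two views of the same image
    masks = torch.ones(2, 1, 32, 32)
    K = torch.eye(3).unsqueeze(0)  # hypothetical 1x3x3 target intrinsics
    K[:, 0, 0] = 100.0
    K[:, 1, 1] = 100.0
    K[:, 0, 2] = 16.0
    K[:, 1, 2] = 16.0
    metas = [
        {"resized_shape": (32, 32), "K_target": K, "flip": False},
        {"resized_shape": (32, 32), "K_target": K.clone(), "flip": False},
    ]
    loss = loss_fn(preds, masks, metas)
    assert loss.ndim == 0


if __name__ == "__main__":
    # smoke-run the illustrative sketches above
    _example_masked_stats()
    _example_silog()
    _example_mse()
    _example_selfcons()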