import torch
import torch.nn as nn

from .fourier_features import FourierFeatures


class RegionModel(nn.Module):
    """Encode and decode region information: a position, a size and a confidence.

    Positions and sizes each pass through their own ``FourierFeatures`` layer
    followed by a linear projection into a ``hidden_dim``-wide vector. Three
    linear heads map hidden vectors back out:
    - ``position_decoder`` emits 2 values.
    - ``size_decoder`` emits 2 values.
    - ``confidence_decoder`` emits 1 value.

    Args:
        hidden_dim: Width of the encoded vectors and of the decoder inputs.
            Defaults to 2048.
        fourier_dim: Output width passed to each ``FourierFeatures`` layer and
            input width of the matching encoder. Defaults to 256.
    """

    def __init__(self, hidden_dim: int = 2048, fourier_dim: int = 256) -> None:
        super().__init__()
        # Attribute names are kept stable so existing checkpoints / state_dict
        # keys continue to load.
        # NOTE(review): FourierFeatures(2, ...) presumably takes a 2-value input
        # (e.g. x/y or w/h) — confirm against fourier_features.py.
        self.position_features = FourierFeatures(2, fourier_dim)
        self.position_encoder = nn.Linear(fourier_dim, hidden_dim)
        self.size_features = FourierFeatures(2, fourier_dim)
        self.size_encoder = nn.Linear(fourier_dim, hidden_dim)
        self.position_decoder = nn.Linear(hidden_dim, 2)
        self.size_decoder = nn.Linear(hidden_dim, 2)
        self.confidence_decoder = nn.Linear(hidden_dim, 1)

    def encode_position(self, position: torch.Tensor) -> torch.Tensor:
        """Project ``position`` through its Fourier features into hidden space."""
        return self.position_encoder(self.position_features(position))

    def encode_size(self, size: torch.Tensor) -> torch.Tensor:
        """Project ``size`` through its Fourier features into hidden space."""
        return self.size_encoder(self.size_features(size))

    def decode_position(self, x: torch.Tensor) -> torch.Tensor:
        """Map a hidden vector to a 2-value position via ``position_decoder``."""
        return self.position_decoder(x)

    def decode_size(self, x: torch.Tensor) -> torch.Tensor:
        """Map a hidden vector to a 2-value size via ``size_decoder``."""
        return self.size_decoder(x)

    def decode_confidence(self, x: torch.Tensor) -> torch.Tensor:
        """Map a hidden vector to a single confidence value via ``confidence_decoder``."""
        return self.confidence_decoder(x)

    def encode(self, position: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
        """Encode ``position`` and ``size`` and stack them along a new leading axis.

        Index 0 of the result is the position encoding and index 1 is the size
        encoding. ``torch.stack`` requires both encodings to have the same
        shape, so ``position`` and ``size`` are expected to share batch
        dimensions.
        """
        return torch.stack(
            [self.encode_position(position), self.encode_size(size)],
            dim=0,
        )

    def decode(
        self,
        position_logits: torch.Tensor,
        size_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Decode hidden vectors into ``(position, size, confidence)``.

        Args:
            position_logits: Hidden vector(s) fed to the position head.
            size_logits: Hidden vector(s) fed to both the size head and the
                confidence head.

        Returns:
            A 3-tuple of the position, size and confidence head outputs.
        """
        return (
            self.decode_position(position_logits),
            self.decode_size(size_logits),
            # NOTE(review): confidence is decoded from size_logits rather than
            # position_logits; kept as-is to preserve behavior — confirm intended.
            self.decode_confidence(size_logits),
        )