# -*- coding: utf-8 -*-

# Copyright 2024 Wen-Chin Huang
#  MIT License (https://opensource.org/licenses/MIT)

# LDNet modules
# taken from: https://github.com/unilight/LDNet/blob/main/models/modules.py (written by myself)

import torch
from torch import nn

STRIDE = 3


class Projection(nn.Module):
    """Two-layer MLP head mapping features to a score.

    The network is ``Linear -> activation() -> Dropout -> Linear`` applied
    over the last dimension of the input. Two output modes are supported:

    * ``"scalar"``: a single output unit. With ``range_clipping`` the raw
      value is squashed with ``tanh`` and linearly rescaled into the open
      interval ``(output_min, output_max)``. With the defaults this is
      ``(1, 5)``.
    * ``"categorical"``: ``_output_dim`` logits. During inference the
      arg-max class index ``k`` is converted to the score
      ``k * output_step + output_min``.

    Args:
        in_dim: Size of the last input dimension.
        hidden_dim: Width of the hidden layer.
        activation: Activation class (not an instance). It is called with
            no arguments, e.g. ``nn.ReLU``.
        output_type: ``"scalar"`` or ``"categorical"``.
        _output_dim: Number of classes for ``"categorical"``. Ignored for
            ``"scalar"``.
        output_step: Score increment between consecutive classes.
            Categorical only.
        range_clipping: Bound scalar outputs via ``tanh``. Ignored for
            ``"categorical"``.
        dropout: Dropout probability between the two linear layers.
        output_min: Lower end of the score range. It is the clipping lower
            bound for scalar output and the score of class 0 for categorical
            output.
        output_max: Upper end of the clipping range. Used only for scalar
            output with ``range_clipping``.

    Raises:
        NotImplementedError: If ``output_type`` is not recognised.
        ValueError: If scalar range clipping is requested with
            ``output_max <= output_min``.
    """

    def __init__(
        self,
        in_dim,
        hidden_dim,
        activation,
        output_type,
        _output_dim,
        output_step=1.0,
        range_clipping=False,
        dropout=0.3,
        output_min=1,
        output_max=5,
    ):
        super().__init__()
        self.output_type = output_type
        self.range_clipping = range_clipping
        self.output_min = output_min

        if output_type == "scalar":
            output_dim = 1
            if range_clipping:
                if output_max <= output_min:
                    raise ValueError(
                        "output_max ({}) must be greater than output_min ({})".format(
                            output_max, output_min
                        )
                    )
                self.proj = nn.Tanh()
                # tanh lies in (-1, 1); map it linearly onto (output_min, output_max).
                # The defaults (1, 5) give scale 2.0 and offset 3.0, i.e. the
                # original "* 2.0 + 3".
                self.clip_scale = (output_max - output_min) / 2.0
                self.clip_offset = (output_max + output_min) / 2.0
        elif output_type == "categorical":
            output_dim = _output_dim
            self.output_step = output_step
        else:
            raise NotImplementedError("wrong output_type: {}".format(output_type))

        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            activation(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x, inference=False):
        """Project ``x`` to scores.

        Args:
            x: Input tensor whose last dimension has size ``in_dim``.
            inference: Categorical only. If True, return the arg-max class
                index converted to a score instead of the raw logits.

        Returns:
            Scalar mode: a tensor with last dimension 1, range-clipped if
            configured. Categorical mode: the logits, or, when
            ``inference`` is True, the decoded scores with the last
            dimension reduced.
        """
        output = self.net(x)

        if self.output_type == "scalar":
            if self.range_clipping:
                return self.proj(output) * self.clip_scale + self.clip_offset
            return output

        # categorical
        if inference:
            return torch.argmax(output, dim=-1) * self.output_step + self.output_min
        return output