File size: 1,633 Bytes
052c3ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# -*- coding: utf-8 -*-

# Copyright 2024 Wen-Chin Huang
#  MIT License (https://opensource.org/licenses/MIT)

# LDNet modules
# taken from: https://github.com/unilight/LDNet/blob/main/models/modules.py (written by myself)

import torch
from torch import nn

# NOTE(review): module-level constant that is not referenced by any code
# visible in this file; presumably imported by other modules — confirm its
# meaning and users before changing the value.
STRIDE = 3

class Projection(nn.Module):
    """Projection head: a two-layer MLP followed by an output transform.

    The MLP is ``Linear(in_dim, hidden_dim) -> activation() -> Dropout(0.3)
    -> Linear(hidden_dim, output_dim)``, where ``output_dim`` is 1 for
    scalar output and ``_output_dim`` for categorical output.

    Args:
        in_dim: size of the last dimension of the input.
        hidden_dim: width of the hidden layer.
        activation: activation module class; it is instantiated with no
            arguments (e.g. ``nn.ReLU``).
        output_type: either ``"scalar"`` or ``"categorical"``.
        _output_dim: number of output units; only used when
            ``output_type == "categorical"``.
        output_step: spacing between the values of adjacent classes; only
            stored (and used at inference) for categorical output.
        range_clipping: for scalar output, squash the prediction into the
            open interval (1, 5) via ``tanh(out) * 2 + 3``.

    Raises:
        NotImplementedError: if ``output_type`` is neither ``"scalar"`` nor
            ``"categorical"``.
    """

    def __init__(
        self,
        in_dim,
        hidden_dim,
        activation,
        output_type,
        _output_dim,
        output_step=1.0,
        range_clipping=False,
    ):
        super().__init__()
        self.output_type = output_type
        self.range_clipping = range_clipping

        is_scalar = output_type == "scalar"
        if not is_scalar and output_type != "categorical":
            raise NotImplementedError("wrong output_type: {}".format(output_type))

        if is_scalar:
            # A single regression unit; optionally squashed by tanh in forward().
            output_dim = 1
            if range_clipping:
                self.proj = nn.Tanh()
        else:
            output_dim = _output_dim
            self.output_step = output_step

        layers = [
            nn.Linear(in_dim, hidden_dim),
            activation(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x, inference=False):
        """Run the MLP and apply the output-type-specific transform.

        Args:
            x: input tensor whose last dimension is ``in_dim``.
            inference: for categorical output, return the decoded value
                ``argmax * output_step + 1`` instead of the raw outputs.
                Ignored for scalar output.

        Returns:
            Scalar: the raw MLP output, or ``tanh(output) * 2.0 + 3`` when
            range clipping is enabled. Categorical: the raw MLP outputs, or
            at inference the decoded values with the last dimension reduced
            by ``argmax``.
        """
        raw = self.net(x)

        if self.output_type != "scalar":
            if not inference:
                return raw
            # Map the most likely class index k to the value 1 + k * output_step.
            return torch.argmax(raw, dim=-1) * self.output_step + 1

        if not self.range_clipping:
            return raw
        # tanh lies in (-1, 1), so this affine map lands in (1, 5).
        return self.proj(raw) * 2.0 + 3