import torch
from torch import nn


class GateMLP(nn.Module):
    """Gated MLP block: two parallel up-projections, sigmoid gating,
    a down-projection back to d_model, then LayerNorm."""

    def __init__(self, d_model, expand):
        super().__init__()
        self.proj_1 = nn.Linear(d_model, d_model * expand, bias=False)
        self.proj_2 = nn.Linear(d_model, d_model * expand, bias=False)
        self.proj_3 = nn.Linear(d_model * expand, d_model, bias=True)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # Value path and gate path computed from the same input.
        x, gate = self.proj_1(x), self.proj_2(x)
        x = x * torch.sigmoid(gate)  # GLU-style sigmoid gating
        x = self.proj_3(x)  # project back to d_model
        x = self.layer_norm(x)
        return x


class GMLPModel(nn.Module):
    # Expected keys: d_model, expand, num_layers, trainable_pe.
    # Must be populated (e.g. by a subclass or direct assignment) before instantiation.
    config = {}

    def __init__(self, positional_embedding):
        super().__init__()
        gmlp_config = {
            "d_model": self.config["d_model"],
            "expand": self.config["expand"],
        }
        # Stack of gated MLP blocks applied sequentially.
        self.gmlp_forward = nn.Sequential(
            *[GateMLP(**gmlp_config) for _ in range(self.config["num_layers"])]
        )
        # positional_embedding: (seq_len, d_model) -> add a leading batch dimension.
        pe = positional_embedding[None, :, :]
        if self.config.get("trainable_pe"):
            self.pe = nn.Parameter(pe)
        else:  # fixed positional embedding
            self.register_buffer("pe", pe)

    def forward(self, output_shape, condition=None):
        # condition: (batch, seq_len, d_model); required despite the None default.
        assert condition is not None and condition.dim() == 3
        x = self.gmlp_forward(self.pe.repeat(output_shape[0], 1, 1) + condition)
        return x
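

# A minimal usage sketch (an assumption, not part of the original file): the
# class-level `config` dict is populated before constructing the model, and a
# random tensor stands in for whatever positional embedding the surrounding
# project actually supplies.
if __name__ == "__main__":
    GMLPModel.config = {
        "d_model": 64,
        "expand": 2,
        "num_layers": 4,
        "trainable_pe": False,
    }
    seq_len, batch_size = 16, 8
    # Hypothetical stand-in positional embedding of shape (seq_len, d_model).
    positional_embedding = torch.randn(seq_len, GMLPModel.config["d_model"])
    model = GMLPModel(positional_embedding)

    condition = torch.randn(batch_size, seq_len, GMLPModel.config["d_model"])
    out = model(output_shape=(batch_size, seq_len), condition=condition)
    print(out.shape)  # torch.Size([8, 16, 64])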