import torch
from torch import nn
from einops import rearrange


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        x = self.net(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, mask=None):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False) if inner_dim != dim else nn.Identity()
        self.softmax = nn.Softmax(dim=-1)
        if mask is not None:
            # Expect a (seq_len, seq_len) mask where True marks positions that
            # are allowed to attend; broadcast it over batch and heads.
            assert mask.dim() == 2
            mask = mask.bool()[None, None, :, :]
            self.register_buffer("mask", mask)
            self.use_mask = True
        else:
            self.use_mask = False

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
        # Scaled dot-product attention scores.
        attn = q @ k.transpose(-1, -2) * self.scale
        if self.use_mask:
            # Disallowed positions get a large negative score so that they
            # receive (near-)zero weight after the softmax.
            attn = attn.masked_fill(~self.mask, torch.finfo(attn.dtype).min)
        out = self.softmax(attn) @ v
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out


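# Illustrative helper (not part of the original interface): one way to build
# the boolean (seq_len, seq_len) mask that Attention expects, assuming a
# causal setup where position i may only attend to positions j <= i.
def make_causal_mask(seq_len):
    # Lower-triangular matrix: True (keep) on and below the diagonal.
    return torch.tril(torch.ones(seq_len, seq_len)).bool()

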
class Transformer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward, dim_head, num_layers, mask=None):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(num_layers):
            self.layers.append(nn.ModuleList([
                Attention(d_model, heads=nhead, dim_head=dim_head, mask=mask),
                FeedForward(d_model, dim_feedforward)
            ]))

    def forward(self, x):
        # Each layer applies pre-norm attention and feed-forward sublayers,
        # each wrapped in a residual connection.
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


class TransformerModel(nn.Module):
    # Hyperparameters are read from this class-level dict, which is expected to
    # be populated (e.g. TransformerModel.config = {...}) before instantiation.
    config = {}

    def __init__(self, positional_embedding):
        super().__init__()
        self.transformer_forward = Transformer(
            d_model=self.config["d_model"],
            nhead=self.config["nhead"],
            dim_feedforward=self.config["dim_feedforward"],
            dim_head=self.config["dim_head"],
            num_layers=self.config["num_layers"],
            mask=self.config.get("mask"),
        )
        # Positional embedding of shape (seq_len, d_model), stored with a
        # leading batch dimension; optionally trainable.
        pe = positional_embedding[None, :, :]
        if self.config.get("trainable_pe"):
            self.pe = nn.Parameter(pe)
        else:
            self.register_buffer("pe", pe)

    def forward(self, output_shape, condition=None):
        # condition: (batch, seq_len, d_model) tensor added to the positional
        # embedding; output_shape[0] supplies the batch size.
        assert condition is not None and condition.dim() == 3
        x = self.transformer_forward(self.pe.repeat(output_shape[0], 1, 1) + condition)
        return x
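

# Minimal usage sketch, illustrative only: the hyperparameter values, the
# random positional embedding, and the random condition tensor below are
# assumptions rather than values prescribed by this module.
if __name__ == "__main__":
    seq_len, d_model, batch = 16, 128, 4
    TransformerModel.config = {
        "d_model": d_model,
        "nhead": 8,
        "dim_feedforward": 256,
        "dim_head": 64,
        "num_layers": 2,
        "mask": make_causal_mask(seq_len),
        "trainable_pe": False,
    }
    model = TransformerModel(torch.randn(seq_len, d_model))
    condition = torch.randn(batch, seq_len, d_model)
    out = model(output_shape=(batch, seq_len, d_model), condition=condition)
    print(out.shape)  # torch.Size([4, 16, 128])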