import torch
from torch import nn
from einops import rearrange
class FeedForward(nn.Module):
    """Pre-norm two-layer MLP block."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return self.net(x)
class Attention(nn.Module):
    """Pre-norm multi-head self-attention with an optional fixed mask."""

    def __init__(self, dim, heads=8, dim_head=64, mask=None):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False) if inner_dim != dim else nn.Identity()
        self.softmax = nn.Softmax(dim=-1)
        if mask is not None:
            # Expect a 2-D boolean mask of shape (seq_len, seq_len), where
            # True marks positions that may be attended to; broadcast it
            # over the batch and head dimensions.
            assert mask.ndim == 2
            mask = mask[None, None, :, :].bool()
            self.register_buffer("mask", mask)
            self.use_mask = True
        else:  # no mask: full (bidirectional) attention
            self.use_mask = False
    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
        # Scaled dot-product attention logits.
        attn = q @ k.transpose(-1, -2) * self.scale
        if self.use_mask:
            # Masked-out logits must go to -inf so that softmax assigns
            # them zero weight.
            attn = attn.masked_fill(~self.mask, float('-inf'))
        out = self.softmax(attn) @ v
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
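# A helper sketch (not in the original file): the `mask` argument above is
# assumed to be a 2-D boolean tensor with True at positions that may be
# attended to. Under that assumption, a causal mask for autoregressive
# attention can be built like this; the function name is hypothetical.
def build_causal_mask(seq_len):
    # Lower-triangular True entries let position i attend to positions <= i.
    return torch.ones(seq_len, seq_len, dtype=torch.bool).tril()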
class Transformer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward, dim_head, num_layers, mask=None):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(num_layers):
            self.layers.append(nn.ModuleList([
                Attention(d_model, heads=nhead, dim_head=dim_head, mask=mask),
                FeedForward(d_model, dim_feedforward)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x
class TransformerModel(nn.Module):
    # Hyperparameters are read from this dict; a subclass (or the caller)
    # is expected to populate it before instantiation.
    config = {}

    def __init__(self, positional_embedding):
        super().__init__()
        self.transformer_forward = Transformer(
            d_model=self.config["d_model"],
            nhead=self.config["nhead"],
            dim_feedforward=self.config["dim_feedforward"],
            dim_head=self.config["dim_head"],
            num_layers=self.config["num_layers"],
            mask=self.config.get("mask"),
        )
        # positional_embedding: (seq_len, d_model) -> (1, seq_len, d_model)
        pe = positional_embedding[None, :, :]
        if self.config.get("trainable_pe"):
            self.pe = nn.Parameter(pe)
        else:  # fixed positional embedding
            self.register_buffer("pe", pe)

    def forward(self, output_shape, condition=None):
        assert condition is not None and condition.ndim == 3
        # Broadcast the positional embedding over the batch and add the condition.
        x = self.transformer_forward(self.pe.repeat(output_shape[0], 1, 1) + condition)
        return x
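# Minimal smoke test (a sketch: the config values and shapes below are
# illustrative assumptions, not taken from the original setup).
if __name__ == "__main__":
    seq_len, d_model, batch = 16, 64, 2
    # The class reads hyperparameters from the `config` class attribute,
    # so populate it before instantiation.
    TransformerModel.config = {
        "d_model": d_model,
        "nhead": 4,
        "dim_feedforward": 128,
        "dim_head": 16,
        "num_layers": 2,
        "mask": build_causal_mask(seq_len),
        "trainable_pe": False,
    }
    pe = torch.randn(seq_len, d_model)
    model = TransformerModel(pe)
    condition = torch.randn(batch, seq_len, d_model)
    out = model(output_shape=(batch, seq_len, d_model), condition=condition)
    print(out.shape)  # expected: torch.Size([2, 16, 64])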