# seminar-demo / module.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange, repeat
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads, dim_head, dropout):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, mask = None):
# x:[b,n,dim]
b, n, _, h = *x.shape, self.heads
# get qkv tuple:([b,n,head_num*head_dim],[...],[...])
qkv = self.to_qkv(x).chunk(3, dim = -1)
# split q,k,v from [b,n,head_num*head_dim] -> [b,head_num,n,head_dim]
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
        # attention scores: q @ k^T * scale -> [b,head_num,n,n]
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        mask_value = -torch.finfo(dots.dtype).max
        # mask value: most negative finite value for this dtype
        if mask is not None:
            # pad for the cls token, then build a pairwise [b,1,n,n] boolean mask
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            mask = mask[:, None, :, None] * mask[:, None, None, :]
            dots.masked_fill_(~mask, mask_value)
            del mask
# softmax normalization -> attention matrix
attn = dots.softmax(dim=-1)
# value * attention matrix -> output
out = torch.einsum('bhij,bhjd->bhid', attn, v)
# cat all output -> [b, n, head_num*head_dim]
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return out
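# A minimal shape-check sketch (not part of the original module): Attention maps
# [b, n, dim] -> [b, n, dim], splitting dim into `heads` x `dim_head` internally.
# The sizes below are illustrative assumptions, not the repo's configuration.
def _attention_shape_example():
    attn = Attention(dim=64, heads=4, dim_head=16, dropout=0.)
    x = torch.randn(2, 10, 64)            # [b, n, dim]
    out = attn(x)                          # [b, n, dim]
    assert out.shape == (2, 10, 64)
    return out.shape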
class CrossAttention(nn.Module):
def __init__(self, dim, heads, dim_head, dropout):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
        self.to_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
    def forward(self, x_qkv):
        # x_qkv: [b, n, dim]; keys and values come from the full sequence,
        # the query comes from the first (cls) token only
        b, n, _, h = *x_qkv.shape, self.heads
        k = self.to_k(x_qkv)
        k = rearrange(k, 'b n (h d) -> b h n d', h = h)
        v = self.to_v(x_qkv)
        v = rearrange(v, 'b n (h d) -> b h n d', h = h)
        # query from the cls token: [b, 1, dim] -> [b, h, 1, d]
        q = self.to_q(x_qkv[:, 0].unsqueeze(1))
        q = rearrange(q, 'b n (h d) -> b h n d', h = h)
dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = dots.softmax(dim=-1)
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return out
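# Sketch (illustrative values, not from the repo): CrossAttention uses only the
# first token of x_qkv as the query, so the output has a single token per batch.
def _cross_attention_shape_example():
    cross = CrossAttention(dim=64, heads=4, dim_head=16, dropout=0.)
    x = torch.randn(2, 11, 64)             # [b, 1 + n, dim], token 0 is the query
    out = cross(x)                          # [b, 1, dim]
    assert out.shape == (2, 1, 64)
    return out.shape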
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_head, dropout, num_channel):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_head, dropout = dropout)))
]))
        # cross-layer fusion convolutions; defined here but not applied in forward
        self.skipcat = nn.ModuleList([])
        for _ in range(depth-2):
            self.skipcat.append(nn.Conv2d(num_channel+1, num_channel+1, [1, 2], 1, 0))
def forward(self, x, mask = None):
for attn, ff in self.layers:
x = attn(x, mask = mask)
x = ff(x)
return x
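# Sketch usage (sizes are assumptions, not the repo's configuration): the
# Transformer keeps the token dimension, so a [b, num_channel+1, dim] input
# comes back with the same shape.
def _transformer_shape_example():
    block = Transformer(dim=64, depth=2, heads=4, dim_head=16,
                        mlp_head=8, dropout=0., num_channel=10)
    x = torch.randn(2, 11, 64)             # [b, num_channel+1, dim]
    out = block(x)
    assert out.shape == (2, 11, 64)
    return out.shape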
class SSTransformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_head, b_dim, b_depth, b_heads, b_dim_head, b_mlp_head, num_patches, dropout):
super().__init__()
self.layers = nn.ModuleList([])
self.k_layers = nn.ModuleList([])
self.channels_to_embedding = nn.Linear(num_patches, b_dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, b_dim))
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_head, dropout = dropout)))
]))
for _ in range(b_depth):
self.k_layers.append(nn.ModuleList([
Residual(PreNorm(b_dim, Attention(dim=b_dim, heads=b_heads, dim_head=b_dim_head, dropout = dropout))),
Residual(PreNorm(b_dim, FeedForward(b_dim, b_mlp_head, dropout = dropout)))
]))
def forward(self, x, mask = None):
for attn, ff in self.layers:
x = attn(x, mask = mask)
x = ff(x)
x = rearrange(x, 'b n d -> b d n')
x = self.channels_to_embedding(x)
b, d, n = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim = 1)
for attn, ff in self.k_layers:
x = attn(x, mask = mask)
x = ff(x)
return x
class SSTransformer_pyramid(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_head, b_dim, b_depth, b_heads, b_dim_head, b_mlp_head, num_patches, dropout):
super().__init__()
self.layers = nn.ModuleList([])
self.k_layers = nn.ModuleList([])
self.channels_to_embedding = nn.Linear(num_patches, b_dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, b_dim))
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_head, dropout = dropout)))
]))
for _ in range(b_depth):
self.k_layers.append(nn.ModuleList([
Residual(PreNorm(b_dim, Attention(dim=b_dim, heads=b_heads, dim_head=b_dim_head, dropout = dropout))),
Residual(PreNorm(b_dim, FeedForward(b_dim, b_mlp_head, dropout = dropout)))
]))
def forward(self, x, mask = None):
for attn, ff in self.layers:
x = attn(x, mask = mask)
x = ff(x)
out_feature = x
x = rearrange(x, 'b n d -> b d n')
x = self.channels_to_embedding(x)
b, d, n = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim = 1)
for attn, ff in self.k_layers:
x = attn(x, mask = mask)
x = ff(x)
return x, out_feature
class ViT(nn.Module):
def __init__(self, image_size, near_band, num_patches, num_classes, dim, depth, heads, mlp_dim, pool='cls', channel_dim=1, dim_head = 16, dropout=0., emb_dropout=0., mode='ViT'):
super().__init__()
        patch_dim = image_size ** 2 * near_band  # flattened patch size (computed for reference; the embedding below uses channel_dim)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(channel_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, num_patches)  # note: `mode` is accepted by ViT but not used by this Transformer implementation
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, x, mask = None):
        # patches: [batch, patch_num, patch_size*patch_size*c] [batch,200,145*145]
# x = rearrange(x, 'b c h w -> b c (h w)')
## embedding every patch vector to embedding size: [batch, patch_num, embedding_size]
x = self.patch_to_embedding(x) #[b,n,dim]
b, n, _ = x.shape
# add position embedding
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) #[b,1,dim]
x = torch.cat((cls_tokens, x), dim = 1) #[b,n+1,dim]
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
# transformer: x[b,n + 1,dim] -> x[b,n + 1,dim]
x = self.transformer(x, mask)
# classification: using cls_token output
x = self.to_latent(x[:,0])
# MLP classification layer
return self.mlp_head(x)
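# Hedged usage sketch (hyperparameters below are hypothetical, not the repo's
# training configuration): the input is [batch, num_patches, channel_dim] and
# the output is [batch, num_classes] logits taken from the cls token.
def _vit_usage_example():
    model = ViT(image_size=7, near_band=3, num_patches=200, num_classes=16,
                dim=64, depth=5, heads=4, mlp_dim=8, channel_dim=147)
    x = torch.randn(2, 200, 147)           # [b, num_patches, channel_dim]
    logits = model(x)                      # [b, num_classes]
    assert logits.shape == (2, 16)
    return logits.shape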
class SSFormer_v4(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_head, b_dim, b_depth, b_heads, b_dim_head, b_mlp_head, num_patches, dropout, mode):
super().__init__()
self.layers = nn.ModuleList([])
self.k_layers = nn.ModuleList([])
self.channels_to_embedding = nn.Linear(num_patches, b_dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, b_dim))
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_head, dropout = dropout)))
]))
for _ in range(b_depth):
self.k_layers.append(nn.ModuleList([
Residual(PreNorm(b_dim, Attention(dim=b_dim, heads=b_heads, dim_head=b_dim_head, dropout = dropout))),
Residual(PreNorm(b_dim, FeedForward(b_dim, b_mlp_head, dropout = dropout)))
]))
self.mode = mode
    def forward(self, x, c, mask = None):
        # c: externally supplied cls token, shape [1, n_cls, b_dim]
for attn, ff in self.layers:
x = attn(x, mask = mask)
x = ff(x)
x = rearrange(x, 'b n d -> b d n')
x = self.channels_to_embedding(x)
b, d, n = x.shape
cls_tokens = repeat(c, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim = 1)
for attn, ff in self.k_layers:
x = attn(x, mask = mask)
x = ff(x)
return x
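# Hedged usage sketch (shapes and sizes are assumptions for illustration):
# SSFormer_v4 first runs attention over [b, num_patches, dim] tokens, then
# re-embeds the transposed features with channels_to_embedding and prepends
# the externally supplied cls token `c`.
def _ssformer_v4_usage_example():
    model = SSFormer_v4(dim=32, depth=1, heads=4, dim_head=8, mlp_head=8,
                        b_dim=64, b_depth=1, b_heads=4, b_dim_head=16,
                        b_mlp_head=8, num_patches=10, dropout=0., mode='ViT')
    x = torch.randn(2, 10, 32)             # [b, num_patches, dim]
    c = torch.randn(1, 1, 64)              # external cls token, [1, 1, b_dim]
    out = model(x, c)                      # [b, dim + 1, b_dim]
    assert out.shape == (2, 33, 64)
    return out.shape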