# TiTok / modeling / blocks.py
# Uploaded by yucornetto ("Upload 20 files"), revision dada74e (verified).
"""Building blocks for TiTok.
Copyright (2024) Bytedance Ltd. and/or its affiliates
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Reference:
https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py
"""
import torch
import torch.nn as nn
from collections import OrderedDict
class ResidualAttentionBlock(nn.Module):
    """Pre-norm Transformer layer: self-attention plus an optional MLP.

    Each sub-layer is preceded by its own normalization layer and wrapped in
    a residual connection. ``nn.MultiheadAttention`` is constructed with its
    default ``batch_first=False``, so inputs are sequence-first ``(L, N, D)``.

    Args:
        d_model: Embedding width ``D``.
        n_head: Number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``d_model``. When it is
            not ``> 0`` the MLP and its norm (``ln_2``) are not created.
        act_layer: Activation class instantiated between the two MLP linears.
        norm_layer: Normalization class applied before each sub-layer.
    """

    def __init__(
        self,
        d_model,
        n_head,
        mlp_ratio=4.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        # Submodules are registered in a fixed order so that the random draws
        # made at init time and the state-dict key layout stay stable.
        self.ln_1 = norm_layer(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.mlp_ratio = mlp_ratio
        if mlp_ratio > 0:
            self.ln_2 = norm_layer(d_model)
            hidden_dim = int(d_model * mlp_ratio)
            ffn_layers = OrderedDict()
            ffn_layers["c_fc"] = nn.Linear(d_model, hidden_dim)
            # Stored under the key "gelu" whatever act_layer actually is.
            ffn_layers["gelu"] = act_layer()
            ffn_layers["c_proj"] = nn.Linear(hidden_dim, d_model)
            self.mlp = nn.Sequential(ffn_layers)

    def attention(self, x: torch.Tensor):
        """Self-attention with ``x`` as query, key and value; weights are discarded."""
        attn_out, _unused_weights = self.attn(x, x, x, need_weights=False)
        return attn_out

    def forward(self, x: torch.Tensor):
        """Apply the attention sub-layer, then (if enabled) the MLP sub-layer."""
        hidden = x + self.attention(x=self.ln_1(x))
        if self.mlp_ratio > 0:
            hidden = hidden + self.mlp(self.ln_2(hidden))
        return hidden
def _expand_token(token, batch_size: int):
return token.unsqueeze(0).expand(batch_size, -1, -1)
class TiTokEncoder(nn.Module):
    """ViT encoder that compresses an image into a 1-D sequence of latent tokens.

    The image is patchified by a strided convolution, a class token is
    prepended, and the caller-supplied latent tokens are appended. The joint
    sequence passes through a stack of ``ResidualAttentionBlock``s. Only the
    outputs at the latent-token positions are kept and projected to
    ``token_size`` channels by a 1x1 convolution.

    Config fields read:
        config.dataset.preprocessing.crop_size: input image side length.
        config.model.vq_model.vit_enc_patch_size: patch side length.
        config.model.vq_model.vit_enc_model_size: "small" | "base" | "large".
        config.model.vq_model.num_latent_tokens: number of latent tokens.
        config.model.vq_model.token_size: output channels per latent token.
        config.model.vq_model.is_legacy: optional, defaults to True. Selects
            the token layout fed to ``conv_out`` (see ``forward``).
    """

    # model size -> (width, num_layers, num_heads)
    _ARCH = {
        "small": (512, 8, 8),
        "base": (768, 12, 12),
        "large": (1024, 24, 16),
    }

    def __init__(self, config):
        super().__init__()
        self.config = config
        vq_cfg = config.model.vq_model
        self.image_size = config.dataset.preprocessing.crop_size
        self.patch_size = vq_cfg.vit_enc_patch_size
        self.grid_size = self.image_size // self.patch_size
        self.model_size = vq_cfg.vit_enc_model_size
        self.num_latent_tokens = vq_cfg.num_latent_tokens
        self.token_size = vq_cfg.token_size
        self.width, self.num_layers, self.num_heads = self._ARCH[self.model_size]

        # Weights trained with this module depend on the original
        # reshape-based layout before conv_out, so it remains the default.
        # A missing key (or one read back as None) means legacy.
        is_legacy = getattr(vq_cfg, "is_legacy", None)
        self.is_legacy = True if is_legacy is None else bool(is_legacy)

        # Parameter creation order is unchanged so seeded init is reproducible.
        self.patch_embed = nn.Conv2d(
            in_channels=3, out_channels=self.width,
            kernel_size=self.patch_size, stride=self.patch_size, bias=True)
        scale = self.width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(self.grid_size ** 2 + 1, self.width))
        self.latent_token_positional_embedding = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.width))
        self.ln_pre = nn.LayerNorm(self.width)
        self.transformer = nn.ModuleList([
            ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0)
            for _ in range(self.num_layers)
        ])
        self.ln_post = nn.LayerNorm(self.width)
        self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True)

    def forward(self, pixel_values, latent_tokens):
        """Encode a batch of images into latent tokens.

        Args:
            pixel_values: Image batch fed to a 3-input-channel Conv2d. Its
                spatial size must yield ``grid_size ** 2`` patches for the
                positional-embedding add to broadcast.
            latent_tokens: 2-D ``(num_latent_tokens, width)`` token table that
                is broadcast over the batch.

        Returns:
            Tensor of shape ``(N, token_size, 1, num_latent_tokens)``.
        """
        batch_size = pixel_values.shape[0]
        x = self.patch_embed(pixel_values)
        x = x.reshape(batch_size, x.shape[1], -1).permute(0, 2, 1)  # [N, grid**2, width]

        # Prepend the class token, then add positional embeddings.
        cls_tokens = _expand_token(self.class_embedding, batch_size).to(x.dtype)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.positional_embedding.to(x.dtype)  # [N, grid**2 + 1, width]

        latent_tokens = _expand_token(latent_tokens, batch_size).to(x.dtype)
        latent_tokens = latent_tokens + self.latent_token_positional_embedding.to(x.dtype)
        x = torch.cat([x, latent_tokens], dim=1)

        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND (attention is sequence-first)
        for block in self.transformer:
            x = block(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # Keep only the latent-token outputs: [N, num_latent_tokens, width].
        latent_tokens = self.ln_post(x[:, 1 + self.grid_size ** 2:])

        # Give conv_out a fake 2-D shape [N, width, num_latent_tokens, 1].
        if self.is_legacy:
            # NOTE: this reshape reinterprets [N, L, D] memory as [N, D, L, 1];
            # it mixes values across tokens/channels rather than transposing.
            # Kept for compatibility with weights trained this way.
            latent_tokens = latent_tokens.reshape(
                batch_size, self.width, self.num_latent_tokens, 1)
        else:
            latent_tokens = latent_tokens.permute(0, 2, 1).unsqueeze(-1)
        latent_tokens = self.conv_out(latent_tokens)  # [N, token_size, L, 1]
        return latent_tokens.reshape(batch_size, self.token_size, 1, self.num_latent_tokens)
class TiTokDecoder(nn.Module):
    """ViT decoder that maps 1-D latent tokens back to a 2-D output map.

    A grid of learnable mask tokens, preceded by a class token, is
    concatenated with the embedded latent tokens. The sequence passes through
    a stack of ``ResidualAttentionBlock``s. The outputs at the grid positions
    are then reshaped to ``(N, width, grid, grid)`` and passed through a 1x1
    convolutional head.

    Config fields read:
        config.dataset.preprocessing.crop_size: output image side length.
        config.model.vq_model.vit_dec_patch_size: patch side length.
        config.model.vq_model.vit_dec_model_size: "small" | "base" | "large".
        config.model.vq_model.num_latent_tokens: number of latent tokens.
        config.model.vq_model.token_size: channels per input latent token.
    """

    # model size -> (width, num_layers, num_heads)
    _ARCH = {
        "small": (512, 8, 8),
        "base": (768, 12, 12),
        "large": (1024, 24, 16),
    }

    def __init__(self, config):
        super().__init__()
        self.config = config
        vq_cfg = config.model.vq_model
        self.image_size = config.dataset.preprocessing.crop_size
        self.patch_size = vq_cfg.vit_dec_patch_size
        self.grid_size = self.image_size // self.patch_size
        self.model_size = vq_cfg.vit_dec_model_size
        self.num_latent_tokens = vq_cfg.num_latent_tokens
        self.token_size = vq_cfg.token_size
        self.width, self.num_layers, self.num_heads = self._ARCH[self.model_size]

        # Parameter creation order is unchanged so seeded init is reproducible.
        self.decoder_embed = nn.Linear(
            self.token_size, self.width, bias=True)
        scale = self.width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(self.grid_size ** 2 + 1, self.width))
        # Mask token (tiled over the output grid) and latent-token pos embed.
        self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width))
        self.latent_token_positional_embedding = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.width))
        self.ln_pre = nn.LayerNorm(self.width)
        self.transformer = nn.ModuleList([
            ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0)
            for _ in range(self.num_layers)
        ])
        self.ln_post = nn.LayerNorm(self.width)
        # NOTE(review): the head's output channel count (1024) is hard-coded
        # rather than read from config; presumably it is the size of a target
        # codebook / logit space. TODO confirm before changing.
        self.ffn = nn.Sequential(
            nn.Conv2d(self.width, 2 * self.width, 1, padding=0, bias=True),
            nn.Tanh(),
            nn.Conv2d(2 * self.width, 1024, 1, padding=0, bias=True),
        )
        self.conv_out = nn.Identity()

    def forward(self, z_quantized):
        """Decode latent tokens into a ``(N, 1024, grid_size, grid_size)`` map.

        Args:
            z_quantized: Tensor of shape ``(N, C, 1, num_latent_tokens)``.
                ``C`` must equal ``token_size`` for ``decoder_embed``.

        Returns:
            Output of the 1x1-conv head, shape ``(N, 1024, grid_size, grid_size)``.
        """
        N, C, H, W = z_quantized.shape
        assert H == 1 and W == self.num_latent_tokens, f"{H}, {W}, {self.num_latent_tokens}"
        x = z_quantized.reshape(N, C * H, W).permute(0, 2, 1)  # NLD
        x = self.decoder_embed(x)

        batchsize, seq_len, _ = x.shape
        # One mask token per output grid cell, preceded by the class token.
        mask_tokens = self.mask_token.repeat(batchsize, self.grid_size ** 2, 1).to(x.dtype)
        cls_tokens = _expand_token(self.class_embedding, batchsize).to(mask_tokens.dtype)
        mask_tokens = torch.cat([cls_tokens, mask_tokens], dim=1)
        mask_tokens = mask_tokens + self.positional_embedding.to(mask_tokens.dtype)

        # Cast like every other embedding add, so a reduced-precision x is
        # not silently promoted by the (fp32) parameter.
        x = x + self.latent_token_positional_embedding[:seq_len].to(x.dtype)
        x = torch.cat([mask_tokens, x], dim=1)

        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND (attention is sequence-first)
        for block in self.transformer:
            x = block(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = x[:, 1:1 + self.grid_size ** 2]  # drop class token and latent tokens
        x = self.ln_post(x)
        # N L D -> N D H W
        x = x.permute(0, 2, 1).reshape(batchsize, self.width, self.grid_size, self.grid_size)
        x = self.ffn(x.contiguous())
        return self.conv_out(x)