"""Building blocks for TiTok. Copyright (2024) Bytedance Ltd. and/or its affiliates Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Reference: https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py https://github.com/baofff/U-ViT/blob/main/libs/timm.py """ import torch import torch.nn as nn from collections import OrderedDict import einops from einops.layers.torch import Rearrange class ResidualAttentionBlock(nn.Module): def __init__( self, d_model, n_head, mlp_ratio = 4.0, act_layer = nn.GELU, norm_layer = nn.LayerNorm ): super().__init__() self.ln_1 = norm_layer(d_model) self.attn = nn.MultiheadAttention(d_model, n_head) self.mlp_ratio = mlp_ratio # optionally we can disable the FFN if mlp_ratio > 0: self.ln_2 = norm_layer(d_model) mlp_width = int(d_model * mlp_ratio) self.mlp = nn.Sequential(OrderedDict([ ("c_fc", nn.Linear(d_model, mlp_width)), ("gelu", act_layer()), ("c_proj", nn.Linear(mlp_width, d_model)) ])) def attention( self, x: torch.Tensor ): return self.attn(x, x, x, need_weights=False)[0] def forward( self, x: torch.Tensor, ): attn_output = self.attention(x=self.ln_1(x)) x = x + attn_output if self.mlp_ratio > 0: x = x + self.mlp(self.ln_2(x)) return x if hasattr(torch.nn.functional, 'scaled_dot_product_attention'): ATTENTION_MODE = 'flash' else: try: import xformers import xformers.ops ATTENTION_MODE = 'xformers' except: ATTENTION_MODE = 'math' print(f'attention mode is {ATTENTION_MODE}') class Attention(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): B, L, C = x.shape qkv = self.qkv(x) if ATTENTION_MODE == 'flash': qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float() q, k, v = qkv[0], qkv[1], qkv[2] # B H L D x = torch.nn.functional.scaled_dot_product_attention(q, k, v) x = einops.rearrange(x, 'B H L D -> B L (H D)') elif ATTENTION_MODE == 'xformers': qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads) q, k, v = qkv[0], qkv[1], qkv[2] # B L H D x = xformers.ops.memory_efficient_attention(q, k, v) x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads) elif ATTENTION_MODE == 'math': qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads) q, k, v = qkv[0], qkv[1], qkv[2] # B H L D attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, L, C) else: raise NotImplemented x = self.proj(x) x = self.proj_drop(x) return x def drop_path(x, drop_prob: float = 0., training: bool = False): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument. """ if drop_prob == 0. or not training: return x keep_prob = 1 - drop_prob shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) random_tensor.floor_() # binarize output = x.div(keep_prob) * random_tensor return output class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path(x, self.drop_prob, self.training) class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class UViTBlock(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) self.skip_linear = nn.Linear(2 * dim, dim) if skip else None self.use_checkpoint = use_checkpoint def forward(self, x, skip=None): if self.use_checkpoint: return torch.utils.checkpoint.checkpoint(self._forward, x, skip) else: return self._forward(x, skip) def _forward(self, x, skip=None): if self.skip_linear is not None: x = self.skip_linear(torch.cat([x, skip], dim=-1)) x = x + self.drop_path(self.attn(self.norm1(x))) x = x + self.drop_path(self.mlp(self.norm2(x))) return x def _expand_token(token, batch_size: int): return token.unsqueeze(0).expand(batch_size, -1, -1) class TiTokEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config self.image_size = config.dataset.preprocessing.crop_size self.patch_size = config.model.vq_model.vit_enc_patch_size self.grid_size = self.image_size // self.patch_size self.model_size = config.model.vq_model.vit_enc_model_size self.num_latent_tokens = config.model.vq_model.num_latent_tokens self.token_size = config.model.vq_model.token_size if config.model.vq_model.get("quantize_mode", "vq") == "vae": self.token_size = self.token_size * 2 # needs to split into mean and std self.is_legacy = config.model.vq_model.get("is_legacy", True) self.width = { "small": 512, "base": 768, "large": 1024, }[self.model_size] self.num_layers = { "small": 8, "base": 12, "large": 24, }[self.model_size] self.num_heads = { "small": 8, "base": 12, "large": 16, }[self.model_size] self.patch_embed = nn.Conv2d( in_channels=3, out_channels=self.width, kernel_size=self.patch_size, stride=self.patch_size, bias=True) scale = self.width ** -0.5 self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width)) self.positional_embedding = nn.Parameter( scale * torch.randn(self.grid_size ** 2 + 1, self.width)) self.latent_token_positional_embedding = nn.Parameter( scale * torch.randn(self.num_latent_tokens, self.width)) self.ln_pre = nn.LayerNorm(self.width) self.transformer = nn.ModuleList() for i in range(self.num_layers): self.transformer.append(ResidualAttentionBlock( self.width, self.num_heads, mlp_ratio=4.0 )) self.ln_post = nn.LayerNorm(self.width) self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True) def forward(self, pixel_values, latent_tokens): batch_size = pixel_values.shape[0] x = pixel_values x = self.patch_embed(x) x = x.reshape(x.shape[0], x.shape[1], -1) x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] # class embeddings and positional embeddings x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1) x = x + self.positional_embedding.to(x.dtype) # shape = [*, grid ** 2 + 1, width] latent_tokens = _expand_token(latent_tokens, x.shape[0]).to(x.dtype) latent_tokens = latent_tokens + self.latent_token_positional_embedding.to(x.dtype) x = torch.cat([x, latent_tokens], dim=1) x = self.ln_pre(x) x = x.permute(1, 0, 2) # NLD -> LND for i in range(self.num_layers): x = self.transformer[i](x) x = x.permute(1, 0, 2) # LND -> NLD latent_tokens = x[:, 1+self.grid_size**2:] latent_tokens = self.ln_post(latent_tokens) # fake 2D shape if self.is_legacy: latent_tokens = latent_tokens.reshape(batch_size, self.width, self.num_latent_tokens, 1) else: # Fix legacy problem. latent_tokens = latent_tokens.reshape(batch_size, self.num_latent_tokens, self.width, 1).permute(0, 2, 1, 3) latent_tokens = self.conv_out(latent_tokens) latent_tokens = latent_tokens.reshape(batch_size, self.token_size, 1, self.num_latent_tokens) return latent_tokens class TiTokDecoder(nn.Module): def __init__(self, config): super().__init__() self.config = config self.image_size = config.dataset.preprocessing.crop_size self.patch_size = config.model.vq_model.vit_dec_patch_size self.grid_size = self.image_size // self.patch_size self.model_size = config.model.vq_model.vit_dec_model_size self.num_latent_tokens = config.model.vq_model.num_latent_tokens self.token_size = config.model.vq_model.token_size self.is_legacy = config.model.vq_model.get("is_legacy", True) self.width = { "small": 512, "base": 768, "large": 1024, }[self.model_size] self.num_layers = { "small": 8, "base": 12, "large": 24, }[self.model_size] self.num_heads = { "small": 8, "base": 12, "large": 16, }[self.model_size] self.decoder_embed = nn.Linear( self.token_size, self.width, bias=True) scale = self.width ** -0.5 self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width)) self.positional_embedding = nn.Parameter( scale * torch.randn(self.grid_size ** 2 + 1, self.width)) # add mask token and query pos embed self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width)) self.latent_token_positional_embedding = nn.Parameter( scale * torch.randn(self.num_latent_tokens, self.width)) self.ln_pre = nn.LayerNorm(self.width) self.transformer = nn.ModuleList() for i in range(self.num_layers): self.transformer.append(ResidualAttentionBlock( self.width, self.num_heads, mlp_ratio=4.0 )) self.ln_post = nn.LayerNorm(self.width) if self.is_legacy: self.ffn = nn.Sequential( nn.Conv2d(self.width, 2 * self.width, 1, padding=0, bias=True), nn.Tanh(), nn.Conv2d(2 * self.width, 1024, 1, padding=0, bias=True), ) self.conv_out = nn.Identity() else: # Directly predicting RGB pixels self.ffn = nn.Sequential( nn.Conv2d(self.width, self.patch_size * self.patch_size * 3, 1, padding=0, bias=True), Rearrange('b (p1 p2 c) h w -> b c (h p1) (w p2)', p1 = self.patch_size, p2 = self.patch_size),) self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True) def forward(self, z_quantized): N, C, H, W = z_quantized.shape assert H == 1 and W == self.num_latent_tokens, f"{H}, {W}, {self.num_latent_tokens}" x = z_quantized.reshape(N, C*H, W).permute(0, 2, 1) # NLD x = self.decoder_embed(x) batchsize, seq_len, _ = x.shape mask_tokens = self.mask_token.repeat(batchsize, self.grid_size**2, 1).to(x.dtype) mask_tokens = torch.cat([_expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype), mask_tokens], dim=1) mask_tokens = mask_tokens + self.positional_embedding.to(mask_tokens.dtype) x = x + self.latent_token_positional_embedding[:seq_len] x = torch.cat([mask_tokens, x], dim=1) x = self.ln_pre(x) x = x.permute(1, 0, 2) # NLD -> LND for i in range(self.num_layers): x = self.transformer[i](x) x = x.permute(1, 0, 2) # LND -> NLD x = x[:, 1:1+self.grid_size**2] # remove cls embed x = self.ln_post(x) # N L D -> N D H W x = x.permute(0, 2, 1).reshape(batchsize, self.width, self.grid_size, self.grid_size) x = self.ffn(x.contiguous()) x = self.conv_out(x) return x