|
|
|
import os |
|
import logging |
|
from collections import OrderedDict |
|
|
|
import torch |
|
from torch import nn |
|
from einops import rearrange |
|
from timm.models.layers import DropPath |
|
from timm.models.registry import register_model |
|
|
|
import torch.utils.checkpoint as checkpoint |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
def _interpolate_temporal_pos_embed(temp_embed_old, num_frames_new):
    """Linearly resample a temporal embedding along its frame axis.

    Args:
        temp_embed_old: tensor of shape (1, num_frames_old, 1, d).
        num_frames_new: number of frames in the returned embedding.

    Returns:
        Tensor of shape (1, num_frames_new, 1, d).
    """
    # interpolate() resamples the trailing axis, so frames must come last.
    embed = temp_embed_old.squeeze(2).permute(0, 2, 1)  # (1, d, num_frames_old)
    embed = torch.nn.functional.interpolate(
        embed, size=num_frames_new, mode="linear"
    )  # (1, d, num_frames_new)
    return embed.permute(0, 2, 1).unsqueeze(2)


def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True):
    """
    Add/Remove extra temporal_embeddings as needed.
    https://arxiv.org/abs/2104.00650 shows adding zero paddings works.

    Args:
        temp_embed_old: (1, num_frames_old, 1, d), the trained embedding.
        temp_embed_new: (1, num_frames_new, 1, d), the target embedding.
        add_zero: bool. Only consulted when the target has more frames than
            the trained embedding. If True, the trained frames are written
            IN PLACE into the first ``num_frames_old`` slots of
            ``temp_embed_new`` and the remaining slots keep whatever values
            they already hold. If False, the trained embedding is linearly
            interpolated to ``num_frames_new`` frames.

    Returns:
        Tensor with ``num_frames_new`` frames. When the target has fewer
        frames, this is the leading slice of ``temp_embed_old``; when the
        lengths match, it is ``temp_embed_old`` itself.
    """
    num_frms_new = temp_embed_new.shape[1]
    num_frms_old = temp_embed_old.shape[1]
    logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}")

    if num_frms_new == num_frms_old:
        return temp_embed_old
    if num_frms_new < num_frms_old:
        # Fewer frames requested: keep only the leading ones.
        return temp_embed_old[:, :num_frms_new]

    if add_zero:
        temp_embed_new[:, :num_frms_old] = temp_embed_old
        return temp_embed_new
    # The interpolation helper lives in this block because no such function
    # is defined or imported anywhere else in the file.
    return _interpolate_temporal_pos_embed(temp_embed_old, num_frms_new)
|
|
|
|
|
class QuickGELU(nn.Module):
    """Sigmoid-gated activation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x):
        """Apply the activation element-wise and return a tensor shaped like ``x``."""
        gate = torch.sigmoid(1.702 * x)
        return x * gate
|
|
|
|
|
class ResidualAttentionBlock(nn.Module):
    """Transformer block: LayerNorm -> self-attention, then LayerNorm -> MLP.

    Both sub-layers are wrapped in a residual connection, and each residual
    branch passes through its own stochastic-depth (``DropPath``) module.

    Args:
        d_model: embedding width.
        n_head: number of attention heads.
        drop_path: stochastic-depth rate; ``nn.Identity`` is used when it is
            not greater than 0.
        attn_mask: optional mask forwarded to ``nn.MultiheadAttention``.
        dropout: dropout rate for the attention module and both MLP dropouts.
    """

    def __init__(self, d_model, n_head, drop_path=0., attn_mask=None, dropout=0.):
        super().__init__()

        def build_drop_path():
            # A fresh module per residual branch.
            if drop_path > 0.:
                return DropPath(drop_path)
            return nn.Identity()

        self.drop_path1 = build_drop_path()
        self.drop_path2 = build_drop_path()

        self.attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout)
        self.ln_1 = nn.LayerNorm(d_model)

        hidden_dim = d_model * 4
        mlp_layers = OrderedDict()
        mlp_layers["c_fc"] = nn.Linear(d_model, hidden_dim)
        mlp_layers["gelu"] = QuickGELU()
        mlp_layers["drop1"] = nn.Dropout(dropout)
        mlp_layers["c_proj"] = nn.Linear(hidden_dim, d_model)
        mlp_layers["drop2"] = nn.Dropout(dropout)
        self.mlp = nn.Sequential(mlp_layers)

        self.ln_2 = nn.LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x):
        """Run self-attention over ``x`` and return only the attended output.

        The stored mask (if any) is converted to the dtype and device of
        ``x`` and written back to ``self.attn_mask`` before use.
        """
        mask = self.attn_mask
        if mask is not None:
            mask = mask.to(dtype=x.dtype, device=x.device)
        self.attn_mask = mask
        attended = self.attn(x, x, x, need_weights=False, attn_mask=mask)
        return attended[0]

    def forward(self, x):
        """Apply the attention and MLP residual branches in sequence.

        ``x`` is presumably (seq_len, batch, d_model), since the attention
        module is built without ``batch_first`` -- confirm against callers.
        """
        attn_branch = self.attention(self.ln_1(x))
        x = x + self.drop_path1(attn_branch)
        mlp_branch = self.mlp(self.ln_2(x))
        x = x + self.drop_path2(mlp_branch)
        return x
|
|
|
|
|
class Transformer(nn.Module):
    """Stack of ``ResidualAttentionBlock`` layers with optional checkpointing.

    Args:
        width: embedding width passed to every block.
        layers: number of blocks.
        heads: attention heads per block.
        drop_path: final stochastic-depth rate; per-block rates are spaced
            linearly from 0 to this value.
        checkpoint_num: the first ``checkpoint_num`` blocks run through
            ``torch.utils.checkpoint``.
        dropout: dropout rate passed to every block.
    """

    def __init__(self, width, layers, heads, drop_path=0., checkpoint_num=0, dropout=0.):
        super().__init__()
        drop_rates = torch.linspace(0, drop_path, layers).tolist()
        self.resblocks = nn.ModuleList([
            ResidualAttentionBlock(width, heads, drop_path=rate, dropout=dropout)
            for rate in drop_rates
        ])
        self.checkpoint_num = checkpoint_num

    def forward(self, x):
        """Pass ``x`` through every block in order and return the result."""
        for depth, block in enumerate(self.resblocks):
            checkpointed = depth < self.checkpoint_num
            x = checkpoint.checkpoint(block, x) if checkpointed else block(x)
        return x
|
|
|
|
|
class VisionTransformer(nn.Module):
    """CLIP-style vision transformer over video clips.

    Frames are patchified by a 3D convolution, every patch token receives a
    spatial and (optionally) a temporal positional embedding, and the tokens
    of all frames are concatenated into one sequence for the transformer.

    Args:
        input_resolution: spatial side length used to size the positional
            embedding, ``(input_resolution // patch_size) ** 2 + 1`` rows.
        patch_size: spatial kernel size and stride of the patch convolution.
        width: embedding width.
        layers: number of transformer blocks.
        heads: attention heads per block.
        output_dim: if not None, the class token is projected to this size.
        kernel_size: temporal kernel size and stride of the patch convolution.
        num_frames: length of the temporal positional embedding.
        drop_path: final stochastic-depth rate passed to ``Transformer``.
        checkpoint_num: number of leading blocks run with checkpointing.
        dropout: dropout rate for the blocks and for the output projection.
        temp_embed: whether to create ``temporal_positional_embedding``.
    """

    def __init__(
        self, input_resolution, patch_size, width, layers, heads, output_dim=None,
        kernel_size=1, num_frames=8, drop_path=0, checkpoint_num=0, dropout=0.,
        temp_embed=True,
    ):
        super().__init__()
        self.output_dim = output_dim
        # Non-overlapping patches: the stride equals the kernel size.
        self.conv1 = nn.Conv3d(
            3, width,
            (kernel_size, patch_size, patch_size),
            (kernel_size, patch_size, patch_size),
            (0, 0, 0), bias=False
        )

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(
            scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)
        )
        self.ln_pre = nn.LayerNorm(width)
        if temp_embed:
            self.temporal_positional_embedding = nn.Parameter(torch.zeros(1, num_frames, width))

        self.transformer = Transformer(
            width, layers, heads, drop_path=drop_path, checkpoint_num=checkpoint_num,
            dropout=dropout)

        self.ln_post = nn.LayerNorm(width)
        if output_dim is not None:
            # Initialise like the other embeddings. torch.empty would leave
            # uninitialised memory here, which is only safe when a checkpoint
            # is guaranteed to overwrite it.
            self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
        else:
            self.proj = None

        self.dropout = nn.Dropout(dropout)

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.transformer.resblocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names that should be excluded from weight decay."""
        return {'positional_embedding', 'class_embedding', 'temporal_positional_embedding'}

    def mask_tokens(self, inputs, masking_prob=0.0):
        """Randomly drop a fixed number of tokens from every sample.

        Args:
            inputs: tensor of shape (B, L, C).
            masking_prob: fraction of the L tokens to drop;
                ``int(masking_prob * L)`` tokens are removed per sample.

        Returns:
            Tensor of shape (B, L - int(masking_prob * L), C) holding the
            surviving tokens in their original order.
        """
        B, L, _ = inputs.shape
        device = inputs.device

        num_masked = int(masking_prob * L)
        # Everything is built on the input's device so indexing does not need
        # a host/device transfer.
        noise = torch.rand(B, L, device=device)
        drop_indices = torch.argsort(noise, dim=-1)[:, :num_masked]
        batch_indices = torch.arange(B, device=device).unsqueeze(-1).expand_as(drop_indices)

        keep = torch.ones(B, L, dtype=torch.bool, device=device)
        keep[batch_indices, drop_indices] = False

        return inputs[keep].reshape(B, -1, inputs.shape[-1])

    def forward(self, x, masking_prob=0.0):
        """Encode a batch of clips.

        Args:
            x: input to the 3-channel ``Conv3d``, i.e. (B, 3, T, H, W).
            masking_prob: if greater than 0, that fraction of patch tokens is
                dropped (see ``mask_tokens``) before the transformer.

        Returns:
            ``(B, output_dim)`` projected class-token features when ``proj``
            is set, otherwise all tokens as ``(B, 1 + num_tokens, width)``.
        """
        x = self.conv1(x)
        B, C, T, H, W = x.shape
        # One row per frame: (B * T, H * W, C).
        x = x.permute(0, 2, 3, 4, 1).reshape(B * T, H * W, C)

        # Prepend the class token to every frame, then add spatial positions.
        class_tokens = self.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        x = torch.cat([class_tokens, x], dim=1)
        x = x + self.positional_embedding.to(x.dtype)

        # The class token (class_embedding + positional_embedding[0]) is the
        # same in every row, so the first B rows give one token per sample.
        cls_tokens = x[:B, :1, :]
        x = x[:, 1:]

        # Temporal positions are added with frames on axis 1.
        x = rearrange(x, '(b t) n m -> (b n) t m', b=B, t=T)
        if hasattr(self, 'temporal_positional_embedding'):
            if x.size(1) == 1:
                # Single frame: add the mean over frames of the embedding
                # (presumably so the parameter is still used for image input).
                x = x + self.temporal_positional_embedding.mean(1)
            else:
                x = x + self.temporal_positional_embedding
        x = rearrange(x, '(b n) t m -> b (n t) m', b=B, t=T)

        if masking_prob > 0.0:
            x = self.mask_tokens(x, masking_prob)

        x = torch.cat((cls_tokens, x), dim=1)

        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # to sequence-first for the transformer
        x = self.transformer(x)

        x = self.ln_post(x)

        if self.proj is not None:
            # Index 0 on the sequence axis is the class token.
            x = self.dropout(x[0]) @ self.proj
        else:
            x = x.permute(1, 0, 2)  # back to batch-first

        return x
|
|
|
|
|
def inflate_weight(weight_2d, time_dim, center=True):
    """Inflate a 2D convolution kernel into a 3D one by adding a time axis.

    Args:
        weight_2d: kernel tensor; a new axis of length ``time_dim`` is
            inserted at position 2 (expected shape: (out, in, kH, kW)).
        time_dim: temporal extent of the inflated kernel.
        center: if True, the 2D kernel is placed at the middle temporal index
            (``time_dim // 2``) and every other slice is zero. If False, the
            kernel is repeated along time and divided by ``time_dim``.

    Returns:
        The inflated kernel, with the same dtype and device as ``weight_2d``.
    """
    logger.info(f'Init center: {center}')
    if center:
        inflated_shape = list(weight_2d.shape)
        inflated_shape.insert(2, time_dim)
        # new_zeros keeps the source dtype/device; a bare torch.zeros would
        # silently cast the kernel to float32 on the CPU.
        weight_3d = weight_2d.new_zeros(inflated_shape)
        middle_idx = time_dim // 2
        weight_3d[:, :, middle_idx] = weight_2d
    else:
        weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
        weight_3d = weight_3d / time_dim
    return weight_3d
|
|
|
|
|
def load_state_dict(model, state_dict, input_resolution=224, patch_size=16, center=True):
    """Load image-model weights into ``model``, adapting mismatched shapes.

    Two adaptations are applied before ``model.load_state_dict(strict=False)``:

    * Entries whose shape differs from the model's and whose target has more
      than two dimensions are inflated along time with ``inflate_weight``.
    * ``positional_embedding`` is bicubically resized when the checkpoint's
      patch grid differs from ``input_resolution // patch_size``.

    Args:
        model: module to load into.
        state_dict: mapping of parameter names to tensors; must contain
            ``'positional_embedding'``. The caller's mapping is not modified.
        input_resolution: spatial input size of ``model``.
        patch_size: spatial patch size of ``model``.
        center: forwarded to ``inflate_weight``.
    """
    import copy

    # Entries are replaced below, so work on a shallow copy: the tensors are
    # shared, but the caller's mapping keeps its original entries.
    state_dict = copy.copy(state_dict)

    state_dict_3d = model.state_dict()
    for k in list(state_dict):
        if k not in state_dict_3d or state_dict[k].shape == state_dict_3d[k].shape:
            continue
        if len(state_dict_3d[k].shape) <= 2:
            # Only inflation is skipped; the tensor stays in the mapping that
            # is handed to model.load_state_dict below.
            logger.info(f'Ignore: {k}')
            continue
        logger.info(f'Inflate: {k}, {state_dict[k].shape} => {state_dict_3d[k].shape}')
        time_dim = state_dict_3d[k].shape[2]
        state_dict[k] = inflate_weight(state_dict[k], time_dim, center=center)

    pos_embed_checkpoint = state_dict['positional_embedding']
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = (input_resolution // patch_size) ** 2
    # The first row is the class token; the remaining rows form a square grid.
    orig_size = int((pos_embed_checkpoint.shape[-2] - 1) ** 0.5)
    new_size = int(num_patches ** 0.5)
    if orig_size != new_size:
        logger.info(f'Pos_emb from {orig_size} to {new_size}')
        extra_tokens = pos_embed_checkpoint[:1]
        pos_tokens = pos_embed_checkpoint[1:]
        # (N, C) -> (1, C, H, W) so the grid can be resized as an image.
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        # Resize in float32 and cast back: a no-op for float32 checkpoints,
        # and it avoids relying on bicubic kernels for reduced-precision
        # CPU tensors.
        pos_dtype = pos_tokens.dtype
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens.float(), size=(new_size, new_size), mode='bicubic', align_corners=False
        ).to(pos_dtype)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(0, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=0)
        state_dict['positional_embedding'] = new_pos_embed

    message = model.load_state_dict(state_dict, strict=False)
    logger.info(f"Load pretrained weights: {message}")
|
|
|
|
|
@register_model
def clip_joint_b16(
    pretrained=False, input_resolution=224, kernel_size=1,
    center=True, num_frames=8, drop_path=0., checkpoint_num=0,
    dropout=0.,
):
    """Build the ViT-B/16 video encoder (patch 16, width 768, 12 layers, 12 heads).

    Args:
        pretrained: falsy to skip weight loading; a string selects that entry
            of ``_MODELS``; any other truthy value selects ``"ViT-B/16"``.
        input_resolution: spatial input size.
        kernel_size: temporal kernel size of the patch convolution.
        center: inflation mode forwarded to ``load_state_dict``.
        num_frames: length of the temporal positional embedding.
        drop_path: final stochastic-depth rate.
        checkpoint_num: number of leading blocks run with checkpointing.
        dropout: dropout rate.

    Returns:
        The model, switched to eval mode.

    NOTE(review): ``_MODELS`` is referenced here but is not defined in the
    part of the file visible to this review -- confirm it exists at module
    level. ``torch.load`` unpickles the checkpoint, so only trusted files
    should be listed there.
    """
    patch_size = 16
    model = VisionTransformer(
        input_resolution=input_resolution,
        patch_size=patch_size,
        width=768,
        layers=12,
        heads=12,
        output_dim=512,
        kernel_size=kernel_size,
        num_frames=num_frames,
        drop_path=drop_path,
        checkpoint_num=checkpoint_num,
        dropout=dropout,
    )
    if not pretrained:
        return model.eval()

    model_name = pretrained if isinstance(pretrained, str) else "ViT-B/16"
    logger.info('load pretrained weights')
    state_dict = torch.load(_MODELS[model_name], map_location='cpu')
    load_state_dict(
        model, state_dict,
        input_resolution=input_resolution, patch_size=patch_size, center=center,
    )
    return model.eval()
|
|
|
|
|
@register_model
def clip_joint_l14(
    pretrained=False, input_resolution=224, kernel_size=1,
    center=True, num_frames=8, drop_path=0., checkpoint_num=0,
    dropout=0.,
):
    """Build the ViT-L/14 video encoder (patch 14, width 1024, 24 layers, 16 heads).

    Args:
        pretrained: falsy to skip weight loading; a string selects that entry
            of ``_MODELS``; any other truthy value selects ``"ViT-L/14"``.
        input_resolution: spatial input size.
        kernel_size: temporal kernel size of the patch convolution.
        center: inflation mode forwarded to ``load_state_dict``.
        num_frames: length of the temporal positional embedding.
        drop_path: final stochastic-depth rate.
        checkpoint_num: number of leading blocks run with checkpointing.
        dropout: dropout rate.

    Returns:
        The model, switched to eval mode.

    NOTE(review): ``_MODELS`` is referenced here but is not defined in the
    part of the file visible to this review -- confirm it exists at module
    level. ``torch.load`` unpickles the checkpoint, so only trusted files
    should be listed there.
    """
    patch_size = 14
    model = VisionTransformer(
        input_resolution=input_resolution,
        patch_size=patch_size,
        width=1024,
        layers=24,
        heads=16,
        output_dim=768,
        kernel_size=kernel_size,
        num_frames=num_frames,
        drop_path=drop_path,
        checkpoint_num=checkpoint_num,
        dropout=dropout,
    )
    if not pretrained:
        return model.eval()

    model_name = pretrained if isinstance(pretrained, str) else "ViT-L/14"
    logger.info('load pretrained weights')
    state_dict = torch.load(_MODELS[model_name], map_location='cpu')
    load_state_dict(
        model, state_dict,
        input_resolution=input_resolution, patch_size=patch_size, center=center,
    )
    return model.eval()
|
|
|
|
|
@register_model
def clip_joint_l14_336(
    pretrained=True, input_resolution=336, kernel_size=1,
    center=True, num_frames=8, drop_path=0.
):
    """Placeholder for the ViT-L/14 encoder at 336-pixel resolution.

    This variant is registered but not implemented: calling it always raises.
    The code that previously followed the ``raise`` was unreachable and has
    been removed. It built a ``VisionTransformer`` with patch size 14, width
    1024, 24 layers, 16 heads and output_dim 768, and loaded the
    ``"ViT-L/14_336"`` entry of ``_MODELS`` when ``pretrained`` was truthy.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("clip_joint_l14_336 is not implemented")
|
|
|
|
|
def interpolate_pos_embed_vit(state_dict, new_model):
    """Adapt sequence-length-dependent embeddings in ``state_dict`` to ``new_model``.

    Two keys are handled when present in ``state_dict``:

    * ``vision_encoder.temporal_positional_embedding``
    * ``text_encoder.positional_embedding``

    Each is reshaped to the 4-D layout expected by
    ``load_temp_embed_with_mismatch`` (called with ``add_zero=False``),
    matched to the length of the corresponding tensor in
    ``new_model.state_dict()``, and written back in its original layout.

    Args:
        state_dict: checkpoint mapping; updated in place.
        new_model: model whose state dict supplies the target lengths.

    Returns:
        The same ``state_dict`` object.
    """
    # (key, needs_leading_axis): the text embedding gets an extra leading
    # axis before the shared axis-2 expansion; the vision one does not.
    targets = (
        ("vision_encoder.temporal_positional_embedding", False),
        ("text_encoder.positional_embedding", True),
    )
    for key, needs_leading_axis in targets:
        if key not in state_dict:
            continue

        embed_new = new_model.state_dict()[key]
        embed_old = state_dict[key]
        if needs_leading_axis:
            embed_new = embed_new.unsqueeze(0)
            embed_old = embed_old.unsqueeze(0)
        embed_new = embed_new.unsqueeze(2)
        embed_old = embed_old.unsqueeze(2)

        resized = load_temp_embed_with_mismatch(
            embed_old, embed_new, add_zero=False
        ).squeeze(2)
        if needs_leading_axis:
            resized = resized.squeeze(0)
        state_dict[key] = resized

    return state_dict