import copy
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from transformers import RobertaModel, RobertaTokenizerFast

from cliport.models.resnet import IdentityBlock, ConvBlock
from cliport.models.core.unet import Up
from cliport.models.core import fusion
from cliport.models.core.fusion import FusionConvLat
from cliport.models.backbone_full import Backbone
from cliport.models.misc import NestedTensor
from cliport.models.position_encoding import build_position_encoding


class FeatureResizer(nn.Module):
    """
    This class takes as input a set of embeddings of dimension C1 and outputs a set of
    embeddings of dimension C2, after a linear transformation, layer normalization (LN) and dropout.
    """

    def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True):
        super().__init__()
        self.do_ln = do_ln
        # Object feature encoding
        self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True)
        self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout)

    def forward(self, encoder_features):
        x = self.fc(encoder_features)
        if self.do_ln:
            x = self.layer_norm(x)
        output = self.dropout(x)
        return output
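

# Usage sketch (hypothetical variable names): in this model the resizer maps RoBERTa's 768-dim
# token features into the 256-dim space of the cross-modal encoder, e.g.
#   resizer = FeatureResizer(input_feat_size=768, output_feat_size=256, dropout=0.1)
#   text_256 = resizer(text_768)  # (seq_len, batch, 768) -> (seq_len, batch, 256)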


class MDETRLingUNetLat_fuse(nn.Module):
    """ MDETR (ResNet-101 + RoBERTa) encoder with U-Net skip connections and lateral connections """

    def __init__(self, input_shape, output_dim, cfg, device, preprocess):
        super(MDETRLingUNetLat_fuse, self).__init__()
        self.input_shape = input_shape
        self.output_dim = output_dim
        self.input_dim = 2048  # channel size of the last MDETR backbone feature map
        self.cfg = cfg
        self.device = device
        self.batchnorm = self.cfg['train']['batchnorm']
        self.lang_fusion_type = self.cfg['train']['lang_fusion_type']
        self.bilinear = True
        self.up_factor = 2 if self.bilinear else 1
        self.preprocess = preprocess

        self.backbone = Backbone('resnet101', True, True, False)
        self.position_embedding = build_position_encoding()
        self.input_proj = nn.Conv2d(2048, 256, kernel_size=1)
        self.tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')
        self.text_encoder = RobertaModel.from_pretrained('roberta-base')
        self.resizer = FeatureResizer(
            input_feat_size=768,
            output_feat_size=256,
            dropout=0.1,
        )

        encoder_layer = TransformerEncoderLayer(d_model=256, nhead=8, dim_feedforward=2048,
                                                dropout=0.1, activation='relu', normalize_before=False)
        self.encoder = TransformerEncoder(encoder_layer, 6, None)

        # Load the pretrained MDETR (ResNet-101) weights for the visual backbone, text encoder,
        # cross-modal encoder, input projection and text resizer (hardcoded local checkpoint path).
        mdetr_checkpoint = torch.load('/home/yzc/shared/project/GPT-CLIPort/ckpts/mdetr_pretrained_resnet101_checkpoint.pth', map_location="cpu")['model']
        checkpoint_new = {}
        for param in mdetr_checkpoint:
            if 'transformer.text_encoder' in param or 'transformer.encoder.' in param or 'input_proj' in param or 'resizer' in param:
                param_new = param.replace('transformer.', '')
                checkpoint_new[param_new] = mdetr_checkpoint[param]
            elif 'backbone.0.body' in param:
                param_new = param.replace('backbone.0.body', 'backbone.body')
                checkpoint_new[param_new] = mdetr_checkpoint[param]
        self.load_state_dict(checkpoint_new, strict=True)

        self._build_decoder()
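
    # The decoder mirrors CLIPort's two-stream LingUNet-lat layout: three language-conditioned
    # fusion stages (lang_fuser1-3 with lang_proj1-3), each followed by bilinear upsampling (up1-3)
    # and lateral fusion with a companion convolutional stream (lat_fusion1-6), then conv/identity
    # blocks reducing channels 128 -> 64 -> 32 -> 16 and a final 1x1 convolution to output_dim.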
    def _build_decoder(self):
        # language
        self.up_fuse1 = nn.UpsamplingBilinear2d(scale_factor=2)
        self.up_fuse2 = nn.UpsamplingBilinear2d(scale_factor=4)
        self.up_fuse3 = nn.UpsamplingBilinear2d(scale_factor=8)
        self.lang_fuser1 = fusion.names[self.lang_fusion_type](input_dim=self.input_dim // 2)
        self.lang_fuser2 = fusion.names[self.lang_fusion_type](input_dim=self.input_dim // 4)
        self.lang_fuser3 = fusion.names[self.lang_fusion_type](input_dim=self.input_dim // 8)

        self.proj_input_dim = 768
        self.lang_proj1 = nn.Linear(self.proj_input_dim, 1024)
        self.lang_proj2 = nn.Linear(self.proj_input_dim, 512)
        self.lang_proj3 = nn.Linear(self.proj_input_dim, 256)

        # vision
        self.conv1 = nn.Sequential(
            nn.Conv2d(self.input_dim + 256, 1024, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(True)
        )
        self.up1 = Up(2048 + 256, 1024 // self.up_factor, self.bilinear)
        self.lat_fusion1 = FusionConvLat(input_dim=1024 + 512, output_dim=512)
        self.up2 = Up(1024 + 256, 512 // self.up_factor, self.bilinear)
        self.lat_fusion2 = FusionConvLat(input_dim=512 + 256, output_dim=256)
        self.up3 = Up(512 + 256, 256 // self.up_factor, self.bilinear)
        self.lat_fusion3 = FusionConvLat(input_dim=256 + 128, output_dim=128)

        self.layer1 = nn.Sequential(
            ConvBlock(128, [64, 64, 64], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            IdentityBlock(64, [64, 64, 64], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            nn.UpsamplingBilinear2d(scale_factor=2),
        )
        self.lat_fusion4 = FusionConvLat(input_dim=128 + 64, output_dim=64)

        self.layer2 = nn.Sequential(
            ConvBlock(64, [32, 32, 32], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            IdentityBlock(32, [32, 32, 32], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            nn.UpsamplingBilinear2d(scale_factor=2),
        )
        self.lat_fusion5 = FusionConvLat(input_dim=64 + 32, output_dim=32)

        self.layer3 = nn.Sequential(
            ConvBlock(32, [16, 16, 16], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            IdentityBlock(16, [16, 16, 16], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            nn.UpsamplingBilinear2d(scale_factor=2),
        )
        self.lat_fusion6 = FusionConvLat(input_dim=32 + 16, output_dim=16)

        self.conv2 = nn.Sequential(
            nn.Conv2d(16, self.output_dim, kernel_size=1)
        )
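
    # Run the ResNet-101 backbone (without gradients) on a padded NestedTensor and return the
    # multi-scale feature maps together with their positional encodings, one entry per scale.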
    def encode_image(self, img):
        img = NestedTensor.from_tensor_list(img)
        with torch.no_grad():
            xs = self.backbone(img)
            out = []
            pos = []
            for name, x in xs.items():
                out.append(x)
                # position encoding
                pos.append(self.position_embedding(x).to(x.tensors.dtype))
        return out, pos
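
    # Encode the language goal with RoBERTa (without gradients). Returns the 256-dim resized token
    # features (sequence-first), the inverted padding mask expected by PyTorch attention, and the
    # mean-pooled 768-dim sentence embedding consumed by the lang_fuser blocks.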
    def encode_text(self, x):
        with torch.no_grad():
            tokenized = self.tokenizer.batch_encode_plus(x, padding="longest", return_tensors="pt").to(self.device)
            encoded_text = self.text_encoder(**tokenized)
            # Transpose the memory because PyTorch's attention expects the sequence dimension first
            text_memory = encoded_text.last_hidden_state.transpose(0, 1)
            text_memory_mean = torch.mean(text_memory, 0)
            # Invert the attention mask from Hugging Face because PyTorch's transformer uses the opposite convention
            text_attention_mask = tokenized.attention_mask.ne(1).bool()
            # Resize the encoder hidden states to the same d_model as the cross-modal encoder
            text_memory_resized = self.resizer(text_memory)
        return text_memory_resized, text_attention_mask, text_memory_mean
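
    # Forward pass: `x` is the input observation (only the first three, RGB, channels are kept),
    # `lat` is a list of lateral feature maps from a companion convolutional stream (indexed
    # lat[-6] ... lat[-1], coarse to fine), and `l` is a list of language goals, e.g.
    # (hypothetical) model(img, lat_feats, ["pack the red block into the brown box"]).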
    def forward(self, x, lat, l):
        x = self.preprocess(x, dist='mdetr')

        in_type = x.dtype
        in_shape = x.shape
        x = x[:, :3]  # select RGB
        x = x.permute(0, 1, 3, 2)
        with torch.no_grad():
            features, pos = self.encode_image(x)
        x1, mask = features[-1].decompose()
        x2, _ = features[-2].decompose()
        x3, _ = features[-3].decompose()
        x4, _ = features[-4].decompose()

        src = self.input_proj(x1)
        pos_embed = pos[-1]
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        device = self.device
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        mask = mask.flatten(1)

        text_memory_resized, text_attention_mask, l_input = self.encode_text(l)
        # l_input = l_input.view(1, -1)
        # text_memory_resized = text_memory_resized.repeat(1, src.shape[1], 1)
        # text_attention_mask = text_attention_mask.repeat(src.shape[1], 1)
        if src.shape[1] == int(36 * 8):
            # Presumably a batch of 8 samples x 36 rotations: tile the text features to match the expanded image batch.
            text_memory_resized = text_memory_resized.repeat_interleave(36, dim=1)
            l_input = l_input.repeat_interleave(36, dim=0)
            text_attention_mask = text_attention_mask.repeat_interleave(36, dim=0)

        src = torch.cat([src, text_memory_resized], dim=0)
        # For the mask, the sequence dimension is second
        mask = torch.cat([mask, text_attention_mask], dim=1)
        # Pad pos_embed with zeros so that the addition is a no-op for the text tokens
        pos_embed = torch.cat([pos_embed, torch.zeros_like(text_memory_resized)], dim=0)
        img_memory, img_memory_all = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)

        dim = img_memory.shape[-1]
        fuse1 = img_memory_all[-1][:h * w].permute(1, 2, 0).reshape(bs, dim, h, w)
        fuse2 = self.up_fuse1(img_memory_all[-2][:h * w].permute(1, 2, 0).reshape(bs, dim, h, w))
        fuse3 = self.up_fuse2(img_memory_all[-3][:h * w].permute(1, 2, 0).reshape(bs, dim, h, w))
        fuse4 = self.up_fuse3(img_memory_all[-4][:h * w].permute(1, 2, 0).reshape(bs, dim, h, w))

        assert x1.shape[1] == self.input_dim
        x1 = torch.cat((x1, fuse1), 1)
        x2 = torch.cat((x2, fuse2), 1)
        x3 = torch.cat((x3, fuse3), 1)
        x4 = torch.cat((x4, fuse4), 1)

        x = self.conv1(x1)

        x = self.lang_fuser1(x, l_input, x2_mask=None, x2_proj=self.lang_proj1)
        x = self.up1(x, x2)
        x = self.lat_fusion1(x, lat[-6].permute(0, 1, 3, 2))

        x = self.lang_fuser2(x, l_input, x2_mask=None, x2_proj=self.lang_proj2)
        x = self.up2(x, x3)
        x = self.lat_fusion2(x, lat[-5].permute(0, 1, 3, 2))

        x = self.lang_fuser3(x, l_input, x2_mask=None, x2_proj=self.lang_proj3)
        x = self.up3(x, x4)
        x = self.lat_fusion3(x, lat[-4].permute(0, 1, 3, 2))

        x = self.layer1(x)
        x = self.lat_fusion4(x, lat[-3].permute(0, 1, 3, 2))

        x = self.layer2(x)
        x = self.lat_fusion5(x, lat[-2].permute(0, 1, 3, 2))

        x = self.layer3(x)
        x = self.lat_fusion6(x, lat[-1].permute(0, 1, 3, 2))

        x = self.conv2(x)
        x = F.interpolate(x, size=(in_shape[-1], in_shape[-2]), mode='bilinear')
        x = x.permute(0, 1, 3, 2)
        return x
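

# DETR/MDETR-style transformer encoder. Unlike torch.nn.TransformerEncoder, it also returns the
# output of every layer, which MDETRLingUNetLat_fuse upsamples and fuses with the backbone features.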
class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        output = src
        output_all = []
        for layer in self.layers:
            output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)
            output_all.append(output)
        if self.norm is not None:
            output = self.norm(output)
        return output, output_all
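

# A single DETR-style encoder layer with post-norm (default) and pre-norm variants; the positional
# embedding is added to the queries and keys only, not to the values.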
class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of the feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")