BuboGPT / imagebind /models /multimodal_formers.py
ikuinen99's picture
update
e4bd7f9
raw
history blame
4.24 kB
import logging
import os
import torch
from torch import nn, Tensor
from bubogpt.common.dist_utils import download_cached_file
from bubogpt.common.utils import is_url
from bubogpt.models.Qformer import BertConfig, BertLMHeadModel
def disabled_train(self, mode=True):
    """Stand-in for ``model.train`` that leaves the train/eval mode untouched.

    Install this in place of a module's ``train`` so that later attempts to
    switch the mode have no effect. The ``mode`` argument is accepted for
    signature compatibility and ignored; the first argument is handed back
    as-is.
    """
    return self
class BaseQFormer(nn.Module):
    """Shared base for Q-Former wrappers: optional freezing and checkpoint loading.

    Subclasses are expected to assign ``self.Qformer`` (and, when they use
    learnable queries, ``self.query_tokens``) before calling
    :meth:`check_and_freeze`.
    """

    def __init__(self, freeze_qformer: bool = False):
        """Record the freeze flag; the Q-Former itself is assigned by subclasses.

        Args:
            freeze_qformer: When true, :meth:`check_and_freeze` disables
                gradients on the Q-Former and puts it into eval mode.
        """
        super().__init__()
        self.freeze_qformer = freeze_qformer
        self.Qformer = None

    def check_and_freeze(self):
        """Freeze the Q-Former (and query tokens, if any) when requested.

        Does nothing besides the sanity check when ``freeze_qformer`` is
        false. Safe to call more than once.

        Raises:
            AssertionError: If ``self.Qformer`` has not been assigned yet.
        """
        assert self.Qformer is not None, "Qformer must be set before check_and_freeze()"
        if not self.freeze_qformer:
            return

        for param in self.Qformer.parameters():
            param.requires_grad = False

        # Deliberately NOT ``self.Qformer = self.Qformer.eval()``: once ``train``
        # has been replaced below by the plain (unbound) ``disabled_train``,
        # ``eval()`` -> ``train(False)`` returns ``False`` instead of the module,
        # so rebinding would fail on a repeated call.
        self.Qformer.eval()
        # NOTE(review): assigned as an instance attribute, so it is not bound;
        # ``Qformer.train(mode)`` then returns ``mode`` and leaves the mode as is.
        self.Qformer.train = disabled_train

        # Not every subclass necessarily defines query tokens; this base class
        # does not create them itself.
        query_tokens = getattr(self, "query_tokens", None)
        if query_tokens is not None:
            query_tokens.requires_grad = False

        logging.info("Freeze This QFormer")

    def load_from_pretrained(self, url_or_filename: str):
        """Load weights from a checkpoint given by URL or local file path.

        The checkpoint is expected to be a mapping whose ``"model"`` entry is
        the state dict. Loading is non-strict, so keys that do not match are
        tolerated; missing keys are logged.

        Args:
            url_or_filename: URL (fetched through ``download_cached_file``) or
                path of an existing file.

        Returns:
            The result object of ``load_state_dict`` (exposes ``missing_keys``).

        Raises:
            RuntimeError: If the argument is neither a URL nor an existing file.
        """
        if is_url(url_or_filename):
            checkpoint_path = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
        elif os.path.isfile(url_or_filename):
            checkpoint_path = url_or_filename
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        # NOTE(security): torch.load may unpickle arbitrary objects depending on
        # the PyTorch version/defaults; only load checkpoints from trusted
        # sources (consider ``weights_only=True`` where supported).
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        state_dict = checkpoint["model"]

        msg = self.load_state_dict(state_dict, strict=False)

        logging.info("Missing keys %s", msg.missing_keys)
        logging.info("load checkpoint from %s", url_or_filename)

        return msg
class SequenceGenericQFormer(BaseQFormer):
    """Q-Former that summarises a sequence of embeddings into query outputs.

    A set of learnable query tokens attends to the input embeddings through
    the Q-Former's BERT encoder, and the encoder's last hidden state for the
    queries is returned.
    """

    def __init__(self,
                 num_query_token: int,
                 encoder_width: int = 768,
                 freeze_qformer: bool = False,
                 q_former_model: str = "",
                 cross_attention_freq: int = 2
                 ):
        """Build the Q-Former, optionally load a checkpoint, optionally freeze.

        Args:
            num_query_token: Number of learnable query tokens.
            encoder_width: Stored on the config as ``encoder_width``
                (presumably the feature size of ``input_embeds`` - confirm
                against the Qformer implementation).
            freeze_qformer: Freeze the Q-Former and query tokens after setup.
            q_former_model: Checkpoint URL or file path; ``""`` skips loading.
            cross_attention_freq: Stored on the config as
                ``cross_attention_freq``.
        """
        super().__init__(freeze_qformer)
        self.Qformer, self.query_tokens = self.init_Qformer(num_query_token, encoder_width, cross_attention_freq)
        if q_former_model != "":
            self.load_Qformer(q_former_model)
        self.check_and_freeze()

    def set_Qformer(self):
        """Drop the Q-Former sub-modules this wrapper sets to ``None``.

        Clears the ``cls`` head, the word and position embeddings, and each
        encoder layer's ``output`` / ``intermediate`` attributes.
        NOTE(review): presumably these belong to the text branch, which is
        unused when only ``query_embeds`` are fed - confirm against the
        Qformer implementation.
        """
        bert = self.Qformer.bert
        self.Qformer.cls = None
        bert.embeddings.word_embeddings = None
        bert.embeddings.position_embeddings = None
        for layer in bert.encoder.layer:
            layer.output = None
            layer.intermediate = None

    def load_Qformer(self, q_former_model):
        """Prune the Q-Former via :meth:`set_Qformer`, then load a checkpoint.

        Args:
            q_former_model: Checkpoint URL or file path, forwarded to
                ``load_from_pretrained``.
        """
        # Same pruning as set_Qformer; reuse it instead of repeating the code.
        self.set_Qformer()
        self.load_from_pretrained(url_or_filename=q_former_model)

    @classmethod
    def init_Qformer(cls, num_query_token, encoder_width, cross_attention_freq=2):
        """Create the Q-Former model and its learnable query tokens.

        The model is constructed from the ``bert-base-uncased`` config with
        cross-attention enabled; no pretrained weights are loaded here.

        Returns:
            ``(Qformer, query_tokens)`` where ``query_tokens`` is an
            ``nn.Parameter`` of shape ``(1, num_query_token, hidden_size)``
            initialised from a normal distribution with std
            ``initializer_range``.
        """
        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.encoder_width = encoder_width
        # Enable cross-attention; how often it is inserted is controlled by
        # cross_attention_freq (presumably every other block for the default
        # of 2 - confirm against the Qformer implementation).
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens

    def forward(self, input_embeds: Tensor) -> Tensor:
        """Run the query tokens against ``input_embeds`` and return their states.

        Args:
            input_embeds: Embeddings with the batch on dim 0 and features on
                the last dim; every position is attended to (all-ones mask).

        Returns:
            ``last_hidden_state`` of the Q-Former's BERT encoder (presumably
            ``(batch, num_query_token, hidden_size)`` - confirm).
        """
        # Build the all-ones mask directly on the input's device rather than
        # allocating on the default device and copying.
        input_atts = torch.ones(
            input_embeds.size()[:-1], dtype=torch.long, device=input_embeds.device
        )

        # Share the single set of query tokens across the batch (no copy).
        query_tokens = self.query_tokens.expand(input_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=input_embeds,
            encoder_attention_mask=input_atts,
            return_dict=True,
        )

        return query_output.last_hidden_state