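"""Q-Former modules for BuboGPT.

`BaseQFormer` handles optional freezing and loading Q-Former weights from a
checkpoint; `SequenceGenericQFormer` builds a BERT-based Q-Former with learned
query tokens that cross-attend to encoder features.
"""
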
import logging
import os

import torch
from torch import nn, Tensor

from bubogpt.common.dist_utils import download_cached_file
from bubogpt.common.utils import is_url
from bubogpt.models.Qformer import BertConfig, BertLMHeadModel


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class BaseQFormer(nn.Module):
    def __init__(self, freeze_qformer=False):
        super().__init__()
        self.freeze_qformer = freeze_qformer
        self.Qformer = None

    def check_and_freeze(self):
        assert self.Qformer is not None
        if self.freeze_qformer:
            for param in self.Qformer.parameters():
                param.requires_grad = False
            self.Qformer = self.Qformer.eval()
            self.Qformer.train = disabled_train
            # query_tokens is expected to be defined by the subclass before
            # check_and_freeze() is called.
            self.query_tokens.requires_grad = False
            logging.info("Froze this Q-Former")

    def load_from_pretrained(self, url_or_filename):
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]
        msg = self.load_state_dict(state_dict, strict=False)

        logging.info("Missing keys {}".format(msg.missing_keys))
        logging.info("load checkpoint from %s" % url_or_filename)

        return msg


class SequenceGenericQFormer(BaseQFormer):
    def __init__(self,
                 num_query_token: int,
                 encoder_width: int = 768,
                 freeze_qformer: bool = False,
                 q_former_model: str = "",
                 cross_attention_freq: int = 2
                 ):
        super().__init__(freeze_qformer)
        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token, encoder_width, cross_attention_freq
        )
        if q_former_model != "":
            self.load_Qformer(q_former_model)
        self.check_and_freeze()

    def set_Qformer(self):
        # Remove sub-modules that are not needed when the Q-Former is used
        # only as a query encoder (LM head, text word/position embeddings,
        # and the text-branch feed-forward layers).
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None

    def load_Qformer(self, q_former_model):
        self.set_Qformer()
        self.load_from_pretrained(url_or_filename=q_former_model)

    @classmethod
    def init_Qformer(cls, num_query_token, encoder_width, cross_attention_freq=2):
        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.encoder_width = encoder_width
        # insert a cross-attention layer every `cross_attention_freq` blocks
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens

    def forward(self, input_embeds: Tensor) -> Tensor:
        # Attend over every position of the encoder features.
        input_atts = torch.ones(
            input_embeds.size()[:-1], dtype=torch.long, device=input_embeds.device
        )
        # Broadcast the learned query tokens across the batch.
        query_tokens = self.query_tokens.expand(input_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=input_embeds,
            encoder_attention_mask=input_atts,
            return_dict=True,
        )
        return query_output.last_hidden_state
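

# Usage sketch (illustrative, not part of the original module): the shapes and
# hyper-parameters below are assumptions for demonstration only.
if __name__ == "__main__":
    qformer = SequenceGenericQFormer(num_query_token=32, encoder_width=768)
    features = torch.randn(2, 50, 768)  # (batch, sequence length, encoder_width)
    queries = qformer(features)
    # One output embedding per query token, e.g. (2, 32, 768) for bert-base-uncased.
    print(queries.shape)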