class Model(PreTrainedModel):
    config_class = VLMConfig

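    # The constructor wires together the three pieces: a vision backbone that
    # produces grid features, a projector that maps those features into the
    # language model's embedding space, and the causal language model itself.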
    def __init__(self, config: VLMConfig, image_model, language_model, num_projections: int, tokenizer, prepend_text: str, image_tokens: int):
        super().__init__(config)
        self.image_model = image_model
        self.language_model = language_model
        self.projector = nn.Sequential(
            *projection_layers(image_model.num_features, language_model.config.hidden_size, num_projections)
        )

        self.tokenizer = tokenizer
        self.eos_token = tokenizer.eos_token
        self.prepend_text = prepend_text

        self.image_tokens = image_tokens

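        # Embed the prompt text once and split it at the EOS token, so the
        # projected image features can be spliced in at that position on every
        # forward pass. The attention mask and fully masked labels for the
        # prompt-plus-image prefix are precomputed here as well.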
        input_ids = tokenizer(prepend_text, return_tensors="pt").input_ids
        eos_token_index = (input_ids[0] == tokenizer.eos_token_id).nonzero(as_tuple=True)[0].item()
        text_embeddings = self.language_model.get_input_embeddings()(input_ids).detach()
        self.prepend_embeddings = text_embeddings[:, :eos_token_index]
        self.postpend_embeddings = text_embeddings[:, eos_token_index:]
        self.attention_mask = torch.ones(1, text_embeddings.shape[1] + image_tokens)
        self.labels = torch.full((1, self.attention_mask.shape[1]), LABEL_MASK)

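    # Run the vision backbone, flatten its spatial feature map into a sequence of
    # image tokens (one per spatial location), and project each token to the
    # language model's hidden size.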
    def project_image_features(self, images: torch.Tensor):
        image_features = self.image_model.forward_features(images)
        image_features = einops.rearrange(image_features, "bs dim w h -> bs (w h) dim")
        encoder_outputs = self.projector(image_features)
        return encoder_outputs

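    # Training forward pass: build the whole sequence in embedding space as
    # [prompt prefix | image tokens | prompt suffix | caption] and let the language
    # model compute the loss. `tokenized_captions` is expected to be a transformers
    # BatchEncoding, since input_ids and attention_mask are accessed as attributes.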
    def forward(self, images: torch.Tensor, tokenized_captions: dict[str, torch.Tensor]):
        image_outputs = self.project_image_features(images)
        caption_embeddings = self.language_model.get_input_embeddings()(tokenized_captions.input_ids).detach()
        device = images.device
        embeddings = torch.cat(
            [
                self.prepend_embeddings.to(device).expand(len(images), -1, -1),
                image_outputs,
                self.postpend_embeddings.to(device).expand(len(images), -1, -1),
                caption_embeddings,
            ],
            dim=1,
        )
        attention_mask = torch.cat(
            [
                self.attention_mask.to(device).expand(len(images), -1),
                tokenized_captions.attention_mask,
            ],
            dim=1,
        )
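        # Only caption tokens contribute to the loss: the prompt-plus-image prefix
        # keeps its precomputed LABEL_MASK labels, and padded caption positions are
        # masked out right after the concatenation.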
        labels = torch.cat(
            [
                self.labels.to(device).expand(len(images), -1),
                tokenized_captions.input_ids.clone(),
            ],
            dim=1,
        )
        labels[attention_mask == 0] = LABEL_MASK

        return self.language_model(
            inputs_embeds=embeddings,
            attention_mask=attention_mask,
            labels=labels,
        )

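    # Inference: the same prompt-plus-image prefix, but with no caption appended;
    # generation is delegated to the language model via inputs_embeds. When only
    # inputs_embeds are passed, recent transformers versions return just the newly
    # generated token ids.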
    def generate(self, images: torch.Tensor, generator_kwargs: dict[str, Union[int, float]]):
        image_outputs = self.project_image_features(images)
        device = images.device
        embeddings = torch.cat(
            [
                self.prepend_embeddings.to(device).expand(len(images), -1, -1),
                image_outputs,
                self.postpend_embeddings.to(device).expand(len(images), -1, -1),
            ],
            dim=1,
        )
        attention_mask = self.attention_mask.to(device).expand(len(images), -1)
        return self.language_model.generate(
            inputs_embeds=embeddings,
            attention_mask=attention_mask,
            eos_token_id=self.tokenizer.eos_token_id,
            **generator_kwargs,
        )
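
# A minimal usage sketch, not taken from the original code. It assumes a timm
# convolutional backbone (forward_features returning an unpooled spatial feature
# map, which is what the rearrange above expects), a Hugging Face causal language
# model, and the VLMConfig / projection_layers / LABEL_MASK definitions this class
# references. The checkpoint names, num_projections=2, and image_tokens=49 (a 7x7
# grid for a 224x224 input) are illustrative assumptions, not values from the source.
import timm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

image_model = timm.create_model("resnet50", pretrained=True, num_classes=0)
language_model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")

model = Model(
    config=VLMConfig(),
    image_model=image_model,
    language_model=language_model,
    num_projections=2,
    tokenizer=tokenizer,
    # The prompt must contain exactly one EOS token; it marks where the image
    # embeddings are spliced in.
    prepend_text=f"Caption this image: {tokenizer.eos_token}",
    image_tokens=49,
)

images = torch.randn(2, 3, 224, 224)  # stand-in for a batch of preprocessed images
generated_ids = model.generate(images, {"max_new_tokens": 32})
captions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)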