sachin's picture
newer model
66fd0f6 verified
class Model(PreTrainedModel):
    """Vision-language wrapper that feeds projected image features to a language model.

    The prompt given to the language model is laid out as::

        [prepend embeddings] [projected image features] [postpend embeddings] [caption]

    where the prepend/postpend embeddings come from ``prepend_text``, split at
    its single EOS token (the EOS embedding itself is the first position of
    the postpend part).

    NOTE(review): ``image_model`` is expected to expose ``num_features`` and
    ``forward_features`` (looks like a timm-style backbone -- confirm), and
    ``LABEL_MASK`` / ``projection_layers`` are defined outside this class.
    """

    config_class = VLMConfig

    def __init__(
        self,
        config: VLMConfig,
        image_model,
        language_model,
        num_projections: int,
        tokenizer,
        prepend_text: str,
        image_tokens: int,
    ):
        """Build the projector and cache the fixed prompt embeddings.

        Args:
            config: Model configuration forwarded to ``PreTrainedModel``.
            image_model: Image encoder; ``num_features`` sizes the projector input.
            language_model: Language model; ``config.hidden_size`` sizes the
                projector output and its input embedding layer embeds the prompt.
            num_projections: Passed to ``projection_layers`` (presumably the
                number of projection layers -- confirm against that helper).
            tokenizer: Tokenizer used for ``prepend_text``; its EOS token id is
                also used as the stop token in ``generate``.
            prepend_text: Prompt text that must tokenize to exactly one EOS
                token, which marks where the image features are inserted.
            image_tokens: Number of sequence positions the projected image
                features occupy; used to size the cached mask and labels.

        Raises:
            ValueError: If ``prepend_text`` does not tokenize to exactly one
                EOS token.
        """
        super().__init__(config)
        self.image_model = image_model
        self.language_model = language_model
        self.projector = nn.Sequential(
            *projection_layers(
                image_model.num_features,
                language_model.config.hidden_size,
                num_projections,
            )
        )
        self.tokenizer = tokenizer
        self.eos_token = tokenizer.eos_token
        self.prepend_text = prepend_text
        self.image_tokens = image_tokens

        input_ids = tokenizer(prepend_text, return_tensors="pt").input_ids
        eos_positions = (input_ids[0] == tokenizer.eos_token_id).nonzero(as_tuple=True)[0]
        # Fail with an actionable message instead of the opaque tensor-to-scalar
        # error that ``.item()`` produces when there are zero or several matches.
        if eos_positions.numel() != 1:
            raise ValueError(
                "prepend_text must tokenize to exactly one EOS token "
                f"({tokenizer.eos_token!r}) marking where the image features are "
                f"inserted; found {eos_positions.numel()}."
            )
        eos_token_index = eos_positions.item()

        # Detached: the cached prompt embeddings are constants, not trained.
        text_embeddings = self.language_model.get_input_embeddings()(input_ids).detach()
        self.prepend_embeddings = text_embeddings[:, :eos_token_index]
        # The slice starts AT the EOS position, so the EOS embedding is kept
        # as the first postpend position.
        self.postpend_embeddings = text_embeddings[:, eos_token_index:]

        # Mask and labels covering the fixed prefix (prompt text + image slots).
        # These only line up with the real embeddings when the projected image
        # features span exactly ``image_tokens`` positions.
        self.attention_mask = torch.ones(1, text_embeddings.shape[1] + image_tokens)
        # Every prefix position carries LABEL_MASK (presumably the loss
        # ignore index -- confirm where LABEL_MASK is defined).
        self.labels = torch.full((1, self.attention_mask.shape[1]), LABEL_MASK)

    def _prompt_embedding_parts(self, image_outputs: torch.Tensor, batch_size: int, device):
        """Return ``[prepend, image, postpend]`` embeddings for one batch.

        The cached text embeddings are moved to ``device`` and expanded
        (without copying) along the batch dimension to ``batch_size``.
        """
        return [
            self.prepend_embeddings.to(device).expand(batch_size, -1, -1),
            image_outputs,
            self.postpend_embeddings.to(device).expand(batch_size, -1, -1),
        ]

    def project_image_features(self, images: torch.Tensor) -> torch.Tensor:
        """Encode ``images`` and project them into the language model's input space.

        The encoder's 4-D feature map ``(bs, dim, w, h)`` is flattened to a
        sequence ``(bs, w*h, dim)`` before being passed through the projector.
        """
        image_features = self.image_model.forward_features(images)
        image_features = einops.rearrange(image_features, "bs dim w h -> bs (w h) dim")
        return self.projector(image_features)

    def forward(self, images: torch.Tensor, tokenized_captions: dict[str, torch.Tensor]):
        """Run the language model on prompt + image + caption and return its output.

        Args:
            images: Batch of images accepted by ``image_model.forward_features``.
            tokenized_captions: Mapping with ``"input_ids"`` and
                ``"attention_mask"`` for the captions. A plain ``dict`` or any
                object supporting string-key lookup works.

        Returns:
            Whatever ``language_model(...)`` returns when called with
            ``inputs_embeds``, ``attention_mask`` and ``labels``.
        """
        image_outputs = self.project_image_features(images)

        # Key lookup (not attribute access) so a plain dict, as the annotation
        # promises, is accepted as well as a tokenizer's batch object.
        caption_ids = tokenized_captions["input_ids"]
        caption_mask = tokenized_captions["attention_mask"]
        # Detached, so no gradient reaches the embedding table via the captions.
        caption_embeddings = self.language_model.get_input_embeddings()(caption_ids).detach()

        device = images.device
        batch_size = len(images)

        embeddings = torch.cat(
            [
                *self._prompt_embedding_parts(image_outputs, batch_size, device),
                caption_embeddings,
            ],
            dim=1,
        )
        attention_mask = torch.cat(
            [
                self.attention_mask.to(device).expand(batch_size, -1),
                caption_mask,
            ],
            dim=1,
        )
        labels = torch.cat(
            [
                self.labels.to(device).expand(batch_size, -1),
                caption_ids.clone(),
            ],
            dim=1,
        )
        # Positions masked out of attention (caption padding) get LABEL_MASK too.
        labels[attention_mask == 0] = LABEL_MASK

        return self.language_model(
            inputs_embeds=embeddings,
            attention_mask=attention_mask,
            labels=labels,
        )

    @torch.no_grad()
    def generate(self, images: torch.Tensor, generator_kwargs: dict[str, Union[int, float]]):
        """Generate continuations for ``images`` from the prompt + image prefix.

        Runs without autograd: nothing returned here can carry gradients, so
        tracking them through the image encoder would only cost memory.

        Args:
            images: Batch of images accepted by ``image_model.forward_features``.
            generator_kwargs: Extra keyword arguments forwarded to
                ``language_model.generate``. ``eos_token_id`` defaults to the
                tokenizer's EOS id but may be overridden here.

        Returns:
            Whatever ``language_model.generate(...)`` returns.
        """
        image_outputs = self.project_image_features(images)
        device = images.device
        batch_size = len(images)

        embeddings = torch.cat(
            self._prompt_embedding_parts(image_outputs, batch_size, device),
            dim=1,
        )
        attention_mask = self.attention_mask.to(device).expand(batch_size, -1)

        # Merge instead of passing both, so a caller-supplied ``eos_token_id``
        # overrides the default rather than raising a duplicate-keyword error.
        generation_kwargs = {
            "eos_token_id": self.tokenizer.eos_token_id,
            **generator_kwargs,
        }
        return self.language_model.generate(
            inputs_embeds=embeddings,
            attention_mask=attention_mask,
            **generation_kwargs,
        )