import torch

from .vision_encoder import VisionEncoder
from .configuration_moondream import MoondreamConfig
from transformers import PreTrainedModel
import re

from .modeling_phi import PhiForCausalLM
from .configuration_moondream import PhiConfig


class Moondream(PreTrainedModel):
    config_class = MoondreamConfig
    _supports_flash_attn_2 = True

    def __init__(self, config):
        super().__init__(config)
        self.vision_encoder = VisionEncoder()

        if type(config.phi_config) == dict:
            phi_config = PhiConfig(
                **config.phi_config, attn_implementation=config._attn_implementation
            )
        else:
            phi_config = config.phi_config
        self.text_model = PhiForCausalLM(phi_config)

    @property
    def device(self):
        return self.text_model.device

    def encode_image(self, image):
        return self.vision_encoder(image)

    def input_embeds(self, prompt, image_embeds, tokenizer):
        def _tokenize(txt):
            return tokenizer(
                txt, return_tensors="pt", add_special_tokens=False
            ).input_ids.to(self.device)

        text_emb = self.text_model.get_input_embeddings()

        # Add BOS token
        embeds = []
        embeds.append(
            text_emb((torch.tensor([[tokenizer.bos_token_id]], device=self.device)))
        )

        if "<image>" not in prompt:
            embeds.append(text_emb(_tokenize(prompt)))
        else:
            # Splice the image embeddings in at the <image> placeholder, wrapping
            # them in <image> ... </image> marker tokens.
            assert prompt.count("<image>") == 1
            before, after = prompt.split("<image>")
            embeds.append(text_emb(_tokenize(f"{before}<image>")))
            embeds.append(image_embeds.to(self.device))
            embeds.append(text_emb(_tokenize(f"</image>{after}")))

        return torch.cat(embeds, dim=1)

    def generate(
        self,
        image_embeds,
        prompt,
        tokenizer,
        eos_text="<END>",
        max_new_tokens=128,
        **kwargs,
    ):
        eos_tokens = tokenizer(eos_text, add_special_tokens=False)[0].ids

        generate_config = {
            "eos_token_id": eos_tokens,
            "bos_token_id": tokenizer.bos_token_id,
            "pad_token_id": tokenizer.eos_token_id,
            "max_new_tokens": max_new_tokens,
            **kwargs,
        }

        with torch.no_grad():
            inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
            output_ids = self.text_model.generate(
                inputs_embeds=inputs_embeds, **generate_config
            )

        return tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    def answer_question(
        self,
        image_embeds,
        question,
        tokenizer,
        chat_history="",
        result_queue=None,
        **kwargs,
    ):
        prompt = f"<image>\n\n{chat_history}Question: {question}\n\nAnswer: "
        answer = self.generate(
            image_embeds,
            prompt,
            eos_text="<END>",
            tokenizer=tokenizer,
            max_new_tokens=512,
            **kwargs,
        )[0]
        # Strip a trailing (possibly truncated) <END> marker from the decoded text.
        cleaned_answer = re.sub("<$|<END$", "", answer).strip()

        # Use the result_queue to pass the result if it is provided
        if result_queue:
            result_queue.put(cleaned_answer)
        else:
            return cleaned_answer

    def batch_answer(
        self,
        images,
        prompts,
        tokenizer,
        **kwargs,
    ):
        eos_tokens = tokenizer("<END>", add_special_tokens=False)[0].ids

        image_embeds = self.encode_image(images)

        templated_prompts = [
            f"<image>\n\nQuestion: {prompt}\n\nAnswer: " for prompt in prompts
        ]
        prompt_embs = [
            self.input_embeds(prompt, image_embed.unsqueeze(0), tokenizer)[0]
            for prompt, image_embed in zip(templated_prompts, image_embeds)
        ]

        # Left-pad shorter prompts with the BOS embedding so every sequence in the
        # batch has the same length, and mask the padding out of attention.
        bos_emb = prompt_embs[0][0]
        max_len = max([p.shape[0] for p in prompt_embs])

        inputs_embeds = torch.cat(
            [
                torch.cat([bos_emb.repeat(max_len - p.shape[0], 1), p]).unsqueeze(0)
                for p in prompt_embs
            ],
            dim=0,
        )
        attention_mask = torch.cat(
            [
                torch.cat(
                    [
                        torch.zeros(
                            1,
                            max_len - p.shape[0],
                            device=self.device,
                            dtype=torch.long,
                        ),
                        torch.ones(1, p.shape[0], device=self.device, dtype=torch.long),
                    ],
                    dim=1,
                )
                for p in prompt_embs
            ],
            dim=0,
        )

        generate_config = {
            "eos_token_id": eos_tokens,
            "bos_token_id": tokenizer.bos_token_id,
            "pad_token_id": tokenizer.eos_token_id,
            "max_new_tokens": 512,
            **kwargs,
        }

        with torch.no_grad():
            output_ids = self.text_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                **generate_config,
            )

        return [
            re.sub("<$|<END$", "", x).strip()
            for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        ]
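

# Usage sketch (illustrative only, kept as comments so it does not run when this
# module is loaded via trust_remote_code). It shows the intended call sequence:
# encode the image once, then answer one or more questions about it. The model id
# "vikhyatk/moondream1", the image path, and the PIL-based loading are assumptions,
# not part of this module.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   from PIL import Image
#
#   model = AutoModelForCausalLM.from_pretrained(
#       "vikhyatk/moondream1", trust_remote_code=True
#   )
#   tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream1")
#
#   image = Image.open("example.jpg")
#   image_embeds = model.encode_image(image)
#   print(model.answer_question(image_embeds, "What is in this image?", tokenizer))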