import warnings
import os
import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PretrainedConfig, AutoConfig, GenerationConfig
from jinja2.exceptions import TemplateError


def add_memory_tokens_to_inputs(input_ids: torch.Tensor, attention_mask: torch.Tensor, n_mem_tokens: int, tokenizer):
    """
    Concatenate the input ids with n_mem_tokens mem_tokens and update the corresponding attention mask.

    Args:
        input_ids: 2-D tensor of token ids, one row per text.
        attention_mask: mask with the same shape as ``input_ids``.
        n_mem_tokens: number of memory tokens to append; must equal ``len(tokenizer.mem_tokens)``.
        tokenizer: tokenizer prepared by ``PISCO.create_tokenizer`` (provides ``mem_tokens`` and
            ``mem_token_ids_pt``).

    Returns:
        Tuple ``(input_ids, attention_mask)``: every row of ``input_ids`` gets the memory token ids appended
        on the right, and ``attention_mask`` is extended with ones for those new positions.
    """
    assert len(tokenizer.mem_tokens) == n_mem_tokens, f"{len(tokenizer.mem_tokens)} VS {n_mem_tokens}"
    batch_size = input_ids.size(0)
    # Same row of memory-token ids for every example, on the inputs' device/dtype so torch.cat works on GPU too.
    # expand() (rather than stacking a Python list) also copes with an empty batch.
    mem_token_row = tokenizer.mem_token_ids_pt.to(device=input_ids.device, dtype=input_ids.dtype)
    mem_tokens = mem_token_row.unsqueeze(0).expand(batch_size, -1)
    assert mem_tokens.dim() == 2
    assert mem_tokens.size(0) == batch_size
    assert mem_tokens.size(1) == n_mem_tokens

    input_ids = torch.cat([input_ids, mem_tokens], dim=1)
    # Match the existing mask's dtype/device: a default float32 CPU tensor would silently promote an
    # integer mask to float, and could not be concatenated with a mask living on another device.
    mem_mask = torch.ones(batch_size, n_mem_tokens, dtype=attention_mask.dtype, device=attention_mask.device)
    attention_mask = torch.cat([attention_mask, mem_mask], dim=1)
    return input_ids, attention_mask


class PISCOConfig(PretrainedConfig):
    """Configuration for :class:`PISCO`: decoder backbone, compression rate and LoRA rank."""

    model_type = "PISCO"

    def __init__(self,
                 decoder_model_name: str = "meta-llama/Llama-2-7b-chat-hf",
                 compr_rate: int = 16,
                 lora_r: int = 16,
                 **kwargs):
        super().__init__(**kwargs)
        self.decoder_model_name = decoder_model_name  # model name of decoder
        self.compr_rate = compr_rate  # compression rate: each document is compressed into 128 // compr_rate mem tokens
        self.lora_r = lora_r  # LoRA rank shared by the encoder and decoder adapters
        # Always True: blend_prompt_and_memory_tokens unconditionally appends the separator token after each
        # document's memory tokens, and replace_emb uses this flag to account for that extra position.
        self.sep = True


class PISCO(PreTrainedModel):
    """
    Retrieval-augmented generation model that compresses documents into memory-token embeddings.

    A single causal LM (``self.decoder``) carries two LoRA adapters: ``encoder_adapter`` is active while
    compressing documents (see ``compr_decoder``) and ``decoder_adapter`` while answering from the prompt
    whose memory-token placeholders have been replaced by the compressed embeddings (see ``replace_emb``).
    """

    config_class = PISCOConfig

    def __init__(self, cfg):
        super().__init__(cfg)
        self.decoder_model_name = cfg.decoder_model_name
        self.sep = cfg.sep
        self.compr_rate = cfg.compr_rate
        # Number of documents per question. The generate_* helpers overwrite it; callers of forward() must set
        # it themselves. Defaulting to 1 turns a missing setting into a clear size assertion instead of an
        # AttributeError.
        self.generation_top_k = 1

        self.create_tokenizer(cfg)

        # Base model config but we modify vocab size since we added tokens (mainly the mem tokens)
        decoder_config = AutoConfig.from_pretrained(cfg.decoder_model_name)
        decoder_config.vocab_size = len(self.tokenizer)

        # Initializing placeholder model:
        self.decoder = AutoModelForCausalLM.from_config(decoder_config,
                                                        attn_implementation='flash_attention_2',
                                                        torch_dtype=torch.bfloat16)

        peft_config = self.get_peft_config(cfg)

        self.adapter_keys = []

        self.decoder.add_adapter(peft_config, 'decoder_adapter')
        self.decoder.set_adapter('decoder_adapter')
        self.adapter_keys.append('decoder_adapter')

        self.decoder.add_adapter(peft_config, 'encoder_adapter')
        self.adapter_keys.append('encoder_adapter')

        self.generation_config = GenerationConfig(do_sample=False, top_p=None)

    def create_tokenizer(self, cfg):
        """
        Load the decoder tokenizer (left padding) and register the PISCO special tokens on it.

        Adds ``128 // cfg.compr_rate`` memory tokens followed by the ``<AE>``, ``<ENC>`` and ``<SEP>`` tokens,
        and stores the token strings and ids as attributes of the tokenizer for later use.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.decoder_model_name, use_fast=True, padding_side='left')
        n_mem_tokens = 128 // cfg.compr_rate
        # Every memory token must be a distinct, non-empty string: compr_decoder selects positions by these ids
        # and replace_emb locates the first one in the prompt. Empty strings would all collapse together.
        mem_tokens = [f'<MEM{i}>' for i in range(n_mem_tokens)]
        ae_token, enc_token, sep_token = '<AE>', '<ENC>', '<SEP>'
        # Order matters: it fixes the ids assigned to the new tokens.
        self.tokenizer.add_special_tokens({'additional_special_tokens': mem_tokens + [ae_token, enc_token, sep_token]})
        self.tokenizer.mem_tokens = mem_tokens
        self.tokenizer.mem_token_ids = [self.tokenizer.convert_tokens_to_ids(elt) for elt in self.tokenizer.mem_tokens]
        self.tokenizer.mem_token_ids_pt = torch.LongTensor(self.tokenizer.mem_token_ids)  # required later on for operations on tensors

        self.tokenizer.ae_token = ae_token  # token for autoencoding on decoder side
        self.tokenizer.ae_token_id = self.tokenizer.convert_tokens_to_ids(ae_token)
        self.tokenizer.enc_token = enc_token  # token for autoencoding on compressor side

        self.tokenizer.sep_token = sep_token  # sep token between document
        self.tokenizer.sep_token_id = self.tokenizer.convert_tokens_to_ids(sep_token)

        # if pad token exists then use pad token, otherwise bos token
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.bos_token_id

    def set_all_adapters(self):
        """Activate every registered adapter at once (so that both are trained)."""
        if len(self.adapter_keys) > 0:
            self.decoder.set_adapter(self.adapter_keys)

    def _use_adapter(self, adapter_name: str):
        """Activate ``adapter_name`` on the decoder if that adapter was registered."""
        if adapter_name in self.adapter_keys:
            self.decoder.set_adapter(adapter_name)

    def get_peft_config(self, cfg: PISCOConfig) -> LoraConfig:
        """
        Builds the peft config: LoRA of rank ``cfg.lora_r`` (alpha = 2 * rank) on all linear layers.
        """
        return LoraConfig(task_type="CAUSAL_LM", r=cfg.lora_r, lora_alpha=2 * cfg.lora_r,
                          target_modules='all-linear', lora_dropout=0.1)

    def compress(self, enc_input_ids, enc_attention_mask):
        """Compress encoder inputs into memory-token embeddings (delegates to ``compr_decoder``)."""
        return self.compr_decoder(enc_input_ids, enc_attention_mask)

    def replace_emb(self, compressed_embs, dec_input_ids):
        """
        Create an input embedding vector combining the compressed_embs and the dec_input_ids.

        Row ``i`` of ``dec_input_ids`` is embedded, then its ``self.generation_top_k`` consecutive memory-token
        slots (starting at the first occurrence of the first memory token) are overwritten with
        ``compressed_embs[i * top_k + k]`` for ``k`` in ``range(top_k)``. When ``self.sep`` is set, each slot is
        one position longer than the compressed sequence to skip the separator token.

        Args:
            compressed_embs: output of ``compress``; indexed as (document, position, hidden).
            dec_input_ids: tokenized prompts built by ``blend_prompt_and_memory_tokens``.

        Returns:
            The decoder input embeddings with memory-token positions replaced.
        """
        top_k = self.generation_top_k
        input_embeds = self.decoder.get_input_embeddings()(dec_input_ids)
        num_embs = compressed_embs.size(1)
        slot_len = num_embs + 1 if self.sep else num_embs
        batch_size = input_embeds.size(0)

        assert compressed_embs.size(0) >= batch_size * top_k, \
            f"{compressed_embs.size(0)} compressed documents for {batch_size} prompts with generation_top_k={top_k}"

        # get first mem_token indices; argmax would return 0 for a row without any mem token, so check first.
        is_first_mem = dec_input_ids == self.tokenizer.mem_token_ids[0]
        assert bool(is_first_mem.any(dim=1).all()), "every decoder prompt must contain the memory tokens"
        first_mem_token_indices = torch.argmax(is_first_mem.int(), dim=1).tolist()

        # for each example in batch, replace its memory-token slots with the compressed embeddings
        for i, first_idx in enumerate(first_mem_token_indices):
            for k in range(top_k):
                doc_idx = i * top_k + k
                start_idx = first_idx + k * slot_len
                target = input_embeds[i, start_idx:start_idx + num_embs, :]
                assert target.size() == compressed_embs[doc_idx].size(), \
                    f"{target.size()} VS {compressed_embs[doc_idx].size()}"
                input_embeds[i, start_idx:start_idx + num_embs, :] = compressed_embs[doc_idx]

        return input_embeds

    def compr_decoder(self, input_ids, attention_mask):
        """
        Compression using the decoder.

        Runs the decoder (with ``encoder_adapter`` active when registered) and keeps the last-layer hidden
        states at the memory-token positions, reshaped to (num_texts, n_mem_positions, hidden).
        """
        assert input_ids.size() == attention_mask.size(), f"{input_ids.size()} vs {attention_mask.size()}"

        # Switch adapter if we are training two different ones:
        self._use_adapter('encoder_adapter')

        emb = self.decoder(input_ids=input_ids,
                           attention_mask=attention_mask,
                           output_hidden_states=True).hidden_states[-1]
        mask = torch.isin(input_ids, self.tokenizer.mem_token_ids_pt.to(input_ids.device))
        return emb[mask].reshape(emb.size(0), -1, emb.size(-1))

    def prepare_encoder_inputs_to_decoder(self, texts, max_length):
        """
        Tokenize ``texts`` for compression: ``<ENC> + bos + text + eos`` followed by the memory tokens.

        ``max_length`` bounds the text part; 3 is added for the enc/bos/eos tokens wrapped around it.
        """
        inp_enc = [self.tokenizer.enc_token + self.tokenizer.bos_token + text + self.tokenizer.eos_token for text in texts]
        inp_enc = self.tokenizer(inp_enc, return_tensors='pt', padding="longest", max_length=max_length + 3,
                                 truncation=True, add_special_tokens=False)
        num_mem_tokens = 128 // self.compr_rate  # hardcode size
        assert num_mem_tokens == len(self.tokenizer.mem_tokens)
        inp_enc['input_ids'], inp_enc['attention_mask'] = add_memory_tokens_to_inputs(inp_enc['input_ids'],
                                                                                      inp_enc['attention_mask'],
                                                                                      num_mem_tokens,
                                                                                      tokenizer=self.tokenizer)
        return inp_enc

    def prepare_encoder_inputs(self, texts, max_length):
        """Tokenize documents for compression (alias of ``prepare_encoder_inputs_to_decoder``)."""
        return self.prepare_encoder_inputs_to_decoder(texts, max_length)

    def _prepare_encoder_batch(self, enc_input_ids, enc_attention_mask, dec_input_ids):
        """Flatten 3-D encoder tensors to 2-D and check there are generation_top_k documents per prompt."""
        assert enc_input_ids.size() == enc_attention_mask.size(), f"{enc_input_ids.size()} vs {enc_attention_mask.size()}"

        if enc_input_ids.dim() == 3:  # likely from bergen: we just flatten all of this to perform encoding in one batch
            batch_size, top_k, seq_length = enc_input_ids.size()
            enc_input_ids = enc_input_ids.reshape(batch_size * top_k, seq_length)
            enc_attention_mask = enc_attention_mask.reshape(batch_size * top_k, seq_length)

        # Here, we should have top_k times more elements in enc_input_ids than in dec_input_ids
        assert enc_input_ids.size(0) == dec_input_ids.size(0) * self.generation_top_k, \
            f"{enc_input_ids.size(0)} VS {dec_input_ids.size(0)} with generation_top_k={self.generation_top_k}"
        return enc_input_ids, enc_attention_mask

    def _prepare_decoder_inputs(self, questions, device):
        """Build the chat prompts (with memory-token placeholders) for ``questions`` and tokenize them on ``device``."""
        instr = [self.blend_prompt_and_memory_tokens(query=q) for q in questions]
        inp_dec = self.tokenizer(instr, return_tensors='pt', padding="longest", add_special_tokens=False,
                                 truncation=True, max_length=2048)
        return inp_dec['input_ids'].to(device), inp_dec['attention_mask'].to(device)

    def _decode_from_embeds(self, inputs_embeds, dec_attention_mask, max_new_tokens):
        """Generate with the decoder adapter from prepared embeddings and detokenize the outputs."""
        self._use_adapter('decoder_adapter')
        output_ids = self.decoder.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=dec_attention_mask,
            generation_config=self.generation_config,
            max_new_tokens=max_new_tokens
        )
        # de-tokenizing
        return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    def forward(self,
                enc_input_ids: torch.LongTensor = None,
                enc_attention_mask: torch.LongTensor = None,
                dec_input_ids: torch.LongTensor = None,
                dec_attention_mask: torch.LongTensor = None,
                labels: torch.LongTensor = None):
        """
        enc_input_ids: stores the contexts, should be flattened from all queries before input, can be of shape:
            - (batch_size*generation_top_k, enc_token_length)
            - (batch_size, generation_top_k, enc_token_length)
        enc_attention_mask: attention mask of enc_input_ids, same shape as enc_input_ids
        dec_input_ids: stores the prompts (including mem tokens), dimension (batch_size, dec_token_length)
        dec_attention_mask: attention mask of dec_input_ids
        labels: passed unchanged to the decoder to compute the loss

        Requires ``self.generation_top_k`` to match the number of documents per prompt.
        Returns a dict with the decoder's ``loss`` and ``logits``.
        """
        enc_input_ids, enc_attention_mask = self._prepare_encoder_batch(enc_input_ids, enc_attention_mask, dec_input_ids)

        # Perform compression with gradient tracking
        compressed_embs = self.compress(enc_input_ids, enc_attention_mask)
        inputs_embeds = self.replace_emb(compressed_embs, dec_input_ids)

        # decoding
        self._use_adapter('decoder_adapter')
        decoder_outputs = self.decoder(inputs_embeds=inputs_embeds, attention_mask=dec_attention_mask, labels=labels)

        # At end of forward, we need to activate all adapters so that they are both trained...
        self.set_all_adapters()

        return {"loss": decoder_outputs.loss, "logits": decoder_outputs.logits}

    def generate_from_text(self, questions: list[str], documents: list[list[str]], max_new_tokens: int = 128) -> list[str]:
        """
        Generates answers from documents (via compression then decoding)
        questions: list of string
        documents: list of list of strings (they should all be of equal length: the nb of doc for each question)
        """
        self.generation_top_k = len(documents[0])
        assert len(documents) == len(questions)
        assert all(len(context) == len(documents[0]) for context in documents)
        # Flatten question-major (linear time, unlike sum(documents, [])).
        flat_documents = [doc for context in documents for doc in context]

        model_input = {}

        # Creating encoder inputs:
        input_encoder = self.prepare_encoder_inputs(flat_documents, max_length=128)
        device = self.decoder.device
        model_input['enc_input_ids'] = input_encoder['input_ids'].to(device)
        model_input['enc_attention_mask'] = input_encoder['attention_mask'].to(device)

        # Creating decoder inputs
        model_input['dec_input_ids'], model_input['dec_attention_mask'] = self._prepare_decoder_inputs(questions, device)

        # Generation
        return self.generate(model_input, max_new_tokens=max_new_tokens)

    def generate_from_compressed_documents_and_questions(self, questions: list[str], compressed_documents: torch.Tensor, max_new_tokens: int = 128) -> list[str]:
        """
        Generates answers from compressed documents
        questions: list of string
        compressed_documents: torch tensor, its first dimension should be a multiple of len(questions)
            (documents are grouped question-major: the first generation_top_k belong to questions[0], ...)
        """
        n_questions = len(questions)
        n_docs = compressed_documents.size(0)
        # The documents must split evenly across questions, otherwise trailing ones would be silently ignored.
        assert n_questions > 0 and n_docs % n_questions == 0, f"{n_docs} {n_questions}"
        self.generation_top_k = n_docs // n_questions
        assert self.generation_top_k > 0, f"{n_docs} {n_questions}"

        # Creating decoder inputs
        device = self.decoder.device
        dec_input_ids, dec_attention_mask = self._prepare_decoder_inputs(questions, device)

        # Creating input decoder embeddings from prompt + compressed documents
        inputs_embeds = self.replace_emb(compressed_documents, dec_input_ids)

        # Activating decoder generator, generating and de-tokenizing:
        return self._decode_from_embeds(inputs_embeds, dec_attention_mask, max_new_tokens)

    def compress_documents(self, documents: list[str]) -> torch.Tensor:
        """
        Compress a list of documents
        """
        input_encoder = self.prepare_encoder_inputs(documents, max_length=128)
        enc_input_ids = input_encoder['input_ids'].to(self.decoder.device)
        attention_mask = input_encoder['attention_mask'].to(self.decoder.device)
        return self.compress(enc_input_ids=enc_input_ids, enc_attention_mask=attention_mask)

    def generate(self, model_input, max_new_tokens=128):
        """
        Generation pipeline including compression + decoding from compressed

        model_input: dict with 'enc_input_ids', 'enc_attention_mask', 'dec_input_ids', 'dec_attention_mask'
        (same layouts as in ``forward``). Returns the decoded answers.
        """
        dec_input_ids = model_input['dec_input_ids']
        dec_attention_mask = model_input['dec_attention_mask']
        enc_input_ids, enc_attention_mask = self._prepare_encoder_batch(model_input['enc_input_ids'],
                                                                        model_input['enc_attention_mask'],
                                                                        dec_input_ids)

        compressed_embs = self.compress(enc_input_ids, enc_attention_mask)
        inputs_embeds = self.replace_emb(compressed_embs, dec_input_ids)

        return self._decode_from_embeds(inputs_embeds, dec_attention_mask, max_new_tokens)

    def blend_prompt_and_memory_tokens(self, query: str):
        """
        Takes care of blending the prompt with the memory tokens.

        The "Background" section holds ``self.generation_top_k`` copies of the memory tokens, each followed by
        the separator token; ``replace_emb`` later overwrites those positions with the compressed documents.
        If the chat template rejects a system role, the system prompt is merged into the user message.

        Returns:
            The untokenized prompt string produced by ``apply_chat_template`` with ``add_generation_prompt=True``.
        """
        mem_tokens_str = ''.join(self.tokenizer.mem_tokens) + self.tokenizer.sep_token

        # proper names for "eval" call, don't remove these lines
        docs = mem_tokens_str * self.generation_top_k
        question = query

        prompt_system = 'You are a helpful assistant. Your task is to extract relevant information from provided documents and to answer to questions as briefly as possible.'
        prompt_user = f"Background:\n{docs}\n\nQuestion:{question}"

        # Prepare the messages with system and user roles
        messages = [
            {"role": "system", "content": prompt_system},
            # The literal pattern is colon, backslash, space (written with an escaped backslash).
            {"role": "user", "content": prompt_user.replace(':\\ ', ': ')}
        ]

        # Attempt to apply the system role and catch if it's not supported
        try:
            prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        except TemplateError as e:
            # Catch the error related to system role and handle it (e.g. gemma)
            if "System role not supported" not in str(e):
                # Re-raise the exception if it's unrelated to system role
                raise
            # Remove system role and proceed with only the user role
            messages = [{"role": "user", "content": messages[0]['content'] + '\n' + messages[1]['content']}]
            # Apply template again without system role
            prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        return prompt