import warnings
import os
import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PretrainedConfig, AutoConfig, GenerationConfig
from jinja2.exceptions import TemplateError


def add_memory_tokens_to_inputs(input_ids: torch.Tensor, attention_mask: torch.Tensor, n_mem_tokens: int, tokenizer):
    """
    Concatenate the input ids with n_mem_tokens mem_tokens and update the corresponding attention mask
    """
    assert len(tokenizer.mem_tokens) == n_mem_tokens, f"{len(tokenizer.mem_tokens)} VS {n_mem_tokens}"
    mem_tokens = torch.stack([tokenizer.mem_token_ids_pt] * input_ids.size(0), 0)
    assert len(mem_tokens.size()) == 2
    assert len(mem_tokens) == input_ids.size(0)
    assert len(mem_tokens[0]) == n_mem_tokens

    input_ids = torch.cat([input_ids, mem_tokens], dim=1)
    attention_mask = torch.cat([attention_mask, torch.ones(input_ids.size(0), n_mem_tokens)], dim=1)
    return input_ids, attention_mask
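
# Illustrative shapes (hypothetical numbers, not from this module): with a batch of 2 documents
# tokenized to 10 ids each and n_mem_tokens=8, input_ids goes from (2, 10) to (2, 18) and the
# attention mask is extended with ones over the appended memory-token positions.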


class PISCOConfig(PretrainedConfig):

    model_type = "PISCO"

    def __init__(self,
                 decoder_model_name: str = "meta-llama/Llama-2-7b-chat-hf",
                 compr_rate: int = 16,
                 **kwargs):
        super().__init__(**kwargs)

        self.decoder_model_name = decoder_model_name
        self.compr_rate = compr_rate
        self.lora_r = 16  # LoRA rank used for both the encoder and decoder adapters
        self.sep = True  # whether a <SEP> token follows each document's memory tokens in the prompt
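        # Worked example (values implied by this config and create_tokenizer below): documents
        # are tokenized to at most 128 tokens and encoded into 128 // compr_rate memory
        # embeddings, so compr_rate=16 yields 8 memory tokens per document, a 16x reduction
        # of a 128-token chunk.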


class PISCO(PreTrainedModel):
    config_class = PISCOConfig

    def __init__(self, cfg):
        super().__init__(cfg)
        self.decoder_model_name = cfg.decoder_model_name
        self.sep = cfg.sep
        self.compr_rate = cfg.compr_rate

        self.create_tokenizer(cfg)

        decoder_config = AutoConfig.from_pretrained(cfg.decoder_model_name)
        decoder_config.vocab_size = len(self.tokenizer)

        self.decoder = AutoModelForCausalLM.from_config(decoder_config,
                                                        attn_implementation='flash_attention_2',
                                                        torch_dtype=torch.bfloat16)

        peft_config = self.get_peft_config(cfg)

        # The same base decoder carries two LoRA adapters: one used for compression
        # (encoding documents into memory embeddings) and one used for generation.
        self.adapter_keys = []
        self.decoder.add_adapter(peft_config, 'decoder_adapter')
        self.decoder.set_adapter('decoder_adapter')
        self.adapter_keys.append('decoder_adapter')
        self.decoder.add_adapter(peft_config, 'encoder_adapter')
        self.adapter_keys.append('encoder_adapter')

        self.generation_config = GenerationConfig(do_sample=False, top_p=None)

    def create_tokenizer(self, cfg):
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.decoder_model_name, use_fast=True, padding_side='left')

        # One memory token per compr_rate input tokens, for documents of at most 128 tokens.
        n_mem_tokens = 128 // cfg.compr_rate
        mem_tokens = ['<MEM' + str(i) + '>' for i in range(n_mem_tokens)]
        self.tokenizer.add_special_tokens({'additional_special_tokens': mem_tokens + ['<AE>', '<ENC>', '<SEP>']})
        self.tokenizer.mem_tokens = mem_tokens

        self.tokenizer.mem_token_ids = [self.tokenizer.convert_tokens_to_ids(elt) for elt in self.tokenizer.mem_tokens]
        self.tokenizer.mem_token_ids_pt = torch.LongTensor(self.tokenizer.mem_token_ids)

        self.tokenizer.ae_token = '<AE>'
        self.tokenizer.ae_token_id = self.tokenizer.convert_tokens_to_ids('<AE>')
        self.tokenizer.enc_token = '<ENC>'
        self.tokenizer.sep_token = '<SEP>'
        self.tokenizer.sep_token_id = self.tokenizer.convert_tokens_to_ids('<SEP>')

        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.bos_token_id
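
        # Example (assuming the default compr_rate of 16): the tokenizer gains 8 memory tokens
        # <MEM0> ... <MEM7> plus <AE>, <ENC> and <SEP>, and __init__ builds the decoder with
        # vocab_size = len(self.tokenizer) so these extra tokens have embeddings.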

    def set_all_adapters(self):
        if len(self.adapter_keys) > 0:
            self.decoder.set_adapter(self.adapter_keys)

    def get_peft_config(self, cfg: PISCOConfig) -> LoraConfig:
        """
        Builds the PEFT (LoRA) config shared by the encoder and decoder adapters
        """
        return LoraConfig(task_type="CAUSAL_LM", r=cfg.lora_r, lora_alpha=2 * cfg.lora_r, target_modules='all-linear', lora_dropout=0.1)

    def compress(self, enc_input_ids, enc_attention_mask):
        return self.compr_decoder(enc_input_ids, enc_attention_mask)

    def replace_emb(self, compressed_embs, dec_input_ids):
        """
        Builds the decoder input embeddings by substituting the memory-token placeholders
        in dec_input_ids with the compressed document embeddings
        """
        # Documents j in [indices[i], indices[i + 1]) belong to query i of the batch.
        indices = range(0, compressed_embs.size(0) + 1, self.generation_top_k)

        input_embeds = self.decoder.get_input_embeddings()(dec_input_ids)
        num_embs = compressed_embs.size(1)
        if self.sep:
            slot_len = num_embs + 1
        else:
            slot_len = num_embs

        # Position of the first memory token in each decoder prompt.
        first_mem_token_indices = torch.argmax((dec_input_ids == self.tokenizer.mem_token_ids[0]).int(), dim=1)
        batch_size = input_embeds.size(0)

        # Overwrite the placeholder embeddings of each memory slot with the compressed embeddings.
        for i in range(batch_size):
            for j in range(indices[i], indices[i + 1]):
                start_idx = first_mem_token_indices[i].item() + (j - indices[i]) * slot_len
                assert input_embeds[i, start_idx:start_idx + num_embs, :].size() == compressed_embs[j].size(), \
                    f"{input_embeds[i, start_idx:start_idx + num_embs, :].size()} VS {compressed_embs[j].size()}"
                input_embeds[i, start_idx:start_idx + num_embs, :] = compressed_embs[j]

        return input_embeds
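
    # Illustrative layout (assuming sep=True and 8 memory tokens per document): with
    # generation_top_k=2 the decoder prompt contains
    #   <MEM0>...<MEM7><SEP><MEM0>...<MEM7><SEP>
    # and replace_emb overwrites each run of 8 placeholder embeddings (slot_len=9 including
    # the <SEP>) with the corresponding rows of compressed_embs.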

    def compr_decoder(self, input_ids, attention_mask):
        """
        Compression using the decoder: returns the last hidden states at the memory-token positions
        """
        assert input_ids.size() == attention_mask.size(), f"{input_ids.size()} vs {attention_mask.size()}"

        if 'encoder_adapter' in self.adapter_keys:
            self.decoder.set_adapter('encoder_adapter')

        emb = self.decoder(input_ids=input_ids,
                           attention_mask=attention_mask,
                           output_hidden_states=True).hidden_states[-1]
        mask = torch.isin(input_ids, self.tokenizer.mem_token_ids_pt.to(input_ids.device))
        return emb[mask].reshape(emb.size(0), -1, emb.size(-1))
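
    # Shape sketch (illustrative): for k documents of at most 128 tokens and compr_rate=16,
    # compr_decoder returns a tensor of shape (k, 8, hidden_size), i.e. one compressed
    # embedding per memory token of each document.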

    def prepare_encoder_inputs_to_decoder(self, texts, max_length):
        inp_enc = [self.tokenizer.enc_token + self.tokenizer.bos_token + text + self.tokenizer.eos_token for text in texts]
        inp_enc = self.tokenizer(inp_enc, return_tensors='pt', padding="longest", max_length=max_length + 3, truncation=True, add_special_tokens=False)
        num_mem_tokens = 128 // self.compr_rate
        assert num_mem_tokens == len(self.tokenizer.mem_tokens)
        inp_enc['input_ids'], inp_enc['attention_mask'] = add_memory_tokens_to_inputs(inp_enc['input_ids'],
                                                                                      inp_enc['attention_mask'],
                                                                                      num_mem_tokens,
                                                                                      tokenizer=self.tokenizer)

        return inp_enc

    def prepare_encoder_inputs(self, texts, max_length):
        return self.prepare_encoder_inputs_to_decoder(texts, max_length)

    def forward(self,
                enc_input_ids: torch.LongTensor = None,
                enc_attention_mask: torch.LongTensor = None,
                dec_input_ids: torch.LongTensor = None,
                dec_attention_mask: torch.LongTensor = None,
                labels: torch.LongTensor = None):
        """
        enc_input_ids: stores the contexts, should be flattened from all queries before input, can be of shape:
            - (batch_size * generation_top_k, enc_token_length)
            - (batch_size, generation_top_k, enc_token_length)
        enc_attention_mask: attention mask of enc_input_ids, same shape as enc_input_ids
        dec_input_ids: stores the prompts (including mem tokens), dimension (batch_size, dec_token_length)
        dec_attention_mask: attention mask of dec_input_ids
        """
        assert enc_input_ids.size() == enc_attention_mask.size(), f"{enc_input_ids.size()} vs {enc_attention_mask.size()}"

        # Flatten (batch_size, top_k, seq_length) inputs to (batch_size * top_k, seq_length).
        if len(enc_input_ids.size()) == 3:
            batch_size, top_k, seq_length = enc_input_ids.size()
            enc_input_ids = enc_input_ids.view(batch_size * top_k, seq_length)
            enc_attention_mask = enc_attention_mask.view(batch_size * top_k, seq_length)

        assert enc_input_ids.size(0) == dec_input_ids.size(0) * self.generation_top_k, \
            f"{enc_input_ids.size(0)} VS {dec_input_ids.size(0)} with generation_top_k={self.generation_top_k}"

        compressed_embs = self.compress(enc_input_ids, enc_attention_mask)
        inputs_embeds = self.replace_emb(compressed_embs, dec_input_ids)

        if 'decoder_adapter' in self.adapter_keys:
            self.decoder.set_adapter('decoder_adapter')

        decoder_outputs = self.decoder(inputs_embeds=inputs_embeds, attention_mask=dec_attention_mask, labels=labels)

        self.set_all_adapters()

        return {"loss": decoder_outputs.loss, "logits": decoder_outputs.logits}

    def generate_from_text(self, questions: list[str], documents: list[list[str]], max_new_tokens: int = 128) -> list[str]:
        """
        Generates answers from documents (via compression then decoding)
        questions: list of strings
        documents: list of lists of strings (all of equal length: the number of documents retrieved for each question)
        """
        self.generation_top_k = len(documents[0])
        assert len(documents) == len(questions)
        assert all([len(context) == len(documents[0]) for context in documents])
        flat_documents = sum(documents, [])

        model_input = {}

        # Encoder inputs: documents truncated to 128 tokens and extended with memory tokens.
        input_encoder = self.prepare_encoder_inputs(flat_documents, max_length=128)
        device = self.decoder.device
        model_input['enc_input_ids'], model_input['enc_attention_mask'] = input_encoder['input_ids'].to(device), input_encoder['attention_mask'].to(device)

        # Decoder inputs: the prompt with one memory-token slot per document.
        instr = [self.blend_prompt_and_memory_tokens(query=q) for q in questions]
        inp_dec = self.tokenizer(instr, return_tensors='pt', padding="longest", add_special_tokens=False, truncation=True, max_length=2048)
        model_input['dec_input_ids'], model_input['dec_attention_mask'] = inp_dec['input_ids'].to(device), inp_dec['attention_mask'].to(device)

        return self.generate(model_input, max_new_tokens=max_new_tokens)
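
    # Hypothetical call (placeholder data, not part of this module): note that documents must be
    # a list of equally sized lists, one inner list of retrieved passages per question, e.g.
    #   answers = model.generate_from_text(questions=["q1", "q2"],
    #                                      documents=[["d1a", "d1b"], ["d2a", "d2b"]])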

    def generate_from_compressed_documents_and_questions(self, questions: list[str], compressed_documents: torch.Tensor, max_new_tokens: int = 128) -> list[str]:
        """
        Generates answers from pre-compressed documents
        questions: list of strings
        compressed_documents: torch tensor, its first dimension should be a multiple of len(questions)
        """
        self.generation_top_k = compressed_documents.size(0) // len(questions)
        assert compressed_documents.size(0) % len(questions) == 0, f"{compressed_documents.size(0)} {len(questions)}"

        # Build the decoder prompts (with memory-token placeholders) for each question.
        instr = [self.blend_prompt_and_memory_tokens(query=q) for q in questions]
        inp_dec = self.tokenizer(instr, return_tensors='pt', padding="longest", add_special_tokens=False, truncation=True, max_length=2048)
        device = self.decoder.device
        dec_input_ids, dec_attention_mask = inp_dec['input_ids'].to(device), inp_dec['attention_mask'].to(device)

        inputs_embeds = self.replace_emb(compressed_documents, dec_input_ids)

        if 'decoder_adapter' in self.adapter_keys:
            self.decoder.set_adapter('decoder_adapter')

        output_ids = self.decoder.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=dec_attention_mask,
            generation_config=self.generation_config,
            max_new_tokens=max_new_tokens
        )

        return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    def compress_documents(self, documents: list[str]) -> torch.Tensor:
        """
        Compresses a list of documents into memory embeddings
        """
        input_encoder = self.prepare_encoder_inputs(documents, max_length=128)
        enc_input_ids = input_encoder['input_ids'].to(self.decoder.device)
        attention_mask = input_encoder['attention_mask'].to(self.decoder.device)
        return self.compress(enc_input_ids=enc_input_ids, enc_attention_mask=attention_mask)
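
    # Hypothetical offline workflow (placeholder data, not part of this module): documents can be
    # compressed once with compress_documents and the embeddings reused later with
    # generate_from_compressed_documents_and_questions, e.g.
    #   embeddings = model.compress_documents(["doc 1 text", "doc 2 text"])
    #   answers = model.generate_from_compressed_documents_and_questions(
    #       questions=["some question"], compressed_documents=embeddings)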

    def generate(self, model_input, max_new_tokens=128):
        """
        Generation pipeline: compression followed by decoding from the compressed embeddings
        """
        enc_input_ids, enc_attention_mask, dec_input_ids, dec_attention_mask = model_input['enc_input_ids'], model_input['enc_attention_mask'], model_input['dec_input_ids'], model_input['dec_attention_mask']

        assert enc_input_ids.size() == enc_attention_mask.size()

        # Flatten (batch_size, top_k, seq_length) inputs to (batch_size * top_k, seq_length).
        if len(enc_input_ids.size()) == 3:
            batch_size, top_k, seq_length = enc_input_ids.size()
            enc_input_ids = enc_input_ids.view(batch_size * top_k, seq_length)
            enc_attention_mask = enc_attention_mask.view(batch_size * top_k, seq_length)

        assert enc_input_ids.size(0) == dec_input_ids.size(0) * self.generation_top_k, \
            f"{enc_input_ids.size(0)} VS {dec_input_ids.size(0)} with generation_top_k={self.generation_top_k}"

        compressed_embs = self.compress(enc_input_ids, enc_attention_mask)
        inputs_embeds = self.replace_emb(compressed_embs, dec_input_ids)

        if 'decoder_adapter' in self.adapter_keys:
            self.decoder.set_adapter('decoder_adapter')

        output_ids = self.decoder.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=dec_attention_mask,
            generation_config=self.generation_config,
            max_new_tokens=max_new_tokens
        )

        return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    def blend_prompt_and_memory_tokens(self, query: str):
        """
        Blends the prompt with the memory-token placeholders (one slot per retrieved document)
        and applies the tokenizer's chat template
        """
        mem_tokens_str = ''.join(self.tokenizer.mem_tokens) + self.tokenizer.sep_token

        docs = mem_tokens_str * self.generation_top_k
        question = query

        prompt_system = 'You are a helpful assistant. Your task is to extract relevant information from provided documents and to answer to questions as briefly as possible.'
        prompt_user = f"Background:\n{docs}\n\nQuestion:{question}"

        messages = [
            {"role": "system", "content": prompt_system},
            {"role": "user", "content": prompt_user.replace(':\ ', ': ')}
        ]

        try:
            prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        except TemplateError as e:
            # Some chat templates do not accept a system role: fall back to merging the
            # system prompt into the user message.
            if "System role not supported" in str(e):
                messages = [{"role": "user", "content": messages[0]['content'] + '\n' + messages[1]['content']}]
                prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            else:
                raise e

        return prompt
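

# Minimal usage sketch, assuming a CUDA GPU with flash-attention installed and access to the
# decoder weights/tokenizer named in PISCOConfig. The questions/documents below are placeholder
# data. Building from a bare config yields randomly initialised weights, so meaningful answers
# require loading a trained PISCO checkpoint (e.g. via PISCO.from_pretrained) instead.
if __name__ == "__main__":
    config = PISCOConfig(decoder_model_name="meta-llama/Llama-2-7b-chat-hf", compr_rate=16)
    model = PISCO(config).to("cuda").eval()

    questions = ["Who wrote Hamlet?"]
    documents = [[
        "Hamlet is a tragedy written by William Shakespeare sometime between 1599 and 1601.",
        "The play is set in Denmark and follows Prince Hamlet.",
    ]]
    answers = model.generate_from_text(questions=questions, documents=documents, max_new_tokens=64)
    print(answers)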