import torch
import torch.nn as nn
from transformers import pipeline
from peft import LoraConfig, get_peft_model

from lm_steer.utils import set_seed


class LORA_GPTNeoModel(nn.Module):
    """LoRA fine-tuning baseline wrapped in the same interface as the
    steerable models (hence the unused ``steer_values`` arguments)."""

    def __init__(self, model_name, rank, epsilon):
        super().__init__()
        # Strip the "lora-" prefix to recover the underlying HF model name.
        self.generator = pipeline(
            "text-generation", model=model_name.replace("lora-", ""))
        self.tokenizer = self.generator.tokenizer
        model = self.generator.model
        # GPT-style models ship without a pad token; reuse EOS for padding.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        # `epsilon` is reused as the LoRA scaling factor (lora_alpha). PEFT
        # matches target_modules by suffix, so only the names that actually
        # exist in the base model receive adapters.
        config = LoraConfig(
            r=rank,
            lora_alpha=epsilon,
            target_modules=["c_attn", "c_proj", "c_fc"],
            lora_dropout=0.1,
            bias="lora_only",
            modules_to_save=[],
        )
        self.model = get_peft_model(model, config)
        # Make the pipeline generate with the LoRA-wrapped model.
        self.generator.model = self.model
        self.model.print_trainable_parameters()

    def forward(self, input_ids, attention_mask, steer_values):
        # Standard causal-LM loss (labels = inputs); steer_values is
        # accepted for interface compatibility but ignored by this baseline.
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids)
        return output

    def to_device(self, device):
        # Keep the pipeline's device attribute in sync so that its inputs
        # are placed on the same device as the model.
        self.generator.device = device
        self.model.to(device)
        self.device = device

    def regularization_term(self):
        # No steering-specific regularizer for the LoRA baseline.
        return torch.tensor(0)

    def generate(self, prompt, steer_values, min_length=20, max_length=100,
                 seed=None, num_beams=1, num_beam_groups=1, do_sample=True,
                 temperature=1, top_p=1):
        if seed is not None:
            set_seed(seed)
        with torch.no_grad():
            text = self.generator(
                prompt,
                num_beams=num_beams,
                num_beam_groups=num_beam_groups,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                min_length=min_length,
                max_length=max_length,
                pad_token_id=self.tokenizer.pad_token_id,
            )
            text = text[0]["generated_text"]
        return text
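

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module): in
# lm-steer this class is normally constructed by the training/evaluation
# scripts. The model name, rank, and epsilon values below are assumptions
# chosen for a small, fast run.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = LORA_GPTNeoModel(
        "lora-EleutherAI/gpt-neo-125M", rank=8, epsilon=16)
    model.to_device(
        torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    # steer_values is unused by this LoRA baseline, so any placeholder works.
    print(model.generate(
        "The movie was", steer_values=None, max_length=40, seed=0))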