# Hugging Face Spaces page header (Space status: Running) -- scrape residue, not part of the module.
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model
from transformers import pipeline

from lm_steer.utils import set_seed
class LORA_GPTNeoModel(nn.Module):
    """Causal LM fine-tuned with a LoRA adapter, wrapped around a HF pipeline.

    The base checkpoint is loaded through ``transformers.pipeline`` and its
    model is wrapped with a PEFT LoRA adapter. The pipeline is then re-pointed
    at the wrapped model, so ``generate`` samples from the adapted weights.
    """

    # Module names handed to ``LoraConfig.target_modules`` when the caller does
    # not supply any.
    # NOTE(review): these look like GPT-2-style names. GPT-Neo checkpoints
    # appear to name their attention projections q_proj/k_proj/v_proj/out_proj,
    # so "c_attn" may match nothing there (only the MLP c_fc/c_proj would be
    # adapted). Confirm against the checkpoint actually used.
    DEFAULT_TARGET_MODULES = ("c_attn", "c_proj", "c_fc")
    # Tag that marks this model family in ``model_name``. It is removed before
    # the underlying checkpoint is loaded.
    LORA_PREFIX = "lora-"

    def __init__(self, model_name, rank, epsilon, target_modules=None):
        """Load ``model_name`` and attach a LoRA adapter to it.

        Args:
            model_name: checkpoint identifier, optionally starting with
                ``"lora-"``; that leading tag is removed before loading.
            rank: LoRA rank, passed as ``LoraConfig(r=...)``.
            epsilon: passed as ``LoraConfig(lora_alpha=...)``.
            target_modules: optional iterable of module names to adapt;
                defaults to ``DEFAULT_TARGET_MODULES``.
        """
        super().__init__()
        self.generator = pipeline(
            'text-generation', model=self._strip_lora_prefix(model_name))
        self.tokenizer = self.generator.tokenizer
        model = self.generator.model
        # Reuse EOS as the padding token for batching and generation
        # (presumably the tokenizer ships without a pad token).
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        if target_modules is None:
            target_modules = self.DEFAULT_TARGET_MODULES
        config = LoraConfig(
            r=rank,
            lora_alpha=epsilon,
            target_modules=list(target_modules),
            lora_dropout=0.1,
            bias="lora_only",
            modules_to_save=[],
        )
        self.model = get_peft_model(model, config)
        # Make the pipeline generate with the PEFT-wrapped model.
        self.generator.model = self.model
        self.model.print_trainable_parameters()

    @classmethod
    def _strip_lora_prefix(cls, model_name):
        """Return ``model_name`` without a leading ``"lora-"`` tag.

        Only a leading tag is removed; ``"lora-"`` appearing elsewhere in the
        name (e.g. ``"org/my-lora-model"``) is left intact.
        """
        if model_name.startswith(cls.LORA_PREFIX):
            return model_name[len(cls.LORA_PREFIX):]
        return model_name

    def forward(self, input_ids, attention_mask, steer_values):
        """Run the PEFT model on ``input_ids``, using them as labels too.

        ``steer_values`` is accepted but unused -- presumably kept for
        interface parity with the steerable models; confirm with callers.

        Returns:
            Whatever the wrapped PEFT model returns for this call.
        """
        # NOTE(review): labels are the raw input_ids, so padded positions
        # (pad == EOS) also contribute to the loss. Setting labels to -100
        # where attention_mask == 0 may be intended; confirm before changing.
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids,
        )

    def to_device(self, device):
        """Move the adapted model to ``device`` and point the pipeline at it.

        Args:
            device: a ``torch.device`` or anything its constructor accepts
                (e.g. ``"cuda:0"``, ``"cpu"``, ``0``). It is stored unchanged
                on ``self.device``.
        """
        # transformers pipelines expect a real torch.device here (their
        # device-placement logic inspects ``.type``), so normalize strings/ints.
        self.generator.device = torch.device(device)
        self.model.to(device)
        self.device = device

    def regularization_term(self):
        """Return a zero tensor: this model adds no extra regularization."""
        return torch.tensor(0)

    def generate(self, prompt, steer_values, min_length=20, max_length=100,
                 seed=None, num_beams=1, num_beam_groups=1, do_sample=True,
                 temperature=1, top_p=1):
        """Generate a continuation of ``prompt`` with the adapted model.

        ``steer_values`` is accepted but unused. When ``seed`` is not None,
        ``set_seed(seed)`` is called before generating. The decoding arguments
        are forwarded to the pipeline call unchanged, with the tokenizer's pad
        token id.

        Returns:
            The ``"generated_text"`` field of the first pipeline result.
        """
        if seed is not None:
            set_seed(seed)
        with torch.no_grad():
            outputs = self.generator(
                prompt, num_beams=num_beams, num_beam_groups=num_beam_groups,
                do_sample=do_sample, temperature=temperature, top_p=top_p,
                min_length=min_length, max_length=max_length,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        return outputs[0]["generated_text"]