import torch
import torch.nn as nn
from transformers import pipeline

from .model_utils import Hack_no_grad
from lm_steer.utils import set_seed
class EmbeddingTuning_GPTNeoModel(nn.Module):
    """GPT-Neo wrapper that fine-tunes only the output embedding (lm_head).

    The transformer trunk is wrapped in ``Hack_no_grad`` so it receives no
    gradients, and ``parameters()`` exposes only ``model.lm_head.weight`` —
    an optimizer built from this module therefore trains the output
    embedding alone.  ``steer_values`` arguments are accepted for interface
    compatibility with the steered model variants but are unused here.
    """

    def __init__(self, model_name):
        """Load a text-generation pipeline for the underlying model.

        ``model_name`` is expected to carry an ``embedding_tuning-`` prefix,
        which is stripped before the Hugging Face model is loaded.
        """
        super().__init__()
        self.generator = pipeline(
            'text-generation',
            model=model_name.replace("embedding_tuning-", ""))
        self.tokenizer = self.generator.tokenizer
        self.model = self.generator.model
        # GPT-Neo ships without a pad token; reuse EOS so batches can be padded.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        # Freeze the transformer trunk; only lm_head stays trainable.
        self.model.transformer = Hack_no_grad(self.model.transformer)

    def forward(self, input_ids, attention_mask, steer_values):
        """Language-modeling forward pass with ``labels=input_ids``.

        Returns the model output (including the LM loss). ``steer_values``
        is intentionally ignored in this baseline.
        """
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids)
        return output

    def parameters(self):
        # Expose only the output embedding so optimizers train nothing else.
        return [self.model.lm_head.weight]

    def state_dict(self):
        # Persist just the tuned lm_head, not the frozen trunk.
        return self.model.lm_head.state_dict()

    def load_state_dict(self, state_dict):
        # Counterpart of state_dict(): restore only the lm_head weights.
        self.model.lm_head.load_state_dict(state_dict)

    def to_device(self, device):
        # Keep the pipeline, the model, and our bookkeeping on the same device.
        self.generator.device = device
        self.model.to(device)
        self.device = device

    def regularization_term(self):
        # No steering matrix exists in this baseline, so nothing to penalize.
        return torch.tensor(0)

    def generate(self, prompt, steer_values, min_length=20, max_length=100,
                 seed=None, num_beams=1, num_beam_groups=1, do_sample=True,
                 temperature=1, top_p=1):
        """Generate a continuation of ``prompt`` and return it as a string.

        ``steer_values`` is ignored (interface compatibility).  When
        ``seed`` is given, sampling is made reproducible via ``set_seed``.
        """
        if seed is not None:
            set_seed(seed)
        with torch.no_grad():
            text = self.generator(
                prompt, num_beams=num_beams, num_beam_groups=num_beam_groups,
                do_sample=do_sample, temperature=temperature, top_p=top_p,
                min_length=min_length, max_length=max_length,
                pad_token_id=self.tokenizer.pad_token_id,
            )
            text = text[0]["generated_text"]
        return text