# LM-Steer / lm_steer/models/model_embedding_tuning_gpt_neo.py
# Source: hanchier, commit e0b11c9 ("init"), 2.01 kB
import torch
import torch.nn as nn
from transformers import pipeline
from .model_utils import Hack_no_grad
from lm_steer.utils import set_seed
class EmbeddingTuning_GPTNeoModel(nn.Module):
    """GPT-Neo wrapper for embedding-only tuning.

    Loads a text-generation pipeline for the base model (the
    "embedding_tuning-" prefix is stripped from ``model_name``), freezes
    the transformer body with ``Hack_no_grad``, and exposes only the
    output-embedding matrix (``lm_head.weight``) as trainable state.
    ``steer_values`` arguments are accepted for interface parity with the
    steered model variants but are ignored here.
    """

    def __init__(self, model_name):
        super().__init__()
        # Recover the underlying checkpoint name from the prefixed id.
        base_model = model_name.replace("embedding_tuning-", "")
        self.generator = pipeline('text-generation', model=base_model)
        self.tokenizer = self.generator.tokenizer
        self.model = self.generator.model
        # GPT-Neo ships without a pad token; reuse EOS for padding.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        # Wrap the transformer body so gradients reach only lm_head.
        self.model.transformer = Hack_no_grad(self.model.transformer)

    def forward(self, input_ids, attention_mask, steer_values):
        """Language-modeling forward pass; labels equal the inputs.

        ``steer_values`` is unused — this variant has no steering.
        """
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids,
        )

    def parameters(self):
        # Only the output-embedding matrix is optimized.
        return [self.model.lm_head.weight]

    def state_dict(self):
        # Persist just the tuned head, not the frozen transformer.
        return self.model.lm_head.state_dict()

    def load_state_dict(self, state_dict):
        self.model.lm_head.load_state_dict(state_dict)

    def to_device(self, device):
        # Keep the pipeline, the model, and our bookkeeping in sync.
        self.generator.device = device
        self.model.to(device)
        self.device = device

    def regularization_term(self):
        # No steering matrix exists, so there is nothing to regularize.
        return torch.tensor(0)

    def generate(self, prompt, steer_values, min_length=20, max_length=100,
                 seed=None, num_beams=1, num_beam_groups=1, do_sample=True,
                 temperature=1, top_p=1):
        """Generate a continuation of ``prompt`` via the pipeline.

        ``steer_values`` is ignored. When ``seed`` is given, the global
        RNGs are seeded first for reproducible sampling.
        """
        if seed is not None:
            set_seed(seed)
        with torch.no_grad():
            outputs = self.generator(
                prompt, num_beams=num_beams, num_beam_groups=num_beam_groups,
                do_sample=do_sample, temperature=temperature, top_p=top_p,
                min_length=min_length, max_length=max_length,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        return outputs[0]["generated_text"]