# LM-Steer / lm_steer/models/model_gpt_neo.py
import torch
from transformers import pipeline

from .model_utils import Hack_no_grad
from .steers import Projected_Adaptor
from .model_base import LMSteerBase


class Switching_GPTNeoModel(LMSteerBase):
    """GPT-Neo language model wrapped with a trainable LM-Steer adaptor on
    either the output projection (lm_head) or the input embedding table."""
def __init__(self, model_name, adapted_component, adaptor_class,
num_steers, rank, epsilon, init_var,
low_resource_mode):
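        # Note: low_resource_mode is accepted but not used in this class.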
super().__init__()
self.adapted_component = adapted_component
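        # Reuse a Hugging Face text-generation pipeline for its model
        # and tokenizer.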
self.pipeline = pipeline('text-generation', model=model_name)
self.model = self.pipeline.model
self.tokenizer = self.pipeline.tokenizer
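        # GPT-Neo ships without a pad token; fall back to EOS for padding.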
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
self.init_var = init_var
self.num_steers = num_steers
self.device = torch.device("cpu")
embed_dim = self.model.lm_head.weight.shape[1]
vocab_size = self.model.lm_head.weight.shape[0]
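        # Freeze the backbone so only the steering adaptor receives gradients.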
for _param in self.model.parameters():
_param.requires_grad_(False)
if adapted_component == "final_layer":
self.model.transformer = Hack_no_grad(self.model.transformer)
self.steer = Projected_Adaptor(
self.model.lm_head, adaptor_class, num_steers, embed_dim,
vocab_size, rank, epsilon, init_var, "output")
self.model.set_output_embeddings(self.steer)
elif adapted_component == "input_embedding":
self.steer = Projected_Adaptor(
self.model.transformer.wte, adaptor_class, num_steers,
embed_dim, vocab_size, rank, epsilon, init_var, "input")
self.model.transformer.set_input_embeddings(self.steer)
        else:
            raise NotImplementedError(
                f"unsupported adapted_component: {adapted_component}")
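

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). The argument values
# below are illustrative assumptions, not values prescribed by the repo;
# "multiply" is one plausible adaptor_class name -- consult
# lm_steer/models/steers.py for the options Projected_Adaptor actually
# supports.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = Switching_GPTNeoModel(
        model_name="EleutherAI/gpt-neo-125M",  # any GPT-Neo checkpoint
        adapted_component="final_layer",       # or "input_embedding"
        adaptor_class="multiply",              # assumed; see steers.py
        num_steers=2,
        rank=1000,       # illustrative low-rank dimension
        epsilon=0.001,   # illustrative steering magnitude
        init_var=0.01,   # illustrative init variance
        low_resource_mode=False,
    )
    print(model.steer)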