# lm_steer/models/model_gpt_neox.py
import torch
from transformers import GPTNeoXForCausalLM, AutoTokenizer

from .model_utils import Hack_no_grad
from .steers import Projected_Adaptor
from .model_base import LMSteerBase


class Switching_GPTNeoXModel(LMSteerBase):
    """GPT-NeoX causal LM whose output embeddings are replaced by a
    trainable LM-Steer adaptor; the backbone itself stays frozen."""

    def __init__(self, model_name, adapted_component, adaptor_class,
                 num_steers, rank, epsilon, init_var,
                 low_resource_mode):
super().__init__()
self.adapted_component = adapted_component
        if low_resource_mode:
            # Half precision and low_cpu_mem_usage keep the memory
            # footprint small on constrained hardware.
            self.model = GPTNeoXForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16, low_cpu_mem_usage=True
            )
        else:
            self.model = GPTNeoXForCausalLM.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # GPT-NeoX tokenizers define no pad token; reuse EOS for padding.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
self.init_var = init_var
self.num_steers = num_steers
self.device = torch.device("cpu")
        # embed_out is the output-embedding (LM head) matrix of GPT-NeoX,
        # with shape (vocab_size, embed_dim).
        embed_dim = self.model.embed_out.weight.shape[1]
        vocab_size = self.model.embed_out.weight.shape[0]
        self.low_resource_mode = low_resource_mode
        # Freeze the backbone; only the steering adaptor below is trained.
        for _param in self.model.parameters():
            _param.requires_grad_(False)
        if adapted_component == "final_layer":
            # Wrap the transformer so no gradients flow through it, then
            # swap the output embeddings for the steering adaptor.
            self.model.gpt_neox = Hack_no_grad(self.model.gpt_neox)
            self.steer = Projected_Adaptor(
                self.model.embed_out, adaptor_class, num_steers, embed_dim,
                vocab_size, rank, epsilon, init_var, "output")
            self.model.set_output_embeddings(self.steer)
        else:
            raise NotImplementedError()

    def generate(self, prompt, steer_values, min_length=20, max_length=100,
                 seed=None, num_beams=1, num_beam_groups=1, do_sample=True,
                 temperature=1, top_p=1):
        '''
        prompt: a string
        steer_values: steering weights, one value per steering dimension
        min_length: minimum generation length
        max_length: maximum generation length
        seed: seed for generation. None if not specified.
        num_beams, num_beam_groups, do_sample, temperature, top_p: standard
            Hugging Face generation arguments, forwarded unchanged.
        '''
return super().generate_low_resource(
prompt, steer_values, min_length, max_length, seed,
num_beams, num_beam_groups, do_sample, temperature, top_p)
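

# A minimal usage sketch, not part of the original module. The checkpoint
# name, adaptor_class value, and hyperparameters below are illustrative
# assumptions, not values mandated by this file; substitute whatever your
# steers.py and training setup actually use.
if __name__ == "__main__":
    model = Switching_GPTNeoXModel(
        model_name="EleutherAI/pythia-70m",  # assumed small GPT-NeoX model
        adapted_component="final_layer",     # the only component supported
        adaptor_class="multiply",            # assumed Projected_Adaptor mode
        num_steers=2,                        # assumed steering dimensions
        rank=1000, epsilon=1e-3, init_var=1e-2,  # assumed hyperparameters
        low_resource_mode=False,
    )
    # One steer value per steering dimension; with untrained steers this
    # only exercises the plumbing, it will not produce controlled text.
    print(model.generate("The movie was", steer_values=[1, 0],
                         max_length=50, seed=0))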