LM-Steer / lm_steer/models/model_gpt_j.py
hanchier's picture
caching
d75dc6d
raw
history blame
10.3 kB
import torch
import torch.nn.functional as F
from transformers import GPTJForCausalLM, AutoTokenizer
from .model_utils import Hack_no_grad, find_max_subspans
from .steers import Projected_Adaptor
from .model_base import LMSteerBase
from lm_steer.utils import set_seed
class Switching_GPTJModel(LMSteerBase):
    """GPT-J-6B with an LM-Steer ``Projected_Adaptor`` spliced into either the
    output head or the input embedding.

    The backbone's own parameters are frozen; only the adaptor (``self.steer``)
    is trainable. In ``low_resource_mode`` the backbone is loaded in fp16 and
    the adaptor projectors are temporarily cast to fp16 around generation.

    NOTE(review): ``model_name`` is accepted but ignored — the checkpoint is
    hard-coded to "EleutherAI/gpt-j-6B" (the fp16 ``revision`` pin is specific
    to that repo). Confirm whether callers ever pass anything else.
    """

    def __init__(self, model_name, adapted_component, adaptor_class,
                 num_steers, rank, epsilon, init_var, low_resource_mode):
        super().__init__()
        self.adapted_component = adapted_component
        self.adaptor_class = adaptor_class
        if low_resource_mode:
            print("using low_resource_mode and fp16")
            # The "float16" revision ships fp16 weights, halving download/RAM.
            self.model = GPTJForCausalLM.from_pretrained(
                "EleutherAI/gpt-j-6B", revision="float16",
                torch_dtype=torch.float16, low_cpu_mem_usage=True
            )
        else:
            self.model = GPTJForCausalLM.from_pretrained(
                "EleutherAI/gpt-j-6B",
            )
        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
        # GPT-J has no pad token; reuse EOS so batched generation can pad.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.init_var = init_var
        self.num_steers = num_steers
        self.device = torch.device("cpu")
        self.low_resource_mode = low_resource_mode
        embed_dim = self.model.lm_head.weight.shape[1]
        vocab_size = self.model.lm_head.weight.shape[0]
        # Freeze the whole backbone; only the steering adaptor trains.
        for _param in self.model.parameters():
            _param.requires_grad_(False)
        if adapted_component == "final_layer":
            # Hack_no_grad blocks gradient flow through the transformer stack;
            # the adaptor replaces the output (lm_head) projection.
            self.model.transformer = Hack_no_grad(self.model.transformer)
            self.steer = Projected_Adaptor(
                self.model.lm_head, adaptor_class, num_steers, embed_dim,
                vocab_size, rank, epsilon, init_var, "output")
            self.model.set_output_embeddings(self.steer)
        elif adapted_component == "input_embedding":
            # The adaptor wraps the token embedding (wte) instead.
            self.steer = Projected_Adaptor(
                self.model.transformer.wte, adaptor_class, num_steers,
                embed_dim, vocab_size, rank, epsilon, init_var, "input")
            self.model.transformer.set_input_embeddings(self.steer)
        else:
            raise NotImplementedError()

    def generate(self, prompt, steer_values, min_length=20, max_length=100,
                 seed=None, num_beams=1, num_beam_groups=1, do_sample=True,
                 temperature=1, top_p=1):
        '''Generate a single continuation under the given steer values.

        prompt: a string
        steer_values: per-steer scalar(s) controlling the adaptor
        min_length: minimum generation length
        max_length: maximum generation length
        seed: seed for generation. None if not specified.

        Delegates to the base class, which handles the fp16 cast/restore
        dance for low_resource_mode.
        '''
        return super().generate_low_resource(
            prompt, steer_values, min_length, max_length, seed,
            num_beams, num_beam_groups, do_sample, temperature, top_p)

    def generate_multiple(
            self, prompts, steer_values, min_length=20, max_length=100,
            seed=None):
        '''Generate continuations for a batch of prompts (sampling only).

        prompts: a list of strings (a single string also works)
        steer_values: per-steer scalar(s), shared across the batch
        min_length: minimum generation length
        max_length: maximum generation length
        seed: seed for generation. None if not specified.

        Returns the decoded texts from ``batch_decode`` (special tokens
        included, matching the original behavior).
        '''
        if seed is not None:
            set_seed(seed)
        steer_values = torch.Tensor(steer_values).to(
            self.device)
        if self.low_resource_mode:
            # Backbone is fp16; cast the adaptor to match for the forward pass.
            fp16 = torch.float16
            steer_values = steer_values.to(fp16)
            self.steer.projector1.data = self.steer.projector1.to(fp16)
            self.steer.projector2.data = self.steer.projector2.to(fp16)
        self.steer.set_value(steer_values)
        with torch.no_grad():
            # BUGFIX: batched prompts can have different lengths, so the
            # tokenizer must pad; without padding=True this call raises for
            # any non-uniform batch. The attention mask must then be passed
            # to generate(), otherwise pad positions (== EOS here) would be
            # attended to as real context.
            encoded = self.tokenizer(
                prompts, return_tensors="pt", padding=True)
            input_ids = encoded.input_ids.to(self.device)
            attention_mask = encoded.attention_mask.to(self.device)
            gen_tokens = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                do_sample=True,
                min_new_tokens=min_length, max_new_tokens=max_length,
                pad_token_id=self.tokenizer.pad_token_id)
            text = self.tokenizer.batch_decode(gen_tokens)
        # Recover fp32 projectors so subsequent training is unaffected.
        if self.low_resource_mode:
            fp32 = torch.float32
            self.steer.projector1.data = self.steer.projector1.to(fp32)
            self.steer.projector2.data = self.steer.projector2.to(fp32)
        return text
# def evidence_words(self, prompt, original_steer_values,
# truncation_length=1024, max_segments=4, max_length=10):
# if isinstance(original_steer_values, list):
# original_steer_values = torch.Tensor(original_steer_values)
# if original_steer_values.abs().sum() <= 0.2:
# return [(prompt, None)]
# tokenized = self.tokenizer(
# prompt, return_tensors="pt", max_length=truncation_length, truncation=True)
# input_ids = torch.LongTensor(tokenized["input_ids"]).to(self.device)
# input_ids = input_ids.expand(2, -1)
# attention_mask = torch.LongTensor(tokenized["attention_mask"]).to(
# self.device)
# attention_mask = attention_mask.expand(2, -1)
# steer_values = torch.zeros(2, self.num_steers).to(self.device)
# steer_values[0] = original_steer_values
# steer_values[1] = (-original_steer_values > 0) * 2 - 1
# if self.low_resource_mode:
# fp16 = torch.float16
# steer_values = steer_values.to(fp16)
# self.steer.projector1.data = self.steer.projector1.to(fp16)
# self.steer.projector2.data = self.steer.projector2.to(fp16)
# self.steer.set_value(steer_values)
# with torch.no_grad():
# output = self.model(
# input_ids=input_ids,
# attention_mask=attention_mask,
# labels=input_ids)
# length = input_ids.shape[1]
# loss_token = F.cross_entropy(
# output.logits[:, :-1].reshape((2)*(length-1), -1),
# input_ids[:, 1:].reshape(-1),
# reduction="none"
# )
# loss_token = loss_token.reshape(2, length - 1)
# token_evidence = (- loss_token[0] + loss_token[1])
# tokens = input_ids[0]
# evidence_segments = find_max_subspans(
# token_evidence.cpu().numpy().tolist(), max_segments, max_length)[0]
# evidence_segments = [
# (_seg[0]+1, _seg[1]+1) for _seg in evidence_segments]
# start = 0
# output = []
# color = (
# "gray" if original_steer_values.shape[0] > 1
# else "red" if original_steer_values[0] > 0
# else "blue"
# )
# if len(evidence_segments) > 0:
# for _segment in evidence_segments:
# if _segment[0] > start:
# output.append((
# self.tokenizer.decode(tokens[start: _segment[0]]),
# None
# ))
# output.append((
# self.tokenizer.decode(tokens[_segment[0]: _segment[1]]),
# color
# ))
# start = _segment[1]
# length = tokens.shape[-1]
# if _segment[1] < length:
# output.append((
# self.tokenizer.decode(tokens[_segment[1]: length]),
# None
# ))
# else:
# output = [(prompt, None)]
# if self.low_resource_mode:
# fp32 = torch.float32
# self.steer.projector1.data = self.steer.projector1.to(fp32)
# self.steer.projector2.data = self.steer.projector2.to(fp32)
# return output
# def steer_analysis(self, prompt, steer_dim, min_value=-3, max_value=3,
# bins=7, truncation_length=1024):
# tokenized = self.tokenizer(
# prompt, return_tensors="pt",
# max_length=truncation_length,
# truncation=True)
# input_ids = torch.LongTensor(tokenized["input_ids"]).to(self.device)
# input_ids = input_ids.expand(bins + 1, -1)
# attention_mask = torch.LongTensor(tokenized["attention_mask"]).to(
# self.device)
# attention_mask = attention_mask.expand(bins + 1, -1)
# steer_values = torch.zeros(bins+1, self.num_steers).to(self.device)
# for bin_i in range(bins):
# steer_values[bin_i, steer_dim] = (
# min_value + (max_value - min_value) / (bins - 1) * bin_i
# )
# if self.low_resource_mode:
# fp16 = torch.float16
# steer_values = steer_values.to(fp16)
# self.steer.projector1.data = self.steer.projector1.to(fp16)
# self.steer.projector2.data = self.steer.projector2.to(fp16)
# self.steer.set_value(steer_values)
# with torch.no_grad():
# output = self.model(
# input_ids=input_ids,
# attention_mask=attention_mask,
# labels=input_ids)
# length = input_ids.shape[1]
# loss_token = F.cross_entropy(
# output.logits[:, :-1].reshape((bins+1)*(length-1), -1),
# input_ids[:, 1:].reshape(-1),
# reduction="none"
# )
# loss_token = loss_token.reshape(bins + 1, length - 1)
# loss = loss_token.mean(-1)[:-1]
# dist = ((- loss + loss.mean()) * 100).softmax(0)
# dist_list = list(zip(
# [
# min_value + (max_value - min_value) / (bins - 1) * bin_i
# for bin_i in range(bins)
# ],
# dist.tolist(),
# ))
# best_guess = loss.argmin(0)
# best_guess_value = min_value + \
# (max_value - min_value) / (bins - 1) * best_guess.item()
# token_evidence = self.evidence_words(
# prompt, steer_values[best_guess],
# )
# if self.low_resource_mode:
# fp32 = torch.float32
# self.steer.projector1.data = self.steer.projector1.to(fp32)
# return best_guess_value, dist_list, token_evidence