Spaces:
Running
Running
import torch | |
import torch.nn.functional as F | |
from transformers import GPTJForCausalLM, AutoTokenizer | |
from .model_utils import Hack_no_grad, find_max_subspans | |
from .steers import Projected_Adaptor | |
from .model_base import LMSteerBase | |
from lm_steer.utils import set_seed | |
class Switching_GPTJModel(LMSteerBase): | |
def __init__(self, model_name, adapted_component, adaptor_class, | |
num_steers, rank, epsilon, init_var, low_resource_mode): | |
super().__init__() | |
self.adapted_component = adapted_component | |
self.adaptor_class = adaptor_class | |
# self.generator = pipeline('text-generation', model=model_name) | |
# self.tokenizer = self.generator.tokenizer | |
# self.model = self.generator.model | |
if low_resource_mode: | |
print("using low_resource_mode and fp16") | |
self.model = GPTJForCausalLM.from_pretrained( | |
"EleutherAI/gpt-j-6B", revision="float16", | |
torch_dtype=torch.float16, low_cpu_mem_usage=True | |
) | |
else: | |
self.model = GPTJForCausalLM.from_pretrained( | |
"EleutherAI/gpt-j-6B", | |
) | |
self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") | |
self.tokenizer.pad_token = self.tokenizer.eos_token | |
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id | |
self.init_var = init_var | |
self.num_steers = num_steers | |
self.device = torch.device("cpu") | |
self.low_resource_mode = low_resource_mode | |
embed_dim = self.model.lm_head.weight.shape[1] | |
vocab_size = self.model.lm_head.weight.shape[0] | |
for _param in self.model.parameters(): | |
_param.requires_grad_(False) | |
if adapted_component == "final_layer": | |
self.model.transformer = Hack_no_grad(self.model.transformer) | |
self.steer = Projected_Adaptor( | |
self.model.lm_head, adaptor_class, num_steers, embed_dim, | |
vocab_size, rank, epsilon, init_var, "output") | |
self.model.set_output_embeddings(self.steer) | |
elif adapted_component == "input_embedding": | |
self.steer = Projected_Adaptor( | |
self.model.transformer.wte, adaptor_class, num_steers, | |
embed_dim, vocab_size, rank, epsilon, init_var, "input") | |
self.model.transformer.set_input_embeddings(self.steer) | |
else: | |
raise NotImplementedError() | |
def generate(self, prompt, steer_values, min_length=20, max_length=100, | |
seed=None, num_beams=1, num_beam_groups=1, do_sample=True, | |
temperature=1, top_p=1): | |
''' | |
prompt: a string | |
steer_values | |
min_length: minimum generation length | |
max_length: maximum generation length | |
seed: seed for generation. None if not specified. | |
''' | |
return super().generate_low_resource( | |
prompt, steer_values, min_length, max_length, seed, | |
num_beams, num_beam_groups, do_sample, temperature, top_p) | |
def generate_multiple( | |
self, prompts, steer_values, min_length=20, max_length=100, | |
seed=None): | |
''' | |
prompt: a string | |
steer_values | |
min_length: minimum generation length | |
max_length: maximum generation length | |
seed: seed for generation. None if not specified. | |
''' | |
if seed is not None: | |
set_seed(seed) | |
steer_values = torch.Tensor(steer_values).to( | |
self.device) | |
if self.low_resource_mode: | |
fp16 = torch.float16 | |
steer_values = steer_values.to(fp16) | |
self.steer.projector1.data = self.steer.projector1.to(fp16) | |
self.steer.projector2.data = self.steer.projector2.to(fp16) | |
self.steer.set_value(steer_values) | |
with torch.no_grad(): | |
input_ids = self.tokenizer( | |
prompts, return_tensors="pt").input_ids.to(self.device) | |
gen_tokens = self.model.generate( | |
input_ids, | |
do_sample=True, | |
min_new_tokens=min_length, max_new_tokens=max_length, | |
pad_token_id=self.tokenizer.pad_token_id) | |
text = self.tokenizer.batch_decode(gen_tokens) | |
# recovering | |
if self.low_resource_mode: | |
fp32 = torch.float32 | |
self.steer.projector1.data = self.steer.projector1.to(fp32) | |
self.steer.projector2.data = self.steer.projector2.to(fp32) | |
return text | |
# def evidence_words(self, prompt, original_steer_values, | |
# truncation_length=1024, max_segments=4, max_length=10): | |
# if isinstance(original_steer_values, list): | |
# original_steer_values = torch.Tensor(original_steer_values) | |
# if original_steer_values.abs().sum() <= 0.2: | |
# return [(prompt, None)] | |
# tokenized = self.tokenizer( | |
# prompt, return_tensors="pt", max_length=truncation_length, truncation=True) | |
# input_ids = torch.LongTensor(tokenized["input_ids"]).to(self.device) | |
# input_ids = input_ids.expand(2, -1) | |
# attention_mask = torch.LongTensor(tokenized["attention_mask"]).to( | |
# self.device) | |
# attention_mask = attention_mask.expand(2, -1) | |
# steer_values = torch.zeros(2, self.num_steers).to(self.device) | |
# steer_values[0] = original_steer_values | |
# steer_values[1] = (-original_steer_values > 0) * 2 - 1 | |
# if self.low_resource_mode: | |
# fp16 = torch.float16 | |
# steer_values = steer_values.to(fp16) | |
# self.steer.projector1.data = self.steer.projector1.to(fp16) | |
# self.steer.projector2.data = self.steer.projector2.to(fp16) | |
# self.steer.set_value(steer_values) | |
# with torch.no_grad(): | |
# output = self.model( | |
# input_ids=input_ids, | |
# attention_mask=attention_mask, | |
# labels=input_ids) | |
# length = input_ids.shape[1] | |
# loss_token = F.cross_entropy( | |
# output.logits[:, :-1].reshape((2)*(length-1), -1), | |
# input_ids[:, 1:].reshape(-1), | |
# reduction="none" | |
# ) | |
# loss_token = loss_token.reshape(2, length - 1) | |
# token_evidence = (- loss_token[0] + loss_token[1]) | |
# tokens = input_ids[0] | |
# evidence_segments = find_max_subspans( | |
# token_evidence.cpu().numpy().tolist(), max_segments, max_length)[0] | |
# evidence_segments = [ | |
# (_seg[0]+1, _seg[1]+1) for _seg in evidence_segments] | |
# start = 0 | |
# output = [] | |
# color = ( | |
# "gray" if original_steer_values.shape[0] > 1 | |
# else "red" if original_steer_values[0] > 0 | |
# else "blue" | |
# ) | |
# if len(evidence_segments) > 0: | |
# for _segment in evidence_segments: | |
# if _segment[0] > start: | |
# output.append(( | |
# self.tokenizer.decode(tokens[start: _segment[0]]), | |
# None | |
# )) | |
# output.append(( | |
# self.tokenizer.decode(tokens[_segment[0]: _segment[1]]), | |
# color | |
# )) | |
# start = _segment[1] | |
# length = tokens.shape[-1] | |
# if _segment[1] < length: | |
# output.append(( | |
# self.tokenizer.decode(tokens[_segment[1]: length]), | |
# None | |
# )) | |
# else: | |
# output = [(prompt, None)] | |
# if self.low_resource_mode: | |
# fp32 = torch.float32 | |
# self.steer.projector1.data = self.steer.projector1.to(fp32) | |
# self.steer.projector2.data = self.steer.projector2.to(fp32) | |
# return output | |
# def steer_analysis(self, prompt, steer_dim, min_value=-3, max_value=3, | |
# bins=7, truncation_length=1024): | |
# tokenized = self.tokenizer( | |
# prompt, return_tensors="pt", | |
# max_length=truncation_length, | |
# truncation=True) | |
# input_ids = torch.LongTensor(tokenized["input_ids"]).to(self.device) | |
# input_ids = input_ids.expand(bins + 1, -1) | |
# attention_mask = torch.LongTensor(tokenized["attention_mask"]).to( | |
# self.device) | |
# attention_mask = attention_mask.expand(bins + 1, -1) | |
# steer_values = torch.zeros(bins+1, self.num_steers).to(self.device) | |
# for bin_i in range(bins): | |
# steer_values[bin_i, steer_dim] = ( | |
# min_value + (max_value - min_value) / (bins - 1) * bin_i | |
# ) | |
# if self.low_resource_mode: | |
# fp16 = torch.float16 | |
# steer_values = steer_values.to(fp16) | |
# self.steer.projector1.data = self.steer.projector1.to(fp16) | |
# self.steer.projector2.data = self.steer.projector2.to(fp16) | |
# self.steer.set_value(steer_values) | |
# with torch.no_grad(): | |
# output = self.model( | |
# input_ids=input_ids, | |
# attention_mask=attention_mask, | |
# labels=input_ids) | |
# length = input_ids.shape[1] | |
# loss_token = F.cross_entropy( | |
# output.logits[:, :-1].reshape((bins+1)*(length-1), -1), | |
# input_ids[:, 1:].reshape(-1), | |
# reduction="none" | |
# ) | |
# loss_token = loss_token.reshape(bins + 1, length - 1) | |
# loss = loss_token.mean(-1)[:-1] | |
# dist = ((- loss + loss.mean()) * 100).softmax(0) | |
# dist_list = list(zip( | |
# [ | |
# min_value + (max_value - min_value) / (bins - 1) * bin_i | |
# for bin_i in range(bins) | |
# ], | |
# dist.tolist(), | |
# )) | |
# best_guess = loss.argmin(0) | |
# best_guess_value = min_value + \ | |
# (max_value - min_value) / (bins - 1) * best_guess.item() | |
# token_evidence = self.evidence_words( | |
# prompt, steer_values[best_guess], | |
# ) | |
# if self.low_resource_mode: | |
# fp32 = torch.float32 | |
# self.steer.projector1.data = self.steer.projector1.to(fp32) | |
# return best_guess_value, dist_list, token_evidence | |