LM-Steer / lm_steer /models /model_base.py
hanchier's picture
caching
d75dc6d
raw
history blame
9.19 kB
import torch.nn as nn
import torch
import numpy as np
import torch.nn.functional as F
from lm_steer.utils import set_seed
from .model_utils import find_max_subspans
# Single characters that mark a word boundary when they show up in decoded
# text (see LMSteerBase.steer_analysis).
punctuations = [
    # ASCII punctuation; '/' and '#' are not part of the list.
    *'!"$%&\'()*+,-.',
    *':;<=>?@',
    *'[\\]^_`',
    *'{|}~',
    # Latin-1 supplement symbols.
    *'¨©ª«¬®¯°±²³´µ¶·',
    *'¸¹º»¼½¾',
    # Whitespace.
    '\n', ' ',
]
class LMSteerBase(nn.Module):
    """Analysis and generation helpers shared by steerable language models.

    The methods read ``self.model``, ``self.tokenizer``, ``self.steer``,
    ``self.device`` and ``self.num_steers``. Only ``self.device`` is assigned
    in this class (by ``to_device``); the other attributes are presumably
    provided by subclasses -- TODO confirm.
    """

    def evidence_words(self, prompt, comparing_steer_values,
                       truncation_length=1024, max_segments=4, max_length=10):
        """Split ``prompt`` into spans and flag those favouring steer 0.

        The prompt is scored twice, once per row of
        ``comparing_steer_values``. For every token after the first, the
        evidence is ``loss under row 1 - loss under row 0``; positive values
        mean row 0 predicts that token better.

        Args:
            prompt: text to analyse.
            comparing_steer_values: two rows of steer values, as a list or a
                tensor.
            truncation_length: maximum number of tokens kept from ``prompt``.
            max_segments: forwarded to ``find_max_subspans``.
            max_length: forwarded to ``find_max_subspans``.

        Returns:
            ``(segments, token_evidence)`` where ``segments`` is a list of
            ``(text, label)`` pairs with ``label`` either ``"evidence"`` or
            ``None``, and ``token_evidence`` is the per-token score list.
            When the two steer rows differ by at most 0.2 (L1 distance) only
            the list ``[(prompt, None)]`` is returned.
        """
        if isinstance(comparing_steer_values, list):
            comparing_steer_values = \
                torch.Tensor(comparing_steer_values).to(self.device)
        steer_gap = (
            comparing_steer_values[0] - comparing_steer_values[1]
        ).abs().sum()
        if steer_gap <= 0.2:
            # NOTE(review): this early exit returns a bare list while the
            # main path returns a 2-tuple. Kept as-is because callers may
            # rely on it -- confirm and unify.
            return [(prompt, None)]

        tokenized = self.tokenizer(
            prompt, return_tensors="pt",
            max_length=truncation_length, truncation=True)
        # Duplicate the single prompt so each row is scored under one of the
        # two steer settings.
        input_ids = torch.LongTensor(tokenized["input_ids"]).to(self.device)
        input_ids = input_ids.expand(2, -1)
        attention_mask = torch.LongTensor(tokenized["attention_mask"]).to(
            self.device)
        attention_mask = attention_mask.expand(2, -1)
        self.steer.set_value(comparing_steer_values)
        with torch.no_grad():
            model_output = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids)
        length = input_ids.shape[1]
        # Per-token loss: logits at position i are scored against token i+1.
        loss_token = F.cross_entropy(
            model_output.logits[:, :-1].reshape(2 * (length - 1), -1),
            input_ids[:, 1:].reshape(-1),
            reduction="none"
        )
        loss_token = loss_token.reshape(2, length - 1)
        token_evidence = (- loss_token[0] + loss_token[1])
        tokens = input_ids[0]
        evidence_segments = find_max_subspans(
            token_evidence.cpu().numpy().tolist(), max_segments, max_length)[0]
        # token_evidence[i] belongs to token i+1, so shift the spans by one to
        # index into ``tokens``.
        evidence_segments = [
            (_seg[0] + 1, _seg[1] + 1) for _seg in evidence_segments]

        if len(evidence_segments) == 0:
            return [(prompt, None)], token_evidence.tolist()

        segments = []
        start = 0
        for seg_start, seg_end in evidence_segments:
            if seg_start > start:
                # Plain text between the previous span and this one.
                segments.append((
                    self.tokenizer.decode(tokens[start: seg_start]),
                    None
                ))
            segments.append((
                self.tokenizer.decode(tokens[seg_start: seg_end]),
                "evidence"
            ))
            start = seg_end
        # Whatever follows the last evidence span is emitted once, unlabeled.
        n_tokens = tokens.shape[-1]
        if start < n_tokens:
            segments.append((
                self.tokenizer.decode(tokens[start: n_tokens]),
                None
            ))
        return segments, token_evidence.tolist()

    def steer_analysis(self, prompt, steer_dim, min_value=-3, max_value=3,
                       bins=7):
        """Estimate which value of one steer dimension best fits ``prompt``.

        The prompt is scored under ``bins`` evenly spaced values of dimension
        ``steer_dim`` in ``[min_value, max_value]`` plus one all-zero
        baseline row.

        Args:
            prompt: text to analyse.
            steer_dim: index of the steer dimension that is swept.
            min_value: smallest steer value tried.
            max_value: largest steer value tried.
            bins: number of steer values tried; must be at least 1.

        Returns:
            ``(best_guess_value, dist_list, word_evidence_list)``:
            the steer value with the lowest mean loss, a list of
            ``(steer_value, probability)`` pairs, and a list of
            ``(text, mean_evidence)`` pairs for punctuation-delimited spans.

        Raises:
            ValueError: if ``bins`` is smaller than 1.
        """
        if bins < 1:
            raise ValueError("bins must be at least 1")
        # A single bin has no spacing; avoid dividing by ``bins - 1 == 0``.
        step = (max_value - min_value) / (bins - 1) if bins > 1 else 0.0
        bin_values = [min_value + step * bin_i for bin_i in range(bins)]

        tokenized = self.tokenizer(prompt)
        input_ids = torch.LongTensor(tokenized["input_ids"]).to(self.device)
        input_ids = input_ids.expand(bins + 1, -1)
        attention_mask = torch.LongTensor(tokenized["attention_mask"]).to(
            self.device)
        attention_mask = attention_mask.expand(bins + 1, -1)
        # The extra last row stays all-zero and serves as the baseline.
        steer_values = torch.zeros(bins + 1, self.num_steers).to(self.device)
        for bin_i, bin_value in enumerate(bin_values):
            steer_values[bin_i, steer_dim] = bin_value
        self.steer.set_value(steer_values)
        with torch.no_grad():
            output = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids)
        length = input_ids.shape[1]
        loss_token = F.cross_entropy(
            output.logits[:, :-1].reshape((bins + 1) * (length - 1), -1),
            input_ids[:, 1:].reshape(-1),
            reduction="none"
        )
        loss_token = loss_token.reshape(bins + 1, length - 1)
        # Mean loss per swept bin; the baseline row is dropped.
        loss = loss_token.mean(-1)[:-1]
        # Sharpened softmax over the mean-centred negative losses.
        dist = ((- loss + loss.mean()) * 100).softmax(0)
        dist_list = list(zip(bin_values, dist.tolist()))
        best_guess = loss.argmin(0)
        best_guess_value = min_value + step * best_guess.item()
        # Per-token gain of the best bin over the baseline, scaled by 10.
        token_evidence = (- loss_token[best_guess] + loss_token[-1]) * 10
        # The first token has no prediction, so pad to align with tokens.
        token_evidence = [0] + token_evidence.tolist()

        ids = input_ids[0]
        n_tokens = len(ids)
        word_evidence_list = []
        start = 0
        # NOTE(review): spans are always cut at ``token_i - 1``, so the final
        # token never appears in ``word_evidence_list``. Behaviour kept from
        # the original -- confirm whether that is intended.
        for token_i in range(1, n_tokens + 1):
            span = self.tokenizer.decode(ids[start: token_i])
            at_boundary = token_i == n_tokens or any(
                _punc in span for _punc in punctuations)
            if not at_boundary:
                continue
            new_span = self.tokenizer.decode(ids[start: token_i - 1]).strip()
            if len(new_span) <= 1:
                # Too short to report; keep growing the current span.
                continue
            word_evidence_list.append((
                new_span,
                np.array(token_evidence[start: token_i - 1]).mean()
            ))
            start = token_i - 1
        return best_guess_value, dist_list, word_evidence_list

    def generate(self, prompt, steer_values, min_length=20, max_length=100,
                 seed=None, num_beams=1, num_beam_groups=1, do_sample=True,
                 temperature=1, top_p=1):
        '''
        Generate a continuation of ``prompt`` under the given steer values.

        prompt: a string
        steer_values: one steer value per steer dimension
        min_length: minimum generation length
        max_length: maximum generation length
        seed: seed for generation. None if not specified.
        num_beams, num_beam_groups, do_sample, temperature, top_p:
            forwarded unchanged to ``self.model.generate``.

        Returns the decoded text with special tokens removed.
        '''
        if seed is not None:
            set_seed(seed)
        steer_values = torch.Tensor(steer_values).to(
            self.device)
        # ``[None]`` adds a leading batch dimension of size one.
        self.steer.set_value(steer_values[None])
        with torch.no_grad():
            inputs = self.tokenizer(
                prompt, return_tensors="pt").to(self.device)
            text = self.model.generate(
                **inputs,
                num_beams=num_beams, num_beam_groups=num_beam_groups,
                do_sample=do_sample, temperature=temperature, top_p=top_p,
                min_length=min_length, max_length=max_length,
                pad_token_id=self.tokenizer.pad_token_id,
            )
            text = self.tokenizer.decode(text[0], skip_special_tokens=True)
        return text

    def generate_low_resource(
        self, prompt, steer_values, min_length=20, max_length=100,
        seed=None, num_beams=1, num_beam_groups=1, do_sample=True,
        temperature=1, top_p=1
    ):
        '''
        Generate like ``generate`` but with the steer projectors in float16.

        prompt: a string
        steer_values: one steer value per steer dimension
        min_length: minimum generation length
        max_length: maximum generation length
        seed: seed for generation. None if not specified.
        num_beams, num_beam_groups, do_sample, temperature, top_p:
            forwarded unchanged to ``self.model.generate``.

        ``steer.projector1`` and ``steer.projector2`` are switched to float16
        only for the duration of the call. Their original tensors are put
        back afterwards, also when generation raises, so the stored weights
        keep their original dtype and precision.

        Returns the decoded text; special tokens are not stripped.
        '''
        if seed is not None:
            set_seed(seed)
        fp16 = torch.float16
        steer_values = torch.Tensor(steer_values).to(self.device).to(fp16)
        projectors = (self.steer.projector1, self.steer.projector2)
        # Keep the original tensors: casting down and back up again would
        # permanently round the weights to float16 precision.
        originals = [_proj.data for _proj in projectors]
        try:
            for _proj, _orig in zip(projectors, originals):
                _proj.data = _orig.to(fp16)
            self.steer.set_value(steer_values[None])
            with torch.no_grad():
                input_ids = self.tokenizer(
                    prompt, return_tensors="pt").input_ids.to(self.device)
                gen_tokens = self.model.generate(
                    input_ids,
                    num_beams=num_beams, num_beam_groups=num_beam_groups,
                    do_sample=do_sample, temperature=temperature,
                    top_p=top_p,
                    min_length=min_length, max_length=max_length,
                    pad_token_id=self.tokenizer.pad_token_id)
                text = self.tokenizer.batch_decode(gen_tokens)[0]
        finally:
            # recovering
            for _proj, _orig in zip(projectors, originals):
                _proj.data = _orig
        return text

    def state_dict(self):
        """Return the state dict of the steer module only."""
        return self.steer.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the steer module only."""
        self.steer.load_state_dict(state_dict)

    def parameters(self):
        """Return the parameters of the steer module only."""
        return self.steer.parameters()

    def to_device(self, device):
        """Move ``self.model`` to ``device`` and remember it."""
        self.model.to(device)
        self.device = device

    def regularization_term(self):
        """Return the regularization term computed by the steer module."""
        return self.steer.regularization_term()

    def forward(self, input_ids, attention_mask, steer_values):
        """Run the model on ``input_ids`` under ``steer_values``.

        ``input_ids`` is also passed as ``labels``; the model's output object
        is returned unchanged.
        """
        self.steer.set_value(steer_values)
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids)
        return output