Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
class Projected_Adaptor(nn.Module):
    """Output head that adds a steerable perturbation before/after `lm_head`.

    The wrapped ``lm_head`` is only read through ``lm_head.weight.detach()``,
    so no gradient reaches it from this module; the only trainable tensors
    are the steering parameters created here. Three adaptor classes exist:

    * ``"multiply"``: low-rank update of the hidden state,
      ``state + epsilon * sum_k s_k * (state @ P1_k) @ P2_k^T``.
    * ``"add"``: adds ``epsilon * (steer_values @ add_vec)`` to the state.
    * ``"offset"``: adds ``epsilon * (steer_values @ offset_vec)`` to the
      logits.

    Args:
        lm_head: Module exposing a ``weight`` attribute; logits are computed
            as ``state @ weight.T`` (presumably ``weight`` is
            ``(vocab_size, embed_dim)`` -- confirm against the caller).
        adaptor_class: One of ``"multiply"``, ``"add"``, ``"offset"``.
        num_steers: Number of independent steering directions.
        embed_dim: Size of the last dimension of ``state``.
        vocab_size: Size of the logit dimension (used by ``"offset"``).
        rank: Rank of the ``"multiply"`` projectors; must be positive for
            every adaptor class.
        epsilon: Scalar multiplier applied to the steering perturbation.
        init_var: Scale applied to the random init of the ``"multiply"``
            projectors. It is stored but not applied to the ``"add"`` and
            ``"offset"`` vectors.
        position: Stored on the instance; not read anywhere in this class.

    Raises:
        NotImplementedError: If ``adaptor_class`` is not recognised.
    """

    # Trainable steering parameters owned by each adaptor class, in the
    # order they are created, reported, and loaded.
    _STEER_PARAMETER_NAMES = {
        "multiply": ("projector1", "projector2"),
        "add": ("add_vec",),
        "offset": ("offset_vec",),
    }

    def __init__(self, lm_head, adaptor_class, num_steers, embed_dim,
                 vocab_size, rank, epsilon, init_var, position="output"):
        super().__init__()
        assert rank > 0
        if adaptor_class == "multiply":
            self.projector1 = nn.Parameter(torch.randn(
                num_steers, embed_dim, rank
            ) * init_var)
            self.projector2 = nn.Parameter(torch.randn(
                num_steers, embed_dim, rank
            ) * init_var)
        elif adaptor_class == "add":
            self.add_vec = nn.Parameter(torch.randn(
                num_steers, embed_dim
            ))
        elif adaptor_class == "offset":
            self.offset_vec = nn.Parameter(torch.randn(
                num_steers, vocab_size
            ))
        else:
            raise NotImplementedError()

        self.adaptor_class = adaptor_class
        self.rank = rank
        self.lm_head = lm_head
        self.epsilon = epsilon
        self.position = position
        self.num_steers = num_steers
        self.init_var = init_var
        # Plain tensor attribute (not a registered buffer). All zeros means
        # "steering disabled"; see `forward`.
        self.steer_values = torch.zeros(num_steers)

    def _steer_parameters(self):
        """Return ``{name: Parameter}`` for the active adaptor class."""
        try:
            names = self._STEER_PARAMETER_NAMES[self.adaptor_class]
        except KeyError:
            raise NotImplementedError() from None
        return {name: getattr(self, name) for name in names}

    def set_value(self, steer_values):
        """Set the steering strengths used by subsequent `forward` calls.

        Args:
            steer_values: Tensor indexed as ``[batch, steer]`` by `forward`
                (it must be 2-D whenever it is non-zero).
        """
        self.steer_values = steer_values

    def forward(self, state):
        """Compute logits for ``state``, applying the steering perturbation.

        Args:
            state: Hidden states; the indexing below expects
                ``(batch, seq_len, embed_dim)`` -- inferred from the
                broadcasting pattern, confirm against the caller.

        Returns:
            Logits ``state' @ lm_head.weight.T`` where ``state'`` is the
            (possibly) steered state; for ``"offset"`` the perturbation is
            added to the logits instead.

        Raises:
            NotImplementedError: If ``adaptor_class`` was changed to an
                unsupported value after construction.
        """
        # Detached: the wrapped head is treated as frozen by this module.
        head_weight_t = self.lm_head.weight.detach().transpose(0, 1)

        # All-zero steering is the identity; skip the extra computation.
        if self.steer_values.abs().sum() == 0:
            return state.matmul(head_weight_t)

        if self.adaptor_class == "multiply":
            # (batch, 1, seq, embed) @ (1, steers, embed, rank), then scale
            # each steer's projection by its strength.
            delta = state[:, None].matmul(self.projector1[None]) * \
                self.steer_values[:, :, None, None]
            # Project back to embed_dim and sum over the steer axis.
            delta = delta.matmul(
                self.projector2.transpose(1, 2)[None]).sum(1)
            projected_state = state + self.epsilon * delta
            return projected_state.matmul(head_weight_t)

        if self.adaptor_class == "add":
            add_values = self.steer_values.matmul(self.add_vec)
            # Broadcast the per-example shift over the sequence axis.
            projected_state = state + self.epsilon * add_values[:, None]
            return projected_state.matmul(head_weight_t)

        if self.adaptor_class == "offset":
            offset_values = self.steer_values.matmul(self.offset_vec)
            logits = state.matmul(head_weight_t)
            return logits + self.epsilon * offset_values[:, None]

        raise NotImplementedError()

    def regularization_term(self):
        """Return the sum of squared entries of all steering parameters."""
        return sum(
            param.pow(2).sum()
            for param in self._steer_parameters().values()
        )

    def parameters(self, recurse=True):
        """Return the steering parameters as a list.

        The wrapped ``lm_head`` is intentionally excluded, so the result is
        the same for either value of ``recurse``; the argument exists only
        for signature compatibility with ``nn.Module.parameters``.
        """
        return list(self._steer_parameters().values())

    def state_dict(self, *args, **kwargs):
        """Return the module state.

        Called without arguments, returns the compact steer-only mapping
        ``{name: Parameter}`` (``lm_head`` excluded). Called with any
        argument -- as a parent ``nn.Module`` does for its children with
        ``destination``/``prefix``/``keep_vars`` -- it defers to
        ``nn.Module.state_dict`` so that nesting this adaptor inside another
        module does not break the parent's ``state_dict()``.
        """
        if args or kwargs:
            return super().state_dict(*args, **kwargs)
        return self._steer_parameters()

    def load_state_dict(self, state_dict):
        """Load steering parameters from a compact steer-only mapping.

        Each parameter's ``.data`` is rebound to the tensor found in
        ``state_dict`` (no copy is made).

        Raises:
            KeyError: If a required parameter name is missing.
        """
        for name, param in self._steer_parameters().items():
            param.data = state_dict[name]