import torch
import torch.nn as nn
from transformers import pipeline

from .model_utils import Hack_no_grad
from lm_steer.utils import set_seed


class EmbeddingTuning_GPTNeoModel(nn.Module):
    """GPT-Neo steering variant that fine-tunes ONLY the output-embedding
    matrix (``lm_head``); the transformer body is wrapped in ``Hack_no_grad``
    so it receives no gradients.

    ``steer_values`` arguments are accepted but unused — they exist so this
    class shares an interface with the other steering-model variants.
    """

    def __init__(self, model_name):
        """Build the text-generation pipeline for ``model_name`` (with the
        ``"embedding_tuning-"`` prefix stripped) and freeze all but lm_head.
        """
        super().__init__()
        base_name = model_name.replace("embedding_tuning-", "")
        self.generator = pipeline('text-generation', model=base_name)
        self.tokenizer = self.generator.tokenizer
        self.model = self.generator.model
        # GPT-style tokenizers ship without a pad token; reuse EOS for padding.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        # Run the transformer body under no_grad so only lm_head trains.
        self.model.transformer = Hack_no_grad(self.model.transformer)

    def forward(self, input_ids, attention_mask, steer_values):
        """Language-modeling forward pass (labels = input_ids, so the
        returned output carries the LM loss). ``steer_values`` is ignored.
        """
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids)

    def parameters(self):
        # Only the output-embedding weight matrix is trainable.
        return [self.model.lm_head.weight]

    def state_dict(self):
        # Persist just the tuned lm_head weights, not the frozen body.
        return self.model.lm_head.state_dict()

    def load_state_dict(self, state_dict):
        # Counterpart of state_dict(): restores lm_head only.
        self.model.lm_head.load_state_dict(state_dict)

    def to_device(self, device):
        # Keep the pipeline's notion of device in sync with the model's.
        self.generator.device = device
        self.model.to(device)
        self.device = device

    def regularization_term(self):
        # This variant adds no regularizer; return a zero tensor so callers
        # can uniformly add it to the loss.
        return torch.tensor(0)

    def generate(self, prompt, steer_values, min_length=20, max_length=100,
                 seed=None, num_beams=1, num_beam_groups=1, do_sample=True,
                 temperature=1, top_p=1):
        """Generate a continuation of ``prompt`` via the pipeline and return
        the generated text. ``steer_values`` is ignored (interface compat);
        ``seed`` reseeds the RNGs when given.
        """
        if seed is not None:
            set_seed(seed)
        with torch.no_grad():
            outputs = self.generator(
                prompt,
                num_beams=num_beams,
                num_beam_groups=num_beam_groups,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                min_length=min_length,
                max_length=max_length,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        return outputs[0]["generated_text"]