import os
import struct

import transformers

def get_inner_params(named_parameters, inner_names):
    """Return (name, parameter) pairs for the parameters named in inner_names."""
    param_dict = dict(named_parameters)
    return [(n, param_dict[n]) for n in inner_names]


def param_subset(named_parameters, inner_names):
    """Return only the parameter tensors for the names in inner_names."""
    param_dict = dict(named_parameters)
    return [param_dict[n] for n in inner_names]

def parent_module(model, pname):
    """Return the module that owns the attribute or parameter at path pname.

    Numeric path components (e.g. the "0" in "h.0.mlp") index into container
    modules such as nn.ModuleList or nn.Sequential.
    """
    components = pname.split('.')
    parent = model

    for component in components[:-1]:
        if hasattr(parent, component):
            parent = getattr(parent, component)
        elif component.isdigit():
            parent = parent[int(component)]
        else:
            raise RuntimeError(f"Couldn't find child module {component}")

    if not hasattr(parent, components[-1]):
        raise RuntimeError(f"Couldn't find child module {components[-1]}")

    return parent

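# Minimal usage sketch for parent_module (the parameter path below is
# hypothetical and assumes a GPT-2 style model):
#
#   parent = parent_module(model, "transformer.h.0.mlp.c_fc.weight")
#   # parent is now the module at transformer.h.0.mlp.c_fc, i.e. the module
#   # owning the "weight" attribute, so the weight can be replaced in place
#   # with setattr(parent, "weight", new_weight).
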
def uuid(digits=4):
    """Return a random run id with up to `digits` digits.

    The value is drawn once and cached on the function object, so every call
    in the same process returns the same id.
    """
    if not hasattr(uuid, "uuid_value"):
        uuid.uuid_value = struct.unpack('I', os.urandom(4))[0] % int(10**digits)

    return uuid.uuid_value


def ckpt_dir():
    """Return the directory in which to store model checkpoints, creating it if needed."""
    path = "./ckpts/"
    os.makedirs(path, exist_ok=True)
    return path

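# Sketch of how uuid() and ckpt_dir() can combine into a stable checkpoint
# path (the filename scheme here is illustrative, not fixed by this module):
#
#   save_path = os.path.join(ckpt_dir(), f"run_{uuid()}.pt")
#   # Because uuid() is cached after its first call, every checkpoint saved
#   # by this process lands under the same run id.
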
def brackets_to_periods(name):
    """Convert bracket indexing to dotted paths, e.g. "layers[0]" -> "layers.0"."""
    return name.replace("[", ".").replace("]", "")


def get_params(model):
    """Return the model's state_dict."""
    return model.state_dict()


def get_shape(p, model):
    """Return p's shape as (in_features, out_features).

    GPT-2's Conv1D layers already store weights as (in_features, out_features);
    nn.Linear weights are (out_features, in_features), so they are transposed.
    """
    return p.shape if isinstance(model, transformers.GPT2LMHeadModel) else (p.shape[1], p.shape[0])


def get_logits(x):
    """Return x.logits if x is a model output object, else x itself."""
    return x.logits if hasattr(x, "logits") else x

def tokenize(batch, tokenizer, device, test=False):
    """Tokenize a prompt/target batch, masking everything but the target out of the labels."""
    prompt, label = batch["prompt"], batch["target_new"]
    if not isinstance(prompt, list):
        prompt = [prompt]
    if not isinstance(label, list):
        label = [label]
    mask_token = -100  # ignored by PyTorch's cross-entropy loss

    if test or not label:
        # At test time the labels are just the prompt tokens, with padding masked.
        tokens = tokenizer(list(prompt), return_tensors="pt", padding=True, truncation=True)
        tokens["labels"] = tokens["input_ids"].clone()
        tokens["labels"][tokens["input_ids"] == tokenizer.pad_token_id] = mask_token
    else:
        # Concatenate each prompt with its target, then mask the prompt tokens
        # out of the labels so the loss covers only the target. Counting prompt
        # tokens this way assumes the tokenizer pads on the right.
        full_prompt = [f"{p} {l}" for p, l in zip(prompt, label)]
        prompt_ids = tokenizer(list(prompt), return_tensors="pt", padding=True, truncation=True)["input_ids"]
        num_prompt_toks = [int((i != tokenizer.pad_token_id).sum()) for i in prompt_ids]
        tokens = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True)
        tokens["labels"] = tokens["input_ids"].clone()
        for i in range(len(prompt)):
            tokens["labels"][i][:num_prompt_toks[i]] = mask_token

        tokens["labels"][tokens["input_ids"] == tokenizer.pad_token_id] = mask_token

    tokens = {k: v.to(device) for k, v in tokens.items()}
    return tokens
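
# Example call (illustrative; assumes a right-padding tokenizer whose pad
# token is set, e.g. tokenizer.pad_token = tokenizer.eos_token for GPT-2):
#
#   batch = {"prompt": "The capital of France is", "target_new": "Paris"}
#   tokens = tokenize(batch, tokenizer, device="cpu")
#   loss = model(**tokens).loss  # loss is computed on the target tokens only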