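"""autoencoders/autoencoder.py

Text autoencoder variants (AE, VAE, AAE, contrastive and conditional
autoencoders) that pair a pretrained Transformer encoder with a causal LM
decoder conditioned on the latent code through a learned prefix
(past_key_values), plus training/evaluation loops and utilities for
embedding, generation, and interpolation. The encoder and decoder hidden
sizes are assumed to match (e.g. roberta-base and gpt2, both 768).
"""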
import math
import numpy as np
import torch
from torch import nn
from torch.distributions import Independent, Normal, MultivariateNormal
import torch.nn.functional as F
from transformers import AutoModel, AutoModelForCausalLM
from tqdm import tqdm
from tqdm.notebook import tqdm as tqdm_notebook
class Res(nn.Module):
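    """Residual block: a linear projection followed by two two-layer ReLU residual branches."""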
def __init__(self, H):
super().__init__()
self.u1 = nn.Linear(H, H)
self.u2 = nn.Linear(H, H)
self.v1 = nn.Linear(H, H)
self.v2 = nn.Linear(H, H)
self.w = nn.Linear(H, H)
def forward(self, x):
x = self.w(x)
x = x + torch.relu(self.v1(torch.relu(self.u1(x))))
return x + torch.relu(self.v2(torch.relu(self.u2(x))))
class MLP(nn.Module):
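    """Three-layer ReLU MLP mapping H -> H -> H -> out (out defaults to H)."""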
def __init__(self, H, out=None):
super().__init__()
out = out or H
self.mlp = nn.Sequential(
nn.Linear(H, H),
nn.ReLU(),
nn.Linear(H, H),
nn.ReLU(),
nn.Linear(H, out),
)
def forward(self, x):
return self.mlp(x)
class Encoder(nn.Module):
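    """Pretrained Transformer encoder (default roberta-base); the sentence
    embedding is the final hidden state of the first token."""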
def __init__(self, tokenizer, model_name_or_path="roberta-base", **kwargs):
super().__init__()
self.encoder = AutoModel.from_pretrained(model_name_or_path)
self.encoder.resize_token_embeddings(len(tokenizer))
self.dim = self.encoder.config.hidden_size
@property
def device(self):
return self.encoder.device
def forward(self, **inputs):
model_inputs = {
k: inputs[k].to(self.device)
for k in ("input_ids", "attention_mask")
}
if inputs.get("token_type_ids", None) is not None:
model_inputs["token_type_ids"] = inputs["token_type_ids"].to(
self.device
)
out = self.encoder(**model_inputs)
emb = out.last_hidden_state[:, 0]
return emb
class PrefixDecoder(nn.Module):
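    """Causal LM decoder (default gpt2) conditioned on a latent code z by mapping z
    to per-layer key/value prefixes of length prefix_length (past_key_values)."""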
def __init__(
self,
tokenizer,
model_name_or_path="gpt2",
prefix_length=1,
ffn="res",
**kwargs,
):
super().__init__()
self.decoder = AutoModelForCausalLM.from_pretrained(model_name_or_path)
self.hidden_dim = D = self.decoder.config.n_embd
self.num_layers = L = self.decoder.config.n_layer
self.num_heads = H = self.decoder.config.n_head
self.prefix_length = K = prefix_length
self.lin1 = nn.Linear(D, D * 2)
self.z_size = D * L * K * 2
if ffn == "res":
self.mlp = nn.Sequential(Res(D), nn.Linear(D, self.z_size))
else:
self.mlp = MLP(D, self.z_size)
def get_prefix(self, z):
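        """Map z to a tuple of num_layers (key, value) pairs, each of shape
        (batch, num_heads, prefix_length, head_dim)."""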
B = z.shape[0]
D, L, H, K = (
self.hidden_dim,
self.num_layers,
self.num_heads,
self.prefix_length,
)
z_up = self.mlp(z).reshape(B, H, K, D // H, L, 2)
keys, vals = (t.squeeze(-1) for t in z_up.chunk(2, dim=-1))
layers = tuple(
[
(k.squeeze(-1), v.squeeze(-1))
for k, v in zip(keys.chunk(L, -1), vals.chunk(L, -1))
]
)
return layers
def forward(self, z, **inputs):
B = z.shape[0]
        K = self.prefix_length
        layers = self.get_prefix(z)
input_ids = inputs["input_ids"].to(z.device)
attention_mask = inputs["attention_mask"].to(z.device)
        attention_mask = torch.cat(
            [torch.ones(B, K, dtype=attention_mask.dtype, device=z.device), attention_mask],
            1,
        )
out = self.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=layers,
)
return out
def get_inputs(
inputs, prefix, keys=["input_ids", "attention_mask", "token_type_ids"]
):
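    """Pull the keys with the given prefix (e.g. "enc_" or "dec_") out of a combined batch."""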
return {k: inputs.get(f"{prefix}{k}", None) for k in keys}
class VAE(nn.Module):
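    """Variational autoencoder: the encoder output parameterizes a diagonal Gaussian
    q(z|x); forward returns the decoder output plus the KL to a standard normal prior."""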
def __init__(self, encoder, decoder, beta=1.0, do_sample=True, **kwargs):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.beta = beta
D = decoder.hidden_dim
self.lin = nn.Linear(D, D * 2)
self.do_sample = do_sample
@property
def device(self):
return self.encoder.device
def get_z(self, sample=True, **inputs):
enc = self.encoder(**get_inputs(inputs, "enc_"))
B, D = enc.shape
        # the second head of self.lin parameterizes the log standard deviation
        mu, log_scale = (
            t.squeeze(-1) for t in self.lin(enc).view(B, D, 2).chunk(2, -1)
        )
        qz = Normal(mu, log_scale.exp())
pz = Normal(torch.zeros_like(mu[0]), torch.ones_like(mu[0]))
kl = torch.distributions.kl_divergence(qz, pz).sum(-1)
if sample:
z = qz.rsample()
else:
z = mu
return z, kl
def forward(self, **inputs):
z, kl = self.get_z(sample=self.do_sample, **inputs)
out = self.decoder(z, **get_inputs(inputs, "dec_"))
out["kl"] = kl
return out
class AAE(nn.Module):
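    """Adversarial autoencoder (after github.com/shentianxiao/text-autoencoders):
    a discriminator D separates codes from Gaussian noise (loss_d) while the
    encoder is trained to fool it (adv); optional word dropout on encoder inputs."""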
def __init__(self, encoder, decoder, _lambda=1.0, word_drop=None, **kwargs):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self._lambda = _lambda
dim = decoder.hidden_dim
self.D = nn.Sequential(
nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 1), nn.Sigmoid()
)
self.word_drop = word_drop
@property
def device(self):
return self.encoder.device
def get_z(self, **inputs):
if self.word_drop is not None:
m = inputs["enc_attention_mask"]
b = torch.rand_like(m.float()) > self.word_drop
inputs["enc_attention_mask"] = m & b
return self.encoder(**get_inputs(inputs, "enc_")), None
def loss_adv(self, z):
# https://github.com/shentianxiao/text-autoencoders
zn = torch.randn_like(z)
zeros = torch.zeros(len(z), 1, device=z.device)
ones = torch.ones(len(z), 1, device=z.device)
loss_d = F.binary_cross_entropy(
self.D(z.detach()), zeros, reduction="none"
) + F.binary_cross_entropy(self.D(zn), ones, reduction="none")
adv = F.binary_cross_entropy(self.D(z), ones, reduction="none")
return loss_d, adv
def forward(self, **inputs):
z, _ = self.get_z(**inputs)
out = self.decoder(z, **get_inputs(inputs, "dec_"))
b, n, _ = out["logits"].shape
log_probs = out["logits"].log_softmax(-1)
log_probs = torch.gather(
log_probs[:, :-1],
-1,
inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
).squeeze(-1)
log_probs = log_probs.masked_fill(
~inputs["dec_attention_mask"][:, 1:], 0
)
out["l_rec"] = -log_probs.sum(-1)
out["loss_d"], out["adv"] = self.loss_adv(z)
return out
class AE(nn.Module):
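    """Plain autoencoder with a token-level reconstruction loss (loss_r); loss_c is a zero placeholder."""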
def __init__(self, encoder, decoder, **kwargs):
super().__init__()
self.encoder = encoder
self.decoder = decoder
dim = decoder.hidden_dim
self.D = nn.Sequential(
nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 1), nn.Sigmoid()
)
@property
def device(self):
return self.encoder.device
def get_z(self, **inputs):
return self.encoder(**get_inputs(inputs, "enc_")), None
def step(self, **inputs):
z, _ = self.get_z(**inputs)
out = self.decoder(z, **get_inputs(inputs, "dec_"))
b, n, _ = out["logits"].shape
log_probs = out["logits"].log_softmax(-1)
log_probs = torch.gather(
log_probs[:, :-1],
-1,
inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
).squeeze(-1)
log_probs = log_probs.masked_fill(
~inputs["dec_attention_mask"][:, 1:], 0
)
out["loss_r"] = -log_probs.sum(-1)
return z, out
def forward(self, **inputs):
z, out = self.step(**inputs)
out["loss_c"] = torch.zeros_like(out["loss_r"])
return out
class CDAE(nn.Module):
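    """Contrastive denoising autoencoder: word dropout produces a corrupted second view,
    both views are reconstructed, and loss_c pulls the two codes together with a
    softmax over negative squared distances (temperature tau)."""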
def __init__(
self, encoder, decoder, _lambda=1.0, word_drop=None, tau=1.0, **kwargs
):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self._lambda = _lambda
dim = decoder.hidden_dim
self.D = nn.Sequential(
nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 1), nn.Sigmoid()
)
self.word_drop = word_drop
self.tau = tau
@property
def device(self):
return self.encoder.device
def do_mask(self, **inputs):
m = inputs["enc_attention_mask"]
b = torch.rand_like(m.float()) > self.word_drop
inputs["enc_attention_mask"] = m & b
B, N = inputs["dec_attention_mask"].shape
_, M = m.shape
m2 = inputs["dec_attention_mask"]
if N <= M:
b2 = b[:, :N]
else:
b_ = torch.rand((B, N - M), device=b.device) > self.word_drop
b2 = torch.cat([b, b_], -1)
inputs["dec_attention_mask"] = m2 & b2
def get_z(self, **inputs):
return self.encoder(**get_inputs(inputs, "enc_")), None
def step(self, **inputs):
z, _ = self.get_z(**inputs)
out = self.decoder(z, **get_inputs(inputs, "dec_"))
b, n, _ = out["logits"].shape
log_probs = out["logits"].log_softmax(-1)
log_probs = torch.gather(
log_probs[:, :-1],
-1,
inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
).squeeze(-1)
log_probs = log_probs.masked_fill(
~inputs["dec_attention_mask"][:, 1:], 0
)
out["loss_r"] = -log_probs.sum(-1)
return z, out
def loss_c(self, z, z2):
scores = -(torch.cdist(z, z2) ** 2)
log_probs = (scores / self.tau).log_softmax(-1)
loss = -torch.diagonal(log_probs)
return loss
def forward(self, **inputs):
z, out = self.step(**inputs)
        # do_mask receives a copy of the kwargs, so keep its return value
        inputs = self.do_mask(**inputs)
z_, out_ = self.step(**inputs)
out["loss_r"] = out["loss_r"] + out_["loss_r"]
out["loss_c"] = self.loss_c(z, z_)
return out
def run_aae_epoch(
model,
batches,
opt,
optD,
num_samples=1,
lambda_adv=1.0,
desc="",
notebook=True,
):
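    """One AAE training epoch: update encoder/decoder on l_rec + lambda_adv * adv,
    then the discriminator on loss_d; returns per-example losses over all batches."""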
losses = {k: [] for k in ("l_rec", "adv", "loss_d")}
t = (
tqdm_notebook(batches, desc=desc)
if notebook
else tqdm(batches, desc=desc)
)
    model.train()
    for batch in t:
        model_inputs = {
            k: v.to(model.device)
            for k, v in batch.items()
            if isinstance(v, torch.Tensor)
        }
out = model(**model_inputs)
loss = (out["l_rec"] + lambda_adv * out["adv"]).sum()
opt.zero_grad()
loss.backward()
opt.step()
loss_d = out["loss_d"].sum()
optD.zero_grad()
loss_d.backward()
optD.step()
d = {}
for k in ("l_rec", "adv", "loss_d"):
d[k] = out[k].mean().item()
losses[k].append(out[k].detach().cpu().numpy())
t.set_postfix(d)
return {k: np.concatenate(v, 0) for k, v in losses.items()}
class GAE(nn.Module):
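    """Autoencoder with an in-batch contrastive loss on the codes
    (cosine similarity / tau, softmax over the batch)."""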
def __init__(self, encoder, decoder, tau=0.05, **kwargs):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.tau = tau
@property
def device(self):
return self.encoder.device
def get_z(self, **inputs):
return self.encoder(**get_inputs(inputs, "enc_")), None
def loss_c(self, z, z2):
scores = F.normalize(z, dim=-1) @ F.normalize(z2, dim=-1).T
log_probs = (scores / self.tau).log_softmax(-1)
loss = -torch.diagonal(log_probs)
return loss
def forward(self, **inputs):
z, _ = self.get_z(**inputs)
out = self.decoder(z, **get_inputs(inputs, "dec_"))
b, n, _ = out["logits"].shape
log_probs = out["logits"].log_softmax(-1)
log_probs = torch.gather(
log_probs[:, :-1],
-1,
inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
).squeeze(-1)
log_probs = log_probs.masked_fill(
~inputs["dec_attention_mask"][:, 1:], 0
)
out["loss_r"] = -log_probs.sum(-1)
out["loss_c"] = self.loss_c(z)
return out
class CAE(nn.Module):
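    """Contrastive autoencoder: the code is contrasted against a second,
    gradient-detached encoding of the same inputs (cosine similarity / tau)."""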
def __init__(self, encoder, decoder, tau=0.05, **kwargs):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.tau = tau
@property
def device(self):
return self.encoder.device
def get_z(self, **inputs):
return self.encoder(**get_inputs(inputs, "enc_")), None
def loss_c(self, z, z2):
scores = F.normalize(z, dim=-1) @ F.normalize(z2, dim=-1).T
log_probs = (scores / self.tau).log_softmax(-1)
loss = -torch.diagonal(log_probs)
return loss
def forward(self, **inputs):
z, _ = self.get_z(**inputs)
with torch.no_grad():
z2, _ = self.get_z(**inputs)
out = self.decoder(z, **get_inputs(inputs, "dec_"))
b, n, _ = out["logits"].shape
log_probs = out["logits"].log_softmax(-1)
log_probs = torch.gather(
log_probs[:, :-1],
-1,
inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
).squeeze(-1)
log_probs = log_probs.masked_fill(
~inputs["dec_attention_mask"][:, 1:], 0
)
out["loss_r"] = -log_probs.sum(-1)
out["loss_c"] = self.loss_c(z, z2)
return out
def run_cae_epoch(
model,
batches,
opt,
num_samples=1,
lambda_c=1.0,
desc="",
notebook=True,
):
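    """One training epoch minimizing loss_r + lambda_c * loss_c; returns per-example losses."""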
losses = {k: [] for k in ("loss_r", "loss_c")}
t = (
tqdm_notebook(batches, desc=desc)
if notebook
else tqdm(batches, desc=desc)
)
model.train()
for batch in t:
model_inputs = {
k: v.to(model.device)
for k, v in batch.items()
            if isinstance(v, torch.Tensor)
}
out = model(**model_inputs)
loss = (out["loss_r"] + lambda_c * out["loss_c"]).sum()
opt.zero_grad()
loss.backward()
opt.step()
d = {}
for k in ("loss_r", "loss_c"):
d[k] = out[k].mean().item()
losses[k].append(out[k].detach().cpu().numpy())
t.set_postfix(d)
return {k: np.concatenate(v, 0) for k, v in losses.items()}
def batch_kl(l1, s1, l2=None, s2=None):
    # KL(N(l1, diag(s1)) || N(l2, diag(s2))) = 1/2[log|s2|/|s1| - d + tr(s2^{-1}s1) + (l2-l1)^T s2^{-1}(l2-l1)]
    if l2 is None:  # assumed default: standard normal prior
        l2, s2 = torch.zeros_like(l1), torch.ones_like(s1)
    return 0.5 * (s2.log().sum(-1) - s1.log().sum(-1) - l1.shape[-1]
                  + (s1 / s2).sum(-1) + ((l2 - l1) ** 2 / s2).sum(-1))
class SubpopCondAE(nn.Module):
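    """Conditional autoencoder whose class-conditional prior is a mixture of
    `sublabels` diagonal Gaussians per label; classification marginalizes over
    sublabels with logsumexp, and get_kl sums component KLs to a standard normal."""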
def __init__(
self,
encoder,
decoder,
num_labels,
sublabels=4,
tau=0.05,
disc_loss=True,
**kwargs,
):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.dim = dim = decoder.hidden_dim
self.locs = nn.Parameter(torch.randn(num_labels * sublabels, dim))
self.log_scales = nn.Parameter(torch.zeros(num_labels * sublabels, dim))
self.num_labels = num_labels
self.sublabels = sublabels
self.L = num_labels * sublabels
self.tau = tau
self.disc_loss = disc_loss
@property
def device(self):
return self.encoder.device
def get_z(self, **inputs):
return self.encoder(**get_inputs(inputs, "enc_")), None
def loss_c(self, z, **inputs):
scores = []
for i in range(self.L):
dist = Independent(
Normal(loc=self.locs[i], scale=self.log_scales[i].exp()), 1
)
scores.append(dist.log_prob(z))
B = z.shape[0]
sub_log_probs = torch.stack(scores, -1)
if self.disc_loss:
sub_log_probs = sub_log_probs.log_softmax(-1)
log_probs = sub_log_probs.view(
            B, self.num_labels, self.sublabels
).logsumexp(-1)
loss = F.nll_loss(log_probs, inputs["label"], reduction="none")
acc = log_probs.argmax(-1) == inputs["label"]
return {
"loss_c": loss,
"log_probs": log_probs,
"sub_log_probs": sub_log_probs,
"acc": acc.float(),
}
def get_kl(self):
p = MultivariateNormal(
torch.zeros(self.dim, device=self.device),
torch.eye(self.dim, device=self.device),
)
kl = 0
for i in range(self.L):
q = MultivariateNormal(
self.locs[i], torch.diag(self.log_scales[i].exp())
)
kl += torch.distributions.kl_divergence(q, p)
return kl
def forward(self, **inputs):
z, _ = self.get_z(**inputs)
out = self.decoder(z, **get_inputs(inputs, "dec_"))
b, n, _ = out["logits"].shape
log_probs = out["logits"].log_softmax(-1)
log_probs = torch.gather(
log_probs[:, :-1],
-1,
inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
).squeeze(-1)
log_probs = log_probs.masked_fill(
~inputs["dec_attention_mask"][:, 1:], 0
)
out["loss_r"] = -log_probs.sum(-1)
out_c = self.loss_c(z, **inputs)
for k, v in out_c.items():
out[k] = v
out["kl"] = self.get_kl().unsqueeze(0)
return out
def gaussian_prob_product(m1, s1, m2, s2, rho=1.0):
# s1, s2 diagonal
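    # appears to follow the Gaussian probability product kernel
    # K_rho(p, q) = \int p(x)^rho q(x)^rho dx (Jebara et al., 2004)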
s1_inv = 1 / s1
s2_inv = 1 / s2
    s_hat = 1 / (s1_inv + s2_inv)
    m_hat = s1_inv * m1 + s2_inv * m2
dim = m1.shape[-1]
return (
((2 * math.pi) ** ((1 - 2 * rho) * dim / 2))
* (rho ** (-dim / 2))
* torch.sqrt(s_hat.prod(-1))
* ((s1.prod(-1) * s2.prod(-1)) ** (-rho / 2))
* torch.exp(
            -(rho / 2)
* (
m1 @ (s1_inv * m1).T
+ m2 @ (s2_inv * m2).T
- m_hat @ (s_hat * m_hat).T
)
)
)
class CondAE(nn.Module):
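    """Conditional autoencoder with one diagonal Gaussian per label; get_kl sums
    per-label KLs to a standard normal, optionally adding a log-determinant or
    Frobenius-norm regularizer on the RBF Gram matrix of the label means."""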
def __init__(
self,
encoder,
decoder,
num_labels,
logdet=False,
l2_reg=False,
disc_loss=True,
tau=0.05,
**kwargs,
):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.dim = dim = decoder.hidden_dim
self.locs = nn.Parameter(torch.randn(num_labels, dim))
self.log_scales = nn.Parameter(torch.zeros(num_labels, dim))
self.num_labels = num_labels
self.tau = tau
self.logdet = logdet
self.l2_reg = l2_reg
self.disc_loss = disc_loss
@property
def device(self):
return self.encoder.device
def get_z(self, **inputs):
return self.encoder(**get_inputs(inputs, "enc_")), None
def loss_c(self, z, **inputs):
scores = []
for i in range(self.num_labels):
dist = Independent(
Normal(loc=self.locs[i], scale=self.log_scales[i].exp()), 1
)
scores.append(dist.log_prob(z))
log_probs = torch.stack(scores, -1)
if self.disc_loss:
log_probs = log_probs.log_softmax(-1)
loss = F.nll_loss(log_probs, inputs["label"], reduction="none")
acc = log_probs.argmax(-1) == inputs["label"]
return {"loss_c": loss, "log_probs": log_probs, "acc": acc.float()}
def get_kl(self):
p = MultivariateNormal(
torch.zeros(self.dim, device=self.device),
torch.eye(self.dim, device=self.device),
)
kl = 0
for i in range(self.num_labels):
q = MultivariateNormal(
self.locs[i], torch.diag(self.log_scales[i].exp())
)
kl += torch.distributions.kl_divergence(q, p)
if self.logdet:
K = torch.exp(-torch.cdist(self.locs, self.locs) ** 2)
kl += torch.logdet(K)
elif self.l2_reg:
K = torch.exp(-torch.cdist(self.locs, self.locs) ** 2)
kl += torch.log(
torch.linalg.norm(K / K.shape[0], dim=(-2, -1)) ** 2
).sum()
return kl
def forward(self, **inputs):
z, _ = self.get_z(**inputs)
out = self.decoder(z, **get_inputs(inputs, "dec_"))
b, n, _ = out["logits"].shape
log_probs = out["logits"].log_softmax(-1)
log_probs = torch.gather(
log_probs[:, :-1],
-1,
inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
).squeeze(-1)
log_probs = log_probs.masked_fill(
~inputs["dec_attention_mask"][:, 1:], 0
)
out["loss_r"] = -log_probs.sum(-1)
out_c = self.loss_c(z, **inputs)
for k, v in out_c.items():
out[k] = v
out["kl"] = self.get_kl().unsqueeze(0)
return out
class BasicCondAE(nn.Module):
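    """Conditional autoencoder with a linear classifier on the code; kl is returned as zeros."""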
def __init__(self, encoder, decoder, num_labels, tau=0.05, **kwargs):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.dim = dim = decoder.hidden_dim
self.linear = nn.Linear(dim, num_labels)
self.num_labels = num_labels
self.tau = tau
@property
def device(self):
return self.encoder.device
def get_z(self, **inputs):
return self.encoder(**get_inputs(inputs, "enc_")), None
def loss_c(self, z, **inputs):
log_probs = self.linear(z).log_softmax(-1)
loss = F.nll_loss(log_probs, inputs["label"], reduction="none")
acc = log_probs.argmax(-1) == inputs["label"]
return {"loss_c": loss, "log_probs": log_probs, "acc": acc.float()}
def forward(self, **inputs):
z, _ = self.get_z(**inputs)
out = self.decoder(z, **get_inputs(inputs, "dec_"))
b, n, _ = out["logits"].shape
log_probs = out["logits"].log_softmax(-1)
log_probs = torch.gather(
log_probs[:, :-1],
-1,
inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
).squeeze(-1)
log_probs = log_probs.masked_fill(
~inputs["dec_attention_mask"][:, 1:], 0
)
out["loss_r"] = -log_probs.sum(-1)
out_c = self.loss_c(z, **inputs)
for k, v in out_c.items():
out[k] = v
out["kl"] = torch.zeros_like(out["loss_r"])
return out
def run_cond_ae_epoch(
model,
batches,
opt,
num_samples=1,
lambda_c=1.0,
lambda_r=1.0,
beta=1.0,
desc="",
notebook=True,
):
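    """One training epoch minimizing lambda_r * loss_r + lambda_c * loss_c + beta * kl;
    also tracks the accuracy of the latent classifier."""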
losses = {k: [] for k in ("loss_r", "loss_c", "kl", "acc")}
t = (
tqdm_notebook(batches, desc=desc)
if notebook
else tqdm(batches, desc=desc)
)
model.train()
for batch in t:
model_inputs = {
k: v.to(model.device)
for k, v in batch.items()
            if isinstance(v, torch.Tensor)
}
out = model(**model_inputs)
loss = (
lambda_r * out["loss_r"] + lambda_c * out["loss_c"]
).sum() + beta * out["kl"].sum()
opt.zero_grad()
loss.backward()
opt.step()
d = {}
for k in ("loss_r", "loss_c", "kl", "acc"):
d[k] = out[k].mean().item()
losses[k].append(out[k].detach().cpu().numpy())
t.set_postfix(d)
return {k: np.concatenate(v, 0) for k, v in losses.items()}
def run_cond_ae_eval(
model,
batches,
lambda_c=1.0,
beta=1.0,
desc="",
notebook=True,
):
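    """Evaluation counterpart of run_cond_ae_epoch: same per-example metrics, no updates."""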
losses = {k: [] for k in ("loss_r", "loss_c", "kl", "acc")}
t = (
tqdm_notebook(batches, desc=desc)
if notebook
else tqdm(batches, desc=desc)
)
model.eval()
for batch in t:
model_inputs = {
k: v.to(model.device)
for k, v in batch.items()
            if isinstance(v, torch.Tensor)
}
with torch.no_grad():
out = model(**model_inputs)
loss = (
out["loss_r"] + lambda_c * out["loss_c"]
).sum() + beta * out["kl"].sum()
d = {}
for k in ("loss_r", "loss_c", "kl", "acc"):
d[k] = out[k].mean().item()
losses[k].append(out[k].detach().cpu().numpy())
t.set_postfix(d)
return {k: np.concatenate(v, 0) for k, v in losses.items()}
def generate(
model,
tokenizer,
batch=None,
z=None,
do_sample=False,
max_length=128,
**kwargs,
):
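    """Decode text from latent codes: encode `batch` (or take `z` directly), map the
    code to a prefix, and generate from the BOS token with the underlying decoder."""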
if z is None:
with torch.no_grad():
z, _ = model.get_z(sample=False, **batch)
B, D = z.shape
else:
        z = torch.as_tensor(z, dtype=torch.float32, device=model.device)
B, D = z.shape
    K = model.decoder.prefix_length
    layers = model.decoder.get_prefix(z)
output = model.decoder.decoder.generate(
input_ids=torch.tensor(
[[tokenizer.bos_token_id]] * B, device=model.device
),
attention_mask=torch.ones((B, K + 1), device=model.device),
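        # note: newer transformers versions expect `past_key_values` instead of `past`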
past=layers,
do_sample=do_sample,
max_length=max_length,
**kwargs,
)
lst = tokenizer.batch_decode(output[:, 1:])
return [l.replace("<|endoftext|>", "") for l in lst]
def get_embeddings(model, batches, desc="", notebook=True):
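    """Encode all batches (without sampling) and return the codes as one numpy array."""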
out = []
t = (
tqdm_notebook(batches, desc=desc)
if notebook
else tqdm(batches, desc=desc)
)
model.eval()
for batch in t:
with torch.no_grad():
model_inputs = {
k: v.to(model.device)
for k, v in batch.items()
                if isinstance(v, torch.Tensor)
}
z, _ = model.get_z(sample=False, **model_inputs)
out.append(z.detach().cpu().numpy())
return np.concatenate(out, 0)
def interpolate(model, tokenizer, a, b, num_steps=10, **kwargs):
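    """Linearly interpolate between codes a and b over num_steps points and decode each one."""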
z = np.stack(
[l * b + (1 - l) * a for l in np.linspace(0, 1.0, num_steps)], 0
)
return generate(model, tokenizer, z=z, **kwargs)