import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions import Independent, MultivariateNormal, Normal

from tqdm import tqdm
from tqdm.notebook import tqdm as tqdm_notebook
from transformers import AutoModel, AutoModelForCausalLM


class Res(nn.Module):
    def __init__(self, H):
        super().__init__()
        self.u1 = nn.Linear(H, H)
        self.u2 = nn.Linear(H, H)

        self.v1 = nn.Linear(H, H)
        self.v2 = nn.Linear(H, H)
        self.w = nn.Linear(H, H)

    def forward(self, x):
        x = self.w(x)
        x = x + torch.relu(self.v1(torch.relu(self.u1(x))))
        return x + torch.relu(self.v2(torch.relu(self.u2(x))))


class MLP(nn.Module):
    def __init__(self, H, out=None):
        super().__init__()
        out = out or H
        self.mlp = nn.Sequential(
            nn.Linear(H, H),
            nn.ReLU(),
            nn.Linear(H, H),
            nn.ReLU(),
            nn.Linear(H, out),
        )

    def forward(self, x):
        return self.mlp(x)


class Encoder(nn.Module):
    def __init__(self, tokenizer, model_name_or_path="roberta-base", **kwargs):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name_or_path)
        # Account for any special tokens added to the tokenizer.
        self.encoder.resize_token_embeddings(len(tokenizer))
        self.dim = self.encoder.config.hidden_size

    @property
    def device(self):
        return self.encoder.device

    def forward(self, **inputs):
        model_inputs = {
            k: inputs[k].to(self.device)
            for k in ("input_ids", "attention_mask")
        }
        if inputs.get("token_type_ids", None) is not None:
            model_inputs["token_type_ids"] = inputs["token_type_ids"].to(
                self.device
            )
        out = self.encoder(**model_inputs)
        # Use the first ([CLS]/<s>) token state as the sentence embedding.
        emb = out.last_hidden_state[:, 0]
        return emb


class PrefixDecoder(nn.Module):
    def __init__(
        self,
        tokenizer,
        model_name_or_path="gpt2",
        prefix_length=1,
        ffn="res",
        **kwargs,
    ):
        super().__init__()
        self.decoder = AutoModelForCausalLM.from_pretrained(model_name_or_path)
        self.hidden_dim = D = self.decoder.config.n_embd
        self.num_layers = L = self.decoder.config.n_layer
        self.num_heads = H = self.decoder.config.n_head
        self.prefix_length = K = prefix_length
        self.lin1 = nn.Linear(D, D * 2)
        # Flat size of the prefix: K key/value slots per layer.
        self.z_size = D * L * K * 2
        if ffn == "res":
            self.mlp = nn.Sequential(Res(D), nn.Linear(D, self.z_size))
        else:
            self.mlp = MLP(D, self.z_size)

    def get_prefix(self, z):
        # Map a latent code z to past_key_values: one (key, value) pair per
        # layer, each of shape (batch, num_heads, prefix_length, head_dim).
        B = z.shape[0]
        D, L, H, K = (
            self.hidden_dim,
            self.num_layers,
            self.num_heads,
            self.prefix_length,
        )
        z_up = self.mlp(z).reshape(B, H, K, D // H, L, 2)
        keys, vals = (t.squeeze(-1) for t in z_up.chunk(2, dim=-1))
        layers = tuple(
            (k.squeeze(-1), v.squeeze(-1))
            for k, v in zip(keys.chunk(L, -1), vals.chunk(L, -1))
        )
        return layers

    def forward(self, z, **inputs):
        B, K = z.shape[0], self.prefix_length
        layers = self.get_prefix(z)
        input_ids = inputs["input_ids"].to(z.device)
        attention_mask = inputs["attention_mask"].to(z.device)
        # Prepend K always-attended positions for the latent prefix.
        attention_mask = torch.cat(
            [
                torch.ones(B, K, dtype=attention_mask.dtype, device=z.device),
                attention_mask,
            ],
            1,
        )
        out = self.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=layers,
        )
        return out


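# A small illustrative check (not used elsewhere in this module): the prefix
# returned by get_prefix should match the legacy past_key_values layout GPT-2
# accepts, i.e. one (key, value) pair per layer with shape
# (batch, num_heads, prefix_length, head_dim).  The helper name is hypothetical.
def _example_prefix_shapes(decoder, batch_size=2):
    z = torch.randn(
        batch_size, decoder.hidden_dim, device=decoder.decoder.device
    )
    layers = decoder.get_prefix(z)
    key, val = layers[0]
    expected = (
        batch_size,
        decoder.num_heads,
        decoder.prefix_length,
        decoder.hidden_dim // decoder.num_heads,
    )
    return (
        len(layers) == decoder.num_layers
        and key.shape == val.shape == expected
    )

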
def get_inputs(
    inputs, prefix, keys=("input_ids", "attention_mask", "token_type_ids")
):
    return {k: inputs.get(f"{prefix}{k}", None) for k in keys}


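# Illustrative only: the autoencoder classes in this module expect flat batches
# whose tensors are prefixed with "enc_" (encoder side) and "dec_" (decoder
# side), plus an optional "label" entry.  The tokenizer names, padding strategy,
# and the manual pad-token assignment for GPT-2 are assumptions, not a fixed API.
def _example_build_batch(enc_tokenizer, dec_tokenizer, texts, labels=None):
    if dec_tokenizer.pad_token is None:
        dec_tokenizer.pad_token = dec_tokenizer.eos_token
    enc = enc_tokenizer(texts, padding=True, return_tensors="pt")
    dec = dec_tokenizer(texts, padding=True, return_tensors="pt")
    batch = {f"enc_{k}": v for k, v in enc.items()}
    batch.update({f"dec_{k}": v for k, v in dec.items()})
    if labels is not None:
        batch["label"] = torch.tensor(labels)
    return batch

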
class VAE(nn.Module):
    def __init__(self, encoder, decoder, beta=1.0, do_sample=True, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.beta = beta
        D = decoder.hidden_dim
        # One linear head producing both the posterior mean and log-scale.
        self.lin = nn.Linear(D, D * 2)
        self.do_sample = do_sample

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, sample=True, **inputs):
        enc = self.encoder(**get_inputs(inputs, "enc_"))
        B, D = enc.shape
        mu, log_scale = (
            t.squeeze(-1) for t in self.lin(enc).view(B, D, 2).chunk(2, -1)
        )
        qz = Normal(mu, log_scale.exp())
        pz = Normal(torch.zeros_like(mu), torch.ones_like(mu))
        kl = torch.distributions.kl_divergence(qz, pz).sum(-1)
        z = qz.rsample() if sample else mu
        return z, kl

    def forward(self, **inputs):
        z, kl = self.get_z(sample=self.do_sample, **inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        out["kl"] = kl
        return out


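# A minimal sketch (the batch layout follows the "enc_"/"dec_" convention above;
# the optimizer and the helper name are assumptions) of how the VAE output can
# be turned into a loss: VAE.forward returns the decoder logits plus the
# per-example KL, so the reconstruction term is computed from shifted logits.
def _example_vae_step(model, batch, opt, beta=1.0):
    model_inputs = {
        k: v.to(model.device)
        for k, v in batch.items()
        if isinstance(v, torch.Tensor)
    }
    out = model(**model_inputs)
    logits = out["logits"][:, :-1]
    targets = model_inputs["dec_input_ids"][:, 1:]
    mask = model_inputs["dec_attention_mask"][:, 1:].bool()
    log_probs = (
        logits.log_softmax(-1).gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    )
    rec = -(log_probs.masked_fill(~mask, 0.0)).sum(-1)
    loss = (rec + beta * out["kl"]).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()

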
class AAE(nn.Module):
    def __init__(self, encoder, decoder, _lambda=1.0, word_drop=None, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self._lambda = _lambda
        dim = decoder.hidden_dim
        # Discriminator that scores whether a code looks like Gaussian noise.
        self.D = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 1), nn.Sigmoid()
        )
        self.word_drop = word_drop

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        if self.word_drop is not None:
            # Denoising: randomly mask out encoder-side tokens.
            m = inputs["enc_attention_mask"]
            b = torch.rand_like(m.float()) > self.word_drop
            inputs["enc_attention_mask"] = m & b
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def loss_adv(self, z):
        # Discriminator loss (noise vs. detached codes) and the adversarial
        # term that pushes codes toward the Gaussian prior; both have shape (B,).
        zn = torch.randn_like(z)
        zeros = torch.zeros(len(z), device=z.device)
        ones = torch.ones(len(z), device=z.device)
        loss_d = F.binary_cross_entropy(
            self.D(z.detach()).squeeze(-1), zeros, reduction="none"
        ) + F.binary_cross_entropy(
            self.D(zn).squeeze(-1), ones, reduction="none"
        )
        adv = F.binary_cross_entropy(
            self.D(z).squeeze(-1), ones, reduction="none"
        )
        return loss_d, adv

    def forward(self, **inputs):
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        # Per-example negative log-likelihood of the shifted decoder inputs.
        log_probs = out["logits"].log_softmax(-1)
        log_probs = torch.gather(
            log_probs[:, :-1],
            -1,
            inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
        ).squeeze(-1)
        log_probs = log_probs.masked_fill(
            ~inputs["dec_attention_mask"][:, 1:].bool(), 0
        )
        out["l_rec"] = -log_probs.sum(-1)
        out["loss_d"], out["adv"] = self.loss_adv(z)
        return out


class AE(nn.Module):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        dim = decoder.hidden_dim
        self.D = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 1), nn.Sigmoid()
        )

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def step(self, **inputs):
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        log_probs = out["logits"].log_softmax(-1)
        log_probs = torch.gather(
            log_probs[:, :-1],
            -1,
            inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
        ).squeeze(-1)
        log_probs = log_probs.masked_fill(
            ~inputs["dec_attention_mask"][:, 1:].bool(), 0
        )
        out["loss_r"] = -log_probs.sum(-1)
        return z, out

    def forward(self, **inputs):
        z, out = self.step(**inputs)
        # Plain autoencoder: no contrastive term.
        out["loss_c"] = torch.zeros_like(out["loss_r"])
        return out


class CDAE(nn.Module):
    def __init__(
        self, encoder, decoder, _lambda=1.0, word_drop=None, tau=1.0, **kwargs
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self._lambda = _lambda
        dim = decoder.hidden_dim
        self.D = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 1), nn.Sigmoid()
        )
        self.word_drop = word_drop
        self.tau = tau

    @property
    def device(self):
        return self.encoder.device

    def do_mask(self, **inputs):
        # Build a word-dropped view of the batch: the same Bernoulli mask is
        # applied to the encoder inputs and, extended if needed, to the decoder.
        m = inputs["enc_attention_mask"]
        b = torch.rand_like(m.float()) > self.word_drop
        inputs["enc_attention_mask"] = m & b

        B, N = inputs["dec_attention_mask"].shape
        _, M = m.shape
        m2 = inputs["dec_attention_mask"]
        if N <= M:
            b2 = b[:, :N]
        else:
            b_ = torch.rand((B, N - M), device=b.device) > self.word_drop
            b2 = torch.cat([b, b_], -1)
        inputs["dec_attention_mask"] = m2 & b2
        return inputs

    def get_z(self, **inputs):
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def step(self, **inputs):
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        log_probs = out["logits"].log_softmax(-1)
        log_probs = torch.gather(
            log_probs[:, :-1],
            -1,
            inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
        ).squeeze(-1)
        log_probs = log_probs.masked_fill(
            ~inputs["dec_attention_mask"][:, 1:].bool(), 0
        )
        out["loss_r"] = -log_probs.sum(-1)
        return z, out

    def loss_c(self, z, z2):
        # Contrastive loss on negative squared distances with temperature tau.
        scores = -(torch.cdist(z, z2) ** 2)
        log_probs = (scores / self.tau).log_softmax(-1)
        return -torch.diagonal(log_probs)

    def forward(self, **inputs):
        z, out = self.step(**inputs)
        # Second pass on the word-dropped view of the same batch.
        inputs = self.do_mask(**inputs)
        z_, out_ = self.step(**inputs)
        out["loss_r"] = out["loss_r"] + out_["loss_r"]
        out["loss_c"] = self.loss_c(z, z_)
        return out


def run_aae_epoch(
    model,
    batches,
    opt,
    optD,
    num_samples=1,
    lambda_adv=1.0,
    desc="",
    notebook=True,
):
    losses = {k: [] for k in ("l_rec", "adv", "loss_d")}
    t = (
        tqdm_notebook(batches, desc=desc)
        if notebook
        else tqdm(batches, desc=desc)
    )
    for batch in t:
        model_inputs = {
            k: v.to(model.device)
            for k, v in batch.items()
            if isinstance(v, torch.Tensor)
        }
        out = model(**model_inputs)
        # Autoencoder update: reconstruction plus the adversarial term.
        loss = (out["l_rec"] + lambda_adv * out["adv"]).sum()
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Discriminator update on its own optimizer.
        loss_d = out["loss_d"].sum()
        optD.zero_grad()
        loss_d.backward()
        optD.step()

        d = {}
        for k in ("l_rec", "adv", "loss_d"):
            d[k] = out[k].mean().item()
            losses[k].append(out[k].detach().cpu().numpy())
        t.set_postfix(d)
    return {k: np.concatenate(v, 0) for k, v in losses.items()}


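# A minimal sketch (learning rates and the helper name are placeholders) of the
# two-optimizer setup run_aae_epoch expects: `opt` updates the encoder/decoder
# on the reconstruction + adversarial terms, while `optD` updates only the
# discriminator model.D.
def _example_aae_optimizers(model, lr=1e-4, lr_d=1e-4):
    d_params = set(model.D.parameters())
    ae_params = [p for p in model.parameters() if p not in d_params]
    opt = torch.optim.AdamW(ae_params, lr=lr)
    optD = torch.optim.AdamW(model.D.parameters(), lr=lr_d)
    return opt, optD

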
class GAE(nn.Module):
    def __init__(self, encoder, decoder, tau=0.05, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tau = tau

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def loss_c(self, z, z2):
        # InfoNCE-style loss on cosine similarities with temperature tau.
        scores = F.normalize(z, dim=-1) @ F.normalize(z2, dim=-1).T
        log_probs = (scores / self.tau).log_softmax(-1)
        return -torch.diagonal(log_probs)

    def forward(self, **inputs):
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        log_probs = out["logits"].log_softmax(-1)
        log_probs = torch.gather(
            log_probs[:, :-1],
            -1,
            inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
        ).squeeze(-1)
        log_probs = log_probs.masked_fill(
            ~inputs["dec_attention_mask"][:, 1:].bool(), 0
        )
        out["loss_r"] = -log_probs.sum(-1)
        # The batch is contrasted against itself (second view assumed to be z).
        out["loss_c"] = self.loss_c(z, z)
        return out


class CAE(nn.Module):
    def __init__(self, encoder, decoder, tau=0.05, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tau = tau

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def loss_c(self, z, z2):
        scores = F.normalize(z, dim=-1) @ F.normalize(z2, dim=-1).T
        log_probs = (scores / self.tau).log_softmax(-1)
        return -torch.diagonal(log_probs)

    def forward(self, **inputs):
        z, _ = self.get_z(**inputs)
        # Second encoding without gradients; with dropout active (train mode)
        # this yields a stochastically different view of the same inputs.
        with torch.no_grad():
            z2, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        log_probs = out["logits"].log_softmax(-1)
        log_probs = torch.gather(
            log_probs[:, :-1],
            -1,
            inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
        ).squeeze(-1)
        log_probs = log_probs.masked_fill(
            ~inputs["dec_attention_mask"][:, 1:].bool(), 0
        )
        out["loss_r"] = -log_probs.sum(-1)
        out["loss_c"] = self.loss_c(z, z2)
        return out


def run_cae_epoch(
    model,
    batches,
    opt,
    num_samples=1,
    lambda_c=1.0,
    desc="",
    notebook=True,
):
    losses = {k: [] for k in ("loss_r", "loss_c")}
    t = (
        tqdm_notebook(batches, desc=desc)
        if notebook
        else tqdm(batches, desc=desc)
    )
    model.train()
    for batch in t:
        model_inputs = {
            k: v.to(model.device)
            for k, v in batch.items()
            if isinstance(v, torch.Tensor)
        }
        out = model(**model_inputs)
        loss = (out["loss_r"] + lambda_c * out["loss_c"]).sum()
        opt.zero_grad()
        loss.backward()
        opt.step()
        d = {}
        for k in ("loss_r", "loss_c"):
            d[k] = out[k].mean().item()
            losses[k].append(out[k].detach().cpu().numpy())
        t.set_postfix(d)
    return {k: np.concatenate(v, 0) for k, v in losses.items()}


def batch_kl(l1, s1, l2=None, s2=None):
    # KL between batches of diagonal Gaussians given locations and scales
    # (scales are assumed to be standard deviations); (l2, s2) defaults to the
    # standard normal prior.
    if l2 is None:
        l2, s2 = torch.zeros_like(l1), torch.ones_like(s1)
    q = Independent(Normal(l1, s1), 1)
    p = Independent(Normal(l2, s2), 1)
    return torch.distributions.kl_divergence(q, p)


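# Illustrative check for batch_kl (not used in training): against the standard
# normal prior the diagonal-Gaussian KL has the closed form
# sum_d [ -log(s) + (s**2 + l**2) / 2 - 1/2 ].  The helper name is hypothetical.
def _example_batch_kl_check():
    l1, s1 = torch.randn(5, 3), torch.rand(5, 3) + 0.5
    closed = (-torch.log(s1) + (s1 ** 2 + l1 ** 2) / 2 - 0.5).sum(-1)
    return torch.allclose(batch_kl(l1, s1), closed, atol=1e-5)

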
class SubpopCondAE(nn.Module):
    def __init__(
        self,
        encoder,
        decoder,
        num_labels,
        sublabels=4,
        tau=0.05,
        disc_loss=True,
        **kwargs,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.dim = dim = decoder.hidden_dim
        # One diagonal Gaussian component per (label, sublabel) pair.
        self.locs = nn.Parameter(torch.randn(num_labels * sublabels, dim))
        self.log_scales = nn.Parameter(torch.zeros(num_labels * sublabels, dim))
        self.num_labels = num_labels
        self.sublabels = sublabels
        self.L = num_labels * sublabels
        self.tau = tau
        self.disc_loss = disc_loss

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def loss_c(self, z, **inputs):
        scores = []
        for i in range(self.L):
            dist = Independent(
                Normal(loc=self.locs[i], scale=self.log_scales[i].exp()), 1
            )
            scores.append(dist.log_prob(z))
        B = z.shape[0]
        sub_log_probs = torch.stack(scores, -1)
        if self.disc_loss:
            sub_log_probs = sub_log_probs.log_softmax(-1)
        # Marginalise over sublabels to get per-label scores.
        log_probs = sub_log_probs.view(
            B, self.num_labels, self.sublabels
        ).logsumexp(-1)
        loss = F.nll_loss(log_probs, inputs["label"], reduction="none")
        acc = log_probs.argmax(-1) == inputs["label"]
        return {
            "loss_c": loss,
            "log_probs": log_probs,
            "sub_log_probs": sub_log_probs,
            "acc": acc.float(),
        }

    def get_kl(self):
        # KL of every Gaussian component from the standard normal prior.
        p = MultivariateNormal(
            torch.zeros(self.dim, device=self.device),
            torch.eye(self.dim, device=self.device),
        )
        kl = 0
        for i in range(self.L):
            q = MultivariateNormal(
                self.locs[i], torch.diag(self.log_scales[i].exp() ** 2)
            )
            kl += torch.distributions.kl_divergence(q, p)
        return kl

    def forward(self, **inputs):
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        log_probs = out["logits"].log_softmax(-1)
        log_probs = torch.gather(
            log_probs[:, :-1],
            -1,
            inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
        ).squeeze(-1)
        log_probs = log_probs.masked_fill(
            ~inputs["dec_attention_mask"][:, 1:].bool(), 0
        )
        out["loss_r"] = -log_probs.sum(-1)
        out_c = self.loss_c(z, **inputs)
        for k, v in out_c.items():
            out[k] = v
        out["kl"] = self.get_kl().unsqueeze(0)
        return out


def gaussian_prob_product(m1, s1, m2, s2, rho=1.0):
    # Probability product kernel between two batches of diagonal Gaussians
    # (Jebara et al., 2004): rho=1 gives the expected-likelihood kernel,
    # rho=0.5 the Bhattacharyya kernel.  m1/s1 are (B1, D) means/variances,
    # m2/s2 are (B2, D); the result is a (B1, B2) kernel matrix.
    dim = m1.shape[-1]
    s1_inv = 1 / s1
    s2_inv = 1 / s2
    # Covariance and natural mean of the pairwise product Gaussians.
    s_hat = 1 / (s1_inv[:, None] + s2_inv[None, :])
    m_hat = s1_inv[:, None] * m1[:, None] + s2_inv[None, :] * m2[None, :]
    quad = (
        (m1 * s1_inv * m1).sum(-1)[:, None]
        + (m2 * s2_inv * m2).sum(-1)[None, :]
        - (m_hat * s_hat * m_hat).sum(-1)
    )
    return (
        ((2 * math.pi) ** ((1 - 2 * rho) * dim / 2))
        * (rho ** (-dim / 2))
        * torch.sqrt(s_hat.prod(-1))
        * ((s1.prod(-1)[:, None] * s2.prod(-1)[None, :]) ** (-rho / 2))
        * torch.exp(-(rho / 2) * quad)
    )


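# Illustrative check (assumes s1/s2 hold diagonal covariances; the helper name
# is hypothetical): at rho=1 the product kernel reduces to the
# expected-likelihood kernel, i.e. the density of one Gaussian's mean under a
# Normal whose covariance is the sum of the two covariances.
def _example_prob_product_check():
    m1, m2 = torch.randn(2, 4), torch.randn(3, 4)
    s1, s2 = torch.rand(2, 4) + 0.5, torch.rand(3, 4) + 0.5
    k = gaussian_prob_product(m1, s1, m2, s2, rho=1.0)
    ref = (
        Independent(Normal(m2[None], (s1[:, None] + s2[None]).sqrt()), 1)
        .log_prob(m1[:, None])
        .exp()
    )
    return torch.allclose(k, ref, atol=1e-5)

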
class CondAE(nn.Module):
    def __init__(
        self,
        encoder,
        decoder,
        num_labels,
        logdet=False,
        l2_reg=False,
        disc_loss=True,
        tau=0.05,
        **kwargs,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.dim = dim = decoder.hidden_dim
        # One diagonal Gaussian per label.
        self.locs = nn.Parameter(torch.randn(num_labels, dim))
        self.log_scales = nn.Parameter(torch.zeros(num_labels, dim))
        self.num_labels = num_labels
        self.tau = tau
        self.logdet = logdet
        self.l2_reg = l2_reg
        self.disc_loss = disc_loss

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def loss_c(self, z, **inputs):
        scores = []
        for i in range(self.num_labels):
            dist = Independent(
                Normal(loc=self.locs[i], scale=self.log_scales[i].exp()), 1
            )
            scores.append(dist.log_prob(z))
        log_probs = torch.stack(scores, -1)
        if self.disc_loss:
            log_probs = log_probs.log_softmax(-1)
        loss = F.nll_loss(log_probs, inputs["label"], reduction="none")
        acc = log_probs.argmax(-1) == inputs["label"]
        return {"loss_c": loss, "log_probs": log_probs, "acc": acc.float()}

    def get_kl(self):
        # KL of each label Gaussian from the standard normal prior, plus an
        # optional regulariser on the pairwise similarity of the means.
        p = MultivariateNormal(
            torch.zeros(self.dim, device=self.device),
            torch.eye(self.dim, device=self.device),
        )
        kl = 0
        for i in range(self.num_labels):
            q = MultivariateNormal(
                self.locs[i], torch.diag(self.log_scales[i].exp() ** 2)
            )
            kl += torch.distributions.kl_divergence(q, p)
        if self.logdet:
            K = torch.exp(-torch.cdist(self.locs, self.locs) ** 2)
            kl += torch.logdet(K)
        elif self.l2_reg:
            K = torch.exp(-torch.cdist(self.locs, self.locs) ** 2)
            kl += torch.log(
                torch.linalg.norm(K / K.shape[0], dim=(-2, -1)) ** 2
            ).sum()
        return kl

    def forward(self, **inputs):
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        log_probs = out["logits"].log_softmax(-1)
        log_probs = torch.gather(
            log_probs[:, :-1],
            -1,
            inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
        ).squeeze(-1)
        log_probs = log_probs.masked_fill(
            ~inputs["dec_attention_mask"][:, 1:].bool(), 0
        )
        out["loss_r"] = -log_probs.sum(-1)
        out_c = self.loss_c(z, **inputs)
        for k, v in out_c.items():
            out[k] = v
        out["kl"] = self.get_kl().unsqueeze(0)
        return out


class BasicCondAE(nn.Module):
    def __init__(self, encoder, decoder, num_labels, tau=0.05, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.dim = dim = decoder.hidden_dim
        self.linear = nn.Linear(dim, num_labels)
        self.num_labels = num_labels
        self.tau = tau

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def loss_c(self, z, **inputs):
        # Plain linear classifier on the latent code.
        log_probs = self.linear(z).log_softmax(-1)
        loss = F.nll_loss(log_probs, inputs["label"], reduction="none")
        acc = log_probs.argmax(-1) == inputs["label"]
        return {"loss_c": loss, "log_probs": log_probs, "acc": acc.float()}

    def forward(self, **inputs):
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        log_probs = out["logits"].log_softmax(-1)
        log_probs = torch.gather(
            log_probs[:, :-1],
            -1,
            inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
        ).squeeze(-1)
        log_probs = log_probs.masked_fill(
            ~inputs["dec_attention_mask"][:, 1:].bool(), 0
        )
        out["loss_r"] = -log_probs.sum(-1)
        out_c = self.loss_c(z, **inputs)
        for k, v in out_c.items():
            out[k] = v
        out["kl"] = torch.zeros_like(out["loss_r"])
        return out


def run_cond_ae_epoch(
    model,
    batches,
    opt,
    num_samples=1,
    lambda_c=1.0,
    lambda_r=1.0,
    beta=1.0,
    desc="",
    notebook=True,
):
    losses = {k: [] for k in ("loss_r", "loss_c", "kl", "acc")}
    t = (
        tqdm_notebook(batches, desc=desc)
        if notebook
        else tqdm(batches, desc=desc)
    )
    model.train()
    for batch in t:
        model_inputs = {
            k: v.to(model.device)
            for k, v in batch.items()
            if isinstance(v, torch.Tensor)
        }
        out = model(**model_inputs)
        loss = (
            lambda_r * out["loss_r"] + lambda_c * out["loss_c"]
        ).sum() + beta * out["kl"].sum()
        opt.zero_grad()
        loss.backward()
        opt.step()
        d = {}
        for k in ("loss_r", "loss_c", "kl", "acc"):
            d[k] = out[k].mean().item()
            losses[k].append(out[k].detach().cpu().numpy())
        t.set_postfix(d)
    return {k: np.concatenate(v, 0) for k, v in losses.items()}


def run_cond_ae_eval(
    model,
    batches,
    lambda_c=1.0,
    beta=1.0,
    desc="",
    notebook=True,
):
    losses = {k: [] for k in ("loss_r", "loss_c", "kl", "acc")}
    t = (
        tqdm_notebook(batches, desc=desc)
        if notebook
        else tqdm(batches, desc=desc)
    )
    model.eval()
    for batch in t:
        model_inputs = {
            k: v.to(model.device)
            for k, v in batch.items()
            if isinstance(v, torch.Tensor)
        }
        with torch.no_grad():
            out = model(**model_inputs)
        loss = (
            out["loss_r"] + lambda_c * out["loss_c"]
        ).sum() + beta * out["kl"].sum()
        d = {}
        for k in ("loss_r", "loss_c", "kl", "acc"):
            d[k] = out[k].mean().item()
            losses[k].append(out[k].detach().cpu().numpy())
        t.set_postfix(d)
    return {k: np.concatenate(v, 0) for k, v in losses.items()}


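# A minimal sketch (the optimizer, learning rate, epoch count, and the helper
# name are placeholders) of how the two helpers above are typically combined
# into a training run for the conditional autoencoders.
def _example_cond_ae_training(model, train_batches, dev_batches, epochs=3):
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    history = []
    for epoch in range(epochs):
        train_stats = run_cond_ae_epoch(
            model, train_batches, opt, desc=f"train {epoch}", notebook=False
        )
        dev_stats = run_cond_ae_eval(
            model, dev_batches, desc=f"dev {epoch}", notebook=False
        )
        history.append((train_stats, dev_stats))
    return history

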
def generate(
    model,
    tokenizer,
    batch=None,
    z=None,
    do_sample=False,
    max_length=128,
    **kwargs,
):
    if z is None:
        with torch.no_grad():
            z, _ = model.get_z(sample=False, **batch)
    else:
        z = torch.as_tensor(z, dtype=torch.float, device=model.device)
    B = z.shape[0]
    K = model.decoder.prefix_length
    with torch.no_grad():
        layers = model.decoder.get_prefix(z)
    output = model.decoder.decoder.generate(
        input_ids=torch.tensor(
            [[tokenizer.bos_token_id]] * B, device=model.device
        ),
        attention_mask=torch.ones(
            (B, K + 1), dtype=torch.long, device=model.device
        ),
        past_key_values=layers,  # older transformers releases call this `past`
        do_sample=do_sample,
        max_length=max_length,
        **kwargs,
    )
    lst = tokenizer.batch_decode(output[:, 1:])
    return [s.replace("<|endoftext|>", "") for s in lst]


def get_embeddings(model, batches, desc="", notebook=True):
    out = []
    t = (
        tqdm_notebook(batches, desc=desc)
        if notebook
        else tqdm(batches, desc=desc)
    )
    model.eval()
    for batch in t:
        with torch.no_grad():
            model_inputs = {
                k: v.to(model.device)
                for k, v in batch.items()
                if isinstance(v, torch.Tensor)
            }
            z, _ = model.get_z(sample=False, **model_inputs)
        out.append(z.detach().cpu().numpy())
    return np.concatenate(out, 0)


def interpolate(model, tokenizer, a, b, num_steps=10, **kwargs):
    # Decode sentences along the straight line between latent codes a and b.
    z = np.stack(
        [w * b + (1 - w) * a for w in np.linspace(0, 1.0, num_steps)], 0
    )
    return generate(model, tokenizer, z=z, **kwargs)


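# Illustrative only: embed a few batches and decode the interpolation between
# the first two latent codes.  Assumes `tokenizer` is the decoder-side (GPT-2)
# tokenizer, the batches follow the "enc_"/"dec_" key convention, and the
# helper name is hypothetical.
def _example_interpolation(model, tokenizer, batches):
    zs = get_embeddings(model, batches, desc="embed", notebook=False)
    return interpolate(
        model, tokenizer, zs[0], zs[1], num_steps=5, max_length=64
    )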