import math

import numpy as np
import torch
from torch import nn
from torch.distributions import Independent, Normal, MultivariateNormal
import torch.nn.functional as F
from transformers import AutoModel, AutoModelForCausalLM
from tqdm import tqdm
from tqdm.notebook import tqdm as tqdm_notebook


class Res(nn.Module):
    # Two-block residual MLP used to map latents to prefix parameters.
    def __init__(self, H):
        super().__init__()
        self.u1 = nn.Linear(H, H)
        self.u2 = nn.Linear(H, H)
        self.v1 = nn.Linear(H, H)
        self.v2 = nn.Linear(H, H)
        self.w = nn.Linear(H, H)

    def forward(self, x):
        x = self.w(x)
        x = x + torch.relu(self.v1(torch.relu(self.u1(x))))
        return x + torch.relu(self.v2(torch.relu(self.u2(x))))


class MLP(nn.Module):
    def __init__(self, H, out=None):
        super().__init__()
        out = out or H
        self.mlp = nn.Sequential(
            nn.Linear(H, H),
            nn.ReLU(),
            nn.Linear(H, H),
            nn.ReLU(),
            nn.Linear(H, out),
        )

    def forward(self, x):
        return self.mlp(x)


class Encoder(nn.Module):
    # Wraps a pretrained masked-LM encoder and returns the first-token
    # ([CLS]/<s>) embedding.
    def __init__(self, tokenizer, model_name_or_path="roberta-base", **kwargs):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name_or_path)
        self.encoder.resize_token_embeddings(len(tokenizer))
        self.dim = self.encoder.config.hidden_size

    @property
    def device(self):
        return self.encoder.device

    def forward(self, **inputs):
        model_inputs = {
            k: inputs[k].to(self.device)
            for k in ("input_ids", "attention_mask")
        }
        if inputs.get("token_type_ids", None) is not None:
            model_inputs["token_type_ids"] = inputs["token_type_ids"].to(
                self.device
            )
        out = self.encoder(**model_inputs)
        emb = out.last_hidden_state[:, 0]
        return emb


class PrefixDecoder(nn.Module):
    # GPT-2-style causal LM conditioned on a latent z through per-layer
    # key/value prefixes (prefix tuning).
    def __init__(
        self,
        tokenizer,
        model_name_or_path="gpt2",
        prefix_length=1,
        ffn="res",
        **kwargs,
    ):
        super().__init__()
        self.decoder = AutoModelForCausalLM.from_pretrained(model_name_or_path)
        self.hidden_dim = D = self.decoder.config.n_embd
        self.num_layers = L = self.decoder.config.n_layer
        self.num_heads = H = self.decoder.config.n_head
        self.prefix_length = K = prefix_length
        self.lin1 = nn.Linear(D, D * 2)  # not used in forward
        self.z_size = D * L * K * 2
        if ffn == "res":
            self.mlp = nn.Sequential(Res(D), nn.Linear(D, self.z_size))
        else:
            self.mlp = MLP(D, self.z_size)

    def get_prefix(self, z):
        # Map z to past_key_values: a tuple of (key, value) pairs per layer,
        # each of shape (B, num_heads, prefix_length, head_dim).
        B = z.shape[0]
        D, L, H, K = (
            self.hidden_dim,
            self.num_layers,
            self.num_heads,
            self.prefix_length,
        )
        z_up = self.mlp(z).reshape(B, H, K, D // H, L, 2)
        keys, vals = (t.squeeze(-1) for t in z_up.chunk(2, dim=-1))
        layers = tuple(
            (k.squeeze(-1), v.squeeze(-1))
            for k, v in zip(keys.chunk(L, -1), vals.chunk(L, -1))
        )
        return layers

    def forward(self, z, **inputs):
        B = z.shape[0]
        K = self.prefix_length
        layers = self.get_prefix(z)
        input_ids = inputs["input_ids"].to(z.device)
        attention_mask = inputs["attention_mask"].to(z.device)
        # Extend the attention mask to cover the K prefix positions.
        attention_mask = torch.cat(
            [
                torch.ones(B, K, dtype=attention_mask.dtype, device=z.device),
                attention_mask,
            ],
            1,
        )
        out = self.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=layers,
        )
        return out


def get_inputs(
    inputs, prefix, keys=("input_ids", "attention_mask", "token_type_ids")
):
    # Select the encoder- or decoder-side tensors by stripping the prefix.
    return {k: inputs.get(f"{prefix}{k}", None) for k in keys}


class VAE(nn.Module):
    def __init__(self, encoder, decoder, beta=1.0, do_sample=True, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.beta = beta
        # Assumes encoder.dim == decoder.hidden_dim (e.g. roberta-base / gpt2).
        D = decoder.hidden_dim
        self.lin = nn.Linear(D, D * 2)
        self.do_sample = do_sample

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, sample=True, **inputs):
        enc = self.encoder(**get_inputs(inputs, "enc_"))
        B, D = enc.shape
        # The second half of the projection is treated as a log standard
        # deviation (not a log variance).
        mu, log_scale = (
            t.squeeze(-1) for t in self.lin(enc).view(B, D, 2).chunk(2, -1)
        )
        qz = Normal(mu, log_scale.exp())
        pz = Normal(torch.zeros_like(mu[0]), torch.ones_like(mu[0]))
        kl = torch.distributions.kl_divergence(qz, pz).sum(-1)
        if sample:
            z = qz.rsample()
        else:
            z = mu
        return z, kl

    def forward(self, **inputs):
        z, kl = self.get_z(sample=self.do_sample, **inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        out["kl"] = kl
        return out


class AAE(nn.Module):
    # Adversarial autoencoder: a discriminator D pushes the latent codes
    # toward a standard Gaussian.
    def __init__(self, encoder, decoder, _lambda=1.0, word_drop=None, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self._lambda = _lambda
        dim = decoder.hidden_dim
        self.D = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 1), nn.Sigmoid()
        )
        self.word_drop = word_drop

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        if self.word_drop is not None:
            # Randomly drop encoder tokens with probability `word_drop`.
            m = inputs["enc_attention_mask"]
            b = torch.rand_like(m.float()) > self.word_drop
            inputs["enc_attention_mask"] = m & b
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def loss_adv(self, z):
        # https://github.com/shentianxiao/text-autoencoders
        zn = torch.randn_like(z)
        zeros = torch.zeros(len(z), 1, device=z.device)
        ones = torch.ones(len(z), 1, device=z.device)
        loss_d = F.binary_cross_entropy(
            self.D(z.detach()), zeros, reduction="none"
        ) + F.binary_cross_entropy(self.D(zn), ones, reduction="none")
        adv = F.binary_cross_entropy(self.D(z), ones, reduction="none")
        return loss_d, adv

    def forward(self, **inputs):
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        b, n, _ = out["logits"].shape
        # Token-level log-likelihood of the decoder inputs (shifted by one).
        log_probs = out["logits"].log_softmax(-1)
        log_probs = torch.gather(
            log_probs[:, :-1],
            -1,
            inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
        ).squeeze(-1)
        log_probs = log_probs.masked_fill(
            ~inputs["dec_attention_mask"][:, 1:].bool(), 0
        )
        out["l_rec"] = -log_probs.sum(-1)
        out["loss_d"], out["adv"] = self.loss_adv(z)
        return out


class AE(nn.Module):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        dim = decoder.hidden_dim
        self.D = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 1), nn.Sigmoid()
        )

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def step(self, **inputs):
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        b, n, _ = out["logits"].shape
        log_probs = out["logits"].log_softmax(-1)
        log_probs = torch.gather(
            log_probs[:, :-1],
            -1,
            inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
        ).squeeze(-1)
        log_probs = log_probs.masked_fill(
            ~inputs["dec_attention_mask"][:, 1:].bool(), 0
        )
        out["loss_r"] = -log_probs.sum(-1)
        return z, out

    def forward(self, **inputs):
        z, out = self.step(**inputs)
        out["loss_c"] = torch.zeros_like(out["loss_r"])
        return out


class CDAE(nn.Module):
    # Contrastive denoising autoencoder: reconstruction from clean and
    # word-dropped inputs plus a distance-based contrastive loss.
    def __init__(
        self, encoder, decoder, _lambda=1.0, word_drop=None, tau=1.0, **kwargs
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self._lambda = _lambda
        dim = decoder.hidden_dim
        self.D = nn.Sequential(
            nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 1), nn.Sigmoid()
        )
        self.word_drop = word_drop
        self.tau = tau

    @property
    def device(self):
        return self.encoder.device

    def do_mask(self, inputs):
        # Takes the batch dict itself (not **kwargs) and mutates it in place
        # so that the word-drop masks are visible to the caller.
        m = inputs["enc_attention_mask"]
        b = torch.rand_like(m.float()) > self.word_drop
        inputs["enc_attention_mask"] = m & b
        # Extend (or truncate) the same drop mask to the decoder side.
        B, N = inputs["dec_attention_mask"].shape
        _, M = m.shape
        m2 = inputs["dec_attention_mask"]
        if N <= M:
            b2 = b[:, :N]
        else:
            b_ = torch.rand((B, N - M), device=b.device) > self.word_drop
            b2 = torch.cat([b, b_], -1)
        inputs["dec_attention_mask"] = m2 & b2

    def get_z(self, **inputs):
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def step(self, **inputs):
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        b, n, _ = out["logits"].shape
        log_probs = out["logits"].log_softmax(-1)
        log_probs = torch.gather(
            log_probs[:, :-1],
            -1,
            inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
        ).squeeze(-1)
        log_probs = log_probs.masked_fill(
            ~inputs["dec_attention_mask"][:, 1:].bool(), 0
        )
        out["loss_r"] = -log_probs.sum(-1)
        return z, out

    def loss_c(self, z, z2):
        # Contrastive loss with negative squared Euclidean distance as the
        # similarity score; positives are the matching rows of z and z2.
        scores = -(torch.cdist(z, z2) ** 2)
        log_probs = (scores / self.tau).log_softmax(-1)
        loss = -torch.diagonal(log_probs)
        return loss

    def forward(self, **inputs):
        z, out = self.step(**inputs)
        # do_mask mutates `inputs`, so the second pass sees the dropped words.
        self.do_mask(inputs)
        z_, out_ = self.step(**inputs)
        out["loss_r"] = out["loss_r"] + out_["loss_r"]
        out["loss_c"] = self.loss_c(z, z_)
        return out


def run_aae_epoch(
    model,
    batches,
    opt,
    optD,
    num_samples=1,
    lambda_adv=1.0,
    desc="",
    notebook=True,
):
    losses = {k: [] for k in ("l_rec", "adv", "loss_d")}
    t = (
        tqdm_notebook(batches, desc=desc)
        if notebook
        else tqdm(batches, desc=desc)
    )
    model.train()
    for batch in t:
        model_inputs = {
            k: v.to(model.device)
            for k, v in batch.items()
            if isinstance(v, torch.Tensor)
        }
        out = model(**model_inputs)
        # Autoencoder update: reconstruction plus adversarial term.
        loss = (out["l_rec"] + lambda_adv * out["adv"]).sum()
        opt.zero_grad()
        loss.backward()
        opt.step()
        # Discriminator update.
        loss_d = out["loss_d"].sum()
        optD.zero_grad()
        loss_d.backward()
        optD.step()
        d = {}
        for k in ("l_rec", "adv", "loss_d"):
            d[k] = out[k].mean().item()
            losses[k].append(out[k].detach().cpu().numpy())
        t.set_postfix(d)
    return {k: np.concatenate(v, 0) for k, v in losses.items()}


class GAE(nn.Module):
    def __init__(self, encoder, decoder, tau=0.05, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tau = tau

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def loss_c(self, z, z2):
        # InfoNCE-style loss over cosine similarities.
        scores = F.normalize(z, dim=-1) @ F.normalize(z2, dim=-1).T
        log_probs = (scores / self.tau).log_softmax(-1)
        loss = -torch.diagonal(log_probs)
        return loss

    def forward(self, **inputs):
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        b, n, _ = out["logits"].shape
        log_probs = out["logits"].log_softmax(-1)
        log_probs = torch.gather(
            log_probs[:, :-1],
            -1,
            inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
        ).squeeze(-1)
        log_probs = log_probs.masked_fill(
            ~inputs["dec_attention_mask"][:, 1:].bool(), 0
        )
        out["loss_r"] = -log_probs.sum(-1)
        # Assumption: a detached copy of z serves as the second contrastive
        # view, since loss_c requires two arguments.
        out["loss_c"] = self.loss_c(z, z.detach())
        return out


class CAE(nn.Module):
    # Contrastive autoencoder: the second view comes from a second (no-grad)
    # encoder pass, which differs only through dropout noise.
    def __init__(self, encoder, decoder, tau=0.05, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.tau = tau

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def loss_c(self, z, z2):
        scores = F.normalize(z, dim=-1) @ F.normalize(z2, dim=-1).T
        log_probs = (scores / self.tau).log_softmax(-1)
        loss = -torch.diagonal(log_probs)
        return loss

    def forward(self, **inputs):
        z, _ = self.get_z(**inputs)
        with torch.no_grad():
            z2, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        b, n, _ = out["logits"].shape
        log_probs = out["logits"].log_softmax(-1)
inputs["dec_input_ids"][:, 1:].unsqueeze(-1), ).squeeze(-1) log_probs = log_probs.masked_fill( ~inputs["dec_attention_mask"][:, 1:], 0 ) out["loss_r"] = -log_probs.sum(-1) out["loss_c"] = self.loss_c(z, z2) return out def run_cae_epoch( model, batches, opt, num_samples=1, lambda_c=1.0, desc="", notebook=True, ): losses = {k: [] for k in ("loss_r", "loss_c")} t = ( tqdm_notebook(batches, desc=desc) if notebook else tqdm(batches, desc=desc) ) model.train() for batch in t: model_inputs = { k: v.to(model.device) for k, v in batch.items() if type(v) == torch.Tensor } out = model(**model_inputs) loss = (out["loss_r"] + lambda_c * out["loss_c"]).sum() opt.zero_grad() loss.backward() opt.step() d = {} for k in ("loss_r", "loss_c"): d[k] = out[k].mean().item() losses[k].append(out[k].detach().cpu().numpy()) t.set_postfix(d) return {k: np.concatenate(v, 0) for k, v in losses.items()} def batch_kl(l1, s1, l2=None, s2=None): # 1/2[log |s1|/|s2| - d + tr[s2^{-1}s1] + (l2 - l1)^{\top} s2^{-1}(l2 - l1)] return class SubpopCondAE(nn.Module): def __init__( self, encoder, decoder, num_labels, sublabels=4, tau=0.05, disc_loss=True, **kwargs, ): super().__init__() self.encoder = encoder self.decoder = decoder self.dim = dim = decoder.hidden_dim self.locs = nn.Parameter(torch.randn(num_labels * sublabels, dim)) self.log_scales = nn.Parameter(torch.zeros(num_labels * sublabels, dim)) self.num_labels = num_labels self.sublabels = sublabels self.L = num_labels * sublabels self.tau = tau self.disc_loss = disc_loss @property def device(self): return self.encoder.device def get_z(self, **inputs): return self.encoder(**get_inputs(inputs, "enc_")), None def loss_c(self, z, **inputs): scores = [] for i in range(self.L): dist = Independent( Normal(loc=self.locs[i], scale=self.log_scales[i].exp()), 1 ) scores.append(dist.log_prob(z)) B = z.shape[0] sub_log_probs = torch.stack(scores, -1) if self.disc_loss: sub_log_probs = sub_log_probs.log_softmax(-1) log_probs = sub_log_probs.view( B, self.num_labels, self.num_sublabels ).logsumexp(-1) loss = F.nll_loss(log_probs, inputs["label"], reduction="none") acc = log_probs.argmax(-1) == inputs["label"] return { "loss_c": loss, "log_probs": log_probs, "sub_log_probs": sub_log_probs, "acc": acc.float(), } def get_kl(self): p = MultivariateNormal( torch.zeros(self.dim, device=self.device), torch.eye(self.dim, device=self.device), ) kl = 0 for i in range(self.L): q = MultivariateNormal( self.locs[i], torch.diag(self.log_scales[i].exp()) ) kl += torch.distributions.kl_divergence(q, p) return kl def forward(self, **inputs): z, _ = self.get_z(**inputs) out = self.decoder(z, **get_inputs(inputs, "dec_")) b, n, _ = out["logits"].shape log_probs = out["logits"].log_softmax(-1) log_probs = torch.gather( log_probs[:, :-1], -1, inputs["dec_input_ids"][:, 1:].unsqueeze(-1), ).squeeze(-1) log_probs = log_probs.masked_fill( ~inputs["dec_attention_mask"][:, 1:], 0 ) out["loss_r"] = -log_probs.sum(-1) out_c = self.loss_c(z, **inputs) for k, v in out_c.items(): out[k] = v out["kl"] = self.get_kl().unsqueeze(0) return out def gaussian_prob_product(m1, s1, m2, s2, rho=1.0): # s1, s2 diagonal s1_inv = 1 / s1 s2_inv = 1 / s2 s_hat = 1 / (s1 + s2) m_hat = s1_inv * s1 + s2_inv * s2 dim = m1.shape[-1] return ( ((2 * math.pi) ** ((1 - 2 * rho) * dim / 2)) * (rho ** (-dim / 2)) * torch.sqrt(s_hat.prod(-1)) * ((s1.prod(-1) * s2.prod(-1)) ** (-rho / 2)) * torch.exp( -(1 / rho) * ( m1 @ (s1_inv * m1).T + m2 @ (s2_inv * m2).T - m_hat @ (s_hat * m_hat).T ) ) ) class CondAE(nn.Module): def __init__( self, 
    # Conditional AE with one diagonal Gaussian per label in latent space and
    # optional regularizers that encourage the label means to spread out.
    def __init__(
        self,
        encoder,
        decoder,
        num_labels,
        logdet=False,
        l2_reg=False,
        disc_loss=True,
        tau=0.05,
        **kwargs,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.dim = dim = decoder.hidden_dim
        self.locs = nn.Parameter(torch.randn(num_labels, dim))
        self.log_scales = nn.Parameter(torch.zeros(num_labels, dim))
        self.num_labels = num_labels
        self.tau = tau
        self.logdet = logdet
        self.l2_reg = l2_reg
        self.disc_loss = disc_loss

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def loss_c(self, z, **inputs):
        # Classify z by its log-density under each label's Gaussian.
        scores = []
        for i in range(self.num_labels):
            dist = Independent(
                Normal(loc=self.locs[i], scale=self.log_scales[i].exp()), 1
            )
            scores.append(dist.log_prob(z))
        log_probs = torch.stack(scores, -1)
        if self.disc_loss:
            log_probs = log_probs.log_softmax(-1)
        loss = F.nll_loss(log_probs, inputs["label"], reduction="none")
        acc = log_probs.argmax(-1) == inputs["label"]
        return {"loss_c": loss, "log_probs": log_probs, "acc": acc.float()}

    def get_kl(self):
        p = MultivariateNormal(
            torch.zeros(self.dim, device=self.device),
            torch.eye(self.dim, device=self.device),
        )
        kl = 0
        for i in range(self.num_labels):
            q = MultivariateNormal(
                self.locs[i], torch.diag(self.log_scales[i].exp())
            )
            kl += torch.distributions.kl_divergence(q, p)
        if self.logdet:
            # RBF kernel matrix over the label means.
            K = torch.exp(-torch.cdist(self.locs, self.locs) ** 2)
            kl += torch.logdet(K)
        elif self.l2_reg:
            K = torch.exp(-torch.cdist(self.locs, self.locs) ** 2)
            kl += torch.log(
                torch.linalg.norm(K / K.shape[0], dim=(-2, -1)) ** 2
            ).sum()
        return kl

    def forward(self, **inputs):
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        b, n, _ = out["logits"].shape
        log_probs = out["logits"].log_softmax(-1)
        log_probs = torch.gather(
            log_probs[:, :-1],
            -1,
            inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
        ).squeeze(-1)
        log_probs = log_probs.masked_fill(
            ~inputs["dec_attention_mask"][:, 1:].bool(), 0
        )
        out["loss_r"] = -log_probs.sum(-1)
        out_c = self.loss_c(z, **inputs)
        for k, v in out_c.items():
            out[k] = v
        out["kl"] = self.get_kl().unsqueeze(0)
        return out


class BasicCondAE(nn.Module):
    # Conditional AE with a plain linear classifier on the latent code.
    def __init__(self, encoder, decoder, num_labels, tau=0.05, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.dim = dim = decoder.hidden_dim
        self.linear = nn.Linear(dim, num_labels)
        self.num_labels = num_labels
        self.tau = tau

    @property
    def device(self):
        return self.encoder.device

    def get_z(self, **inputs):
        return self.encoder(**get_inputs(inputs, "enc_")), None

    def loss_c(self, z, **inputs):
        log_probs = self.linear(z).log_softmax(-1)
        loss = F.nll_loss(log_probs, inputs["label"], reduction="none")
        acc = log_probs.argmax(-1) == inputs["label"]
        return {"loss_c": loss, "log_probs": log_probs, "acc": acc.float()}

    def forward(self, **inputs):
        z, _ = self.get_z(**inputs)
        out = self.decoder(z, **get_inputs(inputs, "dec_"))
        b, n, _ = out["logits"].shape
        log_probs = out["logits"].log_softmax(-1)
        log_probs = torch.gather(
            log_probs[:, :-1],
            -1,
            inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
        ).squeeze(-1)
        log_probs = log_probs.masked_fill(
            ~inputs["dec_attention_mask"][:, 1:].bool(), 0
        )
        out["loss_r"] = -log_probs.sum(-1)
        out_c = self.loss_c(z, **inputs)
        for k, v in out_c.items():
            out[k] = v
        out["kl"] = torch.zeros_like(out["loss_r"])
        return out


def run_cond_ae_epoch(
    model,
    batches,
    opt,
    num_samples=1,
    lambda_c=1.0,
    lambda_r=1.0,
    beta=1.0,
    desc="",
    notebook=True,
):
    losses = {k: [] for k in ("loss_r", "loss_c", "kl", "acc")}
    t = (
        tqdm_notebook(batches, desc=desc)
        if notebook
        else tqdm(batches, desc=desc)
    )
    model.train()
    for batch in t:
        model_inputs = {
            k: v.to(model.device)
            for k, v in batch.items()
            if isinstance(v, torch.Tensor)
        }
        out = model(**model_inputs)
        loss = (
            lambda_r * out["loss_r"] + lambda_c * out["loss_c"]
        ).sum() + beta * out["kl"].sum()
        opt.zero_grad()
        loss.backward()
        opt.step()
        d = {}
        for k in ("loss_r", "loss_c", "kl", "acc"):
            d[k] = out[k].mean().item()
            losses[k].append(out[k].detach().cpu().numpy())
        t.set_postfix(d)
    return {k: np.concatenate(v, 0) for k, v in losses.items()}


def run_cond_ae_eval(
    model,
    batches,
    lambda_c=1.0,
    beta=1.0,
    desc="",
    notebook=True,
):
    losses = {k: [] for k in ("loss_r", "loss_c", "kl", "acc")}
    t = (
        tqdm_notebook(batches, desc=desc)
        if notebook
        else tqdm(batches, desc=desc)
    )
    model.eval()
    for batch in t:
        model_inputs = {
            k: v.to(model.device)
            for k, v in batch.items()
            if isinstance(v, torch.Tensor)
        }
        with torch.no_grad():
            out = model(**model_inputs)
        loss = (
            out["loss_r"] + lambda_c * out["loss_c"]
        ).sum() + beta * out["kl"].sum()
        d = {}
        for k in ("loss_r", "loss_c", "kl", "acc"):
            d[k] = out[k].mean().item()
            losses[k].append(out[k].detach().cpu().numpy())
        t.set_postfix(d)
    return {k: np.concatenate(v, 0) for k, v in losses.items()}


def generate(
    model,
    tokenizer,
    batch=None,
    z=None,
    do_sample=False,
    max_length=128,
    **kwargs,
):
    # Decode text from latent codes: either encode `batch` or use the given z.
    if z is None:
        with torch.no_grad():
            z, _ = model.get_z(sample=False, **batch)
        B, D = z.shape
    else:
        z = torch.as_tensor(z, dtype=torch.float32, device=model.device)
        B, D = z.shape
    K = model.decoder.prefix_length
    layers = model.decoder.get_prefix(z)
    output = model.decoder.decoder.generate(
        input_ids=torch.tensor(
            [[tokenizer.bos_token_id]] * B, device=model.device
        ),
        attention_mask=torch.ones(
            (B, K + 1), dtype=torch.long, device=model.device
        ),
        past=layers,  # newer transformers versions expect `past_key_values=`
        do_sample=do_sample,
        max_length=max_length,
        **kwargs,
    )
    lst = tokenizer.batch_decode(output[:, 1:])
    return [l.replace("<|endoftext|>", "") for l in lst]


def get_embeddings(model, batches, desc="", notebook=True):
    # Encode all batches and return the latent codes as a single numpy array.
    out = []
    t = (
        tqdm_notebook(batches, desc=desc)
        if notebook
        else tqdm(batches, desc=desc)
    )
    model.eval()
    for batch in t:
        with torch.no_grad():
            model_inputs = {
                k: v.to(model.device)
                for k, v in batch.items()
                if isinstance(v, torch.Tensor)
            }
            z, _ = model.get_z(sample=False, **model_inputs)
        out.append(z.detach().cpu().numpy())
    return np.concatenate(out, 0)


def interpolate(model, tokenizer, a, b, num_steps=10, **kwargs):
    # Linearly interpolate between latent codes a and b and decode each step.
    z = np.stack(
        [l * b + (1 - l) * a for l in np.linspace(0, 1.0, num_steps)], 0
    )
    return generate(model, tokenizer, z=z, **kwargs)
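

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): wiring an
# encoder/decoder pair into CondAE and training for one epoch. The batch
# layout with "enc_"/"dec_"-prefixed keys and a "label" key is inferred from
# get_inputs() and the losses above; the tokenizer setup, collate function,
# and AdamW hyperparameters are assumptions, not fixed choices.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    enc_tok = AutoTokenizer.from_pretrained("roberta-base")
    dec_tok = AutoTokenizer.from_pretrained("gpt2")
    dec_tok.pad_token = dec_tok.eos_token  # GPT-2 has no pad token by default

    def collate(examples):
        # Build one batch: encoder-side tensors get an "enc_" prefix and
        # decoder-side tensors a "dec_" prefix, matching get_inputs().
        texts, labels = zip(*examples)
        enc = enc_tok(list(texts), padding=True, return_tensors="pt")
        dec = dec_tok(
            [dec_tok.bos_token + t for t in texts],
            padding=True,
            return_tensors="pt",
        )
        batch = {f"enc_{k}": v for k, v in enc.items()}
        batch.update({f"dec_{k}": v for k, v in dec.items()})
        batch["label"] = torch.tensor(labels)
        return batch

    data = [("a tiny example sentence", 0), ("another example", 1)]
    loader = DataLoader(data, batch_size=2, collate_fn=collate)

    model = CondAE(
        Encoder(enc_tok), PrefixDecoder(dec_tok, prefix_length=1), num_labels=2
    )
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    stats = run_cond_ae_epoch(
        model, loader, opt, lambda_c=1.0, beta=0.0, desc="demo", notebook=False
    )
    print({k: float(v.mean()) for k, v in stats.items()})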