Dan Friedman committed on
Commit f9cfa84
1 Parent(s): 73ae2a9

Add autoencoder.py

Files changed (1)
  1. autoencoder.py +882 -0
autoencoder.py ADDED
@@ -0,0 +1,882 @@
+ import math
+ import numpy as np
+ import torch
+ from torch import nn
+ from torch.distributions import Independent, Normal, MultivariateNormal
+ import torch.nn.functional as F
+
+ from transformers import AutoModel, AutoModelForCausalLM
+ from tqdm import tqdm
+ from tqdm.notebook import tqdm as tqdm_notebook
+
+
+ class Res(nn.Module):
+     def __init__(self, H):
+         super().__init__()
+         self.u1 = nn.Linear(H, H)
+         self.u2 = nn.Linear(H, H)
+
+         self.v1 = nn.Linear(H, H)
+         self.v2 = nn.Linear(H, H)
+         self.w = nn.Linear(H, H)
+
+     def forward(self, x):
+         x = self.w(x)
+         x = x + torch.relu(self.v1(torch.relu(self.u1(x))))
+         return x + torch.relu(self.v2(torch.relu(self.u2(x))))
+
+
+ class MLP(nn.Module):
+     def __init__(self, H, out=None):
+         super().__init__()
+         out = out or H
+         self.mlp = nn.Sequential(
+             nn.Linear(H, H),
+             nn.ReLU(),
+             nn.Linear(H, H),
+             nn.ReLU(),
+             nn.Linear(H, out),
+         )
+
+     def forward(self, x):
+         return self.mlp(x)
+
+
+ class Encoder(nn.Module):
+     # Wraps a pretrained encoder and pools the first ([CLS]) token embedding.
+     def __init__(self, tokenizer, model_name_or_path="roberta-base", **kwargs):
+         super().__init__()
+         self.encoder = AutoModel.from_pretrained(model_name_or_path)
+         self.encoder.resize_token_embeddings(len(tokenizer))
+         self.dim = self.encoder.config.hidden_size
+
+     @property
+     def device(self):
+         return self.encoder.device
+
+     def forward(self, **inputs):
+         model_inputs = {
+             k: inputs[k].to(self.device)
+             for k in ("input_ids", "attention_mask")
+         }
+         if inputs.get("token_type_ids", None) is not None:
+             model_inputs["token_type_ids"] = inputs["token_type_ids"].to(
+                 self.device
+             )
+         out = self.encoder(**model_inputs)
+         emb = out.last_hidden_state[:, 0]
+         return emb
+
+
+ class PrefixDecoder(nn.Module):
+     # GPT-2 decoder conditioned on z: an MLP maps z to a length-K prefix of
+     # per-layer (key, value) pairs that is fed in via past_key_values.
+     def __init__(
+         self,
+         tokenizer,
+         model_name_or_path="gpt2",
+         prefix_length=1,
+         ffn="res",
+         **kwargs,
+     ):
+         super().__init__()
+         self.decoder = AutoModelForCausalLM.from_pretrained(model_name_or_path)
+         self.hidden_dim = D = self.decoder.config.n_embd
+         self.num_layers = L = self.decoder.config.n_layer
+         self.num_heads = H = self.decoder.config.n_head
+         self.prefix_length = K = prefix_length
+         self.lin1 = nn.Linear(D, D * 2)  # unused; kept for checkpoint compatibility
+         self.z_size = D * L * K * 2
+         if ffn == "res":
+             self.mlp = nn.Sequential(Res(D), nn.Linear(D, self.z_size))
+         else:
+             self.mlp = MLP(D, self.z_size)
+
+     def get_prefix(self, z):
+         # Project z to per-layer (key, value) tensors of shape (B, H, K, D // H).
+         B = z.shape[0]
+         D, L, H, K = (
+             self.hidden_dim,
+             self.num_layers,
+             self.num_heads,
+             self.prefix_length,
+         )
+         z_up = self.mlp(z).reshape(B, H, K, D // H, L, 2)
+         keys, vals = (t.squeeze(-1) for t in z_up.chunk(2, dim=-1))
+         layers = tuple(
+             (k.squeeze(-1), v.squeeze(-1))
+             for k, v in zip(keys.chunk(L, -1), vals.chunk(L, -1))
+         )
+         return layers
+
+     def forward(self, z, **inputs):
+         B = z.shape[0]
+         K = self.prefix_length
+         # The original duplicated get_prefix inline here; the logic is identical.
+         layers = self.get_prefix(z)
+         input_ids = inputs["input_ids"].to(z.device)
+         attention_mask = inputs["attention_mask"].to(z.device)
+         attention_mask = torch.cat(
+             [torch.ones(B, K, dtype=bool, device=z.device), attention_mask],
+             1,
+         )
+         out = self.decoder(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             past_key_values=layers,
+         )
+         return out
+
+
+ def get_inputs(
+     inputs, prefix, keys=("input_ids", "attention_mask", "token_type_ids")
+ ):
+     return {k: inputs.get(f"{prefix}{k}", None) for k in keys}
+
+
+ class VAE(nn.Module):
+     # Variational AE: the encoder embedding parameterizes a diagonal Gaussian
+     # posterior; KL to a standard-normal prior is returned alongside the
+     # decoder output. Assumes encoder.dim == decoder.hidden_dim.
+     def __init__(self, encoder, decoder, beta=1.0, do_sample=True, **kwargs):
+         super().__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         self.beta = beta
+         D = decoder.hidden_dim
+         self.lin = nn.Linear(D, D * 2)
+         self.do_sample = do_sample
+
+     @property
+     def device(self):
+         return self.encoder.device
+
+     def get_z(self, sample=True, **inputs):
+         enc = self.encoder(**get_inputs(inputs, "enc_"))
+         B, D = enc.shape
+         # The second head is used as a log standard deviation (the original
+         # named it logvar but passed its exp directly as the Normal scale).
+         mu, log_std = (
+             t.squeeze(-1) for t in self.lin(enc).view(B, D, 2).chunk(2, -1)
+         )
+         qz = Normal(mu, log_std.exp())
+         pz = Normal(torch.zeros_like(mu[0]), torch.ones_like(mu[0]))
+         kl = torch.distributions.kl_divergence(qz, pz).sum(-1)
+         if sample:
+             z = qz.rsample()
+         else:
+             z = mu
+         return z, kl
+
+     def forward(self, **inputs):
+         z, kl = self.get_z(sample=self.do_sample, **inputs)
+         out = self.decoder(z, **get_inputs(inputs, "dec_"))
+         out["kl"] = kl
+         return out
+
+
+ class AAE(nn.Module):
+     # Adversarial autoencoder: a discriminator D is trained to separate
+     # encoder outputs from standard-normal samples, and the encoder is
+     # trained to fool it.
+     def __init__(self, encoder, decoder, _lambda=1.0, word_drop=None, **kwargs):
+         super().__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         self._lambda = _lambda
+         dim = decoder.hidden_dim
+         self.D = nn.Sequential(
+             nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 1), nn.Sigmoid()
+         )
+         self.word_drop = word_drop
+
+     @property
+     def device(self):
+         return self.encoder.device
+
+     def get_z(self, **inputs):
+         if self.word_drop is not None:
+             # Randomly mask out a word_drop fraction of the encoder inputs.
+             m = inputs["enc_attention_mask"]
+             b = torch.rand_like(m.float()) > self.word_drop
+             inputs["enc_attention_mask"] = m & b
+         return self.encoder(**get_inputs(inputs, "enc_")), None
+
+     def loss_adv(self, z):
+         # https://github.com/shentianxiao/text-autoencoders
+         zn = torch.randn_like(z)
+         zeros = torch.zeros(len(z), 1, device=z.device)
+         ones = torch.ones(len(z), 1, device=z.device)
+         loss_d = (
+             F.binary_cross_entropy(self.D(z.detach()), zeros, reduction="none")
+             + F.binary_cross_entropy(self.D(zn), ones, reduction="none")
+         ).squeeze(-1)
+         # Squeeze to shape (B,): the original returned (B, 1) tensors, and
+         # l_rec + lambda_adv * adv would broadcast to (B, B) in run_aae_epoch.
+         adv = F.binary_cross_entropy(
+             self.D(z), ones, reduction="none"
+         ).squeeze(-1)
+         return loss_d, adv
+
+     def forward(self, **inputs):
+         z, _ = self.get_z(**inputs)
+         out = self.decoder(z, **get_inputs(inputs, "dec_"))
+         b, n, _ = out["logits"].shape
+         # Token-level log-likelihood of the decoder inputs, shifted by one.
+         log_probs = out["logits"].log_softmax(-1)
+         log_probs = torch.gather(
+             log_probs[:, :-1],
+             -1,
+             inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
+         ).squeeze(-1)
+         log_probs = log_probs.masked_fill(
+             ~inputs["dec_attention_mask"][:, 1:], 0
+         )
+         out["l_rec"] = -log_probs.sum(-1)
+         out["loss_d"], out["adv"] = self.loss_adv(z)
+         return out
+
+
+ class AE(nn.Module):
+     # Plain autoencoder: reconstruction loss only (loss_c is returned as zeros
+     # so the training loops can treat all variants uniformly).
+     def __init__(self, encoder, decoder, **kwargs):
+         super().__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         dim = decoder.hidden_dim
+         self.D = nn.Sequential(  # unused here; kept for checkpoint compatibility
+             nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 1), nn.Sigmoid()
+         )
+
+     @property
+     def device(self):
+         return self.encoder.device
+
+     def get_z(self, **inputs):
+         return self.encoder(**get_inputs(inputs, "enc_")), None
+
+     def step(self, **inputs):
+         z, _ = self.get_z(**inputs)
+         out = self.decoder(z, **get_inputs(inputs, "dec_"))
+         b, n, _ = out["logits"].shape
+         log_probs = out["logits"].log_softmax(-1)
+         log_probs = torch.gather(
+             log_probs[:, :-1],
+             -1,
+             inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
+         ).squeeze(-1)
+         log_probs = log_probs.masked_fill(
+             ~inputs["dec_attention_mask"][:, 1:], 0
+         )
+         out["loss_r"] = -log_probs.sum(-1)
+         return z, out
+
+     def forward(self, **inputs):
+         z, out = self.step(**inputs)
+         out["loss_c"] = torch.zeros_like(out["loss_r"])
+         return out
+
+
+ class CDAE(nn.Module):
+     # Contrastive denoising autoencoder: reconstruct from both the clean and a
+     # word-dropped view, and contrast the two latents with a distance-based
+     # InfoNCE loss.
+     def __init__(
+         self, encoder, decoder, _lambda=1.0, word_drop=None, tau=1.0, **kwargs
+     ):
+         super().__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         self._lambda = _lambda
+         dim = decoder.hidden_dim
+         self.D = nn.Sequential(
+             nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 1), nn.Sigmoid()
+         )
+         self.word_drop = word_drop
+         self.tau = tau
+
+     @property
+     def device(self):
+         return self.encoder.device
+
+     def do_mask(self, **inputs):
+         # Drop words from the encoder mask, and apply a consistent drop pattern
+         # to the decoder mask (extending it with fresh noise if it is longer).
+         m = inputs["enc_attention_mask"]
+         b = torch.rand_like(m.float()) > self.word_drop
+         inputs["enc_attention_mask"] = m & b
+
+         B, N = inputs["dec_attention_mask"].shape
+         _, M = m.shape
+         m2 = inputs["dec_attention_mask"]
+         if N <= M:
+             b2 = b[:, :N]
+         else:
+             b_ = torch.rand((B, N - M), device=b.device) > self.word_drop
+             b2 = torch.cat([b, b_], -1)
+         inputs["dec_attention_mask"] = m2 & b2
+
+     def get_z(self, **inputs):
+         return self.encoder(**get_inputs(inputs, "enc_")), None
+
+     def step(self, **inputs):
+         z, _ = self.get_z(**inputs)
+         out = self.decoder(z, **get_inputs(inputs, "dec_"))
+         b, n, _ = out["logits"].shape
+         log_probs = out["logits"].log_softmax(-1)
+         log_probs = torch.gather(
+             log_probs[:, :-1],
+             -1,
+             inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
+         ).squeeze(-1)
+         log_probs = log_probs.masked_fill(
+             ~inputs["dec_attention_mask"][:, 1:], 0
+         )
+         out["loss_r"] = -log_probs.sum(-1)
+         return z, out
+
+     def loss_c(self, z, z2):
+         # InfoNCE over negative squared distances; positives on the diagonal.
+         scores = -(torch.cdist(z, z2) ** 2)
+         log_probs = (scores / self.tau).log_softmax(-1)
+         loss = -torch.diagonal(log_probs)
+         return loss
+
+     def forward(self, **inputs):
+         z, out = self.step(**inputs)
+         self.do_mask(**inputs)
+         z_, out_ = self.step(**inputs)
+         out["loss_r"] = out["loss_r"] + out_["loss_r"]
+         out["loss_c"] = self.loss_c(z, z_)
+         return out
+
+
+ def run_aae_epoch(
+     model,
+     batches,
+     opt,
+     optD,
+     num_samples=1,
+     lambda_adv=1.0,
+     desc="",
+     notebook=True,
+ ):
+     losses = {k: [] for k in ("l_rec", "adv", "loss_d")}
+     t = (
+         tqdm_notebook(batches, desc=desc)
+         if notebook
+         else tqdm(batches, desc=desc)
+     )
+     for batch in t:
+         model_inputs = {
+             k: v.to(model.device)
+             for k, v in batch.items()
+             if isinstance(v, torch.Tensor)
+         }
+         out = model(**model_inputs)
+         loss = (out["l_rec"] + lambda_adv * out["adv"]).sum()
+         opt.zero_grad()
+         loss.backward()
+         opt.step()
+
+         loss_d = out["loss_d"].sum()
+         optD.zero_grad()
+         loss_d.backward()
+         optD.step()
+
+         d = {}
+         for k in ("l_rec", "adv", "loss_d"):
+             d[k] = out[k].mean().item()
+             losses[k].append(out[k].detach().cpu().numpy())
+         t.set_postfix(d)
+     return {k: np.concatenate(v, 0) for k, v in losses.items()}
+
+
+ class GAE(nn.Module):
+     # Autoencoder with a cosine-similarity InfoNCE term computed on a single
+     # view: each latent is contrasted against the rest of the batch.
+     def __init__(self, encoder, decoder, tau=0.05, **kwargs):
+         super().__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         self.tau = tau
+
+     @property
+     def device(self):
+         return self.encoder.device
+
+     def get_z(self, **inputs):
+         return self.encoder(**get_inputs(inputs, "enc_")), None
+
+     def loss_c(self, z, z2):
+         scores = F.normalize(z, dim=-1) @ F.normalize(z2, dim=-1).T
+         log_probs = (scores / self.tau).log_softmax(-1)
+         loss = -torch.diagonal(log_probs)
+         return loss
+
+     def forward(self, **inputs):
+         z, _ = self.get_z(**inputs)
+         out = self.decoder(z, **get_inputs(inputs, "dec_"))
+         b, n, _ = out["logits"].shape
+         log_probs = out["logits"].log_softmax(-1)
+         log_probs = torch.gather(
+             log_probs[:, :-1],
+             -1,
+             inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
+         ).squeeze(-1)
+         log_probs = log_probs.masked_fill(
+             ~inputs["dec_attention_mask"][:, 1:], 0
+         )
+         out["loss_r"] = -log_probs.sum(-1)
+         # loss_c takes two sets of embeddings; the original called loss_c(z)
+         # with one argument, a TypeError. Contrasting z with itself matches
+         # the single-view setup of this class.
+         out["loss_c"] = self.loss_c(z, z)
+         return out
+
+
+ class CAE(nn.Module):
+     # Contrastive autoencoder: a second encoder pass under no_grad gives the
+     # positive view; when dropout is active (model.train()), the two passes
+     # differ, as in SimCSE-style dropout augmentation.
+     def __init__(self, encoder, decoder, tau=0.05, **kwargs):
+         super().__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         self.tau = tau
+
+     @property
+     def device(self):
+         return self.encoder.device
+
+     def get_z(self, **inputs):
+         return self.encoder(**get_inputs(inputs, "enc_")), None
+
+     def loss_c(self, z, z2):
+         scores = F.normalize(z, dim=-1) @ F.normalize(z2, dim=-1).T
+         log_probs = (scores / self.tau).log_softmax(-1)
+         loss = -torch.diagonal(log_probs)
+         return loss
+
+     def forward(self, **inputs):
+         z, _ = self.get_z(**inputs)
+         with torch.no_grad():
+             z2, _ = self.get_z(**inputs)
+         out = self.decoder(z, **get_inputs(inputs, "dec_"))
+         b, n, _ = out["logits"].shape
+         log_probs = out["logits"].log_softmax(-1)
+         log_probs = torch.gather(
+             log_probs[:, :-1],
+             -1,
+             inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
+         ).squeeze(-1)
+         log_probs = log_probs.masked_fill(
+             ~inputs["dec_attention_mask"][:, 1:], 0
+         )
+         out["loss_r"] = -log_probs.sum(-1)
+         out["loss_c"] = self.loss_c(z, z2)
+         return out
+
+
+ def run_cae_epoch(
+     model,
+     batches,
+     opt,
+     num_samples=1,
+     lambda_c=1.0,
+     desc="",
+     notebook=True,
+ ):
+     losses = {k: [] for k in ("loss_r", "loss_c")}
+     t = (
+         tqdm_notebook(batches, desc=desc)
+         if notebook
+         else tqdm(batches, desc=desc)
+     )
+     model.train()
+     for batch in t:
+         model_inputs = {
+             k: v.to(model.device)
+             for k, v in batch.items()
+             if isinstance(v, torch.Tensor)
+         }
+         out = model(**model_inputs)
+         loss = (out["loss_r"] + lambda_c * out["loss_c"]).sum()
+         opt.zero_grad()
+         loss.backward()
+         opt.step()
+         d = {}
+         for k in ("loss_r", "loss_c"):
+             d[k] = out[k].mean().item()
+             losses[k].append(out[k].detach().cpu().numpy())
+         t.set_postfix(d)
+     return {k: np.concatenate(v, 0) for k, v in losses.items()}
+
+
+ def batch_kl(l1, s1, l2=None, s2=None):
+     # KL(N(l1, diag(s1)) || N(l2, diag(s2))) with s1, s2 variances; the
+     # original stub returned None. (l2, s2) default to the standard normal.
+     l2 = torch.zeros_like(l1) if l2 is None else l2
+     s2 = torch.ones_like(s1) if s2 is None else s2
+     return 0.5 * (
+         (s2 / s1).log().sum(-1) - l1.shape[-1]
+         + (s1 / s2).sum(-1) + ((l2 - l1) ** 2 / s2).sum(-1)
+     )
+
+
+ class SubpopCondAE(nn.Module):
+     # Conditional AE in which each class label owns `sublabels` Gaussian
+     # sub-populations in latent space; class scores marginalize over them.
+     def __init__(
+         self,
+         encoder,
+         decoder,
+         num_labels,
+         sublabels=4,
+         tau=0.05,
+         disc_loss=True,
+         **kwargs,
+     ):
+         super().__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         self.dim = dim = decoder.hidden_dim
+         self.locs = nn.Parameter(torch.randn(num_labels * sublabels, dim))
+         self.log_scales = nn.Parameter(torch.zeros(num_labels * sublabels, dim))
+         self.num_labels = num_labels
+         self.sublabels = sublabels
+         self.L = num_labels * sublabels
+         self.tau = tau
+         self.disc_loss = disc_loss
+
+     @property
+     def device(self):
+         return self.encoder.device
+
+     def get_z(self, **inputs):
+         return self.encoder(**get_inputs(inputs, "enc_")), None
+
+     def loss_c(self, z, **inputs):
+         scores = []
+         for i in range(self.L):
+             dist = Independent(
+                 Normal(loc=self.locs[i], scale=self.log_scales[i].exp()), 1
+             )
+             scores.append(dist.log_prob(z))
+         B = z.shape[0]
+         sub_log_probs = torch.stack(scores, -1)
+         if self.disc_loss:
+             sub_log_probs = sub_log_probs.log_softmax(-1)
+         # The original read self.num_sublabels, which is never defined;
+         # the attribute is self.sublabels.
+         log_probs = sub_log_probs.view(
+             B, self.num_labels, self.sublabels
+         ).logsumexp(-1)
+         loss = F.nll_loss(log_probs, inputs["label"], reduction="none")
+         acc = log_probs.argmax(-1) == inputs["label"]
+         return {
+             "loss_c": loss,
+             "log_probs": log_probs,
+             "sub_log_probs": sub_log_probs,
+             "acc": acc.float(),
+         }
+
+     def get_kl(self):
+         p = MultivariateNormal(
+             torch.zeros(self.dim, device=self.device),
+             torch.eye(self.dim, device=self.device),
+         )
+         kl = 0
+         for i in range(self.L):
+             # Covariance is std**2: loss_c uses exp(log_scales) as a std, but
+             # the original passed it directly as the covariance diagonal.
+             q = MultivariateNormal(
+                 self.locs[i], torch.diag((2 * self.log_scales[i]).exp())
+             )
+             kl += torch.distributions.kl_divergence(q, p)
+         return kl
+
+     def forward(self, **inputs):
+         z, _ = self.get_z(**inputs)
+         out = self.decoder(z, **get_inputs(inputs, "dec_"))
+         b, n, _ = out["logits"].shape
+         log_probs = out["logits"].log_softmax(-1)
+         log_probs = torch.gather(
+             log_probs[:, :-1],
+             -1,
+             inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
+         ).squeeze(-1)
+         log_probs = log_probs.masked_fill(
+             ~inputs["dec_attention_mask"][:, 1:], 0
+         )
+         out["loss_r"] = -log_probs.sum(-1)
+         out_c = self.loss_c(z, **inputs)
+         for k, v in out_c.items():
+             out[k] = v
+         out["kl"] = self.get_kl().unsqueeze(0)
+         return out
+
+
+ def gaussian_prob_product(m1, s1, m2, s2, rho=1.0):
+     # Probability product kernel between diagonal Gaussians (cf. Jebara et al.,
+     # "Probability Product Kernels"); s1, s2 are the diagonal variances.
+     # Unused elsewhere in this file.
+     s1_inv = 1 / s1
+     s2_inv = 1 / s2
+     # Combined precision and precision-weighted mean. The original computed
+     # s_hat = 1 / (s1 + s2) and m_hat = s1_inv * s1 + s2_inv * s2 (a constant);
+     # the forms below match the m_hat @ (s_hat * m_hat).T term further down.
+     s_hat = 1 / (s1_inv + s2_inv)
+     m_hat = s1_inv * m1 + s2_inv * m2
+     dim = m1.shape[-1]
+     return (
+         ((2 * math.pi) ** ((1 - 2 * rho) * dim / 2))
+         * (rho ** (-dim / 2))
+         * torch.sqrt(s_hat.prod(-1))
+         * ((s1.prod(-1) * s2.prod(-1)) ** (-rho / 2))
+         * torch.exp(
+             -(1 / rho)
+             * (
+                 m1 @ (s1_inv * m1).T
+                 + m2 @ (s2_inv * m2).T
+                 - m_hat @ (s_hat * m_hat).T
+             )
+         )
+     )
+
+
+ class CondAE(nn.Module):
+     # Conditional AE: one diagonal Gaussian per label in latent space. loss_c
+     # scores z under each class Gaussian (normalized into a classifier when
+     # disc_loss is set); logdet / l2_reg add kernel-based regularizers on the
+     # class means.
+     def __init__(
+         self,
+         encoder,
+         decoder,
+         num_labels,
+         logdet=False,
+         l2_reg=False,
+         disc_loss=True,
+         tau=0.05,
+         **kwargs,
+     ):
+         super().__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         self.dim = dim = decoder.hidden_dim
+         self.locs = nn.Parameter(torch.randn(num_labels, dim))
+         self.log_scales = nn.Parameter(torch.zeros(num_labels, dim))
+         self.num_labels = num_labels
+         self.tau = tau
+         self.logdet = logdet
+         self.l2_reg = l2_reg
+         self.disc_loss = disc_loss
+
+     @property
+     def device(self):
+         return self.encoder.device
+
+     def get_z(self, **inputs):
+         return self.encoder(**get_inputs(inputs, "enc_")), None
+
+     def loss_c(self, z, **inputs):
+         scores = []
+         for i in range(self.num_labels):
+             dist = Independent(
+                 Normal(loc=self.locs[i], scale=self.log_scales[i].exp()), 1
+             )
+             scores.append(dist.log_prob(z))
+         log_probs = torch.stack(scores, -1)
+         if self.disc_loss:
+             log_probs = log_probs.log_softmax(-1)
+         loss = F.nll_loss(log_probs, inputs["label"], reduction="none")
+         acc = log_probs.argmax(-1) == inputs["label"]
+         return {"loss_c": loss, "log_probs": log_probs, "acc": acc.float()}
+
+     def get_kl(self):
+         p = MultivariateNormal(
+             torch.zeros(self.dim, device=self.device),
+             torch.eye(self.dim, device=self.device),
+         )
+         kl = 0
+         for i in range(self.num_labels):
+             # Covariance is std**2: loss_c uses exp(log_scales) as a std, but
+             # the original passed it directly as the covariance diagonal.
+             q = MultivariateNormal(
+                 self.locs[i], torch.diag((2 * self.log_scales[i]).exp())
+             )
+             kl += torch.distributions.kl_divergence(q, p)
+         if self.logdet:
+             K = torch.exp(-torch.cdist(self.locs, self.locs) ** 2)
+             kl += torch.logdet(K)
+         elif self.l2_reg:
+             K = torch.exp(-torch.cdist(self.locs, self.locs) ** 2)
+             kl += torch.log(
+                 torch.linalg.norm(K / K.shape[0], dim=(-2, -1)) ** 2
+             ).sum()
+         return kl
+
+     def forward(self, **inputs):
+         z, _ = self.get_z(**inputs)
+         out = self.decoder(z, **get_inputs(inputs, "dec_"))
+         b, n, _ = out["logits"].shape
+         log_probs = out["logits"].log_softmax(-1)
+         log_probs = torch.gather(
+             log_probs[:, :-1],
+             -1,
+             inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
+         ).squeeze(-1)
+         log_probs = log_probs.masked_fill(
+             ~inputs["dec_attention_mask"][:, 1:], 0
+         )
+         out["loss_r"] = -log_probs.sum(-1)
+         out_c = self.loss_c(z, **inputs)
+         for k, v in out_c.items():
+             out[k] = v
+         out["kl"] = self.get_kl().unsqueeze(0)
+         return out
+
+
+ class BasicCondAE(nn.Module):
+     # Conditional AE with a plain linear classifier head on z (no latent
+     # Gaussians, so kl is returned as zeros).
+     def __init__(self, encoder, decoder, num_labels, tau=0.05, **kwargs):
+         super().__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         self.dim = dim = decoder.hidden_dim
+         self.linear = nn.Linear(dim, num_labels)
+         self.num_labels = num_labels
+         self.tau = tau
+
+     @property
+     def device(self):
+         return self.encoder.device
+
+     def get_z(self, **inputs):
+         return self.encoder(**get_inputs(inputs, "enc_")), None
+
+     def loss_c(self, z, **inputs):
+         log_probs = self.linear(z).log_softmax(-1)
+         loss = F.nll_loss(log_probs, inputs["label"], reduction="none")
+         acc = log_probs.argmax(-1) == inputs["label"]
+         return {"loss_c": loss, "log_probs": log_probs, "acc": acc.float()}
+
+     def forward(self, **inputs):
+         z, _ = self.get_z(**inputs)
+         out = self.decoder(z, **get_inputs(inputs, "dec_"))
+         b, n, _ = out["logits"].shape
+         log_probs = out["logits"].log_softmax(-1)
+         log_probs = torch.gather(
+             log_probs[:, :-1],
+             -1,
+             inputs["dec_input_ids"][:, 1:].unsqueeze(-1),
+         ).squeeze(-1)
+         log_probs = log_probs.masked_fill(
+             ~inputs["dec_attention_mask"][:, 1:], 0
+         )
+         out["loss_r"] = -log_probs.sum(-1)
+         out_c = self.loss_c(z, **inputs)
+         for k, v in out_c.items():
+             out[k] = v
+         out["kl"] = torch.zeros_like(out["loss_r"])
+         return out
+
+
+ def run_cond_ae_epoch(
+     model,
+     batches,
+     opt,
+     num_samples=1,
+     lambda_c=1.0,
+     lambda_r=1.0,
+     beta=1.0,
+     desc="",
+     notebook=True,
+ ):
+     losses = {k: [] for k in ("loss_r", "loss_c", "kl", "acc")}
+     t = (
+         tqdm_notebook(batches, desc=desc)
+         if notebook
+         else tqdm(batches, desc=desc)
+     )
+     model.train()
+     for batch in t:
+         model_inputs = {
+             k: v.to(model.device)
+             for k, v in batch.items()
+             if isinstance(v, torch.Tensor)
+         }
+         out = model(**model_inputs)
+         loss = (
+             lambda_r * out["loss_r"] + lambda_c * out["loss_c"]
+         ).sum() + beta * out["kl"].sum()
+         opt.zero_grad()
+         loss.backward()
+         opt.step()
+         d = {}
+         for k in ("loss_r", "loss_c", "kl", "acc"):
+             d[k] = out[k].mean().item()
+             losses[k].append(out[k].detach().cpu().numpy())
+         t.set_postfix(d)
+     return {k: np.concatenate(v, 0) for k, v in losses.items()}
+
+
+ def run_cond_ae_eval(
+     model,
+     batches,
+     lambda_c=1.0,
+     beta=1.0,
+     desc="",
+     notebook=True,
+ ):
+     losses = {k: [] for k in ("loss_r", "loss_c", "kl", "acc")}
+     t = (
+         tqdm_notebook(batches, desc=desc)
+         if notebook
+         else tqdm(batches, desc=desc)
+     )
+     model.eval()
+     for batch in t:
+         model_inputs = {
+             k: v.to(model.device)
+             for k, v in batch.items()
+             if isinstance(v, torch.Tensor)
+         }
+         with torch.no_grad():
+             out = model(**model_inputs)
+             loss = (
+                 out["loss_r"] + lambda_c * out["loss_c"]
+             ).sum() + beta * out["kl"].sum()
+         d = {}
+         for k in ("loss_r", "loss_c", "kl", "acc"):
+             d[k] = out[k].mean().item()
+             losses[k].append(out[k].detach().cpu().numpy())
+         t.set_postfix(d)
+     return {k: np.concatenate(v, 0) for k, v in losses.items()}
+
+
+ def generate(
+     model,
+     tokenizer,
+     batch=None,
+     z=None,
+     do_sample=False,
+     max_length=128,
+     **kwargs,
+ ):
+     # Decode text from a latent z (encoded from `batch` when z is not given)
+     # by seeding generation with the key/value prefix computed from z.
+     if z is None:
+         with torch.no_grad():
+             z, _ = model.get_z(sample=False, **batch)
+     else:
+         z = torch.tensor(z, device=model.device).float()
+     B, D = z.shape
+     K = model.decoder.prefix_length
+     # The original duplicated PrefixDecoder.get_prefix inline here.
+     layers = model.decoder.get_prefix(z)
+     output = model.decoder.decoder.generate(
+         input_ids=torch.tensor(
+             [[tokenizer.bos_token_id]] * B, device=model.device
+         ),
+         attention_mask=torch.ones((B, K + 1), device=model.device),
+         # The original passed `past=`, the pre-rename name of this argument
+         # in older transformers releases.
+         past_key_values=layers,
+         do_sample=do_sample,
+         max_length=max_length,
+         **kwargs,
+     )
+     lst = tokenizer.batch_decode(output[:, 1:])
+     return [l.replace("<|endoftext|>", "") for l in lst]
+
+
+ def get_embeddings(model, batches, desc="", notebook=True):
+     out = []
+     t = (
+         tqdm_notebook(batches, desc=desc)
+         if notebook
+         else tqdm(batches, desc=desc)
+     )
+     model.eval()
+     for batch in t:
+         with torch.no_grad():
+             model_inputs = {
+                 k: v.to(model.device)
+                 for k, v in batch.items()
+                 if isinstance(v, torch.Tensor)
+             }
+             z, _ = model.get_z(sample=False, **model_inputs)
+             out.append(z.detach().cpu().numpy())
+     return np.concatenate(out, 0)
+
+
878
+ def interpolate(model, tokenizer, a, b, num_steps=10, **kwargs):
879
+ z = np.stack(
880
+ [l * b + (1 - l) * a for l in np.linspace(0, 1.0, num_steps)], 0
881
+ )
882
+ return generate(model, tokenizer, z=z, **kwargs)