Spaces:
Sleeping
Sleeping
File size: 4,763 Bytes
41b9d24 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import time
from typing import Optional
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    """Snake activation: x + (1/alpha) * sin^2(alpha * x), per channel.

    Flattens trailing dims so alpha (shape (1, C, 1)) broadcasts over them,
    then restores the input shape.
    """
    original_shape = x.shape
    flat = x.reshape(original_shape[0], original_shape[1], -1)
    # 1e-9 guards the reciprocal against alpha == 0.
    flat = flat + (alpha + 1e-9).reciprocal() * torch.sin(alpha * flat).pow(2)
    return flat.reshape(original_shape)
class Snake1d(nn.Module):
    """Snake activation with one learnable alpha per channel.

    Alpha is stored as (1, channels, 1) so it broadcasts over batch and time.
    """

    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        # Delegate to the jit-scripted functional form.
        return snake(x, self.alpha)
def num_params(model):
    """Return the total number of trainable (requires_grad) parameters."""
    trainable = (p for p in model.parameters() if p.requires_grad)
    return sum(p.numel() for p in trainable)
def recurse_children(module, fn):
    """Depth-first generator over a module's descendants, yielding fn(child).

    Recurses through nn.ModuleList / nn.ModuleDict containers as well as
    ordinary submodules, and yields fn(child) for every direct child.
    """
    for child in module.children():
        if isinstance(child, nn.ModuleList):
            for c in child:
                yield from recurse_children(c, fn)
        if isinstance(child, nn.ModuleDict):
            for c in child.values():
                yield from recurse_children(c, fn)
        # BUG FIX: the original used `yield recurse_children(...)`, which
        # yields generator OBJECTS (always truthy) rather than their results,
        # so any() consumers (e.g. has_film) reported True for any module
        # that merely had children. `yield from` flattens the recursion.
        yield from recurse_children(child, fn)
        yield fn(child)
def WNConv1d(*args, **kwargs):
    """Construct an nn.Conv1d wrapped in weight normalization."""
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
def WNConvTranspose1d(*args, **kwargs):
    """Construct an nn.ConvTranspose1d wrapped in weight normalization."""
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(deconv)
class SequentialWithFiLM(nn.Module):
    """
    handy wrapper for nn.Sequential that allows FiLM layers to be
    inserted in between other layers.
    """

    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    @staticmethod
    def has_film(module):
        # True iff some descendant of `module` is a FiLM layer.
        return any(recurse_children(module, lambda c: isinstance(c, FiLM)))

    def forward(self, x, cond):
        for layer in self.layers:
            # Layers containing FiLM also receive the conditioning signal.
            x = layer(x, cond) if self.has_film(layer) else layer(x)
        return x
class FiLM(nn.Module):
    """Feature-wise linear modulation.

    Projects a conditioning vector r into a per-channel scale (gamma) and
    shift (beta) applied to x. With input_dim == 0 the layer is an identity.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        # No projections are needed when there is no conditioning input.
        if input_dim > 0:
            self.beta = nn.Linear(input_dim, output_dim)
            self.gamma = nn.Linear(input_dim, output_dim)

    def forward(self, x, r):
        if self.input_dim == 0:
            return x
        batch = x.size(0)
        shift = self.beta(r).view(batch, self.output_dim, 1)
        scale = self.gamma(r).view(batch, self.output_dim, 1)
        return x * (scale + 1) + shift
class CodebookEmbedding(nn.Module):
    """Embeds multi-codebook discrete codes into a continuous sequence.

    Args:
        vocab_size: vocabulary size of each codebook; special tokens index
            past this value.
        latent_dim: dimensionality of each codebook's latent vectors.
        n_codebooks: number of codebooks.
        emb_dim: output embedding dimension of the final projection.
        special_tokens: optional names of extra tokens (e.g. "MASK") that get
            a learned embedding per codebook.
    """

    def __init__(
        self,
        vocab_size: int,
        latent_dim: int,
        n_codebooks: int,
        emb_dim: int,
        special_tokens: Optional[Tuple[str]] = None,
    ):
        super().__init__()
        self.n_codebooks = n_codebooks
        self.emb_dim = emb_dim
        self.latent_dim = latent_dim
        self.vocab_size = vocab_size

        if special_tokens is not None:
            # FIX: the original wrapped this in a redundant
            # `for tkn in special_tokens:` loop that rebuilt the identical
            # ParameterDict once per token (shadowing `tkn` and re-drawing
            # the random init each pass). Build it exactly once.
            self.special = nn.ParameterDict(
                {
                    tkn: nn.Parameter(torch.randn(n_codebooks, self.latent_dim))
                    for tkn in special_tokens
                }
            )
            # Special tokens occupy indices immediately after the vocabulary.
            self.special_idxs = {
                tkn: i + vocab_size for i, tkn in enumerate(special_tokens)
            }

        self.out_proj = nn.Conv1d(n_codebooks * self.latent_dim, self.emb_dim, 1)

    def from_codes(self, codes: torch.Tensor, codec):
        """
        get a sequence of continuous embeddings from a sequence of discrete codes.
        unlike its counterpart in the original VQ-VAE, this function also adds
        embeddings for any special tokens necessary for the language model,
        like <MASK>.
        """
        n_codebooks = codes.shape[1]
        latent = []
        for i in range(n_codebooks):
            c = codes[:, i, :]
            # Codebook weights come from the external codec's quantizers.
            lookup_table = codec.quantizer.quantizers[i].codebook.weight
            if hasattr(self, "special"):
                # Append this codebook's special-token rows so that indices
                # >= vocab_size resolve to the learned special embeddings.
                special_lookup = torch.cat(
                    [self.special[tkn][i : i + 1] for tkn in self.special], dim=0
                )
                lookup_table = torch.cat([lookup_table, special_lookup], dim=0)

            emb = F.embedding(c, lookup_table).transpose(1, 2)
            latent.append(emb)

        # (batch, n_codebooks * latent_dim, time)
        return torch.cat(latent, dim=1)

    def forward(self, latents: torch.Tensor):
        """
        project a sequence of latents to a sequence of embeddings
        """
        return self.out_proj(latents)
|