import math
from typing import Optional

import torch
from torch import Tensor, nn

# `freeze_params` is called below (when `freeze=True`) but was not imported in
# this module. It is assumed to come from a project-local `helpers` module;
# adjust the import path to your layout. The fallback below is only a stand-in
# that disables gradient updates for all parameters of a module.
try:
    from helpers import freeze_params
except ImportError:
    def freeze_params(module: nn.Module) -> None:
        """Fallback: do not update the parameters of `module` during training."""
        for param in module.parameters():
            param.requires_grad = False
|
|
def get_activation(activation_type: str) -> nn.Module:
    """Return a freshly constructed activation module for the given type name."""
    if activation_type == "relu":
        return nn.ReLU()
    elif activation_type == "relu6":
        return nn.ReLU6()
    elif activation_type == "prelu":
        return nn.PReLU()
    elif activation_type == "selu":
        return nn.SELU()
    elif activation_type == "celu":
        return nn.CELU()
    elif activation_type == "gelu":
        return nn.GELU()
    elif activation_type == "sigmoid":
        return nn.Sigmoid()
    elif activation_type == "softplus":
        return nn.Softplus()
    elif activation_type == "softshrink":
        return nn.Softshrink()
    elif activation_type == "softsign":
        return nn.Softsign()
    elif activation_type == "tanh":
        return nn.Tanh()
    elif activation_type == "tanhshrink":
        return nn.Tanhshrink()
    else:
        raise ValueError("Unknown activation type {}".format(activation_type))
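
# Example (hypothetical layer sizes): the factory above builds a fresh module
# instance per call, so it can be dropped straight into a config-driven model:
#
#     act = get_activation("gelu")
#     mlp = nn.Sequential(nn.Linear(512, 512), act)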
|
|
class MaskedNorm(nn.Module):
    """
    Normalization that, during training, is computed only over the non-padded
    (masked-in) positions of a batch.

    Original code from:
    https://discuss.pytorch.org/t/batchnorm-for-different-sized-samples-in-batch/44251/8
    """

    def __init__(self, norm_type: str, num_groups: int, num_features: int):
        super().__init__()
        self.norm_type = norm_type
        if self.norm_type == "batch":
            self.norm = nn.BatchNorm1d(num_features=num_features)
        elif self.norm_type == "group":
            self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_features)
        elif self.norm_type == "layer":
            self.norm = nn.LayerNorm(normalized_shape=num_features)
        else:
            raise ValueError("Unsupported Normalization Layer: {}".format(norm_type))

        self.num_features = num_features

    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        if self.training:
            # Flatten batch and time, keep only the masked-in positions,
            # normalize those, and scatter the result back into place.
            reshaped = x.reshape([-1, self.num_features])
            reshaped_mask = mask.reshape([-1, 1]) > 0
            selected = torch.masked_select(reshaped, reshaped_mask).reshape(
                [-1, self.num_features]
            )
            batch_normed = self.norm(selected)
            scattered = reshaped.masked_scatter(reshaped_mask, batch_normed)
            return scattered.reshape([x.shape[0], -1, self.num_features])
        else:
            # At inference time, normalize the full (padded) sequence.
            reshaped = x.reshape([-1, self.num_features])
            batch_normed = self.norm(reshaped)
            return batch_normed.reshape([x.shape[0], -1, self.num_features])
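
# Shape conventions assumed for MaskedNorm (not asserted by the class itself):
# `x` is [batch, time, num_features] and `mask` is broadcastable to
# [batch, time, 1], with non-zero entries marking valid (non-padded) positions.
#
#     norm = MaskedNorm(norm_type="batch", num_groups=8, num_features=64)
#     y = norm(torch.randn(2, 10, 64), torch.ones(2, 10, 1))   # y: [2, 10, 64]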
|
|
class Embeddings(nn.Module):

    """
    Simple embeddings class
    """

    def __init__(
        self,
        embedding_dim: int = 64,
        num_heads: int = 8,
        scale: bool = False,
        scale_factor: Optional[float] = None,
        norm_type: Optional[str] = None,
        activation_type: Optional[str] = None,
        vocab_size: int = 0,
        padding_idx: int = 1,
        freeze: bool = False,
        **kwargs
    ):
        """
        Create new embeddings for the vocabulary.
        Use scaling for the Transformer.

        :param embedding_dim: dimensionality of the embedding vectors
        :param num_heads: number of groups if group normalization is used
        :param scale: if True, multiply the embeddings by `scale_factor`
        :param scale_factor: scaling constant; defaults to sqrt(embedding_dim)
        :param norm_type: optional masked normalization ("batch", "group" or "layer")
        :param activation_type: optional activation applied after normalization
        :param vocab_size: size of the vocabulary
        :param padding_idx: index of the padding token
        :param freeze: freeze the embeddings during training
        """
        super().__init__()

        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.lut = nn.Embedding(vocab_size, self.embedding_dim, padding_idx=padding_idx)

        self.norm_type = norm_type
        if self.norm_type:
            self.norm = MaskedNorm(
                norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim
            )

        self.activation_type = activation_type
        if self.activation_type:
            self.activation = get_activation(activation_type)

        self.scale = scale
        if self.scale:
            if scale_factor:
                self.scale_factor = scale_factor
            else:
                self.scale_factor = math.sqrt(self.embedding_dim)

        if freeze:
            freeze_params(self)

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """
        Perform lookup for input `x` in the embedding table.

        :param x: indices in the vocabulary
        :param mask: token masks (required when `norm_type` is set)
        :return: embedded representation for `x`
        """
        x = self.lut(x)

        if self.norm_type:
            x = self.norm(x, mask)

        if self.activation_type:
            x = self.activation(x)

        if self.scale:
            return x * self.scale_factor
        else:
            return x

    def __repr__(self):
        return "%s(embedding_dim=%d, vocab_size=%d)" % (
            self.__class__.__name__,
            self.embedding_dim,
            self.vocab_size,
        )
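
# Example (hypothetical sizes): token indices of shape [batch, time] are looked
# up and returned as [batch, time, embedding_dim]; the mask is only needed when
# `norm_type` is set.
#
#     emb = Embeddings(embedding_dim=64, vocab_size=1000, padding_idx=1)
#     out = emb(torch.randint(0, 1000, (2, 7)))   # out: [2, 7, 64]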
|
|
class SpatialEmbeddings(nn.Module):

    """
    Simple Linear Projection Layer
    (For encoder outputs to predict glosses)
    """

    def __init__(
        self,
        embedding_dim: int,
        input_size: int,
        num_heads: int,
        freeze: bool = False,
        norm_type: str = "batch",
        activation_type: str = "softsign",
        scale: bool = False,
        scale_factor: Optional[float] = None,
        **kwargs
    ):
        """
        Create new embeddings for the vocabulary.
        Use scaling for the Transformer.

        :param embedding_dim: dimensionality of the projected frame features
        :param input_size: dimensionality of the input frame features
        :param num_heads: number of groups if group normalization is used
        :param freeze: freeze the embeddings during training
        :param norm_type: masked normalization ("batch", "group" or "layer")
        :param activation_type: activation applied after normalization
        :param scale: if True, multiply the embeddings by `scale_factor`
        :param scale_factor: scaling constant; defaults to sqrt(embedding_dim)
        """
        super().__init__()

        self.embedding_dim = embedding_dim
        self.input_size = input_size
        self.ln = nn.Linear(self.input_size, self.embedding_dim)

        self.norm_type = norm_type
        if self.norm_type:
            self.norm = MaskedNorm(
                norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim
            )

        self.activation_type = activation_type
        if self.activation_type:
            self.activation = get_activation(activation_type)

        self.scale = scale
        if self.scale:
            if scale_factor:
                self.scale_factor = scale_factor
            else:
                self.scale_factor = math.sqrt(self.embedding_dim)

        if freeze:
            freeze_params(self)

    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        """
        :param x: input frame features
        :param mask: frame masks
        :return: embedded representation for `x`
        """
        x = self.ln(x)

        if self.norm_type:
            x = self.norm(x, mask)

        if self.activation_type:
            x = self.activation(x)

        if self.scale:
            return x * self.scale_factor
        else:
            return x

    def __repr__(self):
        return "%s(embedding_dim=%d, input_size=%d)" % (
            self.__class__.__name__,
            self.embedding_dim,
            self.input_size,
        )
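
# Example (hypothetical sizes): continuous frame features of shape
# [batch, time, input_size] are projected to [batch, time, embedding_dim];
# with the default "batch" norm a frame mask is required.
#
#     emb = SpatialEmbeddings(embedding_dim=64, input_size=150, num_heads=8)
#     out = emb(torch.randn(2, 20, 150), torch.ones(2, 20, 1))   # out: [2, 20, 64]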
|
|
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings. This matches the implementation in
    Denoising Diffusion Probabilistic Models.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param embedding_dim: the dimension of the output.
    :param flip_sin_to_cos: if True, put the cosine half before the sine half.
    :param downscale_freq_shift: shift applied to the frequency exponent denominator.
    :param scale: constant multiplier applied to the arguments of sin/cos.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x embedding_dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concatenate sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero-pad when embedding_dim is odd
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
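
# Example (hypothetical sizes): four integer timesteps embedded into 128
# sinusoidal features each.
#
#     emb = get_timestep_embedding(torch.arange(4), 128)   # emb: [4, 128]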
|
|
class TimestepEmbedding(nn.Module):
    def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
        super().__init__()

        self.linear_1 = nn.Linear(channel, time_embed_dim)
        self.act = None
        if act_fn == "silu":
            self.act = nn.SiLU()
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)

    def forward(self, sample):
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)
        return sample
|
|
class Timesteps(nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
        return t_emb
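

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; the sizes below are arbitrary
    # examples, not values required by this module): embed four integer
    # timesteps as 32-dim sinusoidal features, then project them to a
    # 128-dim conditioning vector.
    steps = torch.arange(4)
    time_proj = Timesteps(num_channels=32, flip_sin_to_cos=False, downscale_freq_shift=1.0)
    time_embed = TimestepEmbedding(channel=32, time_embed_dim=128)
    sinusoidal = time_proj(steps)         # [4, 32]
    conditioned = time_embed(sinusoidal)  # [4, 128]
    print(sinusoidal.shape, conditioned.shape)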
|
|