Spaces:
Running
on
Zero
Running
on
Zero
# Copyright 2024 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import math | |
from typing import Optional, Tuple, Union | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from ..activations import get_activation | |
from ..resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims | |
class DownResnetBlock1D(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: Optional[int] = None, | |
num_layers: int = 1, | |
conv_shortcut: bool = False, | |
temb_channels: int = 32, | |
groups: int = 32, | |
groups_out: Optional[int] = None, | |
non_linearity: Optional[str] = None, | |
time_embedding_norm: str = "default", | |
output_scale_factor: float = 1.0, | |
add_downsample: bool = True, | |
): | |
super().__init__() | |
self.in_channels = in_channels | |
out_channels = in_channels if out_channels is None else out_channels | |
self.out_channels = out_channels | |
self.use_conv_shortcut = conv_shortcut | |
self.time_embedding_norm = time_embedding_norm | |
self.add_downsample = add_downsample | |
self.output_scale_factor = output_scale_factor | |
if groups_out is None: | |
groups_out = groups | |
# there will always be at least one resnet | |
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)] | |
for _ in range(num_layers): | |
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) | |
self.resnets = nn.ModuleList(resnets) | |
if non_linearity is None: | |
self.nonlinearity = None | |
else: | |
self.nonlinearity = get_activation(non_linearity) | |
self.downsample = None | |
if add_downsample: | |
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) | |
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: | |
output_states = () | |
hidden_states = self.resnets[0](hidden_states, temb) | |
for resnet in self.resnets[1:]: | |
hidden_states = resnet(hidden_states, temb) | |
output_states += (hidden_states,) | |
if self.nonlinearity is not None: | |
hidden_states = self.nonlinearity(hidden_states) | |
if self.downsample is not None: | |
hidden_states = self.downsample(hidden_states) | |
return hidden_states, output_states | |
class UpResnetBlock1D(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: Optional[int] = None, | |
num_layers: int = 1, | |
temb_channels: int = 32, | |
groups: int = 32, | |
groups_out: Optional[int] = None, | |
non_linearity: Optional[str] = None, | |
time_embedding_norm: str = "default", | |
output_scale_factor: float = 1.0, | |
add_upsample: bool = True, | |
): | |
super().__init__() | |
self.in_channels = in_channels | |
out_channels = in_channels if out_channels is None else out_channels | |
self.out_channels = out_channels | |
self.time_embedding_norm = time_embedding_norm | |
self.add_upsample = add_upsample | |
self.output_scale_factor = output_scale_factor | |
if groups_out is None: | |
groups_out = groups | |
# there will always be at least one resnet | |
resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)] | |
for _ in range(num_layers): | |
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) | |
self.resnets = nn.ModuleList(resnets) | |
if non_linearity is None: | |
self.nonlinearity = None | |
else: | |
self.nonlinearity = get_activation(non_linearity) | |
self.upsample = None | |
if add_upsample: | |
self.upsample = Upsample1D(out_channels, use_conv_transpose=True) | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None, | |
temb: Optional[torch.FloatTensor] = None, | |
) -> torch.FloatTensor: | |
if res_hidden_states_tuple is not None: | |
res_hidden_states = res_hidden_states_tuple[-1] | |
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1) | |
hidden_states = self.resnets[0](hidden_states, temb) | |
for resnet in self.resnets[1:]: | |
hidden_states = resnet(hidden_states, temb) | |
if self.nonlinearity is not None: | |
hidden_states = self.nonlinearity(hidden_states) | |
if self.upsample is not None: | |
hidden_states = self.upsample(hidden_states) | |
return hidden_states | |
class ValueFunctionMidBlock1D(nn.Module): | |
def __init__(self, in_channels: int, out_channels: int, embed_dim: int): | |
super().__init__() | |
self.in_channels = in_channels | |
self.out_channels = out_channels | |
self.embed_dim = embed_dim | |
self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim) | |
self.down1 = Downsample1D(out_channels // 2, use_conv=True) | |
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim) | |
self.down2 = Downsample1D(out_channels // 4, use_conv=True) | |
def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: | |
x = self.res1(x, temb) | |
x = self.down1(x) | |
x = self.res2(x, temb) | |
x = self.down2(x) | |
return x | |
class MidResTemporalBlock1D(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
embed_dim: int, | |
num_layers: int = 1, | |
add_downsample: bool = False, | |
add_upsample: bool = False, | |
non_linearity: Optional[str] = None, | |
): | |
super().__init__() | |
self.in_channels = in_channels | |
self.out_channels = out_channels | |
self.add_downsample = add_downsample | |
# there will always be at least one resnet | |
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)] | |
for _ in range(num_layers): | |
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim)) | |
self.resnets = nn.ModuleList(resnets) | |
if non_linearity is None: | |
self.nonlinearity = None | |
else: | |
self.nonlinearity = get_activation(non_linearity) | |
self.upsample = None | |
if add_upsample: | |
self.upsample = Downsample1D(out_channels, use_conv=True) | |
self.downsample = None | |
if add_downsample: | |
self.downsample = Downsample1D(out_channels, use_conv=True) | |
if self.upsample and self.downsample: | |
raise ValueError("Block cannot downsample and upsample") | |
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor: | |
hidden_states = self.resnets[0](hidden_states, temb) | |
for resnet in self.resnets[1:]: | |
hidden_states = resnet(hidden_states, temb) | |
if self.upsample: | |
hidden_states = self.upsample(hidden_states) | |
if self.downsample: | |
self.downsample = self.downsample(hidden_states) | |
return hidden_states | |
class OutConv1DBlock(nn.Module): | |
def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str): | |
super().__init__() | |
self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2) | |
self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim) | |
self.final_conv1d_act = get_activation(act_fn) | |
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1) | |
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: | |
hidden_states = self.final_conv1d_1(hidden_states) | |
hidden_states = rearrange_dims(hidden_states) | |
hidden_states = self.final_conv1d_gn(hidden_states) | |
hidden_states = rearrange_dims(hidden_states) | |
hidden_states = self.final_conv1d_act(hidden_states) | |
hidden_states = self.final_conv1d_2(hidden_states) | |
return hidden_states | |
class OutValueFunctionBlock(nn.Module): | |
def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"): | |
super().__init__() | |
self.final_block = nn.ModuleList( | |
[ | |
nn.Linear(fc_dim + embed_dim, fc_dim // 2), | |
get_activation(act_fn), | |
nn.Linear(fc_dim // 2, 1), | |
] | |
) | |
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor: | |
hidden_states = hidden_states.view(hidden_states.shape[0], -1) | |
hidden_states = torch.cat((hidden_states, temb), dim=-1) | |
for layer in self.final_block: | |
hidden_states = layer(hidden_states) | |
return hidden_states | |
_kernels = { | |
"linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8], | |
"cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875], | |
"lanczos3": [ | |
0.003689131001010537, | |
0.015056144446134567, | |
-0.03399861603975296, | |
-0.066637322306633, | |
0.13550527393817902, | |
0.44638532400131226, | |
0.44638532400131226, | |
0.13550527393817902, | |
-0.066637322306633, | |
-0.03399861603975296, | |
0.015056144446134567, | |
0.003689131001010537, | |
], | |
} | |
class Downsample1d(nn.Module): | |
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"): | |
super().__init__() | |
self.pad_mode = pad_mode | |
kernel_1d = torch.tensor(_kernels[kernel]) | |
self.pad = kernel_1d.shape[0] // 2 - 1 | |
self.register_buffer("kernel", kernel_1d) | |
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: | |
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode) | |
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) | |
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) | |
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1) | |
weight[indices, indices] = kernel | |
return F.conv1d(hidden_states, weight, stride=2) | |
class Upsample1d(nn.Module): | |
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"): | |
super().__init__() | |
self.pad_mode = pad_mode | |
kernel_1d = torch.tensor(_kernels[kernel]) * 2 | |
self.pad = kernel_1d.shape[0] // 2 - 1 | |
self.register_buffer("kernel", kernel_1d) | |
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: | |
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) | |
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) | |
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) | |
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1) | |
weight[indices, indices] = kernel | |
return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1) | |
class SelfAttention1d(nn.Module): | |
def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0): | |
super().__init__() | |
self.channels = in_channels | |
self.group_norm = nn.GroupNorm(1, num_channels=in_channels) | |
self.num_heads = n_head | |
self.query = nn.Linear(self.channels, self.channels) | |
self.key = nn.Linear(self.channels, self.channels) | |
self.value = nn.Linear(self.channels, self.channels) | |
self.proj_attn = nn.Linear(self.channels, self.channels, bias=True) | |
self.dropout = nn.Dropout(dropout_rate, inplace=True) | |
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: | |
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) | |
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) | |
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) | |
return new_projection | |
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: | |
residual = hidden_states | |
batch, channel_dim, seq = hidden_states.shape | |
hidden_states = self.group_norm(hidden_states) | |
hidden_states = hidden_states.transpose(1, 2) | |
query_proj = self.query(hidden_states) | |
key_proj = self.key(hidden_states) | |
value_proj = self.value(hidden_states) | |
query_states = self.transpose_for_scores(query_proj) | |
key_states = self.transpose_for_scores(key_proj) | |
value_states = self.transpose_for_scores(value_proj) | |
scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1])) | |
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) | |
attention_probs = torch.softmax(attention_scores, dim=-1) | |
# compute attention output | |
hidden_states = torch.matmul(attention_probs, value_states) | |
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() | |
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) | |
hidden_states = hidden_states.view(new_hidden_states_shape) | |
# compute next hidden_states | |
hidden_states = self.proj_attn(hidden_states) | |
hidden_states = hidden_states.transpose(1, 2) | |
hidden_states = self.dropout(hidden_states) | |
output = hidden_states + residual | |
return output | |
class ResConvBlock(nn.Module): | |
def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False): | |
super().__init__() | |
self.is_last = is_last | |
self.has_conv_skip = in_channels != out_channels | |
if self.has_conv_skip: | |
self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False) | |
self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2) | |
self.group_norm_1 = nn.GroupNorm(1, mid_channels) | |
self.gelu_1 = nn.GELU() | |
self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2) | |
if not self.is_last: | |
self.group_norm_2 = nn.GroupNorm(1, out_channels) | |
self.gelu_2 = nn.GELU() | |
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: | |
residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states | |
hidden_states = self.conv_1(hidden_states) | |
hidden_states = self.group_norm_1(hidden_states) | |
hidden_states = self.gelu_1(hidden_states) | |
hidden_states = self.conv_2(hidden_states) | |
if not self.is_last: | |
hidden_states = self.group_norm_2(hidden_states) | |
hidden_states = self.gelu_2(hidden_states) | |
output = hidden_states + residual | |
return output | |
class UNetMidBlock1D(nn.Module): | |
def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None): | |
super().__init__() | |
out_channels = in_channels if out_channels is None else out_channels | |
# there is always at least one resnet | |
self.down = Downsample1d("cubic") | |
resnets = [ | |
ResConvBlock(in_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, out_channels), | |
] | |
attentions = [ | |
SelfAttention1d(mid_channels, mid_channels // 32), | |
SelfAttention1d(mid_channels, mid_channels // 32), | |
SelfAttention1d(mid_channels, mid_channels // 32), | |
SelfAttention1d(mid_channels, mid_channels // 32), | |
SelfAttention1d(mid_channels, mid_channels // 32), | |
SelfAttention1d(out_channels, out_channels // 32), | |
] | |
self.up = Upsample1d(kernel="cubic") | |
self.attentions = nn.ModuleList(attentions) | |
self.resnets = nn.ModuleList(resnets) | |
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: | |
hidden_states = self.down(hidden_states) | |
for attn, resnet in zip(self.attentions, self.resnets): | |
hidden_states = resnet(hidden_states) | |
hidden_states = attn(hidden_states) | |
hidden_states = self.up(hidden_states) | |
return hidden_states | |
class AttnDownBlock1D(nn.Module): | |
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None): | |
super().__init__() | |
mid_channels = out_channels if mid_channels is None else mid_channels | |
self.down = Downsample1d("cubic") | |
resnets = [ | |
ResConvBlock(in_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, out_channels), | |
] | |
attentions = [ | |
SelfAttention1d(mid_channels, mid_channels // 32), | |
SelfAttention1d(mid_channels, mid_channels // 32), | |
SelfAttention1d(out_channels, out_channels // 32), | |
] | |
self.attentions = nn.ModuleList(attentions) | |
self.resnets = nn.ModuleList(resnets) | |
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: | |
hidden_states = self.down(hidden_states) | |
for resnet, attn in zip(self.resnets, self.attentions): | |
hidden_states = resnet(hidden_states) | |
hidden_states = attn(hidden_states) | |
return hidden_states, (hidden_states,) | |
class DownBlock1D(nn.Module): | |
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None): | |
super().__init__() | |
mid_channels = out_channels if mid_channels is None else mid_channels | |
self.down = Downsample1d("cubic") | |
resnets = [ | |
ResConvBlock(in_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, out_channels), | |
] | |
self.resnets = nn.ModuleList(resnets) | |
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: | |
hidden_states = self.down(hidden_states) | |
for resnet in self.resnets: | |
hidden_states = resnet(hidden_states) | |
return hidden_states, (hidden_states,) | |
class DownBlock1DNoSkip(nn.Module): | |
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None): | |
super().__init__() | |
mid_channels = out_channels if mid_channels is None else mid_channels | |
resnets = [ | |
ResConvBlock(in_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, out_channels), | |
] | |
self.resnets = nn.ModuleList(resnets) | |
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: | |
hidden_states = torch.cat([hidden_states, temb], dim=1) | |
for resnet in self.resnets: | |
hidden_states = resnet(hidden_states) | |
return hidden_states, (hidden_states,) | |
class AttnUpBlock1D(nn.Module): | |
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None): | |
super().__init__() | |
mid_channels = out_channels if mid_channels is None else mid_channels | |
resnets = [ | |
ResConvBlock(2 * in_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, out_channels), | |
] | |
attentions = [ | |
SelfAttention1d(mid_channels, mid_channels // 32), | |
SelfAttention1d(mid_channels, mid_channels // 32), | |
SelfAttention1d(out_channels, out_channels // 32), | |
] | |
self.attentions = nn.ModuleList(attentions) | |
self.resnets = nn.ModuleList(resnets) | |
self.up = Upsample1d(kernel="cubic") | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], | |
temb: Optional[torch.FloatTensor] = None, | |
) -> torch.FloatTensor: | |
res_hidden_states = res_hidden_states_tuple[-1] | |
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) | |
for resnet, attn in zip(self.resnets, self.attentions): | |
hidden_states = resnet(hidden_states) | |
hidden_states = attn(hidden_states) | |
hidden_states = self.up(hidden_states) | |
return hidden_states | |
class UpBlock1D(nn.Module): | |
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None): | |
super().__init__() | |
mid_channels = in_channels if mid_channels is None else mid_channels | |
resnets = [ | |
ResConvBlock(2 * in_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, out_channels), | |
] | |
self.resnets = nn.ModuleList(resnets) | |
self.up = Upsample1d(kernel="cubic") | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], | |
temb: Optional[torch.FloatTensor] = None, | |
) -> torch.FloatTensor: | |
res_hidden_states = res_hidden_states_tuple[-1] | |
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) | |
for resnet in self.resnets: | |
hidden_states = resnet(hidden_states) | |
hidden_states = self.up(hidden_states) | |
return hidden_states | |
class UpBlock1DNoSkip(nn.Module): | |
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None): | |
super().__init__() | |
mid_channels = in_channels if mid_channels is None else mid_channels | |
resnets = [ | |
ResConvBlock(2 * in_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, mid_channels), | |
ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True), | |
] | |
self.resnets = nn.ModuleList(resnets) | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], | |
temb: Optional[torch.FloatTensor] = None, | |
) -> torch.FloatTensor: | |
res_hidden_states = res_hidden_states_tuple[-1] | |
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) | |
for resnet in self.resnets: | |
hidden_states = resnet(hidden_states) | |
return hidden_states | |
DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip] | |
MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D] | |
OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock] | |
UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip] | |
def get_down_block( | |
down_block_type: str, | |
num_layers: int, | |
in_channels: int, | |
out_channels: int, | |
temb_channels: int, | |
add_downsample: bool, | |
) -> DownBlockType: | |
if down_block_type == "DownResnetBlock1D": | |
return DownResnetBlock1D( | |
in_channels=in_channels, | |
num_layers=num_layers, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
add_downsample=add_downsample, | |
) | |
elif down_block_type == "DownBlock1D": | |
return DownBlock1D(out_channels=out_channels, in_channels=in_channels) | |
elif down_block_type == "AttnDownBlock1D": | |
return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels) | |
elif down_block_type == "DownBlock1DNoSkip": | |
return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels) | |
raise ValueError(f"{down_block_type} does not exist.") | |
def get_up_block( | |
up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool | |
) -> UpBlockType: | |
if up_block_type == "UpResnetBlock1D": | |
return UpResnetBlock1D( | |
in_channels=in_channels, | |
num_layers=num_layers, | |
out_channels=out_channels, | |
temb_channels=temb_channels, | |
add_upsample=add_upsample, | |
) | |
elif up_block_type == "UpBlock1D": | |
return UpBlock1D(in_channels=in_channels, out_channels=out_channels) | |
elif up_block_type == "AttnUpBlock1D": | |
return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels) | |
elif up_block_type == "UpBlock1DNoSkip": | |
return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels) | |
raise ValueError(f"{up_block_type} does not exist.") | |
def get_mid_block( | |
mid_block_type: str, | |
num_layers: int, | |
in_channels: int, | |
mid_channels: int, | |
out_channels: int, | |
embed_dim: int, | |
add_downsample: bool, | |
) -> MidBlockType: | |
if mid_block_type == "MidResTemporalBlock1D": | |
return MidResTemporalBlock1D( | |
num_layers=num_layers, | |
in_channels=in_channels, | |
out_channels=out_channels, | |
embed_dim=embed_dim, | |
add_downsample=add_downsample, | |
) | |
elif mid_block_type == "ValueFunctionMidBlock1D": | |
return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim) | |
elif mid_block_type == "UNetMidBlock1D": | |
return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels) | |
raise ValueError(f"{mid_block_type} does not exist.") | |
def get_out_block( | |
*, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int | |
) -> Optional[OutBlockType]: | |
if out_block_type == "OutConv1DBlock": | |
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn) | |
elif out_block_type == "ValueFunction": | |
return OutValueFunctionBlock(fc_dim, embed_dim, act_fn) | |
return None | |