from typing import Optional
import torch
from torch import nn
from torch.nn.utils import weight_norm
from vocos.modules import ConvNeXtBlock, ResBlock1, AdaLayerNorm
class Backbone(nn.Module):
"""Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Args:
            x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
                C denotes the number of input feature channels, and L is the sequence length.
Returns:
Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
and H denotes the model dimension.
"""
raise NotImplementedError("Subclasses must implement the forward method.")
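
# Both concrete backbones below follow this contract: channels-first features in,
# time-major hidden states out. For example, with the shapes used in the demo at
# the bottom of this file, an input of (2, 1024, 100) maps to (2, 100, 512).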
class VocosBackbone(Backbone):
"""
    Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization.
Args:
input_channels (int): Number of input features channels.
dim (int): Hidden dimension of the model.
intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
num_layers (int): Number of ConvNeXtBlock layers.
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
            None means a non-conditional model. Defaults to None.
        ckpt (str, optional): Path to a checkpoint from which `backbone.*` weights are
            loaded via `_fuzzy_load_state_dict`. Defaults to None.
"""
def __init__(
self,
input_channels: int,
dim: int,
intermediate_dim: int,
num_layers: int,
layer_scale_init_value: Optional[float] = None,
adanorm_num_embeddings: Optional[int] = None,
ckpt: Optional[str] = None,
):
super().__init__()
self.input_channels = input_channels
self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
        self.adanorm = adanorm_num_embeddings is not None
        if self.adanorm:
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
else:
self.norm = nn.LayerNorm(dim, eps=1e-6)
layer_scale_init_value = layer_scale_init_value or 1 / num_layers
self.convnext = nn.ModuleList(
[
ConvNeXtBlock(
dim=dim,
intermediate_dim=intermediate_dim,
layer_scale_init_value=layer_scale_init_value,
adanorm_num_embeddings=adanorm_num_embeddings,
)
for _ in range(num_layers)
]
)
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
        self.apply(self._init_weights)
        # Load checkpoint weights after initialization; applying `_init_weights`
        # after loading would overwrite the loaded parameters.
        if ckpt is not None:
            state_dict = torch.load(ckpt, map_location='cpu')
            state_dict = self._fuzzy_load_state_dict(state_dict)
            self.load_state_dict(state_dict)
    def _fuzzy_load_state_dict(self, state_dict):
        """Map a full-model checkpoint onto this backbone: keep `backbone.*` entries
        whose shapes match, and re-initialize the ones that do not."""
        def _get_key(key):
            return key.split('backbone.')[-1]

        own_state = self.state_dict()
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith('backbone'):
                key = _get_key(k)
                if v.shape == own_state[key].shape:
                    new_state_dict[key] = v
                else:
                    # Shape mismatch: fall back to a freshly initialized tensor,
                    # mirroring `_init_weights` (truncated normal for weights,
                    # zeros for biases).
                    new_state_dict[key] = own_state[key].clone()
                    if key.endswith('bias'):
                        nn.init.constant_(new_state_dict[key], 0)
                    else:
                        nn.init.trunc_normal_(new_state_dict[key], std=0.02)
        return new_state_dict
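    # Illustrative mapping performed by `_fuzzy_load_state_dict` (the checkpoint
    # keys here are assumptions, shown for documentation only):
    #   'backbone.embed.weight' -> 'embed.weight'  (kept when shapes match)
    #   'backbone.norm.weight'  -> 'norm.weight'   (re-initialized on mismatch)
    #   'head.out.weight'       -> dropped         (no 'backbone' prefix)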
    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        bandwidth_id = kwargs.get('bandwidth_id', None)
        # (B, input_channels, L) -> (B, dim, L)
        x = self.embed(x)
        if self.adanorm:
            assert bandwidth_id is not None, "bandwidth_id is required when AdaLayerNorm conditioning is enabled"
            x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
        else:
            x = self.norm(x.transpose(1, 2))
        x = x.transpose(1, 2)
        for conv_block in self.convnext:
            x = conv_block(x, cond_embedding_id=bandwidth_id)
        # (B, dim, L) -> (B, L, dim)
        x = self.final_layer_norm(x.transpose(1, 2))
        return x
class VocosResNetBackbone(Backbone):
"""
Vocos backbone module built with ResBlocks.
Args:
input_channels (int): Number of input features channels.
dim (int): Hidden dimension of the model.
num_blocks (int): Number of ResBlock1 blocks.
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
"""
def __init__(
self, input_channels, dim, num_blocks, layer_scale_init_value=None,
):
super().__init__()
self.input_channels = input_channels
self.embed = weight_norm(nn.Conv1d(input_channels, dim, kernel_size=3, padding=1))
layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
self.resnet = nn.Sequential(
*[ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) for _ in range(num_blocks)]
)
    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        # (B, input_channels, L) -> (B, dim, L)
        x = self.embed(x)
        x = self.resnet(x)
        # (B, dim, L) -> (B, L, dim)
        x = x.transpose(1, 2)
        return x
if __name__ == '__main__':
# Define the model
model = VocosBackbone(
input_channels=1024,
dim=512,
intermediate_dim=1536,
num_layers=8,
        ckpt="/root/OpenMusicVoco/vocos/pretrained.pth",  # environment-specific path
)
# Generate some random input
x = torch.randn(2, 1024, 100)
# Forward pass
output = model(x)
    print(output.shape)  # torch.Size([2, 100, 512])
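
    # Hedged sketch: exercising the AdaLayerNorm-conditioned path. The value of
    # `adanorm_num_embeddings` and the bandwidth index below are assumptions for
    # illustration; `bandwidth_id` selects a row of the conditioning embedding table.
    cond_model = VocosBackbone(
        input_channels=1024,
        dim=512,
        intermediate_dim=1536,
        num_layers=8,
        adanorm_num_embeddings=4,
    )
    cond_output = cond_model(x, bandwidth_id=torch.tensor([0]))
    print(cond_output.shape)  # torch.Size([2, 100, 512])

    # Hedged sketch: the ResNet variant follows the same (B, C, L) -> (B, L, H)
    # contract; `num_blocks=3` is an assumed value for illustration.
    resnet_model = VocosResNetBackbone(input_channels=1024, dim=512, num_blocks=3)
    print(resnet_model(x).shape)  # torch.Size([2, 100, 512])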