# Spaces: Running on Zero
# Copyright (c) 2023 Amphion.
#
# This code is modified from https://github.com/ZhangXInFD/SpeechTokenizer/blob/main/speechtokenizer/model.py
# Licensed under Apache License 2.0
import json

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange

from .modules.quantization import ResidualVectorQuantizer
from .modules.seanet import SEANetEncoder, SEANetDecoder
class SpeechTokenizer(nn.Module):
    """Speech codec built from a SEANet encoder, an RVQ and a SEANet decoder.

    The encoder output is quantized by a ``ResidualVectorQuantizer`` and the
    quantized result is fed to the decoder. ``transform`` maps the quantized
    feature returned by ``forward`` from ``dimension`` to
    ``semantic_dimension`` (identity when both are equal).
    """

    def __init__(self, config):
        """
        Parameters
        ----------
        config : dict-like
            Model config, read through ``config.get``. Keys used:
            ``n_filters``, ``dimension``, ``strides``, ``lstm_layers``,
            ``bidirectional``, ``dilation_base``, ``residual_kernel_size``,
            ``n_residual_layers``, ``activation``, ``sample_rate``, ``n_q``,
            ``semantic_dimension`` and ``codebook_size``.
        """
        super().__init__()
        dimension = config.get("dimension")
        semantic_dimension = config.get("semantic_dimension")
        strides = config.get("strides")

        # Keyword arguments shared by the encoder and the decoder; only
        # ``bidirectional`` differs between the two.
        seanet_kwargs = dict(
            n_filters=config.get("n_filters"),
            dimension=dimension,
            ratios=strides,
            lstm=config.get("lstm_layers"),
            dilation_base=config.get("dilation_base"),
            residual_kernel_size=config.get("residual_kernel_size"),
            n_residual_layers=config.get("n_residual_layers"),
            activation=config.get("activation"),
        )

        # Sub-modules are created in the order encoder -> transform ->
        # quantizer -> decoder so registration order stays as before.
        self.encoder = SEANetEncoder(
            bidirectional=config.get("bidirectional"), **seanet_kwargs
        )
        self.sample_rate = config.get("sample_rate")
        self.n_q = config.get("n_q")
        # Product of all strides.
        self.downsample_rate = np.prod(strides)
        if dimension != semantic_dimension:
            self.transform = nn.Linear(dimension, semantic_dimension)
        else:
            self.transform = nn.Identity()
        self.quantizer = ResidualVectorQuantizer(
            dimension=dimension,
            n_q=config.get("n_q"),
            bins=config.get("codebook_size"),
        )
        # The decoder is always built with a unidirectional LSTM setting,
        # regardless of the encoder's ``bidirectional`` config value.
        self.decoder = SEANetDecoder(bidirectional=False, **seanet_kwargs)

    @classmethod
    def load_from_checkpoint(cls, config_path: str, ckpt_path: str):
        """
        Build a model from a JSON config file and load its weights.

        Parameters
        ----------
        config_path : str
            Path of model configuration file (JSON).
        ckpt_path : str
            Path of model checkpoint (a state dict readable by ``torch.load``).

        Returns
        -------
        model : SpeechTokenizer
            SpeechTokenizer model with the checkpoint weights loaded.
        """
        with open(config_path, encoding="utf-8") as f:
            cfg = json.load(f)
        model = cls(cfg)
        # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
        # from trusted sources (consider ``weights_only=True`` where the
        # installed torch version supports it).
        params = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(params)
        return model

    def forward(self, x: torch.Tensor, n_q: int = None, layers: list = None):
        """
        Encode, quantize and reconstruct the input.

        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape: (batch, channels, timesteps).
        n_q : int, optional
            Number of quantizers in RVQ used to encode. The default is all layers.
        layers : list[int], optional
            Layers of RVQ should return quantized result. The default is the
            first layer (``[0]``).

        Returns
        -------
        o : torch.Tensor
            Output wavs. Shape: (batch, channels, timesteps).
        commit_loss : torch.Tensor
            Commitment loss from residual vector quantizers.
        feature : torch.Tensor
            First entry of the quantized list returned by the quantizer (the
            first RVQ layer with the default ``layers``), passed through
            ``transform``. Shape: (batch, timesteps, dimension)
        """
        # Resolve the default here instead of using a shared mutable default.
        if layers is None:
            layers = [0]
        # A falsy n_q (None or 0) selects every quantizer.
        n_q = n_q if n_q else self.n_q
        e = self.encoder(x)
        quantized, codes, commit_loss, quantized_list = self.quantizer(
            e, n_q=n_q, layers=layers
        )
        feature = rearrange(quantized_list[0], "b d t -> b t d")
        feature = self.transform(feature)
        o = self.decoder(quantized)
        return o, commit_loss, feature

    def forward_feature(self, x: torch.Tensor, layers: list = None):
        """
        Return the quantized outputs of the requested RVQ layers.

        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape should be (batch, channels, timesteps).
        layers : list[int], optional
            Layers of RVQ should return quantized result. The default is all layers.

        Returns
        -------
        quantized_list : list[torch.Tensor]
            Quantized of required layers.
        """
        e = self.encoder(x)
        # A falsy ``layers`` (None or empty) selects every layer.
        layers = layers if layers else list(range(self.n_q))
        quantized, codes, commit_loss, quantized_list = self.quantizer(
            e, layers=layers
        )
        return quantized_list

    def encode(self, x: torch.Tensor, n_q: int = None, st: int = None):
        """
        Encode waveforms into RVQ code indices.

        Parameters
        ----------
        x : torch.Tensor
            Input wavs. Shape: (batch, channels, timesteps).
        n_q : int, optional
            Number of quantizers in RVQ used to encode. The default is all layers.
        st : int, optional
            Start quantizer index in RVQ. The default is 0.

        Returns
        -------
        codes : torch.Tensor
            Output indices for each quantizer. Shape: (n_q, batch, timesteps)
        """
        e = self.encoder(x)
        if st is None:
            st = 0
        # A falsy n_q (None or 0) selects every quantizer.
        n_q = n_q if n_q else self.n_q
        codes = self.quantizer.encode(e, n_q=n_q, st=st)
        return codes

    def decode(self, codes: torch.Tensor, st: int = 0):
        """
        Reconstruct waveforms from RVQ code indices.

        Parameters
        ----------
        codes : torch.Tensor
            Indices for each quantizer. Shape: (n_q, batch, timesteps).
        st : int, optional
            Start quantizer index in RVQ. The default is 0.

        Returns
        -------
        o : torch.Tensor
            Reconstruct wavs from codes. Shape: (batch, channels, timesteps)
        """
        quantized = self.quantizer.decode(codes, st=st)
        o = self.decoder(quantized)
        return o