Spaces:
Running
on
Zero
Running
on
Zero
Hecheng0625
commited on
Upload 12 files
Browse files- Amphion/models/ns3_codec/README.md +160 -0
- Amphion/models/ns3_codec/__init__.py +1 -0
- Amphion/models/ns3_codec/alias_free_torch/__init__.py +6 -0
- Amphion/models/ns3_codec/alias_free_torch/act.py +30 -0
- Amphion/models/ns3_codec/alias_free_torch/filter.py +99 -0
- Amphion/models/ns3_codec/alias_free_torch/resample.py +58 -0
- Amphion/models/ns3_codec/facodec.py +593 -0
- Amphion/models/ns3_codec/gradient_reversal.py +30 -0
- Amphion/models/ns3_codec/quantize/__init__.py +2 -0
- Amphion/models/ns3_codec/quantize/fvq.py +111 -0
- Amphion/models/ns3_codec/quantize/rvq.py +82 -0
- Amphion/models/ns3_codec/transformer.py +217 -0
Amphion/models/ns3_codec/README.md
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## FACodec: Speech Codec with Attribute Factorization used for NaturalSpeech 3
|
2 |
+
|
3 |
+
[![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/pdf/2403.03100.pdf)
|
4 |
+
[![demo](https://img.shields.io/badge/FACodec-Demo-red)](https://speechresearch.github.io/naturalspeech3/)
|
5 |
+
[![model](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Models-pink)](https://huggingface.co/amphion/naturalspeech3_facodec)
|
6 |
+
[![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Spaces-yellow)](https://huggingface.co/spaces/amphion/naturalspeech3_facodec)
|
7 |
+
|
8 |
+
## Overview
|
9 |
+
|
10 |
+
FACodec is a core component of the advanced text-to-speech (TTS) model NaturalSpeech 3. FACodec converts complex speech waveform into disentangled subspaces representing speech attributes of content, prosody, timbre, and acoustic details and reconstruct high-quality speech waveform from these attributes. FACodec decomposes complex speech into subspaces representing different attributes, thus simplifying the modeling of speech representation.
|
11 |
+
|
12 |
+
Research can use FACodec to develop different modes of TTS models, such as non-autoregressive based discrete diffusion (NaturalSpeech 3) or autoregressive models (like VALL-E).
|
13 |
+
|
14 |
+
<br>
|
15 |
+
<div align="center">
|
16 |
+
<img src="../../imgs/ns3/ns3_overview.png" width="65%">
|
17 |
+
</div>
|
18 |
+
<br>
|
19 |
+
|
20 |
+
<br>
|
21 |
+
<div align="center">
|
22 |
+
<img src="../../imgs/ns3/ns3_facodec.png" width="100%">
|
23 |
+
</div>
|
24 |
+
<br>
|
25 |
+
|
26 |
+
## Useage
|
27 |
+
|
28 |
+
Download the pre-trained FACodec model from HuggingFace: [Pretrained FACodec checkpoint](https://huggingface.co/amphion/naturalspeech3_facodec)
|
29 |
+
|
30 |
+
Install Amphion
|
31 |
+
```bash
|
32 |
+
git https://github.com/open-mmlab/Amphion.git
|
33 |
+
```
|
34 |
+
|
35 |
+
Few lines of code to use the pre-trained FACodec model
|
36 |
+
```python
|
37 |
+
from AmphionOpen.models.ns3_codec import FACodecEncoder, FACodecDecoder
|
38 |
+
|
39 |
+
fa_encoder = FACodecEncoder(
|
40 |
+
ngf=32,
|
41 |
+
up_ratios=[2, 4, 5, 5],
|
42 |
+
out_channels=256,
|
43 |
+
)
|
44 |
+
|
45 |
+
fa_decoder = FACodecDecoder(
|
46 |
+
in_channels=256,
|
47 |
+
upsample_initial_channel=1024,
|
48 |
+
ngf=32,
|
49 |
+
up_ratios=[5, 5, 4, 2],
|
50 |
+
vq_num_q_c=2,
|
51 |
+
vq_num_q_p=1,
|
52 |
+
vq_num_q_r=3,
|
53 |
+
vq_dim=256,
|
54 |
+
codebook_dim=8,
|
55 |
+
codebook_size_prosody=10,
|
56 |
+
codebook_size_content=10,
|
57 |
+
codebook_size_residual=10,
|
58 |
+
use_gr_x_timbre=True,
|
59 |
+
use_gr_residual_f0=True,
|
60 |
+
use_gr_residual_phone=True,
|
61 |
+
)
|
62 |
+
|
63 |
+
fa_encoder = torch.load("ns3_facodec_encoder.bin")
|
64 |
+
fa_decoder = torch.load("ns3_facodec_decoder.bin")
|
65 |
+
|
66 |
+
fa_encoder.eval()
|
67 |
+
fa_decoder.eval()
|
68 |
+
|
69 |
+
```
|
70 |
+
|
71 |
+
Test
|
72 |
+
```python
|
73 |
+
test_wav_path = "test.wav"
|
74 |
+
test_wav = librosa.load(test_wav_path, sr=16000)[0]
|
75 |
+
test_wav = torch.from_numpy(test_wav).float()
|
76 |
+
test_wav = test_wav.unsqueeze(0).unsqueeze(0)
|
77 |
+
|
78 |
+
with torch.no_grad():
|
79 |
+
|
80 |
+
# encode
|
81 |
+
enc_out = fa_encoder(test_wav)
|
82 |
+
print(enc_out.shape)
|
83 |
+
|
84 |
+
# quantize
|
85 |
+
vq_post_emb, vq_id, _, quantized, spk_embs = fa_decoder(enc_out, eval_vq=False, vq=True)
|
86 |
+
|
87 |
+
# latent after quantization
|
88 |
+
print(vq_post_emb.shape)
|
89 |
+
|
90 |
+
# codes
|
91 |
+
print("vq id shape:", vq_id.shape)
|
92 |
+
|
93 |
+
# get prosody code
|
94 |
+
prosody_code = vq_id[:1]
|
95 |
+
print("prosody code shape:", prosody_code.shape)
|
96 |
+
|
97 |
+
# get content code
|
98 |
+
cotent_code = vq_id[1:3]
|
99 |
+
print("content code shape:", cotent_code.shape)
|
100 |
+
|
101 |
+
# get residual code (acoustic detail codes)
|
102 |
+
residual_code = vq_id[3:]
|
103 |
+
print("residual code shape:", residual_code.shape)
|
104 |
+
|
105 |
+
# speaker embedding
|
106 |
+
print("speaker embedding shape:", spk_embs.shape)
|
107 |
+
|
108 |
+
# decode (recommand)
|
109 |
+
recon_wav = fa_decoder.inference(vq_post_emb, spk_embs)
|
110 |
+
print(recon_wav.shape)
|
111 |
+
sf.write("recon.wav", recon_wav[0][0].cpu().numpy(), 16000)
|
112 |
+
```
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
## Some Q&A
|
117 |
+
|
118 |
+
Q1: What audio sample rate does FACodec support? What is the hop size? How many codes will be generated for each frame?
|
119 |
+
|
120 |
+
A1: FACodec supports 16KHz speech audio. The hop size is 200 samples, and (16000/200) * 6 (total number of codebooks) codes will be generated for each frame.
|
121 |
+
|
122 |
+
Q2: Is it possible to train an autoregressive TTS model like VALL-E using FACodec?
|
123 |
+
|
124 |
+
A2: Yes. In fact, the authors of NaturalSpeech 3 have already employ explore the autoregressive generative model for discrete token generation with FACodec. They use an autoregressive language model to generate prosody codes, followed by a non-autoregressive model to generate the remaining content and acoustic details codes.
|
125 |
+
|
126 |
+
Q3: Is it possible to train a latent diffusion TTS model like NaturalSpeech2 using FACodec?
|
127 |
+
|
128 |
+
A3: Yes. You can use the latent getted after quanzaition as the modelling target for the latent diffusion model.
|
129 |
+
|
130 |
+
Q4: Can FACodec compress and reconstruct audio from other domains? Such as sound effects, music, etc.
|
131 |
+
|
132 |
+
A4: Since FACodec is designed for speech, it may not be suitable for other audio domains. However, it is possible to use the FACodec model to compress and reconstruct audio from other domains, but the quality may not be as good as the original audio.
|
133 |
+
|
134 |
+
Q5: Can FACodec be used for content feature for some other tasks like voice conversion?
|
135 |
+
|
136 |
+
A5: I think the answer is yes. Researchers can use the content code of FACodec as the content feature for voice conversion. We hope to see more research in this direction.
|
137 |
+
|
138 |
+
## Citations
|
139 |
+
|
140 |
+
If you use our FACodec model, please cite the following paper:
|
141 |
+
|
142 |
+
```bibtex
|
143 |
+
@misc{ju2024naturalspeech,
|
144 |
+
title={NaturalSpeech 3: Zero-Shot Speech Synthesis with Factorized Codec and Diffusion Models},
|
145 |
+
author={Zeqian Ju and Yuancheng Wang and Kai Shen and Xu Tan and Detai Xin and Dongchao Yang and Yanqing Liu and Yichong Leng and Kaitao Song and Siliang Tang and Zhizheng Wu and Tao Qin and Xiang-Yang Li and Wei Ye and Shikun Zhang and Jiang Bian and Lei He and Jinyu Li and Sheng Zhao},
|
146 |
+
year={2024},
|
147 |
+
eprint={2403.03100},
|
148 |
+
archivePrefix={arXiv},
|
149 |
+
primaryClass={eess.AS}
|
150 |
+
}
|
151 |
+
|
152 |
+
@article{zhang2023amphion,
|
153 |
+
title={Amphion: An Open-Source Audio, Music and Speech Generation Toolkit},
|
154 |
+
author={Xueyao Zhang and Liumeng Xue and Yicheng Gu and Yuancheng Wang and Haorui He and Chaoren Wang and Xi Chen and Zihao Fang and Haopeng Chen and Junan Zhang and Tze Ying Tang and Lexiao Zou and Mingxuan Wang and Jun Han and Kai Chen and Haizhou Li and Zhizheng Wu},
|
155 |
+
journal={arXiv},
|
156 |
+
year={2024},
|
157 |
+
volume={abs/2312.09911}
|
158 |
+
}
|
159 |
+
```
|
160 |
+
|
Amphion/models/ns3_codec/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .facodec import *
|
Amphion/models/ns3_codec/alias_free_torch/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
# LICENSE is in incl_licenses directory.
|
3 |
+
|
4 |
+
from .filter import *
|
5 |
+
from .resample import *
|
6 |
+
from .act import *
|
Amphion/models/ns3_codec/alias_free_torch/act.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
# LICENSE is in incl_licenses directory.
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
from .resample import UpSample1d, DownSample1d
|
6 |
+
|
7 |
+
|
8 |
+
class Activation1d(nn.Module):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
activation,
|
12 |
+
up_ratio: int = 2,
|
13 |
+
down_ratio: int = 2,
|
14 |
+
up_kernel_size: int = 12,
|
15 |
+
down_kernel_size: int = 12,
|
16 |
+
):
|
17 |
+
super().__init__()
|
18 |
+
self.up_ratio = up_ratio
|
19 |
+
self.down_ratio = down_ratio
|
20 |
+
self.act = activation
|
21 |
+
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
22 |
+
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
23 |
+
|
24 |
+
# x: [B,C,T]
|
25 |
+
def forward(self, x):
|
26 |
+
x = self.upsample(x)
|
27 |
+
x = self.act(x)
|
28 |
+
x = self.downsample(x)
|
29 |
+
|
30 |
+
return x
|
Amphion/models/ns3_codec/alias_free_torch/filter.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
# LICENSE is in incl_licenses directory.
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import math
|
8 |
+
|
9 |
+
if "sinc" in dir(torch):
|
10 |
+
sinc = torch.sinc
|
11 |
+
else:
|
12 |
+
# This code is adopted from adefossez's julius.core.sinc under the MIT License
|
13 |
+
# https://adefossez.github.io/julius/julius/core.html
|
14 |
+
# LICENSE is in incl_licenses directory.
|
15 |
+
def sinc(x: torch.Tensor):
|
16 |
+
"""
|
17 |
+
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
|
18 |
+
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
|
19 |
+
"""
|
20 |
+
return torch.where(
|
21 |
+
x == 0,
|
22 |
+
torch.tensor(1.0, device=x.device, dtype=x.dtype),
|
23 |
+
torch.sin(math.pi * x) / math.pi / x,
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
28 |
+
# https://adefossez.github.io/julius/julius/lowpass.html
|
29 |
+
# LICENSE is in incl_licenses directory.
|
30 |
+
def kaiser_sinc_filter1d(
|
31 |
+
cutoff, half_width, kernel_size
|
32 |
+
): # return filter [1,1,kernel_size]
|
33 |
+
even = kernel_size % 2 == 0
|
34 |
+
half_size = kernel_size // 2
|
35 |
+
|
36 |
+
# For kaiser window
|
37 |
+
delta_f = 4 * half_width
|
38 |
+
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
39 |
+
if A > 50.0:
|
40 |
+
beta = 0.1102 * (A - 8.7)
|
41 |
+
elif A >= 21.0:
|
42 |
+
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
|
43 |
+
else:
|
44 |
+
beta = 0.0
|
45 |
+
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
46 |
+
|
47 |
+
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
|
48 |
+
if even:
|
49 |
+
time = torch.arange(-half_size, half_size) + 0.5
|
50 |
+
else:
|
51 |
+
time = torch.arange(kernel_size) - half_size
|
52 |
+
if cutoff == 0:
|
53 |
+
filter_ = torch.zeros_like(time)
|
54 |
+
else:
|
55 |
+
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
|
56 |
+
# Normalize filter to have sum = 1, otherwise we will have a small leakage
|
57 |
+
# of the constant component in the input signal.
|
58 |
+
filter_ /= filter_.sum()
|
59 |
+
filter = filter_.view(1, 1, kernel_size)
|
60 |
+
|
61 |
+
return filter
|
62 |
+
|
63 |
+
|
64 |
+
class LowPassFilter1d(nn.Module):
|
65 |
+
def __init__(
|
66 |
+
self,
|
67 |
+
cutoff=0.5,
|
68 |
+
half_width=0.6,
|
69 |
+
stride: int = 1,
|
70 |
+
padding: bool = True,
|
71 |
+
padding_mode: str = "replicate",
|
72 |
+
kernel_size: int = 12,
|
73 |
+
):
|
74 |
+
# kernel_size should be even number for stylegan3 setup,
|
75 |
+
# in this implementation, odd number is also possible.
|
76 |
+
super().__init__()
|
77 |
+
if cutoff < -0.0:
|
78 |
+
raise ValueError("Minimum cutoff must be larger than zero.")
|
79 |
+
if cutoff > 0.5:
|
80 |
+
raise ValueError("A cutoff above 0.5 does not make sense.")
|
81 |
+
self.kernel_size = kernel_size
|
82 |
+
self.even = kernel_size % 2 == 0
|
83 |
+
self.pad_left = kernel_size // 2 - int(self.even)
|
84 |
+
self.pad_right = kernel_size // 2
|
85 |
+
self.stride = stride
|
86 |
+
self.padding = padding
|
87 |
+
self.padding_mode = padding_mode
|
88 |
+
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
89 |
+
self.register_buffer("filter", filter)
|
90 |
+
|
91 |
+
# input [B, C, T]
|
92 |
+
def forward(self, x):
|
93 |
+
_, C, _ = x.shape
|
94 |
+
|
95 |
+
if self.padding:
|
96 |
+
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
97 |
+
out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
98 |
+
|
99 |
+
return out
|
Amphion/models/ns3_codec/alias_free_torch/resample.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
# LICENSE is in incl_licenses directory.
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
from .filter import LowPassFilter1d
|
7 |
+
from .filter import kaiser_sinc_filter1d
|
8 |
+
|
9 |
+
|
10 |
+
class UpSample1d(nn.Module):
|
11 |
+
def __init__(self, ratio=2, kernel_size=None):
|
12 |
+
super().__init__()
|
13 |
+
self.ratio = ratio
|
14 |
+
self.kernel_size = (
|
15 |
+
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
16 |
+
)
|
17 |
+
self.stride = ratio
|
18 |
+
self.pad = self.kernel_size // ratio - 1
|
19 |
+
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
20 |
+
self.pad_right = (
|
21 |
+
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
22 |
+
)
|
23 |
+
filter = kaiser_sinc_filter1d(
|
24 |
+
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
|
25 |
+
)
|
26 |
+
self.register_buffer("filter", filter)
|
27 |
+
|
28 |
+
# x: [B, C, T]
|
29 |
+
def forward(self, x):
|
30 |
+
_, C, _ = x.shape
|
31 |
+
|
32 |
+
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
33 |
+
x = self.ratio * F.conv_transpose1d(
|
34 |
+
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
|
35 |
+
)
|
36 |
+
x = x[..., self.pad_left : -self.pad_right]
|
37 |
+
|
38 |
+
return x
|
39 |
+
|
40 |
+
|
41 |
+
class DownSample1d(nn.Module):
|
42 |
+
def __init__(self, ratio=2, kernel_size=None):
|
43 |
+
super().__init__()
|
44 |
+
self.ratio = ratio
|
45 |
+
self.kernel_size = (
|
46 |
+
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
47 |
+
)
|
48 |
+
self.lowpass = LowPassFilter1d(
|
49 |
+
cutoff=0.5 / ratio,
|
50 |
+
half_width=0.6 / ratio,
|
51 |
+
stride=ratio,
|
52 |
+
kernel_size=self.kernel_size,
|
53 |
+
)
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
xx = self.lowpass(x)
|
57 |
+
|
58 |
+
return xx
|
Amphion/models/ns3_codec/facodec.py
ADDED
@@ -0,0 +1,593 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from torch import nn, sin, pow
|
4 |
+
from torch.nn import Parameter
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch.nn.utils import weight_norm
|
7 |
+
from .alias_free_torch import *
|
8 |
+
from .quantize import *
|
9 |
+
from einops import rearrange
|
10 |
+
from einops.layers.torch import Rearrange
|
11 |
+
from .transformer import TransformerEncoder
|
12 |
+
from .gradient_reversal import GradientReversal
|
13 |
+
|
14 |
+
|
15 |
+
def init_weights(m):
|
16 |
+
if isinstance(m, nn.Conv1d):
|
17 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
18 |
+
nn.init.constant_(m.bias, 0)
|
19 |
+
|
20 |
+
|
21 |
+
def WNConv1d(*args, **kwargs):
|
22 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
23 |
+
|
24 |
+
|
25 |
+
def WNConvTranspose1d(*args, **kwargs):
|
26 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
27 |
+
|
28 |
+
|
29 |
+
class CNNLSTM(nn.Module):
|
30 |
+
def __init__(self, indim, outdim, head, global_pred=False):
|
31 |
+
super().__init__()
|
32 |
+
self.global_pred = global_pred
|
33 |
+
self.model = nn.Sequential(
|
34 |
+
ResidualUnit(indim, dilation=1),
|
35 |
+
ResidualUnit(indim, dilation=2),
|
36 |
+
ResidualUnit(indim, dilation=3),
|
37 |
+
Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)),
|
38 |
+
Rearrange("b c t -> b t c"),
|
39 |
+
)
|
40 |
+
self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)])
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
# x: [B, C, T]
|
44 |
+
x = self.model(x)
|
45 |
+
if self.global_pred:
|
46 |
+
x = torch.mean(x, dim=1, keepdim=False)
|
47 |
+
outs = [head(x) for head in self.heads]
|
48 |
+
return outs
|
49 |
+
|
50 |
+
|
51 |
+
class SnakeBeta(nn.Module):
|
52 |
+
"""
|
53 |
+
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
54 |
+
Shape:
|
55 |
+
- Input: (B, C, T)
|
56 |
+
- Output: (B, C, T), same shape as the input
|
57 |
+
Parameters:
|
58 |
+
- alpha - trainable parameter that controls frequency
|
59 |
+
- beta - trainable parameter that controls magnitude
|
60 |
+
References:
|
61 |
+
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
62 |
+
https://arxiv.org/abs/2006.08195
|
63 |
+
Examples:
|
64 |
+
>>> a1 = snakebeta(256)
|
65 |
+
>>> x = torch.randn(256)
|
66 |
+
>>> x = a1(x)
|
67 |
+
"""
|
68 |
+
|
69 |
+
def __init__(
|
70 |
+
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
|
71 |
+
):
|
72 |
+
"""
|
73 |
+
Initialization.
|
74 |
+
INPUT:
|
75 |
+
- in_features: shape of the input
|
76 |
+
- alpha - trainable parameter that controls frequency
|
77 |
+
- beta - trainable parameter that controls magnitude
|
78 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
79 |
+
beta is initialized to 1 by default, higher values = higher-magnitude.
|
80 |
+
alpha will be trained along with the rest of your model.
|
81 |
+
"""
|
82 |
+
super(SnakeBeta, self).__init__()
|
83 |
+
self.in_features = in_features
|
84 |
+
|
85 |
+
# initialize alpha
|
86 |
+
self.alpha_logscale = alpha_logscale
|
87 |
+
if self.alpha_logscale: # log scale alphas initialized to zeros
|
88 |
+
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
89 |
+
self.beta = Parameter(torch.zeros(in_features) * alpha)
|
90 |
+
else: # linear scale alphas initialized to ones
|
91 |
+
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
92 |
+
self.beta = Parameter(torch.ones(in_features) * alpha)
|
93 |
+
|
94 |
+
self.alpha.requires_grad = alpha_trainable
|
95 |
+
self.beta.requires_grad = alpha_trainable
|
96 |
+
|
97 |
+
self.no_div_by_zero = 0.000000001
|
98 |
+
|
99 |
+
def forward(self, x):
|
100 |
+
"""
|
101 |
+
Forward pass of the function.
|
102 |
+
Applies the function to the input elementwise.
|
103 |
+
SnakeBeta := x + 1/b * sin^2 (xa)
|
104 |
+
"""
|
105 |
+
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
106 |
+
beta = self.beta.unsqueeze(0).unsqueeze(-1)
|
107 |
+
if self.alpha_logscale:
|
108 |
+
alpha = torch.exp(alpha)
|
109 |
+
beta = torch.exp(beta)
|
110 |
+
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
111 |
+
|
112 |
+
return x
|
113 |
+
|
114 |
+
|
115 |
+
class ResidualUnit(nn.Module):
|
116 |
+
def __init__(self, dim: int = 16, dilation: int = 1):
|
117 |
+
super().__init__()
|
118 |
+
pad = ((7 - 1) * dilation) // 2
|
119 |
+
self.block = nn.Sequential(
|
120 |
+
Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
|
121 |
+
WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
|
122 |
+
Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
|
123 |
+
WNConv1d(dim, dim, kernel_size=1),
|
124 |
+
)
|
125 |
+
|
126 |
+
def forward(self, x):
|
127 |
+
return x + self.block(x)
|
128 |
+
|
129 |
+
|
130 |
+
class EncoderBlock(nn.Module):
|
131 |
+
def __init__(self, dim: int = 16, stride: int = 1):
|
132 |
+
super().__init__()
|
133 |
+
self.block = nn.Sequential(
|
134 |
+
ResidualUnit(dim // 2, dilation=1),
|
135 |
+
ResidualUnit(dim // 2, dilation=3),
|
136 |
+
ResidualUnit(dim // 2, dilation=9),
|
137 |
+
Activation1d(activation=SnakeBeta(dim // 2, alpha_logscale=True)),
|
138 |
+
WNConv1d(
|
139 |
+
dim // 2,
|
140 |
+
dim,
|
141 |
+
kernel_size=2 * stride,
|
142 |
+
stride=stride,
|
143 |
+
padding=stride // 2 + stride % 2,
|
144 |
+
),
|
145 |
+
)
|
146 |
+
|
147 |
+
def forward(self, x):
|
148 |
+
return self.block(x)
|
149 |
+
|
150 |
+
|
151 |
+
class FACodecEncoder(nn.Module):
|
152 |
+
def __init__(
|
153 |
+
self,
|
154 |
+
ngf=32,
|
155 |
+
up_ratios=(2, 4, 5, 5),
|
156 |
+
out_channels=1024,
|
157 |
+
):
|
158 |
+
super().__init__()
|
159 |
+
self.hop_length = np.prod(up_ratios)
|
160 |
+
self.up_ratios = up_ratios
|
161 |
+
|
162 |
+
# Create first convolution
|
163 |
+
d_model = ngf
|
164 |
+
self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
|
165 |
+
|
166 |
+
# Create EncoderBlocks that double channels as they downsample by `stride`
|
167 |
+
for stride in up_ratios:
|
168 |
+
d_model *= 2
|
169 |
+
self.block += [EncoderBlock(d_model, stride=stride)]
|
170 |
+
|
171 |
+
# Create last convolution
|
172 |
+
self.block += [
|
173 |
+
Activation1d(activation=SnakeBeta(d_model, alpha_logscale=True)),
|
174 |
+
WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
|
175 |
+
]
|
176 |
+
|
177 |
+
# Wrap black into nn.Sequential
|
178 |
+
self.block = nn.Sequential(*self.block)
|
179 |
+
self.enc_dim = d_model
|
180 |
+
|
181 |
+
self.reset_parameters()
|
182 |
+
|
183 |
+
def forward(self, x):
|
184 |
+
out = self.block(x)
|
185 |
+
return out
|
186 |
+
|
187 |
+
def inference(self, x):
|
188 |
+
return self.block(x)
|
189 |
+
|
190 |
+
def remove_weight_norm(self):
|
191 |
+
"""Remove weight normalization module from all of the layers."""
|
192 |
+
|
193 |
+
def _remove_weight_norm(m):
|
194 |
+
try:
|
195 |
+
torch.nn.utils.remove_weight_norm(m)
|
196 |
+
except ValueError: # this module didn't have weight norm
|
197 |
+
return
|
198 |
+
|
199 |
+
self.apply(_remove_weight_norm)
|
200 |
+
|
201 |
+
def apply_weight_norm(self):
|
202 |
+
"""Apply weight normalization module from all of the layers."""
|
203 |
+
|
204 |
+
def _apply_weight_norm(m):
|
205 |
+
if isinstance(m, nn.Conv1d):
|
206 |
+
torch.nn.utils.weight_norm(m)
|
207 |
+
|
208 |
+
self.apply(_apply_weight_norm)
|
209 |
+
|
210 |
+
def reset_parameters(self):
|
211 |
+
self.apply(init_weights)
|
212 |
+
|
213 |
+
|
214 |
+
class DecoderBlock(nn.Module):
|
215 |
+
def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
|
216 |
+
super().__init__()
|
217 |
+
self.block = nn.Sequential(
|
218 |
+
Activation1d(activation=SnakeBeta(input_dim, alpha_logscale=True)),
|
219 |
+
WNConvTranspose1d(
|
220 |
+
input_dim,
|
221 |
+
output_dim,
|
222 |
+
kernel_size=2 * stride,
|
223 |
+
stride=stride,
|
224 |
+
padding=stride // 2 + stride % 2,
|
225 |
+
output_padding=stride % 2,
|
226 |
+
),
|
227 |
+
ResidualUnit(output_dim, dilation=1),
|
228 |
+
ResidualUnit(output_dim, dilation=3),
|
229 |
+
ResidualUnit(output_dim, dilation=9),
|
230 |
+
)
|
231 |
+
|
232 |
+
def forward(self, x):
|
233 |
+
return self.block(x)
|
234 |
+
|
235 |
+
|
236 |
+
class FACodecDecoder(nn.Module):
|
237 |
+
def __init__(
|
238 |
+
self,
|
239 |
+
in_channels=256,
|
240 |
+
upsample_initial_channel=1536,
|
241 |
+
ngf=32,
|
242 |
+
up_ratios=(5, 5, 4, 2),
|
243 |
+
vq_num_q_c=2,
|
244 |
+
vq_num_q_p=1,
|
245 |
+
vq_num_q_r=3,
|
246 |
+
vq_dim=1024,
|
247 |
+
vq_commit_weight=0.005,
|
248 |
+
vq_weight_init=False,
|
249 |
+
vq_full_commit_loss=False,
|
250 |
+
codebook_dim=8,
|
251 |
+
codebook_size_prosody=10, # true codebook size is equal to 2^codebook_size
|
252 |
+
codebook_size_content=10,
|
253 |
+
codebook_size_residual=10,
|
254 |
+
quantizer_dropout=0.0,
|
255 |
+
dropout_type="linear",
|
256 |
+
use_gr_content_f0=False,
|
257 |
+
use_gr_prosody_phone=False,
|
258 |
+
use_gr_residual_f0=False,
|
259 |
+
use_gr_residual_phone=False,
|
260 |
+
use_gr_x_timbre=False,
|
261 |
+
use_random_mask_residual=True,
|
262 |
+
prob_random_mask_residual=0.75,
|
263 |
+
):
|
264 |
+
super().__init__()
|
265 |
+
self.hop_length = np.prod(up_ratios)
|
266 |
+
self.ngf = ngf
|
267 |
+
self.up_ratios = up_ratios
|
268 |
+
|
269 |
+
self.use_random_mask_residual = use_random_mask_residual
|
270 |
+
self.prob_random_mask_residual = prob_random_mask_residual
|
271 |
+
|
272 |
+
self.vq_num_q_p = vq_num_q_p
|
273 |
+
self.vq_num_q_c = vq_num_q_c
|
274 |
+
self.vq_num_q_r = vq_num_q_r
|
275 |
+
|
276 |
+
self.codebook_size_prosody = codebook_size_prosody
|
277 |
+
self.codebook_size_content = codebook_size_content
|
278 |
+
self.codebook_size_residual = codebook_size_residual
|
279 |
+
|
280 |
+
quantizer_class = ResidualVQ
|
281 |
+
|
282 |
+
self.quantizer = nn.ModuleList()
|
283 |
+
|
284 |
+
# prosody
|
285 |
+
quantizer = quantizer_class(
|
286 |
+
num_quantizers=vq_num_q_p,
|
287 |
+
dim=vq_dim,
|
288 |
+
codebook_size=codebook_size_prosody,
|
289 |
+
codebook_dim=codebook_dim,
|
290 |
+
threshold_ema_dead_code=2,
|
291 |
+
commitment=vq_commit_weight,
|
292 |
+
weight_init=vq_weight_init,
|
293 |
+
full_commit_loss=vq_full_commit_loss,
|
294 |
+
quantizer_dropout=quantizer_dropout,
|
295 |
+
dropout_type=dropout_type,
|
296 |
+
)
|
297 |
+
self.quantizer.append(quantizer)
|
298 |
+
|
299 |
+
# phone
|
300 |
+
quantizer = quantizer_class(
|
301 |
+
num_quantizers=vq_num_q_c,
|
302 |
+
dim=vq_dim,
|
303 |
+
codebook_size=codebook_size_content,
|
304 |
+
codebook_dim=codebook_dim,
|
305 |
+
threshold_ema_dead_code=2,
|
306 |
+
commitment=vq_commit_weight,
|
307 |
+
weight_init=vq_weight_init,
|
308 |
+
full_commit_loss=vq_full_commit_loss,
|
309 |
+
quantizer_dropout=quantizer_dropout,
|
310 |
+
dropout_type=dropout_type,
|
311 |
+
)
|
312 |
+
self.quantizer.append(quantizer)
|
313 |
+
|
314 |
+
# residual
|
315 |
+
if self.vq_num_q_r > 0:
|
316 |
+
quantizer = quantizer_class(
|
317 |
+
num_quantizers=vq_num_q_r,
|
318 |
+
dim=vq_dim,
|
319 |
+
codebook_size=codebook_size_residual,
|
320 |
+
codebook_dim=codebook_dim,
|
321 |
+
threshold_ema_dead_code=2,
|
322 |
+
commitment=vq_commit_weight,
|
323 |
+
weight_init=vq_weight_init,
|
324 |
+
full_commit_loss=vq_full_commit_loss,
|
325 |
+
quantizer_dropout=quantizer_dropout,
|
326 |
+
dropout_type=dropout_type,
|
327 |
+
)
|
328 |
+
self.quantizer.append(quantizer)
|
329 |
+
|
330 |
+
# Add first conv layer
|
331 |
+
channels = upsample_initial_channel
|
332 |
+
layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
|
333 |
+
|
334 |
+
# Add upsampling + MRF blocks
|
335 |
+
for i, stride in enumerate(up_ratios):
|
336 |
+
input_dim = channels // 2**i
|
337 |
+
output_dim = channels // 2 ** (i + 1)
|
338 |
+
layers += [DecoderBlock(input_dim, output_dim, stride)]
|
339 |
+
|
340 |
+
# Add final conv layer
|
341 |
+
layers += [
|
342 |
+
Activation1d(activation=SnakeBeta(output_dim, alpha_logscale=True)),
|
343 |
+
WNConv1d(output_dim, 1, kernel_size=7, padding=3),
|
344 |
+
nn.Tanh(),
|
345 |
+
]
|
346 |
+
|
347 |
+
self.model = nn.Sequential(*layers)
|
348 |
+
|
349 |
+
self.timbre_encoder = TransformerEncoder(
|
350 |
+
enc_emb_tokens=None,
|
351 |
+
encoder_layer=4,
|
352 |
+
encoder_hidden=256,
|
353 |
+
encoder_head=4,
|
354 |
+
conv_filter_size=1024,
|
355 |
+
conv_kernel_size=5,
|
356 |
+
encoder_dropout=0.1,
|
357 |
+
use_cln=False,
|
358 |
+
)
|
359 |
+
|
360 |
+
self.timbre_linear = nn.Linear(in_channels, in_channels * 2)
|
361 |
+
self.timbre_linear.bias.data[:in_channels] = 1
|
362 |
+
self.timbre_linear.bias.data[in_channels:] = 0
|
363 |
+
self.timbre_norm = nn.LayerNorm(in_channels, elementwise_affine=False)
|
364 |
+
|
365 |
+
self.f0_predictor = CNNLSTM(in_channels, 1, 2)
|
366 |
+
self.phone_predictor = CNNLSTM(in_channels, 5003, 1)
|
367 |
+
|
368 |
+
self.use_gr_content_f0 = use_gr_content_f0
|
369 |
+
self.use_gr_prosody_phone = use_gr_prosody_phone
|
370 |
+
self.use_gr_residual_f0 = use_gr_residual_f0
|
371 |
+
self.use_gr_residual_phone = use_gr_residual_phone
|
372 |
+
self.use_gr_x_timbre = use_gr_x_timbre
|
373 |
+
|
374 |
+
if self.vq_num_q_r > 0 and self.use_gr_residual_f0:
|
375 |
+
self.res_f0_predictor = nn.Sequential(
|
376 |
+
GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2)
|
377 |
+
)
|
378 |
+
|
379 |
+
if self.vq_num_q_r > 0 and self.use_gr_residual_phone > 0:
|
380 |
+
self.res_phone_predictor = nn.Sequential(
|
381 |
+
GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1)
|
382 |
+
)
|
383 |
+
|
384 |
+
if self.use_gr_content_f0:
|
385 |
+
self.content_f0_predictor = nn.Sequential(
|
386 |
+
GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2)
|
387 |
+
)
|
388 |
+
|
389 |
+
if self.use_gr_prosody_phone:
|
390 |
+
self.prosody_phone_predictor = nn.Sequential(
|
391 |
+
GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1)
|
392 |
+
)
|
393 |
+
|
394 |
+
if self.use_gr_x_timbre:
|
395 |
+
self.x_timbre_predictor = nn.Sequential(
|
396 |
+
GradientReversal(alpha=1),
|
397 |
+
CNNLSTM(in_channels, 245200, 1, global_pred=True),
|
398 |
+
)
|
399 |
+
|
400 |
+
self.reset_parameters()
|
401 |
+
|
402 |
+
def quantize(self, x, n_quantizers=None):
|
403 |
+
outs, qs, commit_loss, quantized_buf = 0, [], [], []
|
404 |
+
|
405 |
+
# prosody
|
406 |
+
f0_input = x # (B, d, T)
|
407 |
+
f0_quantizer = self.quantizer[0]
|
408 |
+
out, q, commit, quantized = f0_quantizer(f0_input, n_quantizers=n_quantizers)
|
409 |
+
outs += out
|
410 |
+
qs.append(q)
|
411 |
+
quantized_buf.append(quantized.sum(0))
|
412 |
+
commit_loss.append(commit)
|
413 |
+
|
414 |
+
# phone
|
415 |
+
phone_input = x
|
416 |
+
phone_quantizer = self.quantizer[1]
|
417 |
+
out, q, commit, quantized = phone_quantizer(
|
418 |
+
phone_input, n_quantizers=n_quantizers
|
419 |
+
)
|
420 |
+
outs += out
|
421 |
+
qs.append(q)
|
422 |
+
quantized_buf.append(quantized.sum(0))
|
423 |
+
commit_loss.append(commit)
|
424 |
+
|
425 |
+
# residual
|
426 |
+
if self.vq_num_q_r > 0:
|
427 |
+
residual_quantizer = self.quantizer[2]
|
428 |
+
residual_input = x - (quantized_buf[0] + quantized_buf[1]).detach()
|
429 |
+
out, q, commit, quantized = residual_quantizer(
|
430 |
+
residual_input, n_quantizers=n_quantizers
|
431 |
+
)
|
432 |
+
outs += out
|
433 |
+
qs.append(q)
|
434 |
+
quantized_buf.append(quantized.sum(0)) # [L, B, C, T] -> [B, C, T]
|
435 |
+
commit_loss.append(commit)
|
436 |
+
|
437 |
+
qs = torch.cat(qs, dim=0)
|
438 |
+
commit_loss = torch.cat(commit_loss, dim=0)
|
439 |
+
return outs, qs, commit_loss, quantized_buf
|
440 |
+
|
441 |
+
def forward(
|
442 |
+
self,
|
443 |
+
x,
|
444 |
+
vq=True,
|
445 |
+
get_vq=False,
|
446 |
+
eval_vq=True,
|
447 |
+
speaker_embedding=None,
|
448 |
+
n_quantizers=None,
|
449 |
+
quantized=None,
|
450 |
+
):
|
451 |
+
if get_vq:
|
452 |
+
return self.quantizer.get_emb()
|
453 |
+
if vq is True:
|
454 |
+
if eval_vq:
|
455 |
+
self.quantizer.eval()
|
456 |
+
x_timbre = x
|
457 |
+
outs, qs, commit_loss, quantized_buf = self.quantize(
|
458 |
+
x, n_quantizers=n_quantizers
|
459 |
+
)
|
460 |
+
|
461 |
+
x_timbre = x_timbre.transpose(1, 2)
|
462 |
+
x_timbre = self.timbre_encoder(x_timbre, None, None)
|
463 |
+
x_timbre = x_timbre.transpose(1, 2)
|
464 |
+
spk_embs = torch.mean(x_timbre, dim=2)
|
465 |
+
return outs, qs, commit_loss, quantized_buf, spk_embs
|
466 |
+
|
467 |
+
out = {}
|
468 |
+
|
469 |
+
layer_0 = quantized[0]
|
470 |
+
f0, uv = self.f0_predictor(layer_0)
|
471 |
+
f0 = rearrange(f0, "... 1 -> ...")
|
472 |
+
uv = rearrange(uv, "... 1 -> ...")
|
473 |
+
|
474 |
+
layer_1 = quantized[1]
|
475 |
+
(phone,) = self.phone_predictor(layer_1)
|
476 |
+
|
477 |
+
out = {"f0": f0, "uv": uv, "phone": phone}
|
478 |
+
|
479 |
+
if self.use_gr_prosody_phone:
|
480 |
+
(prosody_phone,) = self.prosody_phone_predictor(layer_0)
|
481 |
+
out["prosody_phone"] = prosody_phone
|
482 |
+
|
483 |
+
if self.use_gr_content_f0:
|
484 |
+
content_f0, content_uv = self.content_f0_predictor(layer_1)
|
485 |
+
content_f0 = rearrange(content_f0, "... 1 -> ...")
|
486 |
+
content_uv = rearrange(content_uv, "... 1 -> ...")
|
487 |
+
out["content_f0"] = content_f0
|
488 |
+
out["content_uv"] = content_uv
|
489 |
+
|
490 |
+
if self.vq_num_q_r > 0:
|
491 |
+
layer_2 = quantized[2]
|
492 |
+
|
493 |
+
if self.use_gr_residual_f0:
|
494 |
+
res_f0, res_uv = self.res_f0_predictor(layer_2)
|
495 |
+
res_f0 = rearrange(res_f0, "... 1 -> ...")
|
496 |
+
res_uv = rearrange(res_uv, "... 1 -> ...")
|
497 |
+
out["res_f0"] = res_f0
|
498 |
+
out["res_uv"] = res_uv
|
499 |
+
|
500 |
+
if self.use_gr_residual_phone:
|
501 |
+
(res_phone,) = self.res_phone_predictor(layer_2)
|
502 |
+
out["res_phone"] = res_phone
|
503 |
+
|
504 |
+
style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
|
505 |
+
gamma, beta = style.chunk(2, 1) # (B, d, 1)
|
506 |
+
if self.vq_num_q_r > 0:
|
507 |
+
if self.use_random_mask_residual:
|
508 |
+
bsz = quantized[2].shape[0]
|
509 |
+
res_mask = np.random.choice(
|
510 |
+
[0, 1],
|
511 |
+
size=bsz,
|
512 |
+
p=[
|
513 |
+
self.prob_random_mask_residual,
|
514 |
+
1 - self.prob_random_mask_residual,
|
515 |
+
],
|
516 |
+
)
|
517 |
+
res_mask = (
|
518 |
+
torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1)
|
519 |
+
) # (B, 1, 1)
|
520 |
+
res_mask = res_mask.to(
|
521 |
+
device=quantized[2].device, dtype=quantized[2].dtype
|
522 |
+
)
|
523 |
+
x = (
|
524 |
+
quantized[0].detach()
|
525 |
+
+ quantized[1].detach()
|
526 |
+
+ quantized[2] * res_mask
|
527 |
+
)
|
528 |
+
# x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2] * res_mask
|
529 |
+
else:
|
530 |
+
x = quantized[0].detach() + quantized[1].detach() + quantized[2]
|
531 |
+
# x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2]
|
532 |
+
else:
|
533 |
+
x = quantized[0].detach() + quantized[1].detach()
|
534 |
+
# x = quantized_perturbe[0].detach() + quantized[1].detach()
|
535 |
+
|
536 |
+
if self.use_gr_x_timbre:
|
537 |
+
(x_timbre,) = self.x_timbre_predictor(x)
|
538 |
+
out["x_timbre"] = x_timbre
|
539 |
+
|
540 |
+
x = x.transpose(1, 2)
|
541 |
+
x = self.timbre_norm(x)
|
542 |
+
x = x.transpose(1, 2)
|
543 |
+
x = x * gamma + beta
|
544 |
+
|
545 |
+
x = self.model(x)
|
546 |
+
out["audio"] = x
|
547 |
+
|
548 |
+
return out
|
549 |
+
|
550 |
+
def vq2emb(self, vq, use_residual_code=True):
|
551 |
+
# vq: [num_quantizer, B, T]
|
552 |
+
self.quantizer = self.quantizer.eval()
|
553 |
+
out = 0
|
554 |
+
out += self.quantizer[0].vq2emb(vq[0 : self.vq_num_q_p])
|
555 |
+
out += self.quantizer[1].vq2emb(
|
556 |
+
vq[self.vq_num_q_p : self.vq_num_q_p + self.vq_num_q_c]
|
557 |
+
)
|
558 |
+
if self.vq_num_q_r > 0 and use_residual_code:
|
559 |
+
out += self.quantizer[2].vq2emb(vq[self.vq_num_q_p + self.vq_num_q_c :])
|
560 |
+
return out
|
561 |
+
|
562 |
+
def inference(self, x, speaker_embedding):
|
563 |
+
style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
|
564 |
+
gamma, beta = style.chunk(2, 1) # (B, d, 1)
|
565 |
+
x = x.transpose(1, 2)
|
566 |
+
x = self.timbre_norm(x)
|
567 |
+
x = x.transpose(1, 2)
|
568 |
+
x = x * gamma + beta
|
569 |
+
x = self.model(x)
|
570 |
+
return x
|
571 |
+
|
572 |
+
def remove_weight_norm(self):
|
573 |
+
"""Remove weight normalization module from all of the layers."""
|
574 |
+
|
575 |
+
def _remove_weight_norm(m):
|
576 |
+
try:
|
577 |
+
torch.nn.utils.remove_weight_norm(m)
|
578 |
+
except ValueError: # this module didn't have weight norm
|
579 |
+
return
|
580 |
+
|
581 |
+
self.apply(_remove_weight_norm)
|
582 |
+
|
583 |
+
def apply_weight_norm(self):
|
584 |
+
"""Apply weight normalization module from all of the layers."""
|
585 |
+
|
586 |
+
def _apply_weight_norm(m):
|
587 |
+
if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
|
588 |
+
torch.nn.utils.weight_norm(m)
|
589 |
+
|
590 |
+
self.apply(_apply_weight_norm)
|
591 |
+
|
592 |
+
def reset_parameters(self):
|
593 |
+
self.apply(init_weights)
|
Amphion/models/ns3_codec/gradient_reversal.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.autograd import Function
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
|
6 |
+
class GradientReversal(Function):
|
7 |
+
@staticmethod
|
8 |
+
def forward(ctx, x, alpha):
|
9 |
+
ctx.save_for_backward(x, alpha)
|
10 |
+
return x
|
11 |
+
|
12 |
+
@staticmethod
|
13 |
+
def backward(ctx, grad_output):
|
14 |
+
grad_input = None
|
15 |
+
_, alpha = ctx.saved_tensors
|
16 |
+
if ctx.needs_input_grad[0]:
|
17 |
+
grad_input = -alpha * grad_output
|
18 |
+
return grad_input, None
|
19 |
+
|
20 |
+
|
21 |
+
revgrad = GradientReversal.apply
|
22 |
+
|
23 |
+
|
24 |
+
class GradientReversal(nn.Module):
|
25 |
+
def __init__(self, alpha):
|
26 |
+
super().__init__()
|
27 |
+
self.alpha = torch.tensor(alpha, requires_grad=False)
|
28 |
+
|
29 |
+
def forward(self, x):
|
30 |
+
return revgrad(x, self.alpha)
|
Amphion/models/ns3_codec/quantize/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .fvq import *
|
2 |
+
from .rvq import *
|
Amphion/models/ns3_codec/quantize/fvq.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from einops import rearrange
|
8 |
+
from torch.nn.utils import weight_norm
|
9 |
+
|
10 |
+
|
11 |
+
class FactorizedVectorQuantize(nn.Module):
|
12 |
+
def __init__(self, dim, codebook_size, codebook_dim, commitment, **kwargs):
|
13 |
+
super().__init__()
|
14 |
+
self.codebook_size = codebook_size
|
15 |
+
self.codebook_dim = codebook_dim
|
16 |
+
self.commitment = commitment
|
17 |
+
|
18 |
+
if dim != self.codebook_dim:
|
19 |
+
self.in_proj = weight_norm(nn.Linear(dim, self.codebook_dim))
|
20 |
+
self.out_proj = weight_norm(nn.Linear(self.codebook_dim, dim))
|
21 |
+
else:
|
22 |
+
self.in_proj = nn.Identity()
|
23 |
+
self.out_proj = nn.Identity()
|
24 |
+
self._codebook = nn.Embedding(codebook_size, self.codebook_dim)
|
25 |
+
|
26 |
+
@property
|
27 |
+
def codebook(self):
|
28 |
+
return self._codebook
|
29 |
+
|
30 |
+
def forward(self, z):
|
31 |
+
"""Quantized the input tensor using a fixed codebook and returns
|
32 |
+
the corresponding codebook vectors
|
33 |
+
|
34 |
+
Parameters
|
35 |
+
----------
|
36 |
+
z : Tensor[B x D x T]
|
37 |
+
|
38 |
+
Returns
|
39 |
+
-------
|
40 |
+
Tensor[B x D x T]
|
41 |
+
Quantized continuous representation of input
|
42 |
+
Tensor[1]
|
43 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
44 |
+
entries
|
45 |
+
Tensor[1]
|
46 |
+
Codebook loss to update the codebook
|
47 |
+
Tensor[B x T]
|
48 |
+
Codebook indices (quantized discrete representation of input)
|
49 |
+
Tensor[B x D x T]
|
50 |
+
Projected latents (continuous representation of input before quantization)
|
51 |
+
"""
|
52 |
+
# transpose since we use linear
|
53 |
+
|
54 |
+
z = rearrange(z, "b d t -> b t d")
|
55 |
+
|
56 |
+
# Factorized codes project input into low-dimensional space
|
57 |
+
z_e = self.in_proj(z) # z_e : (B x T x D)
|
58 |
+
z_e = rearrange(z_e, "b t d -> b d t")
|
59 |
+
z_q, indices = self.decode_latents(z_e)
|
60 |
+
|
61 |
+
if self.training:
|
62 |
+
commitment_loss = (
|
63 |
+
F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
64 |
+
* self.commitment
|
65 |
+
)
|
66 |
+
codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
67 |
+
commit_loss = commitment_loss + codebook_loss
|
68 |
+
else:
|
69 |
+
commit_loss = torch.zeros(z.shape[0], device=z.device)
|
70 |
+
|
71 |
+
z_q = (
|
72 |
+
z_e + (z_q - z_e).detach()
|
73 |
+
) # noop in forward pass, straight-through gradient estimator in backward pass
|
74 |
+
|
75 |
+
z_q = rearrange(z_q, "b d t -> b t d")
|
76 |
+
z_q = self.out_proj(z_q)
|
77 |
+
z_q = rearrange(z_q, "b t d -> b d t")
|
78 |
+
|
79 |
+
return z_q, indices, commit_loss
|
80 |
+
|
81 |
+
def vq2emb(self, vq, proj=True):
|
82 |
+
emb = self.embed_code(vq)
|
83 |
+
if proj:
|
84 |
+
emb = self.out_proj(emb)
|
85 |
+
return emb.transpose(1, 2)
|
86 |
+
|
87 |
+
def get_emb(self):
|
88 |
+
return self.codebook.weight
|
89 |
+
|
90 |
+
def embed_code(self, embed_id):
|
91 |
+
return F.embedding(embed_id, self.codebook.weight)
|
92 |
+
|
93 |
+
def decode_code(self, embed_id):
|
94 |
+
return self.embed_code(embed_id).transpose(1, 2)
|
95 |
+
|
96 |
+
def decode_latents(self, latents):
|
97 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
98 |
+
codebook = self.codebook.weight # codebook: (N x D)
|
99 |
+
# L2 normalize encodings and codebook
|
100 |
+
encodings = F.normalize(encodings)
|
101 |
+
codebook = F.normalize(codebook)
|
102 |
+
|
103 |
+
# Compute euclidean distance with codebook
|
104 |
+
dist = (
|
105 |
+
encodings.pow(2).sum(1, keepdim=True)
|
106 |
+
- 2 * encodings @ codebook.t()
|
107 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
108 |
+
)
|
109 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
110 |
+
z_q = self.decode_code(indices)
|
111 |
+
return z_q, indices
|
Amphion/models/ns3_codec/quantize/rvq.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from .fvq import FactorizedVectorQuantize
|
5 |
+
|
6 |
+
|
7 |
+
class ResidualVQ(nn.Module):
|
8 |
+
"""Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""
|
9 |
+
|
10 |
+
def __init__(self, *, num_quantizers, codebook_size, **kwargs):
|
11 |
+
super().__init__()
|
12 |
+
VQ = FactorizedVectorQuantize
|
13 |
+
if type(codebook_size) == int:
|
14 |
+
codebook_size = [codebook_size] * num_quantizers
|
15 |
+
self.layers = nn.ModuleList(
|
16 |
+
[VQ(codebook_size=2**size, **kwargs) for size in codebook_size]
|
17 |
+
)
|
18 |
+
self.num_quantizers = num_quantizers
|
19 |
+
self.quantizer_dropout = kwargs.get("quantizer_dropout", 0.0)
|
20 |
+
self.dropout_type = kwargs.get("dropout_type", None)
|
21 |
+
|
22 |
+
def forward(self, x, n_quantizers=None):
|
23 |
+
quantized_out = 0.0
|
24 |
+
residual = x
|
25 |
+
|
26 |
+
all_losses = []
|
27 |
+
all_indices = []
|
28 |
+
all_quantized = []
|
29 |
+
|
30 |
+
if n_quantizers is None:
|
31 |
+
n_quantizers = self.num_quantizers
|
32 |
+
if self.training:
|
33 |
+
n_quantizers = torch.ones((x.shape[0],)) * self.num_quantizers + 1
|
34 |
+
if self.dropout_type == "linear":
|
35 |
+
dropout = torch.randint(1, self.num_quantizers + 1, (x.shape[0],))
|
36 |
+
elif self.dropout_type == "exp":
|
37 |
+
dropout = torch.randint(
|
38 |
+
1, int(math.log2(self.num_quantizers)), (x.shape[0],)
|
39 |
+
)
|
40 |
+
dropout = torch.pow(2, dropout)
|
41 |
+
n_dropout = int(x.shape[0] * self.quantizer_dropout)
|
42 |
+
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
43 |
+
n_quantizers = n_quantizers.to(x.device)
|
44 |
+
|
45 |
+
for idx, layer in enumerate(self.layers):
|
46 |
+
if not self.training and idx >= n_quantizers:
|
47 |
+
break
|
48 |
+
quantized, indices, loss = layer(residual)
|
49 |
+
|
50 |
+
mask = (
|
51 |
+
torch.full((x.shape[0],), fill_value=idx, device=x.device)
|
52 |
+
< n_quantizers
|
53 |
+
)
|
54 |
+
|
55 |
+
residual = residual - quantized
|
56 |
+
|
57 |
+
quantized_out = quantized_out + quantized * mask[:, None, None]
|
58 |
+
|
59 |
+
# loss
|
60 |
+
loss = (loss * mask).mean()
|
61 |
+
|
62 |
+
all_indices.append(indices)
|
63 |
+
all_losses.append(loss)
|
64 |
+
all_quantized.append(quantized)
|
65 |
+
all_losses, all_indices, all_quantized = map(
|
66 |
+
torch.stack, (all_losses, all_indices, all_quantized)
|
67 |
+
)
|
68 |
+
return quantized_out, all_indices, all_losses, all_quantized
|
69 |
+
|
70 |
+
def vq2emb(self, vq):
|
71 |
+
# vq: [n_quantizers, B, T]
|
72 |
+
quantized_out = 0.0
|
73 |
+
for idx, layer in enumerate(self.layers):
|
74 |
+
quantized = layer.vq2emb(vq[idx])
|
75 |
+
quantized_out += quantized
|
76 |
+
return quantized_out
|
77 |
+
|
78 |
+
def get_emb(self):
|
79 |
+
embs = []
|
80 |
+
for idx, layer in enumerate(self.layers):
|
81 |
+
embs.append(layer.get_emb())
|
82 |
+
return embs
|
Amphion/models/ns3_codec/transformer.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import math
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class StyleAdaptiveLayerNorm(nn.Module):
|
9 |
+
def __init__(self, normalized_shape, eps=1e-5):
|
10 |
+
super().__init__()
|
11 |
+
self.in_dim = normalized_shape
|
12 |
+
self.norm = nn.LayerNorm(self.in_dim, eps=eps, elementwise_affine=False)
|
13 |
+
self.style = nn.Linear(self.in_dim, self.in_dim * 2)
|
14 |
+
self.style.bias.data[: self.in_dim] = 1
|
15 |
+
self.style.bias.data[self.in_dim :] = 0
|
16 |
+
|
17 |
+
|
18 |
+
class PositionalEncoding(nn.Module):
|
19 |
+
def __init__(self, d_model, dropout, max_len=5000):
|
20 |
+
super().__init__()
|
21 |
+
|
22 |
+
self.dropout = dropout
|
23 |
+
position = torch.arange(max_len).unsqueeze(1)
|
24 |
+
div_term = torch.exp(
|
25 |
+
torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
|
26 |
+
)
|
27 |
+
pe = torch.zeros(max_len, 1, d_model)
|
28 |
+
pe[:, 0, 0::2] = torch.sin(position * div_term)
|
29 |
+
pe[:, 0, 1::2] = torch.cos(position * div_term)
|
30 |
+
self.register_buffer("pe", pe)
|
31 |
+
|
32 |
+
def forward(self, x):
|
33 |
+
x = x + self.pe[: x.size(0)]
|
34 |
+
return F.dropout(x, self.dropout, training=self.training)
|
35 |
+
|
36 |
+
|
37 |
+
class TransformerFFNLayer(nn.Module):
|
38 |
+
def __init__(
|
39 |
+
self, encoder_hidden, conv_filter_size, conv_kernel_size, encoder_dropout
|
40 |
+
):
|
41 |
+
super().__init__()
|
42 |
+
|
43 |
+
self.encoder_hidden = encoder_hidden
|
44 |
+
self.conv_filter_size = conv_filter_size
|
45 |
+
self.conv_kernel_size = conv_kernel_size
|
46 |
+
self.encoder_dropout = encoder_dropout
|
47 |
+
|
48 |
+
self.ffn_1 = nn.Conv1d(
|
49 |
+
self.encoder_hidden,
|
50 |
+
self.conv_filter_size,
|
51 |
+
self.conv_kernel_size,
|
52 |
+
padding=self.conv_kernel_size // 2,
|
53 |
+
)
|
54 |
+
self.ffn_1.weight.data.normal_(0.0, 0.02)
|
55 |
+
self.ffn_2 = nn.Linear(self.conv_filter_size, self.encoder_hidden)
|
56 |
+
self.ffn_2.weight.data.normal_(0.0, 0.02)
|
57 |
+
|
58 |
+
def forward(self, x):
|
59 |
+
# x: (B, T, d)
|
60 |
+
x = self.ffn_1(x.permute(0, 2, 1)).permute(
|
61 |
+
0, 2, 1
|
62 |
+
) # (B, T, d) -> (B, d, T) -> (B, T, d)
|
63 |
+
x = F.relu(x)
|
64 |
+
x = F.dropout(x, self.encoder_dropout, training=self.training)
|
65 |
+
x = self.ffn_2(x)
|
66 |
+
return x
|
67 |
+
|
68 |
+
|
69 |
+
class TransformerEncoderLayer(nn.Module):
|
70 |
+
def __init__(
|
71 |
+
self,
|
72 |
+
encoder_hidden,
|
73 |
+
encoder_head,
|
74 |
+
conv_filter_size,
|
75 |
+
conv_kernel_size,
|
76 |
+
encoder_dropout,
|
77 |
+
use_cln,
|
78 |
+
):
|
79 |
+
super().__init__()
|
80 |
+
self.encoder_hidden = encoder_hidden
|
81 |
+
self.encoder_head = encoder_head
|
82 |
+
self.conv_filter_size = conv_filter_size
|
83 |
+
self.conv_kernel_size = conv_kernel_size
|
84 |
+
self.encoder_dropout = encoder_dropout
|
85 |
+
self.use_cln = use_cln
|
86 |
+
|
87 |
+
if not self.use_cln:
|
88 |
+
self.ln_1 = nn.LayerNorm(self.encoder_hidden)
|
89 |
+
self.ln_2 = nn.LayerNorm(self.encoder_hidden)
|
90 |
+
else:
|
91 |
+
self.ln_1 = StyleAdaptiveLayerNorm(self.encoder_hidden)
|
92 |
+
self.ln_2 = StyleAdaptiveLayerNorm(self.encoder_hidden)
|
93 |
+
|
94 |
+
self.self_attn = nn.MultiheadAttention(
|
95 |
+
self.encoder_hidden, self.encoder_head, batch_first=True
|
96 |
+
)
|
97 |
+
|
98 |
+
self.ffn = TransformerFFNLayer(
|
99 |
+
self.encoder_hidden,
|
100 |
+
self.conv_filter_size,
|
101 |
+
self.conv_kernel_size,
|
102 |
+
self.encoder_dropout,
|
103 |
+
)
|
104 |
+
|
105 |
+
def forward(self, x, key_padding_mask, conditon=None):
|
106 |
+
# x: (B, T, d); key_padding_mask: (B, T), mask is 0; condition: (B, T, d)
|
107 |
+
|
108 |
+
# self attention
|
109 |
+
residual = x
|
110 |
+
if self.use_cln:
|
111 |
+
x = self.ln_1(x, conditon)
|
112 |
+
else:
|
113 |
+
x = self.ln_1(x)
|
114 |
+
|
115 |
+
if key_padding_mask != None:
|
116 |
+
key_padding_mask_input = ~(key_padding_mask.bool())
|
117 |
+
else:
|
118 |
+
key_padding_mask_input = None
|
119 |
+
x, _ = self.self_attn(
|
120 |
+
query=x, key=x, value=x, key_padding_mask=key_padding_mask_input
|
121 |
+
)
|
122 |
+
x = F.dropout(x, self.encoder_dropout, training=self.training)
|
123 |
+
x = residual + x
|
124 |
+
|
125 |
+
# ffn
|
126 |
+
residual = x
|
127 |
+
if self.use_cln:
|
128 |
+
x = self.ln_2(x, conditon)
|
129 |
+
else:
|
130 |
+
x = self.ln_2(x)
|
131 |
+
x = self.ffn(x)
|
132 |
+
x = residual + x
|
133 |
+
|
134 |
+
return x
|
135 |
+
|
136 |
+
|
137 |
+
class TransformerEncoder(nn.Module):
|
138 |
+
def __init__(
|
139 |
+
self,
|
140 |
+
enc_emb_tokens=None,
|
141 |
+
encoder_layer=4,
|
142 |
+
encoder_hidden=256,
|
143 |
+
encoder_head=4,
|
144 |
+
conv_filter_size=1024,
|
145 |
+
conv_kernel_size=5,
|
146 |
+
encoder_dropout=0.1,
|
147 |
+
use_cln=False,
|
148 |
+
cfg=None,
|
149 |
+
):
|
150 |
+
super().__init__()
|
151 |
+
|
152 |
+
self.encoder_layer = (
|
153 |
+
encoder_layer if encoder_layer is not None else cfg.encoder_layer
|
154 |
+
)
|
155 |
+
self.encoder_hidden = (
|
156 |
+
encoder_hidden if encoder_hidden is not None else cfg.encoder_hidden
|
157 |
+
)
|
158 |
+
self.encoder_head = (
|
159 |
+
encoder_head if encoder_head is not None else cfg.encoder_head
|
160 |
+
)
|
161 |
+
self.conv_filter_size = (
|
162 |
+
conv_filter_size if conv_filter_size is not None else cfg.conv_filter_size
|
163 |
+
)
|
164 |
+
self.conv_kernel_size = (
|
165 |
+
conv_kernel_size if conv_kernel_size is not None else cfg.conv_kernel_size
|
166 |
+
)
|
167 |
+
self.encoder_dropout = (
|
168 |
+
encoder_dropout if encoder_dropout is not None else cfg.encoder_dropout
|
169 |
+
)
|
170 |
+
self.use_cln = use_cln if use_cln is not None else cfg.use_cln
|
171 |
+
|
172 |
+
if enc_emb_tokens != None:
|
173 |
+
self.use_enc_emb = True
|
174 |
+
self.enc_emb_tokens = enc_emb_tokens
|
175 |
+
else:
|
176 |
+
self.use_enc_emb = False
|
177 |
+
|
178 |
+
self.position_emb = PositionalEncoding(
|
179 |
+
self.encoder_hidden, self.encoder_dropout
|
180 |
+
)
|
181 |
+
|
182 |
+
self.layers = nn.ModuleList([])
|
183 |
+
self.layers.extend(
|
184 |
+
[
|
185 |
+
TransformerEncoderLayer(
|
186 |
+
self.encoder_hidden,
|
187 |
+
self.encoder_head,
|
188 |
+
self.conv_filter_size,
|
189 |
+
self.conv_kernel_size,
|
190 |
+
self.encoder_dropout,
|
191 |
+
self.use_cln,
|
192 |
+
)
|
193 |
+
for i in range(self.encoder_layer)
|
194 |
+
]
|
195 |
+
)
|
196 |
+
|
197 |
+
if self.use_cln:
|
198 |
+
self.last_ln = StyleAdaptiveLayerNorm(self.encoder_hidden)
|
199 |
+
else:
|
200 |
+
self.last_ln = nn.LayerNorm(self.encoder_hidden)
|
201 |
+
|
202 |
+
def forward(self, x, key_padding_mask, condition=None):
|
203 |
+
if len(x.shape) == 2 and self.use_enc_emb:
|
204 |
+
x = self.enc_emb_tokens(x)
|
205 |
+
x = self.position_emb(x)
|
206 |
+
else:
|
207 |
+
x = self.position_emb(x) # (B, T, d)
|
208 |
+
|
209 |
+
for layer in self.layers:
|
210 |
+
x = layer(x, key_padding_mask, condition)
|
211 |
+
|
212 |
+
if self.use_cln:
|
213 |
+
x = self.last_ln(x, condition)
|
214 |
+
else:
|
215 |
+
x = self.last_ln(x)
|
216 |
+
|
217 |
+
return x
|