# Source: maskgct-audio-lab/modules/encoder/condition_encoder.py
# (uploaded by Hecheng0625 in "Upload 409 files", commit c968fc3, 7.93 kB)
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
import torch.nn as nn
from torchaudio.models import Conformer
from models.svc.transformer.transformer import PositionalEncoding
from utils.f0 import f0_to_coarse
class ContentEncoder(nn.Module):
    """Project frame-level content features to ``output_dim``.

    When ``cfg`` contains a truthy ``use_conformer_for_content_features``
    entry, the features first receive a positional encoding and pass through
    a 6-layer torchaudio ``Conformer`` before the final linear projection.

    Args:
        cfg: model config; must support ``in`` and attribute access.
        input_dim: size of the last axis of the input features; must be non-zero.
        output_dim: size of the last axis of the output features.

    Raises:
        ValueError: if ``input_dim`` is 0.
    """

    def __init__(self, cfg, input_dim, output_dim):
        super().__init__()
        self.cfg = cfg
        # Explicit check instead of ``assert`` so it is not stripped by ``python -O``.
        if input_dim == 0:
            raise ValueError("ContentEncoder requires a non-zero input_dim")
        self.nn = nn.Linear(input_dim, output_dim)

        # Introduce conformer or not
        if (
            "use_conformer_for_content_features" in cfg
            and cfg.use_conformer_for_content_features
        ):
            # NOTE(review): confirm PositionalEncoding expects the same
            # (N, seq_len, input_dim) axis order that forward() documents.
            self.pos_encoder = PositionalEncoding(input_dim)
            self.conformer = Conformer(
                input_dim=input_dim,
                num_heads=2,
                ffn_dim=256,
                num_layers=6,
                depthwise_conv_kernel_size=3,
            )
        else:
            self.conformer = None

    def forward(self, x, length=None):
        """Encode content features.

        Args:
            x: (N, seq_len, input_dim) features.
            length: per-sequence valid lengths for the Conformer; ignored when
                the Conformer is disabled. If ``None`` while the Conformer is
                enabled, every sequence is treated as fully valid.

        Returns:
            (N, seq_len, output_dim) tensor.
        """
        if self.conformer is not None:
            if length is None:
                # torchaudio's Conformer needs a lengths tensor of shape (N,);
                # it cannot build its padding mask from None.
                length = torch.full(
                    (x.shape[0],), x.shape[1], dtype=torch.long, device=x.device
                )
            x = self.pos_encoder(x)
            x, _ = self.conformer(x, length)
        return self.nn(x)
class MelodyEncoder(nn.Module):
    """Encode a frame-level F0 contour, optionally adding a voiced/unvoiced embedding.

    Config fields read at construction: ``input_melody_dim``,
    ``output_melody_dim``, ``n_bins_melody`` and, when quantising,
    ``f0_min``/``f0_max``. ``use_uv`` is read on every forward call.

    * ``n_bins_melody == 0``: raw pitch values go through ``nn.Linear``.
    * otherwise: pitch is quantised with ``f0_to_coarse`` and looked up in an
      ``n_bins_melody``-entry embedding table.

    If ``input_melody_dim`` is 0, no layers are created.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.input_dim = self.cfg.input_melody_dim
        self.output_dim = self.cfg.output_melody_dim
        self.n_bins = self.cfg.n_bins_melody

        if self.input_dim == 0:
            # Melody input disabled: nothing to build.
            return

        quantised = self.n_bins != 0
        if quantised:
            self.f0_min = cfg.f0_min
            self.f0_max = cfg.f0_max
            self.nn = nn.Embedding(self.n_bins, self.output_dim, padding_idx=None)
        else:
            # Not use quantization
            self.nn = nn.Linear(self.input_dim, self.output_dim)
        # Two-entry table, indexed by ``uv`` in forward (presumably 0/1 flags).
        self.uv_embedding = nn.Embedding(2, self.output_dim)

    def forward(self, x, uv=None, length=None):
        """Embed pitch ``x`` (assumed (B, frame_len)) and optionally add the uv embedding.

        ``length`` is accepted for interface parity with the other encoders
        but is not used.
        """
        if self.n_bins != 0:
            features = f0_to_coarse(x, self.n_bins, self.f0_min, self.f0_max)
        else:
            features = x.unsqueeze(-1)
        encoded = self.nn(features)
        if not self.cfg.use_uv:
            return encoded
        return encoded + self.uv_embedding(uv)
class LoudnessEncoder(nn.Module):
    """Encode frame-level loudness/energy values.

    Config fields read: ``input_loudness_dim``, ``output_loudness_dim`` and
    ``n_bins_loudness``.

    * ``n_bins_loudness == 0``: raw values go through ``nn.Linear``.
    * otherwise: values are bucketised against ``n_bins - 1`` log-spaced
      boundaries between 1e-30 and 1.5, and the resulting bucket index
      (0 .. n_bins - 1) is looked up in an embedding table.

    If ``input_loudness_dim`` is 0, no layers are created.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.input_dim = self.cfg.input_loudness_dim
        self.output_dim = self.cfg.output_loudness_dim
        self.n_bins = self.cfg.n_bins_loudness

        if self.input_dim == 0:
            # Loudness input disabled: nothing to build.
            return

        if self.n_bins == 0:
            # Not use quantization
            self.nn = nn.Linear(self.input_dim, self.output_dim)
            return

        # TODO: set empirically now
        self.loudness_min = 1e-30
        self.loudness_max = 1.5
        # Bucket boundaries evenly spaced in log domain; stored as a frozen
        # Parameter so it is saved in the state dict under "energy_bins".
        log_edges = torch.linspace(
            np.log(self.loudness_min), np.log(self.loudness_max), self.n_bins - 1
        )
        self.energy_bins = nn.Parameter(log_edges.exp(), requires_grad=False)
        self.nn = nn.Embedding(self.n_bins, self.output_dim, padding_idx=None)

    def forward(self, x):
        """Encode loudness ``x`` (assumed (N, frame_len)) to (N, frame_len, output_dim)."""
        if self.n_bins == 0:
            return self.nn(x.unsqueeze(-1))
        bucket_ids = torch.bucketize(x, self.energy_bins)
        return self.nn(bucket_ids)
class SingerEncoder(nn.Module):
    """Look up a learned embedding for each singer/speaker id.

    Config fields read: ``output_singer_dim`` (embedding size) and
    ``singer_table_size`` (number of ids).
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.input_dim = 1
        self.output_dim = self.cfg.output_singer_dim
        table = nn.Embedding(cfg.singer_table_size, self.output_dim, padding_idx=None)
        self.nn = table

    def forward(self, x):
        """Map integer ids ``x`` of shape (N, 1) to embeddings of shape (N, 1, output_dim)."""
        embeddings = self.nn(x)
        return embeddings
class ConditionEncoder(nn.Module):
    """Encode all enabled conditioning signals and merge them into one tensor.

    Which sub-encoders exist is controlled by the ``cfg.use_*`` flags:
    content features (whisper / contentvec / mert / wenet), prosody
    (f0 via ``MelodyEncoder``, energy via ``LoudnessEncoder``) and speaker id
    (``SingerEncoder``). ``cfg.merge_mode`` selects how the per-encoder
    outputs are combined: ``"concat"`` along the last axis, or ``"add"``
    (element-wise sum, which requires identical output shapes).
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.merge_mode = cfg.merge_mode

        ### Semantic Features ###
        if cfg.use_whisper:
            self.whisper_encoder = ContentEncoder(
                self.cfg, self.cfg.whisper_dim, self.cfg.content_encoder_dim
            )
        if cfg.use_contentvec:
            self.contentvec_encoder = ContentEncoder(
                self.cfg, self.cfg.contentvec_dim, self.cfg.content_encoder_dim
            )
        if cfg.use_mert:
            self.mert_encoder = ContentEncoder(
                self.cfg, self.cfg.mert_dim, self.cfg.content_encoder_dim
            )
        if cfg.use_wenet:
            self.wenet_encoder = ContentEncoder(
                self.cfg, self.cfg.wenet_dim, self.cfg.content_encoder_dim
            )

        ### Prosody Features ###
        if cfg.use_f0:
            self.melody_encoder = MelodyEncoder(self.cfg)
        if cfg.use_energy:
            self.loudness_encoder = LoudnessEncoder(self.cfg)

        ### Speaker Features ###
        if cfg.use_spkid:
            self.singer_encoder = SingerEncoder(self.cfg)

    # (cfg flag, sub-encoder attribute, batch key) for each content feature,
    # in the order their outputs are appended.
    _CONTENT_SOURCES = (
        ("use_whisper", "whisper_encoder", "whisper_feat"),
        ("use_contentvec", "contentvec_encoder", "contentvec_feat"),
        ("use_mert", "mert_encoder", "mert_feat"),
        ("use_wenet", "wenet_encoder", "wenet_feat"),
    )

    def forward(self, x):
        """Encode and merge the enabled conditions.

        Args:
            x: dict of batch inputs. Keys read depend on the enabled flags:
                ``frame_pitch``, ``frame_uv`` (if ``use_uv``) and ``target_len``
                for f0; ``frame_energy``; ``whisper_feat`` / ``contentvec_feat``
                / ``mert_feat`` / ``wenet_feat`` (each with ``target_len``);
                ``spk_id``.

        Returns:
            Merged encoder output: concatenated along the last axis for
            ``merge_mode == "concat"``, summed element-wise for ``"add"``.

        Raises:
            ValueError: if speaker conditioning is enabled but no content
                encoder ran to provide the sequence length, or if
                ``merge_mode`` is neither ``"concat"`` nor ``"add"``.
        """
        cfg = self.cfg
        outputs = []
        # Time length used to broadcast the speaker embedding; taken from the
        # last enabled content encoder.
        seq_len = None

        if cfg.use_f0:
            uv = x["frame_uv"] if cfg.use_uv else None
            outputs.append(
                self.melody_encoder(x["frame_pitch"], uv=uv, length=x["target_len"])
            )

        if cfg.use_energy:
            outputs.append(self.loudness_encoder(x["frame_energy"]))

        for flag, encoder_name, key in self._CONTENT_SOURCES:
            if getattr(cfg, flag):
                content_out = getattr(self, encoder_name)(
                    x[key], length=x["target_len"]
                )
                outputs.append(content_out)
                seq_len = content_out.shape[1]

        if cfg.use_spkid:
            # Checked against the encoders that actually ran, not the batch
            # keys, so seq_len can never be unbound here.
            if seq_len is None:
                raise ValueError(
                    "use_spkid requires at least one enabled content feature "
                    "(whisper/contentvec/mert/wenet) to define the sequence length"
                )
            speaker_enc_out = self.singer_encoder(x["spk_id"])  # e.g. [b, 1, 384]
            outputs.append(speaker_enc_out.expand(-1, seq_len, -1))

        if self.merge_mode == "concat":
            return torch.cat(outputs, dim=-1)
        if self.merge_mode == "add":
            # (#modules, N, seq_len, output_dim) -> (N, seq_len, output_dim)
            return torch.stack(outputs, dim=0).sum(dim=0)
        raise ValueError(
            f"Unsupported merge_mode {self.merge_mode!r}; expected 'concat' or 'add'"
        )