Glow-HiFi-TTS / Hmodel.py
marigold334's picture
Upload 10 files
41989ff
raw
history blame
No virus
12.7 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm, spectral_norm, remove_weight_norm
# Based on the V2 model configuration.
class ResBlock(nn.Module):
    """Residual block of the generator's Multi-Receptive Field (MRF) module.

    Holds three groups of two weight-normalised ``Conv1d`` layers. In the
    groups ``convs1``, ``convs2`` and ``convs3`` the first conv uses dilation
    1, 3 and 5 respectively, and the second conv uses dilation 1. Every conv
    keeps the channel count and, for odd ``kernel_size``, the time length.
    """

    def __init__(self, channels, kernel_size):
        """
        Args:
            channels: number of input and output channels of every conv.
            kernel_size: conv kernel size (MRF uses 3, 7 or 11).
        """
        super(ResBlock, self).__init__()
        self.convs1 = nn.ModuleList([
            self._make_conv(channels, kernel_size, dilation=1),
            self._make_conv(channels, kernel_size, dilation=1),
        ])
        self.convs2 = nn.ModuleList([
            self._make_conv(channels, kernel_size, dilation=3),
            self._make_conv(channels, kernel_size, dilation=1),
        ])
        self.convs3 = nn.ModuleList([
            self._make_conv(channels, kernel_size, dilation=5),
            self._make_conv(channels, kernel_size, dilation=1),
        ])
        # The groups are deliberately NOT stored as ``self.modules``: such an
        # instance attribute would shadow nn.Module.modules() and make
        # ``block.modules()`` raise "TypeError: 'list' object is not callable".
        # Use ``_conv_groups()`` instead.

    @staticmethod
    def _make_conv(channels, kernel_size, dilation):
        """Build one "same"-padded Conv1d with N(0, 0.01) init and weight_norm."""
        # padding = (kernel_size - 1) * dilation // 2 keeps the time length
        # unchanged whenever (kernel_size - 1) * dilation is even.
        conv = nn.Conv1d(channels, channels, kernel_size, stride=1,
                         dilation=dilation,
                         padding=(kernel_size - 1) * dilation // 2)
        # Initialise the raw weight BEFORE applying weight_norm, which derives
        # its weight_g / weight_v parameters from it. After weight_norm,
        # ``conv.weight`` is only a derived tensor that the forward pre-hook
        # recomputes from weight_g / weight_v, so initialising it there would
        # have no lasting effect.
        nn.init.normal_(conv.weight, mean=0.0, std=0.01)
        return weight_norm(conv)

    def _conv_groups(self):
        """Return the three conv groups in the order ``forward`` applies them."""
        return (self.convs1, self.convs2, self.convs3)

    def forward(self, x):
        """
        Args:
            x: (B, channels, T) input features (derived from the mel-spectrogram).

        Returns:
            (B, channels, T) output features.
        """
        for group in self._conv_groups():
            for conv in group:
                # NOTE(review): the residual is added after every single conv.
                # The reference HiFi-GAN ResBlock adds it once per
                # (dilated conv, dilation-1 conv) pair. Kept as-is so existing
                # checkpoints behave identically -- confirm which is intended.
                y = F.leaky_relu(x, 0.1)
                y = conv(y)
                x = x + y
        return x

    def remove_weight_norm(self):
        """Remove weight normalisation from all six convs in place."""
        for group in self._conv_groups():
            for conv in group:
                remove_weight_norm(conv)
class MRF(nn.Module):
    """Multi-Receptive Field fusion.

    Applies ResBlocks with kernel sizes 3, 7 and 11 to the same input and
    returns the mean of their outputs.
    """

    def __init__(self, channels):
        """
        Args:
            channels: channel count passed to every ResBlock.
        """
        super(MRF, self).__init__()
        self.res_blocks = nn.ModuleList(
            [ResBlock(channels, kernel_size=k) for k in (3, 7, 11)]
        )

    def forward(self, x):
        """
        Args:
            x: (B, channels, T) input features.

        Returns:
            (B, channels, T) element-wise mean of the ResBlock outputs.
        """
        total = sum(res_block(x) for res_block in self.res_blocks)
        return total / len(self.res_blocks)

    def remove_weight_norm(self):
        """Strip weight normalisation from every ResBlock in place."""
        for res_block in self.res_blocks:
            res_block.remove_weight_norm()
class Generator(nn.Module):
    """Generator: mel-spectrogram -> waveform in [-1, 1].

    pre_conv -> N x (LeakyReLU -> ConvTranspose1d upsampling -> MRF)
    -> LeakyReLU -> post_conv -> tanh.

    The defaults reproduce the original hard-coded configuration: 80 input
    channels, 128 initial channels and transposed-conv kernel sizes
    (16, 16, 4, 4), i.e. upsampling by 8, 8, 2, 2 (x256 overall) while the
    channel count halves at each stage (128 -> 64 -> 32 -> 16 -> 8).
    """

    def __init__(self, in_channels=80, upsample_initial_channel=128,
                 upsample_kernel_sizes=(16, 16, 4, 4)):
        """
        Args:
            in_channels: number of mel bins of the input spectrogram.
            upsample_initial_channel: channel count produced by ``pre_conv``.
                It is halved (floor division) after every upsampling stage.
            upsample_kernel_sizes: kernel size ``ku`` of each ConvTranspose1d.
                Its stride is ``ku // 2``, so each stage upsamples by roughly
                ``ku // 2`` (exactly for the default sizes).
        """
        super(Generator, self).__init__()
        # (B, in_channels, F) -> (B, upsample_initial_channel, F)
        # (The N(0, 0.01) init of pre_conv is not in the paper authors'
        # implementation.)
        self.pre_conv = self._normal_init_weight_norm(
            nn.Conv1d(in_channels, upsample_initial_channel, kernel_size=7,
                      stride=1, dilation=1, padding=(7 - 1) // 2))
        self.up_convs = nn.ModuleList()
        self.mrfs = nn.ModuleList()
        channels = upsample_initial_channel
        for ku in upsample_kernel_sizes:
            out_channels = channels // 2
            stride = ku // 2
            # (B, channels, L) -> (B, channels // 2, L * stride) for even stride
            self.up_convs.append(self._normal_init_weight_norm(
                nn.ConvTranspose1d(channels, out_channels, kernel_size=ku,
                                   stride=stride, padding=(ku - stride) // 2)))
            # MRF keeps the shape (B, out_channels, L)
            self.mrfs.append(MRF(out_channels))
            channels = out_channels
        # (B, channels, F * upsampling) -> (B, 1, F * upsampling)
        self.post_conv = self._normal_init_weight_norm(
            nn.Conv1d(channels, 1, kernel_size=7, stride=1, dilation=1,
                      padding=(7 - 1) // 2))

    @staticmethod
    def _normal_init_weight_norm(conv):
        """Initialise ``conv.weight`` from N(0, 0.01), then apply weight_norm."""
        # The init must happen BEFORE weight_norm: afterwards ``conv.weight`` is
        # a derived tensor recomputed from weight_g / weight_v by the forward
        # pre-hook, so initialising it in place would have no lasting effect.
        nn.init.normal_(conv.weight, mean=0.0, std=0.01)
        return weight_norm(conv)

    def forward(self, x):
        """
        Args:
            x: (B, in_channels, F) mel-spectrogram.

        Returns:
            (B, 1, F * 256) waveform in [-1, 1] with the default configuration.
        """
        x = self.pre_conv(x)
        for up_conv, mrf in zip(self.up_convs, self.mrfs):
            x = F.leaky_relu(x, 0.1)
            x = up_conv(x)
            x = mrf(x)
        x = F.leaky_relu(x, 0.1)
        x = self.post_conv(x)
        return torch.tanh(x)

    def remove_weight_norm(self):
        """Remove weight normalisation from every conv in place."""
        print('Removing weight norm...')
        for up_conv in self.up_convs:
            remove_weight_norm(up_conv)
        for mrf in self.mrfs:
            mrf.remove_weight_norm()
        remove_weight_norm(self.pre_conv)
        remove_weight_norm(self.post_conv)
class SubPD(nn.Module):
    """Period sub-discriminator used by the Multi-Period Discriminator.

    The waveform is reflect-padded on the right to a multiple of ``period`` and
    viewed as (B, 1, T / period, period). The (k, 1) Conv2d kernels then only
    mix samples that lie ``period`` apart.
    """

    def __init__(self, period):
        """
        Args:
            period: reshape period (MPD uses 2, 3, 5, 7 and 11).
        """
        super(SubPD, self).__init__()
        self.period = period
        # Author's note: implemented as described in the paper rather than the
        # paper authors' modified implementation. Four strided layers with
        # 2 ** (5 + i) = 64, 128, 256, 512 output channels and no padding.
        layers = []
        in_channels = 1
        for i in range(1, 5):
            out_channels = 2 ** (5 + i)
            layers.append(weight_norm(nn.Conv2d(
                in_channels, out_channels, kernel_size=(5, 1), stride=(3, 1),
                dilation=1, padding=0)))
            in_channels = out_channels
        # Height-preserving layer: (B, 512, H, p) -> (B, 1024, H, p)
        layers.append(weight_norm(nn.Conv2d(
            in_channels, 1024, kernel_size=(5, 1), stride=(1, 1), dilation=1,
            padding=(2, 0))))
        self.convs = nn.ModuleList(layers)
        # (B, 1024, H, p) -> (B, 1, H, p)
        self.post_conv = weight_norm(nn.Conv2d(
            1024, 1, kernel_size=(3, 1), stride=(1, 1), dilation=1,
            padding=(1, 0)))

    def forward(self, waveform):
        """
        Args:
            waveform: (B, 1, T) audio.

        Returns:
            x: (B, ?) flattened raw real/fake scores.
            features: list of every layer's output (post-activation for
                ``convs``, plus the ``post_conv`` output), used for the
                feature matching loss.
        """
        batch, _, length = waveform.size()
        period = self.period
        remainder = length % period
        if remainder != 0:
            extra = period - remainder
            # Pad 0 samples on the left and ``extra`` on the right. Reflect mode
            # mirrors the signal, e.g. reflect-padding [1, 2, 3, 4, 5] by 2 on
            # the left and 3 on the right gives [3, 2, 1, 2, 3, 4, 5, 4, 3, 2].
            waveform = F.pad(waveform, (0, extra), "reflect")
            length += extra
        x = waveform.view(batch, 1, length // period, period)

        features = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), 0.1)
            features.append(x)
        x = self.post_conv(x)
        features.append(x)
        # NOTE(review): no sigmoid/clipping is applied, so the scores are
        # unbounded -- confirm the GAN loss in use expects raw scores.
        return torch.flatten(x, 1, -1), features
class MPD(nn.Module):
    """Multi-Period Discriminator: one SubPD per period 2, 3, 5, 7 and 11."""

    def __init__(self):
        super(MPD, self).__init__()
        # each SubPD maps (B, 1, T) -> ((B, ?) scores, features list)
        self.sub_pds = nn.ModuleList(
            [SubPD(period) for period in (2, 3, 5, 7, 11)]
        )

    def forward(self, real_waveform, gen_waveform):
        """Run every SubPD on the real and on the generated waveform.

        Args:
            real_waveform: (B, 1, T) real audio.
            gen_waveform: (B, 1, T) generated audio.

        Returns:
            Tuple ``(real_outputs, gen_outputs, real_features, gen_features)``.
            Each is a list with one entry per SubPD (5 in total). Outputs are
            (B, ?) score tensors; features are the per-layer feature lists.
        """
        real_results = []
        gen_results = []
        for sub_pd in self.sub_pds:
            real_results.append(sub_pd(real_waveform))
            gen_results.append(sub_pd(gen_waveform))
        real_outputs = [output for output, _ in real_results]
        gen_outputs = [output for output, _ in gen_results]
        real_features = [features for _, features in real_results]
        gen_features = [features for _, features in gen_results]
        return real_outputs, gen_outputs, real_features, gen_features
class SubSD(nn.Module):
    """Scale sub-discriminator used by the Multi-Scale Discriminator.

    The layer layout follows the MelGAN paper (per the original author's note).
    """

    def __init__(self, first=False):
        """
        Args:
            first: if True, every conv uses spectral normalisation; otherwise
                weight normalisation.
        """
        super(SubSD, self).__init__()
        norm = spectral_norm if first else weight_norm
        # (in_channels, out_channels, kernel_size, stride, groups) for each conv.
        # With padding (kernel_size - 1) // 2, the time length goes
        # T -> T -> ceil(T/4) -> ceil(T/16) -> ceil(T/64) -> ceil(T/256) -> same.
        layer_specs = (
            (1, 16, 15, 1, 1),
            (16, 64, 41, 4, 4),
            (64, 256, 41, 4, 16),
            (256, 1024, 41, 4, 64),
            (1024, 1024, 41, 4, 256),
            (1024, 1024, 5, 1, 1),
        )
        self.convs = nn.ModuleList([
            norm(nn.Conv1d(c_in, c_out, kernel_size=k, stride=s, groups=g,
                           padding=(k - 1) // 2))
            for c_in, c_out, k, s, g in layer_specs
        ])
        # (B, 1024, L) -> (B, 1, L)
        self.post_conv = norm(nn.Conv1d(1024, 1, kernel_size=3, stride=1,
                                        padding=(3 - 1) // 2))

    def forward(self, waveform):
        """
        Args:
            waveform: (B, 1, T) audio.

        Returns:
            x: (B, ?) raw real/fake scores (``post_conv`` output without its
                channel axis).
            features: list of every layer's output (post-activation for
                ``convs``, plus the ``post_conv`` output), used for the
                feature matching loss.
        """
        features = []
        x = waveform
        for conv in self.convs:
            x = F.leaky_relu(conv(x), 0.1)
            features.append(x)
        x = self.post_conv(x)
        features.append(x)
        # NOTE(review): no sigmoid/clipping is applied, so the scores are
        # unbounded -- confirm the GAN loss in use expects raw scores.
        return x.squeeze(1), features
class MSD(nn.Module):
    """Multi-Scale Discriminator: three SubSDs on progressively pooled audio.

    The first SubSD (spectral norm) sees the raw waveform. Each following
    SubSD sees the previous scale's input passed through one more AvgPool1d
    with stride 2 (about x2 downsampling).
    """

    def __init__(self):
        super(MSD, self).__init__()
        # each SubSD maps (B, 1, T) -> ((B, ?) scores, features list)
        self.sub_sds = nn.ModuleList(
            [SubSD(first=(scale == 0)) for scale in range(3)]
        )
        self.avgpool = nn.AvgPool1d(kernel_size=4, stride=2, padding=2)

    def forward(self, real_waveform, gen_waveform):
        """Run every SubSD on the real and the generated waveform.

        Args:
            real_waveform: (B, 1, T) real audio.
            gen_waveform: (B, 1, T) generated audio.

        Returns:
            Tuple ``(real_outputs, gen_outputs, real_features, gen_features)``.
            Each is a list with one entry per SubSD (3 in total). Outputs are
            (B, ?) score tensors; features are the per-layer feature lists.
        """
        real_outputs, gen_outputs = [], []
        real_features, gen_features = [], []
        real, gen = real_waveform, gen_waveform
        for scale, sub_sd in enumerate(self.sub_sds):
            if scale > 0:
                # downsample once more for every scale after the first
                real = self.avgpool(real)
                gen = self.avgpool(gen)
            real_out, real_feat = sub_sd(real)
            gen_out, gen_feat = sub_sd(gen)
            real_outputs.append(real_out)
            gen_outputs.append(gen_out)
            real_features.append(real_feat)
            gen_features.append(gen_feat)
        return real_outputs, gen_outputs, real_features, gen_features