Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.nn.utils import weight_norm, spectral_norm, remove_weight_norm | |
# Based on the V2 model configuration.
class ResBlock(nn.Module):
    """Residual block used by MRF: three groups of two 1-D convolutions.

    Group k holds a conv with dilation (1, 3, 5)[k] followed by a conv with
    dilation 1. Every conv is weight-normalized, maps `channels` -> `channels`
    and uses "same" padding ((kernel_size - 1) * dilation // 2), so the time
    axis length is preserved for odd kernel sizes.
    """

    def __init__(self, channels, kernel_size):
        """
        channels: number of input (and output) channels.
        kernel_size: convolution kernel size; MRF uses 3, 7 or 11.
        """
        super(ResBlock, self).__init__()
        # Attribute names (convs1/2/3) are kept so existing state_dicts still load.
        self.convs1 = self._make_pair(channels, kernel_size, dilation=1)
        self.convs2 = self._make_pair(channels, kernel_size, dilation=3)
        self.convs3 = self._make_pair(channels, kernel_size, dilation=5)
        # Do not store the groups in an attribute called `modules`: that would
        # shadow nn.Module.modules() and make `block.modules()` raise TypeError.

    @staticmethod
    def _make_conv(channels, kernel_size, dilation):
        """Return a weight-normalized "same"-padded Conv1d initialised from N(0, 0.01)."""
        conv = nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=dilation,
                         padding=(kernel_size - 1) * dilation // 2)
        # Initialise BEFORE weight_norm: once wrapped, `conv.weight` is recomputed
        # from weight_g / weight_v on every forward, so an in-place init applied
        # afterwards would be silently discarded.
        nn.init.normal_(conv.weight, mean=0.0, std=0.01)
        return weight_norm(conv)

    @classmethod
    def _make_pair(cls, channels, kernel_size, dilation):
        """Return [conv(dilation), conv(dilation=1)] as an nn.ModuleList."""
        return nn.ModuleList([
            cls._make_conv(channels, kernel_size, dilation),
            cls._make_conv(channels, kernel_size, 1),
        ])

    def _conv_groups(self):
        """Return the three conv pairs in the order they are applied."""
        return (self.convs1, self.convs2, self.convs3)

    def forward(self, x):
        """
        =====inputs=====
        x: (B, channels, F)  # input features derived from the mel-spectrogram
        =====outputs=====
        x: (B, channels, F)  # output features, same shape as the input
        """
        for group in self._conv_groups():
            for conv in group:
                # NOTE(review): the residual is added after EVERY conv, whereas the
                # reference HiFi-GAN ResBlock adds it once per (dilated, 1) pair.
                # Kept as-is so weights trained with this code still match.
                x = x + conv(F.leaky_relu(x, 0.1))
        return x

    def remove_weight_norm(self):
        """Fold weight normalization into plain conv weights (e.g. before inference)."""
        for group in self._conv_groups():
            for conv in group:
                remove_weight_norm(conv)
class MRF(nn.Module):
    """Multi-Receptive-field Fusion: averages the outputs of parallel ResBlocks.

    Three ResBlocks (kernel sizes 3, 7 and 11) receive the same input; their
    outputs are summed and divided by the number of blocks.
    """

    def __init__(self, channels):
        """
        channels: channel count of the input, shared by every ResBlock.
        """
        super(MRF, self).__init__()
        self.res_blocks = nn.ModuleList(
            [ResBlock(channels, kernel_size=size) for size in (3, 7, 11)]
        )

    def forward(self, x):
        """
        =====inputs=====
        x: (B, channels, F)
        =====outputs=====
        x: (B, channels, F)  # mean of the ResBlock outputs
        """
        total = 0
        for block in self.res_blocks:
            total = total + block(x)
        return total / len(self.res_blocks)

    def remove_weight_norm(self):
        """Remove weight normalization from every contained ResBlock."""
        for block in self.res_blocks:
            block.remove_weight_norm()
class Generator(nn.Module):
    """Mel-spectrogram -> waveform generator.

    pre_conv maps 80 mel channels to 128. Four stages of
    (LeakyReLU -> transposed-conv upsampling -> MRF) upsample time by 8, 8, 2, 2
    (x256 in total) while halving channels 128 -> 64 -> 32 -> 16 -> 8. A final
    LeakyReLU, post_conv (8 -> 1) and tanh produce the waveform.
    """

    def __init__(self):
        super(Generator, self).__init__()
        # (B, 80, F) -> (B, 128, F). This normal init is an addition that is not
        # present in the original authors' implementation.
        self.pre_conv = self._init_and_weight_norm(
            nn.Conv1d(80, 128, kernel_size=7, stride=1, dilation=1, padding=(7 - 1) // 2))
        self.up_convs = nn.ModuleList()
        self.mrfs = nn.ModuleList()
        ku = [16, 16, 4, 4]  # transposed-conv kernel sizes; stride (upsampling factor) = ku // 2
        for i, kernel in enumerate(ku):
            in_channels = 128 // (2 ** i)
            channels = 128 // (2 ** (i + 1))
            stride = kernel // 2
            # padding (kernel - stride) // 2 gives an output length of exactly
            # F * stride for these kernel sizes.
            self.up_convs.append(self._init_and_weight_norm(
                nn.ConvTranspose1d(in_channels, channels, kernel_size=kernel, stride=stride,
                                   padding=(kernel - stride) // 2)))
            # (B, 128, F) -(1)-> (B, 64, F*8) -(2)-> (B, 32, F*64) -(3)-> (B, 16, F*128) -(4)-> (B, 8, F*256)
            self.mrfs.append(MRF(channels))  # (B, channels, F') -> (B, channels, F')
        # (B, 8, F*256) -> (B, 1, F*256)
        self.post_conv = self._init_and_weight_norm(
            nn.Conv1d(8, 1, kernel_size=7, stride=1, dilation=1, padding=(7 - 1) // 2))

    @staticmethod
    def _init_and_weight_norm(conv):
        """Initialise `conv.weight` from N(0, 0.01), then wrap it with weight_norm.

        The init must happen before weight_norm: afterwards `conv.weight` is
        recomputed from weight_g / weight_v on every forward, so an in-place
        init would be silently discarded.
        """
        nn.init.normal_(conv.weight, mean=0.0, std=0.01)
        return weight_norm(conv)

    def forward(self, x):
        """
        =====inputs=====
        x: (B, 80, F)  # mel-spectrogram
        =====outputs=====
        x: (B, 1, F*256)  # waveform in [-1, 1] (tanh output)
        """
        x = self.pre_conv(x)  # (B, 80, F) -> (B, 128, F)
        for up_conv, mrf in zip(self.up_convs, self.mrfs):
            x = F.leaky_relu(x, 0.1)
            x = up_conv(x)
            x = mrf(x)
        # after the stages: (B, 8, F*256)
        x = F.leaky_relu(x, 0.1)
        x = self.post_conv(x)  # (B, 8, F*256) -> (B, 1, F*256)
        return torch.tanh(x)

    def remove_weight_norm(self):
        """Fold weight normalization into plain weights in every layer (for inference)."""
        print('Removing weight norm...')
        for up_conv in self.up_convs:
            remove_weight_norm(up_conv)
        for mrf in self.mrfs:
            mrf.remove_weight_norm()
        remove_weight_norm(self.pre_conv)
        remove_weight_norm(self.post_conv)
class SubPD(nn.Module):
    """One period sub-discriminator.

    The waveform is right-padded (reflect mode) to a multiple of `period`, folded
    into a (B, 1, T_padded // period, period) map, and passed through (5, 1)
    convolutions, which never mix different columns of the fold.
    """

    def __init__(self, period):
        """
        period: fold width (MPD in this file uses 2, 3, 5, 7 and 11).
        """
        super(SubPD, self).__init__()
        self.period = period
        # Channel widths 2**(5+l), l = 1..4 (64, 128, 256, 512), implemented as
        # described in the paper rather than the authors' modified implementation.
        widths = [2 ** (5 + layer_idx) for layer_idx in range(1, 5)]
        layers = []
        in_ch = 1
        for out_ch in widths:
            layers.append(weight_norm(nn.Conv2d(in_ch, out_ch, kernel_size=(5, 1), stride=(3, 1),
                                                dilation=1, padding=0)))
            in_ch = out_ch
        # (B, 512, ?, p) -> (B, 1024, ?, p); height preserved by padding (2, 0)
        layers.append(weight_norm(nn.Conv2d(in_ch, 1024, kernel_size=(5, 1), stride=(1, 1),
                                            dilation=1, padding=(2, 0))))
        self.convs = nn.ModuleList(layers)
        # (B, 1024, ?, p) -> (B, 1, ?, p)
        self.post_conv = weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1), stride=(1, 1),
                                               dilation=1, padding=(1, 0)))

    def forward(self, waveform):
        """
        =====inputs=====
        waveform: (B, 1, T)
        =====outputs=====
        x: (B, ?)  # flattened raw real/fake scores; no sigmoid or clipping is applied.
                   # NOTE(review): whether that is appropriate depends on the GAN loss
                   # used by the caller (not visible here).
        features: list of every intermediate activation plus the post_conv output
                  (used for a feature-matching loss)
        """
        batch, _, length = waveform.size()
        period = self.period
        remainder = length % period
        if remainder:
            # Pad only the end; reflect mirrors the signal, e.g. [1, 2, 3, 4, 5]
            # padded by 2 in front and 3 at the back -> [3, 2, 1, 2, 3, 4, 5, 4, 3, 2].
            extra = period - remainder
            waveform = F.pad(waveform, (0, extra), "reflect")
            length += extra
        x = waveform.view(batch, 1, length // period, period)  # (B, 1, ceil(T/P), P)
        features = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), 0.1)
            features.append(x)
        x = self.post_conv(x)
        features.append(x)
        # flatten every dimension from index 1 onward -> (B, ?)
        return torch.flatten(x, 1, -1), features
class MPD(nn.Module):
    """Multi-Period Discriminator: one SubPD for each period 2, 3, 5, 7 and 11."""

    def __init__(self):
        super(MPD, self).__init__()
        # each SubPD: (B, 1, T) -> (B, ?), features list
        self.sub_pds = nn.ModuleList([SubPD(period) for period in (2, 3, 5, 7, 11)])

    def forward(self, real_waveform, gen_waveform):
        """
        =====inputs=====
        real_waveform: (B, 1, T)  # real speech
        gen_waveform: (B, 1, T)   # generated speech
        =====outputs=====
        real_outputs: list (len=5) of (B, ?) SubPD outputs for the real speech
        gen_outputs: list of SubPD outputs for the generated speech
        real_features: list of SubPD feature lists for the real speech
        gen_features: list of SubPD feature lists for the generated speech
        """
        # Each entry is ((real_out, real_feats), (gen_out, gen_feats)); real is
        # evaluated before generated for every sub-discriminator.
        results = [(sub_pd(real_waveform), sub_pd(gen_waveform)) for sub_pd in self.sub_pds]
        real_outputs = [real[0] for real, _ in results]
        gen_outputs = [gen[0] for _, gen in results]
        real_features = [real[1] for real, _ in results]
        gen_features = [gen[1] for _, gen in results]
        return real_outputs, gen_outputs, real_features, gen_features
class SubSD(nn.Module):
    """One scale sub-discriminator (MelGAN-style stack of grouped strided Conv1d)."""

    def __init__(self, first=False):
        """
        first: if True, wrap every conv with spectral normalization;
               otherwise use weight normalization.
        """
        super(SubSD, self).__init__()
        norm = spectral_norm if first else weight_norm
        # (in_channels, out_channels, kernel_size, stride, groups), following the MelGAN paper.
        # Length comments assume T is divisible by the strides (approximate otherwise).
        layer_specs = [
            (1, 16, 15, 1, 1),         # (B, 1, T)      -> (B, 16, T)
            (16, 64, 41, 4, 4),        # (B, 16, T)     -> (B, 64, T/4)
            (64, 256, 41, 4, 16),      # (B, 64, T/4)   -> (B, 256, T/16)
            (256, 1024, 41, 4, 64),    # (B, 256, T/16) -> (B, 1024, T/64)
            (1024, 1024, 41, 4, 256),  # (B, 1024, T/64) -> (B, 1024, T/256)
            (1024, 1024, 5, 1, 1),     # (B, 1024, T/256) -> (B, 1024, T/256)
        ]
        self.convs = nn.ModuleList([
            norm(nn.Conv1d(c_in, c_out, kernel_size=k, stride=s, groups=g, padding=(k - 1) // 2))
            for c_in, c_out, k, s, g in layer_specs
        ])
        # (B, 1024, ?) -> (B, 1, ?)
        self.post_conv = norm(nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=(3 - 1) // 2))

    def forward(self, waveform):
        """
        =====inputs=====
        waveform: (B, 1, T)
        =====outputs=====
        x: (B, ?)  # raw real/fake scores; no sigmoid or clipping is applied.
                   # NOTE(review): whether that is appropriate depends on the GAN loss
                   # used by the caller (not visible here).
        features: list of every intermediate activation plus the post_conv output
                  (used for a feature-matching loss)
        """
        features = []
        x = waveform
        for conv in self.convs:
            x = F.leaky_relu(conv(x), 0.1)
            features.append(x)
        x = self.post_conv(x)  # (B, 1, ?)
        features.append(x)
        return x.squeeze(1), features
class MSD(nn.Module):
    """Multi-Scale Discriminator: three SubSDs applied to progressively pooled audio.

    Scale 0 sees the raw waveform, and its SubSD uses spectral normalization.
    Each later scale sees the previous scale's input average-pooled with
    kernel 4 / stride 2, which roughly halves the length.
    """

    def __init__(self):
        super(MSD, self).__init__()
        # each SubSD: (B, 1, T) -> (B, ?), features list
        self.sub_sds = nn.ModuleList([SubSD(first=True), SubSD(), SubSD()])
        self.avgpool = nn.AvgPool1d(kernel_size=4, stride=2, padding=2)  # x2 down sampling

    def forward(self, real_waveform, gen_waveform):
        """
        =====inputs=====
        real_waveform: (B, 1, T)  # real speech
        gen_waveform: (B, 1, T)   # generated speech
        =====outputs=====
        real_outputs: list (len=3) of (B, ?) SubSD outputs for the real speech
        gen_outputs: list of SubSD outputs for the generated speech
        real_features: list of SubSD feature lists for the real speech
        gen_features: list of SubSD feature lists for the generated speech
        """
        # Input for scale k is the waveform pooled k times.
        real_inputs, gen_inputs = [real_waveform], [gen_waveform]
        for _ in range(len(self.sub_sds) - 1):
            real_inputs.append(self.avgpool(real_inputs[-1]))
            gen_inputs.append(self.avgpool(gen_inputs[-1]))

        real_outputs, gen_outputs, real_features, gen_features = [], [], [], []
        for sub_sd, real_in, gen_in in zip(self.sub_sds, real_inputs, gen_inputs):
            real_out, real_feat = sub_sd(real_in)
            gen_out, gen_feat = sub_sd(gen_in)
            real_outputs.append(real_out)
            gen_outputs.append(gen_out)
            real_features.append(real_feat)
            gen_features.append(gen_feat)
        return real_outputs, gen_outputs, real_features, gen_features