import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm, spectral_norm, remove_weight_norm

# Follows the HiFi-GAN V2 configuration.

# Negative slope used by every LeakyReLU in this file.
LRELU_SLOPE = 0.1


def _same_padding(kernel_size, dilation=1):
    """Return the padding that keeps the length unchanged for a stride-1 conv with an odd kernel."""
    return (kernel_size - 1) * dilation // 2


def _normal_init_weight_norm(layer, mean=0.0, std=0.01):
    """Initialise ``layer.weight`` from a normal distribution, then apply weight norm.

    The initialisation must happen *before* ``weight_norm``. Once weight norm
    is applied, ``layer.weight`` is a derived tensor that a forward pre-hook
    recomputes from ``weight_g`` and ``weight_v`` on every call, so anything
    written into it afterwards is discarded at the first forward pass.
    Initialising first makes ``weight_v`` the sampled tensor and ``weight_g``
    its norm, so the effective weight equals the sampled one.

    layer: a conv module with a ``weight`` parameter
    mean, std: mean and standard deviation of the normal distribution
    returns: the same module, wrapped with weight norm
    """
    nn.init.normal_(layer.weight, mean=mean, std=std)
    return weight_norm(layer)


class ResBlock(nn.Module):
    """HiFi-GAN residual block with three dilation groups: (1, 1), (3, 1), (5, 1).

    Every conv is a weight-normalised, stride-1, length-preserving Conv1d
    preceded by LeakyReLU. One residual connection wraps each *group* of two
    convs, matching the ResBlock of the HiFi-GAN paper.
    """

    def __init__(self, channels, kernel_size):
        """
        channels: number of input and output channels
        kernel_size: one of 3, 7, 11
        """
        super().__init__()
        # padding = (kernel_size - 1) * dilation // 2 ("same")
        self.convs1 = self._make_group(channels, kernel_size, dilations=(1, 1))
        self.convs2 = self._make_group(channels, kernel_size, dilations=(3, 1))
        self.convs3 = self._make_group(channels, kernel_size, dilations=(5, 1))
        # NOTE: the groups are deliberately not stored as ``self.modules``,
        # because that name would shadow ``nn.Module.modules()``.

    @staticmethod
    def _make_group(channels, kernel_size, dilations):
        """Build one ModuleList of N(0, 0.01)-initialised, weight-normed convs."""
        return nn.ModuleList([
            _normal_init_weight_norm(
                nn.Conv1d(channels, channels, kernel_size, stride=1,
                          dilation=d, padding=_same_padding(kernel_size, d)))
            for d in dilations
        ])

    def _conv_groups(self):
        """Return the three dilation groups in forward order."""
        return (self.convs1, self.convs2, self.convs3)

    def forward(self, x):
        """
        =====inputs=====
        x: (B, channels, F)  # input features derived from the mel-spectrogram
        =====outputs=====
        x: (B, channels, F)  # output features, same shape as the input
        """
        for group in self._conv_groups():
            y = x
            for conv in group:
                y = conv(F.leaky_relu(y, LRELU_SLOPE))
            # The residual connection wraps the whole group, not each conv.
            x = x + y
        return x

    def remove_weight_norm(self):
        """Remove weight norm from every conv, leaving plain weight parameters."""
        for group in self._conv_groups():
            for conv in group:
                remove_weight_norm(conv)


class MRF(nn.Module):
    """Multi-Receptive Field Fusion: mean of ResBlocks with kernel sizes 3, 7 and 11."""

    def __init__(self, channels):
        """
        channels: number of input and output channels
        """
        super().__init__()
        self.res_blocks = nn.ModuleList([
            ResBlock(channels, kernel_size=3),
            ResBlock(channels, kernel_size=7),
            ResBlock(channels, kernel_size=11),
        ])

    def forward(self, x):
        """
        =====inputs=====
        x: (B, channels, F)
        =====outputs=====
        x: (B, channels, F)  # element-wise mean of the ResBlock outputs
        """
        outputs = [res_block(x) for res_block in self.res_blocks]
        return sum(outputs) / len(self.res_blocks)

    def remove_weight_norm(self):
        """Remove weight norm from every ResBlock."""
        for block in self.res_blocks:
            block.remove_weight_norm()


class Generator(nn.Module):
    """HiFi-GAN V2 generator: 80-bin mel-spectrogram -> waveform, upsampled x256."""

    def __init__(self):
        super().__init__()
        # (B, 80, F) -> (B, 128, F).
        # The N(0, 0.01) init of this layer is not in the authors' implementation.
        self.pre_conv = _normal_init_weight_norm(
            nn.Conv1d(80, 128, kernel_size=7, stride=1, dilation=1, padding=(7 - 1) // 2))

        self.up_convs = nn.ModuleList()
        self.mrfs = nn.ModuleList()
        ku = [16, 16, 4, 4]
        for i, k in enumerate(ku):
            # Each transposed conv upsamples by stride k // 2 (x8, x8, x2, x2) and
            # halves the channels:
            # (B, 128, F) -(1)-> (B, 64, F*8) -(2)-> (B, 32, F*8*8)
            #             -(3)-> (B, 16, F*8*8*2) -(4)-> (B, 8, F*8*8*2*2)
            in_channels = 128 // (2 ** i)
            channels = 128 // (2 ** (i + 1))
            self.up_convs.append(_normal_init_weight_norm(
                nn.ConvTranspose1d(in_channels, channels, kernel_size=k,
                                   stride=k // 2, padding=(k - k // 2) // 2)))
            # MRF keeps the shape: (B, channels, L) -> (B, channels, L)
            self.mrfs.append(MRF(channels))

        # (B, 8, F*256) -> (B, 1, F*256)
        self.post_conv = _normal_init_weight_norm(
            nn.Conv1d(8, 1, kernel_size=7, stride=1, dilation=1, padding=(7 - 1) // 2))

    def forward(self, x):
        """
        =====inputs=====
        x: (B, 80, F)  # mel-spectrogram
        =====outputs=====
        x: (B, 1, F*256)  # waveform, squashed to (-1, 1) by tanh
        """
        x = self.pre_conv(x)  # (B, 80, F) -> (B, 128, F)
        for up_conv, mrf in zip(self.up_convs, self.mrfs):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = up_conv(x)
            x = mrf(x)
        # After the loop: (B, 8, F*256)
        x = F.leaky_relu(x, LRELU_SLOPE)
        x = self.post_conv(x)  # (B, 8, F*256) -> (B, 1, F*256)
        return torch.tanh(x)

    def remove_weight_norm(self):
        """Remove weight norm from every layer of the generator."""
        print('Removing weight norm...')
        for layer in self.up_convs:
            remove_weight_norm(layer)
        for mrf in self.mrfs:
            mrf.remove_weight_norm()
        remove_weight_norm(self.pre_conv)
        remove_weight_norm(self.post_conv)


class SubPD(nn.Module):
    """One period sub-discriminator of the multi-period discriminator (MPD)."""

    def __init__(self, period):
        """
        period: one of 2, 3, 5, 7, 11
        """
        super().__init__()
        self.period = period
        self.convs = nn.ModuleList()
        channels = 1
        for i in range(1, 5):
            # Implemented as described in the paper (padding=0) rather than the
            # authors' modified implementation.
            out_channels = 2 ** (5 + i)
            self.convs.append(weight_norm(nn.Conv2d(
                channels, out_channels, kernel_size=(5, 1), stride=(3, 1),
                dilation=1, padding=0)))
            channels = out_channels
        # (B, 1, ceil(T/P), P) -(1)-> (B, 64, ?, P) -(2)-> (B, 128, ?, P)
        #                      -(3)-> (B, 256, ?, P) -(4)-> (B, 512, ?, P)
        # Last conv: (B, 512, ?, P) -> (B, 1024, ?, P)
        self.convs.append(weight_norm(nn.Conv2d(
            channels, 1024, kernel_size=(5, 1), stride=(1, 1), dilation=1, padding=(2, 0))))
        # (B, 1024, ?, P) -> (B, 1, ?, P)
        self.post_conv = weight_norm(nn.Conv2d(
            1024, 1, kernel_size=(3, 1), stride=(1, 1), dilation=1, padding=(1, 0)))

    def forward(self, waveform):
        """
        =====inputs=====
        waveform: (B, 1, T)
        =====outputs=====
        x: (B, ?)  # flattened real/fake scores (no sigmoid or clipping applied)
        features: list of every intermediate feature map (for the feature-matching loss)
        """
        features = []
        B, _, T = waveform.size()
        P = self.period

        if T % P != 0:
            # Pad only at the end, up to the next multiple of P. "reflect" mirrors
            # the signal, e.g. [1, 2, 3, 4, 5] padded by 2 (front) and 3 (back)
            # -> [3, 2, 1, 2, 3, 4, 5, 4, 3, 2]
            padding = P - (T % P)
            waveform = F.pad(waveform, (0, padding), "reflect")
            T += padding

        # Fold the 1-D signal into a (T/P, P) grid: column c holds samples c, c+P, c+2P, ...
        x = waveform.view(B, 1, T // P, P)
        for conv in self.convs:
            x = F.leaky_relu(conv(x), LRELU_SLOPE)
            features.append(x)
        x = self.post_conv(x)
        features.append(x)
        # Flatten from dim 1 through the last dim -> (B, ?)
        # NOTE(review): the original author asked whether a sigmoid/clipping is
        # needed here; raw scores are returned, presumably for an LSGAN-style loss.
        x = torch.flatten(x, 1, -1)
        return x, features


class MPD(nn.Module):
    """Multi-period discriminator: SubPDs with periods 2, 3, 5, 7 and 11."""

    def __init__(self):
        super().__init__()
        # Each SubPD: (B, 1, T) -> (B, ?), features list
        self.sub_pds = nn.ModuleList([
            SubPD(2),
            SubPD(3),
            SubPD(5),
            SubPD(7),
            SubPD(11),
        ])

    def forward(self, real_waveform, gen_waveform):
        """
        =====inputs=====
        real_waveform: (B, 1, T)  # real audio
        gen_waveform: (B, 1, T)   # generated audio
        =====outputs=====
        real_outputs: list (len=5) of (B, ?)  # SubPD outputs for the real audio
        gen_outputs: list of (B, ?)           # SubPD outputs for the generated audio
        real_features: list of feature lists  # SubPD features for the real audio
        gen_features: list of feature lists   # SubPD features for the generated audio
        """
        real_outputs, gen_outputs, real_features, gen_features = [], [], [], []
        for sub_pd in self.sub_pds:
            real_output, real_feature = sub_pd(real_waveform)
            gen_output, gen_feature = sub_pd(gen_waveform)
            real_outputs.append(real_output)
            gen_outputs.append(gen_output)
            real_features.append(real_feature)
            gen_features.append(gen_feature)
        return real_outputs, gen_outputs, real_features, gen_features


class SubSD(nn.Module):
    """One scale sub-discriminator of the multi-scale discriminator (MelGAN-style layers)."""

    def __init__(self, first=False):
        """
        first: if True, use spectral normalization; otherwise weight normalization.
        """
        super().__init__()
        norm = spectral_norm if first else weight_norm
        # Layer stack follows the MelGAN paper.
        self.convs = nn.ModuleList([
            # (B, 1, T) -> (B, 16, T)
            norm(nn.Conv1d(1, 16, kernel_size=15, stride=1, padding=(15 - 1) // 2)),
            # (B, 16, T) -> (B, 64, ~T/4)
            norm(nn.Conv1d(16, 64, kernel_size=41, stride=4, groups=4, padding=(41 - 1) // 2)),
            # (B, 64, ~T/4) -> (B, 256, ~T/16)
            norm(nn.Conv1d(64, 256, kernel_size=41, stride=4, groups=16, padding=(41 - 1) // 2)),
            # (B, 256, ~T/16) -> (B, 1024, ~T/64)
            norm(nn.Conv1d(256, 1024, kernel_size=41, stride=4, groups=64, padding=(41 - 1) // 2)),
            # (B, 1024, ~T/64) -> (B, 1024, ~T/256)
            norm(nn.Conv1d(1024, 1024, kernel_size=41, stride=4, groups=256, padding=(41 - 1) // 2)),
            # (B, 1024, ~T/256) -> (B, 1024, ~T/256)
            norm(nn.Conv1d(1024, 1024, kernel_size=5, stride=1, padding=(5 - 1) // 2)),
        ])
        # (B, 1024, ?) -> (B, 1, ?)
        self.post_conv = norm(nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=(3 - 1) // 2))

    def forward(self, waveform):
        """
        =====inputs=====
        waveform: (B, 1, T)
        =====outputs=====
        x: (B, ?)  # real/fake scores with the channel dim squeezed (no sigmoid or clipping)
        features: list of every intermediate feature map (for the feature-matching loss)
        """
        features = []
        x = waveform
        for conv in self.convs:
            x = F.leaky_relu(conv(x), LRELU_SLOPE)
            features.append(x)
        x = self.post_conv(x)  # (B, 1, ?)
        features.append(x)
        # NOTE(review): as in SubPD, raw scores are returned without sigmoid/clipping.
        x = x.squeeze(1)  # (B, ?)
        return x, features


class MSD(nn.Module):
    """Multi-scale discriminator: three SubSDs on raw, x2- and x4-downsampled audio."""

    def __init__(self):
        super().__init__()
        # Each SubSD: (B, 1, T) -> (B, ?), features list. Only the first uses spectral norm.
        self.sub_sds = nn.ModuleList([
            SubSD(first=True),
            SubSD(),
            SubSD(),
        ])
        self.avgpool = nn.AvgPool1d(kernel_size=4, stride=2, padding=2)  # x2 downsampling

    def forward(self, real_waveform, gen_waveform):
        """
        =====inputs=====
        real_waveform: (B, 1, T)  # real audio
        gen_waveform: (B, 1, T)   # generated audio
        =====outputs=====
        real_outputs: list (len=3) of (B, ?)  # SubSD outputs for the real audio
        gen_outputs: list of (B, ?)           # SubSD outputs for the generated audio
        real_features: list of feature lists  # SubSD features for the real audio
        gen_features: list of feature lists   # SubSD features for the generated audio
        """
        real_outputs, gen_outputs, real_features, gen_features = [], [], [], []
        for idx, sub_sd in enumerate(self.sub_sds):
            if idx != 0:
                # Pooling accumulates: scale k sees the input downsampled k times.
                real_waveform = self.avgpool(real_waveform)
                gen_waveform = self.avgpool(gen_waveform)
            real_output, real_feature = sub_sd(real_waveform)
            gen_output, gen_feature = sub_sd(gen_waveform)
            real_outputs.append(real_output)
            gen_outputs.append(gen_output)
            real_features.append(real_feature)
            gen_features.append(gen_feature)
        return real_outputs, gen_outputs, real_features, gen_features