alibabasglab committed on
Commit 30c9347 · verified · 1 Parent(s): ce131df

Delete models/av_mossformer2_tse/av_mossformer_tmp.py

models/av_mossformer2_tse/av_mossformer_tmp.py DELETED
@@ -1,252 +0,0 @@
- import math
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- from .mossformer.utils.one_path_flash_fsmn import Dual_Path_Model, SBFLASHBlock_DualA
- from models.av_mossformer2_tse.visual_frontend import VisualFrontend
-
- EPS = 1e-8
-
- class avMossformer(nn.Module):
-     def __init__(self, args):
-         super(avMossformer, self).__init__()
-
-         N, L = args.network_audio.encoder_out_nchannels, args.network_audio.encoder_kernel_size
-
-         self.encoder = Encoder(L, N)
-         self.separator = Separator(args)
-         self.decoder = Decoder(args, N, L)
-
-         for p in self.parameters():
-             if p.dim() > 1:
-                 nn.init.xavier_normal_(p)
-
-     def forward(self, mixture, visual):
-         """
-         Args:
-             mixture: [M, T], M is batch size, T is #samples
-             visual: [M, T_v, V], visual embedding of the target speaker
-         Returns:
-             est_source: [M, T], the extracted target speech
-         """
-         mixture_w = self.encoder(mixture)
-         est_mask = self.separator(mixture_w, visual)
-         est_source = self.decoder(mixture_w, est_mask)
-
-         # T changed after conv1d in encoder; pad (or trim, via a negative pad) back to the input length
-         T_origin = mixture.size(-1)
-         T_conv = est_source.size(-1)
-         est_source = F.pad(est_source, (0, T_origin - T_conv))
-         return est_source
-
- class Encoder(nn.Module):
-     def __init__(self, L, N):
-         super(Encoder, self).__init__()
-         self.L, self.N = L, N
-         self.conv1d_U = nn.Conv1d(1, N, kernel_size=L, stride=L // 2, bias=False)
-
-     def forward(self, mixture):
-         """
-         Args:
-             mixture: [M, T], M is batch size, T is #samples
-         Returns:
-             mixture_w: [M, N, K], where K = (T - L) / (L / 2) + 1 = 2T / L - 1
-         """
-         mixture = torch.unsqueeze(mixture, 1)  # [M, 1, T]
-         mixture_w = F.relu(self.conv1d_U(mixture))  # [M, N, K]
-         return mixture_w
-
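As a sanity check on the encoder's frame count K = 2T/L - 1 (which holds when T is a multiple of L), here is a minimal standalone sketch; the kernel size, filter count, and signal length below are illustrative values, not the repo's configuration:

import torch
import torch.nn as nn
import torch.nn.functional as F

L, N, T = 16, 512, 16000  # illustrative kernel size, #filters, #samples
conv1d_U = nn.Conv1d(1, N, kernel_size=L, stride=L // 2, bias=False)

mixture = torch.randn(2, T)                          # [M, T]
mixture_w = F.relu(conv1d_U(mixture.unsqueeze(1)))   # [M, N, K]
assert mixture_w.shape[-1] == 2 * T // L - 1         # K = (T - L) / (L / 2) + 1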
-
- class Decoder(nn.Module):
-     def __init__(self, args, N, L):
-         super(Decoder, self).__init__()
-         self.N, self.L, self.args = N, L, args
-         self.basis_signals = nn.Linear(N, L, bias=False)
-
-     def forward(self, mixture_w, est_mask):
-         """
-         Args:
-             mixture_w: [M, N, K]
-             est_mask: [M, N, K]
-         Returns:
-             est_source: [M, T]
-         """
-         est_source = mixture_w * est_mask  # [M, N, K]
-         est_source = torch.transpose(est_source, 2, 1)  # [M, K, N]
-         est_source = self.basis_signals(est_source)  # [M, K, L]
-         est_source = overlap_and_add(est_source, self.L // 2)  # [M, T]
-         return est_source
-
-
- class Separator(nn.Module):
-     def __init__(self, args):
-         super(Separator, self).__init__()
-
-         self.layer_norm = nn.GroupNorm(1, args.network_audio.encoder_out_nchannels, eps=1e-8)
-         self.bottleneck_conv1x1 = nn.Conv1d(args.network_audio.encoder_out_nchannels, args.network_audio.encoder_out_nchannels, 1, bias=False)
-
-         # MossFormer2 intra-chunk model
-         intra_model = SBFLASHBlock_DualA(
-             num_layers=args.network_audio.intra_numlayers,
-             d_model=args.network_audio.encoder_out_nchannels,
-             nhead=args.network_audio.intra_nhead,
-             d_ffn=args.network_audio.intra_dffn,
-             dropout=args.network_audio.intra_dropout,
-             use_positional_encoding=args.network_audio.intra_use_positional,
-             norm_before=args.network_audio.intra_norm_before,
-         )
-
-         self.masknet = Dual_Path_Model(
-             in_channels=args.network_audio.encoder_out_nchannels,
-             out_channels=args.network_audio.encoder_out_nchannels,
-             intra_model=intra_model,
-             num_layers=args.network_audio.masknet_numlayers,
-             norm=args.network_audio.masknet_norm,
-             K=args.network_audio.masknet_chunksize,
-             num_spks=args.network_audio.masknet_numspks,
-             skip_around_intra=args.network_audio.masknet_extraskipconnection,
-             linear_layer_after_inter_intra=args.network_audio.masknet_useextralinearlayer,
-         )
-
-         # visual reference branch
-         stacks = []
-         for x in range(5):
-             stacks += [VisualConv1D(V=256, H=512)]
-         self.visual_conv = nn.Sequential(*stacks)
-         self.v_ds = nn.Conv1d(512, 256, 1, bias=False)
-         self.av_conv = nn.Conv1d(args.network_audio.encoder_out_nchannels + args.network_reference.emb_size, args.network_audio.encoder_out_nchannels, 1, bias=True)
-
-     def forward(self, x, visual):
-         """
-         Keeps the same API as TasNet.
-         Args:
-             x: [M, N, K], M is batch size
-             visual: [M, T_v, 512], visual embedding of the target speaker
-         Returns:
-             est_mask: [M, N, K]
-         """
-         M, N, D = x.size()
-
-         x = self.layer_norm(x)
-         x = self.bottleneck_conv1x1(x)
-
-         # bring the visual stream to the audio frame rate, then fuse by concatenation
-         visual = visual.transpose(1, 2)
-         visual = self.v_ds(visual)
-         visual = self.visual_conv(visual)
-         visual = F.interpolate(visual, size=D, mode='linear')
-
-         x = torch.cat((x, visual), 1)
-         x = self.av_conv(x)
-
-         x = self.masknet(x)  # [num_spks, M, N, K]
-
-         x = x.squeeze(0)  # num_spks = 1 for target speaker extraction
-
-         return x
-
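The fusion step above relies on F.interpolate to stretch the slower visual feature sequence to the audio frame rate before channel-wise concatenation. A minimal sketch with made-up batch, channel, and length values:

import torch
import torch.nn.functional as F

audio = torch.randn(2, 256, 1999)   # [M, N, K] bottleneck features
visual = torch.randn(2, 256, 50)    # [M, V, T_v] processed visual features

visual = F.interpolate(visual, size=audio.shape[-1], mode='linear')
fused = torch.cat((audio, visual), dim=1)  # [M, N + V, K]
print(fused.shape)                         # torch.Size([2, 512, 1999])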
- def overlap_and_add(signal, frame_step):
-     """Reconstructs a signal from a framed representation.
-
-     Adds potentially overlapping frames of a signal with shape
-     `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
-     The resulting tensor has shape `[..., output_size]` where
-
-         output_size = (frames - 1) * frame_step + frame_length
-
-     Args:
-         signal: A [..., frames, frame_length] Tensor. All dimensions may be unknown, and rank must be at least 2.
-         frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length.
-
-     Returns:
-         A Tensor with shape [..., output_size] containing the overlap-added frames of signal's inner-most two dimensions.
-
-     Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
-     """
-     outer_dimensions = signal.size()[:-2]
-     frames, frame_length = signal.size()[-2:]
-
-     subframe_length = math.gcd(frame_length, frame_step)  # greatest common divisor
-     subframe_step = frame_step // subframe_length
-     subframes_per_frame = frame_length // subframe_length
-     output_size = frame_step * (frames - 1) + frame_length
-     output_subframes = output_size // subframe_length
-
-     # split each frame into subframes of the gcd length, then scatter-add them
-     subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
-
-     frame = torch.arange(0, output_subframes).unfold(0, subframes_per_frame, subframe_step)
-     frame = frame.contiguous().view(-1).to(signal.device)  # follow the input's device (CPU or GPU)
-
-     result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
-     result.index_add_(-2, frame, subframe_signal)
-     result = result.view(*outer_dimensions, -1)
-     return result
-
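A quick CPU self-check of overlap_and_add with 50% overlap: every interior sample is covered by exactly two frames, so framing a signal and overlap-adding it back doubles the interior (the sizes below are illustrative):

import torch

x = torch.arange(16, dtype=torch.float32)
frames = x.unfold(0, 4, 2).contiguous()      # [7, 4], 50% overlap
y = overlap_and_add(frames.unsqueeze(0), 2)  # [1, 16]

assert y.shape[-1] == (frames.shape[0] - 1) * 2 + 4
assert torch.equal(y[0, 2:14], 2 * x[2:14])  # interior: covered by two frames
assert torch.equal(y[0, :2], x[:2])          # edges: covered once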
-
- class av_mossformer_tmp(nn.Module):
-     def __init__(self, args):
-         super(av_mossformer_tmp, self).__init__()
-         args.causal = 0
-         self.sep_network = avMossformer(args)
-         self.v_front_end = VisualFrontend(args)
-
-     def forward(self, mixture, ref):
-         # ref: lip video of the target speaker; add a channel dim for the visual front-end
-         ref = self.v_front_end(ref.unsqueeze(1)).transpose(1, 2)
-         return self.sep_network(mixture, ref)
-
-
- class VisualConv1D(nn.Module):
-     def __init__(self, V=256, H=512):
-         super(VisualConv1D, self).__init__()
-         relu_0 = nn.ReLU()
-         norm_0 = GlobalLayerNorm(V)
-         conv1x1 = nn.Conv1d(V, H, 1, bias=False)
-         relu = nn.ReLU()
-         norm_1 = GlobalLayerNorm(H)
-         dsconv = nn.Conv1d(H, H, 3, stride=1, padding=1, dilation=1, groups=H, bias=False)
-         prelu = nn.PReLU()
-         norm_2 = GlobalLayerNorm(H)
-         pw_conv = nn.Conv1d(H, V, 1, bias=False)
-         self.net = nn.Sequential(relu_0, norm_0, conv1x1, relu, norm_1, dsconv, prelu, norm_2, pw_conv)
-
-     def forward(self, x):
-         out = self.net(x)
-         return out + x  # residual connection
-
-
- class GlobalLayerNorm(nn.Module):
-     """Global Layer Normalization (gLN)"""
-     def __init__(self, channel_size):
-         super(GlobalLayerNorm, self).__init__()
-         self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
-         self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
-         self.reset_parameters()
-
-     def reset_parameters(self):
-         self.gamma.data.fill_(1)
-         self.beta.data.zero_()
-
-     def forward(self, y):
-         """
-         Args:
-             y: [M, N, K], M is batch size, N is channel size, K is length
-         Returns:
-             gLN_y: [M, N, K]
-         """
-         mean = y.mean(dim=(1, 2), keepdim=True)  # [M, 1, 1]
-         var = torch.pow(y - mean, 2).mean(dim=(1, 2), keepdim=True)
-         gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
-         return gLN_y
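Since GlobalLayerNorm normalizes each sample over all channels and time steps jointly, it should match nn.GroupNorm with a single group (the same layer Separator uses for its input normalization) at initialization, when gamma = 1 and beta = 0. A small equivalence check, with an illustrative channel count:

import torch
import torch.nn as nn

gln = GlobalLayerNorm(8)
gn = nn.GroupNorm(1, 8, eps=1e-8)  # weight = 1, bias = 0 at init, like gln

y = torch.randn(2, 8, 100)
assert torch.allclose(gln(y), gn(y), atol=1e-5)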