nevreal committed on
Commit
1ec1063
·
verified ·
1 Parent(s): 4ef50e4

Delete infer_pack

Browse files
infer_pack/__pycache__/attentions.cpython-39.pyc DELETED
Binary file (9.92 kB)
 
infer_pack/__pycache__/commons.cpython-39.pyc DELETED
Binary file (5.88 kB)
 
infer_pack/__pycache__/models.cpython-39.pyc DELETED
Binary file (16.7 kB)
 
infer_pack/__pycache__/modules.cpython-39.pyc DELETED
Binary file (11.9 kB)
 
infer_pack/__pycache__/transforms.cpython-39.pyc DELETED
Binary file (3.92 kB)
 
infer_pack/attentions.py DELETED
@@ -1,417 +0,0 @@
1
- import copy
2
- import math
3
- import numpy as np
4
- import torch
5
- from torch import nn
6
- from torch.nn import functional as F
7
-
8
- from infer_pack import commons
9
- from infer_pack import modules
10
- from infer_pack.modules import LayerNorm
11
-
12
-
13
class Encoder(nn.Module):
    """Stack of self-attention and feed-forward layers with residual LayerNorm.

    Each of the ``n_layers`` layers runs ``MultiHeadAttention`` (built with
    ``window_size``, i.e. with relative-position embeddings) followed by an
    ``FFN``; each sub-layer output goes through dropout, is added to its
    input and then layer-normalised.

    Args:
        hidden_channels: channel count of the input, output and attention.
        filter_channels: inner channel count of each FFN.
        n_heads: number of attention heads.
        n_layers: number of (attention, FFN) layers.
        kernel_size: convolution kernel size used by each FFN.
        p_dropout: dropout probability shared by all sub-layers.
        window_size: relative-attention window passed to every attention layer.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=10,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        # Sub-modules are created layer by layer (not list by list) so the
        # order in which random initial weights are drawn stays the same.
        for _ in range(self.n_layers):
            attention = MultiHeadAttention(
                hidden_channels,
                hidden_channels,
                n_heads,
                p_dropout=p_dropout,
                window_size=window_size,
            )
            self.attn_layers.append(attention)
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            feed_forward = FFN(
                hidden_channels,
                hidden_channels,
                filter_channels,
                kernel_size,
                p_dropout=p_dropout,
            )
            self.ffn_layers.append(feed_forward)
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        """Encode ``x``, zeroing masked positions before and after the stack.

        Args:
            x: input features (assumed [b, hidden_channels, t] -- TODO confirm
                against callers).
            x_mask: multiplicative validity mask broadcastable against ``x``.

        Returns:
            The encoded tensor multiplied by ``x_mask``.
        """
        # A pair of positions may attend only if both positions are valid.
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for layer in range(self.n_layers):
            attended = self.drop(self.attn_layers[layer](x, x, attn_mask))
            x = self.norm_layers_1[layer](x + attended)

            transformed = self.drop(self.ffn_layers[layer](x, x_mask))
            x = self.norm_layers_2[layer](x + transformed)
        return x * x_mask
74
-
75
-
76
class Decoder(nn.Module):
    """Stack of self-attention, encoder-decoder attention and causal FFN layers.

    Every layer applies three sub-layers in order -- self-attention over the
    decoder input, attention over the encoder output, and an ``FFN`` built
    with ``causal=True``. Each sub-layer output goes through dropout, is added
    to its input and then layer-normalised.

    Args:
        hidden_channels: channel count of the input, output and attention.
        filter_channels: inner channel count of each FFN.
        n_heads: number of attention heads.
        n_layers: number of decoder layers.
        kernel_size: convolution kernel size used by each FFN.
        p_dropout: dropout probability shared by all sub-layers.
        proximal_bias: forwarded to the self-attention layers.
        proximal_init: forwarded to the self-attention layers.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=True,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init

        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.encdec_attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        # Sub-modules are created layer by layer so the order in which random
        # initial weights are drawn stays the same.
        for _ in range(self.n_layers):
            self_attention = MultiHeadAttention(
                hidden_channels,
                hidden_channels,
                n_heads,
                p_dropout=p_dropout,
                proximal_bias=proximal_bias,
                proximal_init=proximal_init,
            )
            self.self_attn_layers.append(self_attention)
            self.norm_layers_0.append(LayerNorm(hidden_channels))

            cross_attention = MultiHeadAttention(
                hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
            )
            self.encdec_attn_layers.append(cross_attention)
            self.norm_layers_1.append(LayerNorm(hidden_channels))

            feed_forward = FFN(
                hidden_channels,
                hidden_channels,
                filter_channels,
                kernel_size,
                p_dropout=p_dropout,
                causal=True,
            )
            self.ffn_layers.append(feed_forward)
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, h, h_mask):
        """Decode ``x`` while attending to the encoder output ``h``.

        Args:
            x: decoder input.
            x_mask: multiplicative validity mask for ``x``; its size along
                dim 2 sets the length of the causal self-attention mask.
            h: encoder output.
            h_mask: multiplicative validity mask for ``h``.

        Returns:
            The decoded tensor multiplied by ``x_mask``.
        """
        # Lower-triangular mask: a position sees itself and earlier positions.
        causal_mask = commons.subsequent_mask(x_mask.size(2))
        self_attn_mask = causal_mask.to(device=x.device, dtype=x.dtype)
        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for layer in range(self.n_layers):
            own = self.drop(self.self_attn_layers[layer](x, x, self_attn_mask))
            x = self.norm_layers_0[layer](x + own)

            cross = self.drop(self.encdec_attn_layers[layer](x, h, encdec_attn_mask))
            x = self.norm_layers_1[layer](x + cross)

            transformed = self.drop(self.ffn_layers[layer](x, x_mask))
            x = self.norm_layers_2[layer](x + transformed)
        return x * x_mask
160
-
161
-
162
class MultiHeadAttention(nn.Module):
    """Multi-head attention over [b, channels, t] tensors using 1x1 convolutions.

    Optional extras, each switched on by a constructor argument:

    * ``window_size``: learned relative-position embeddings for keys and
      values, covering offsets in ``[-window_size, window_size]``.
    * ``proximal_bias``: adds ``-log(1 + |i - j|)`` to the attention logits.
    * ``block_length``: masks logits further than ``block_length`` positions
      from the diagonal.
    * ``proximal_init``: starts the key projection as a copy of the query
      projection.

    The three options that rely on position offsets assert that the query and
    key lengths are equal (self-attention).

    Args:
        channels: input channel count; must be divisible by ``n_heads``.
        out_channels: channel count produced by the output projection.
        n_heads: number of attention heads.
        p_dropout: dropout probability applied to the attention weights.
        window_size: relative-attention window, or ``None`` to disable.
        heads_share: if true, one set of relative embeddings serves all heads.
        block_length: local-attention band half-width, or ``None``.
        proximal_bias: whether to add the proximity bias to the logits.
        proximal_init: whether to copy the query projection into the key
            projection at construction time.
    """

    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        p_dropout=0.0,
        window_size=None,
        heads_share=True,
        block_length=None,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        # Attention weights of the most recent forward pass.
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        """Attend from ``x`` (queries) to ``c`` (keys and values).

        Args:
            x: query source, [b, channels, t_t].
            c: key/value source, [b, channels, t_s].
            attn_mask: optional mask; logits where it equals 0 are set to -1e4.

        Returns:
            Tensor of shape [b, out_channels, t_t]. The attention weights are
            stored on ``self.attn`` as a side effect.
        """
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        """Scaled dot-product attention on already-projected tensors.

        Args:
            query: [b, d, t_t].
            key: [b, d, t_s].
            value: [b, d, t_s].
            mask: optional mask; logits where it equals 0 are set to -1e4.

        Returns:
            ``(output, p_attn)`` with ``output`` of shape [b, d, t_t] and
            ``p_attn`` of shape [b, n_h, t_t, t_s] (after dropout).
        """
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        # Scale once; both the content and the relative logits use it.
        scaled_query = query / math.sqrt(self.k_channels)
        scores = torch.matmul(scaled_query, key.transpose(-2, -1))
        if self.window_size is not None:
            assert (
                t_s == t_t
            ), "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(
                scaled_query, key_relative_embeddings
            )
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(
                device=scores.device, dtype=scores.dtype
            )
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert (
                    t_s == t_t
                ), "Local attention is only available for self-attention."
                # Band of ones within block_length of the diagonal.
                block_mask = (
                    torch.ones_like(scores)
                    .triu(-self.block_length)
                    .tril(self.block_length)
                )
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, t_s
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )
        output = (
            output.transpose(2, 3).contiguous().view(b, d, t_t)
        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        return torch.matmul(x, y.unsqueeze(0))

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        return torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))

    def _get_relative_embeddings(self, relative_embeddings, length):
        """Return the ``2 * length - 1`` relative embeddings for ``length``.

        The stored table holds ``2 * window_size + 1`` offsets. It is
        zero-padded when the sequence is longer than the window allows and
        sliced around its centre when the sequence is shorter.
        """
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            # F.pad lists pads from the last dim backwards: leave the last
            # dim alone, pad the offset dim on both sides.
            padded_relative_embeddings = F.pad(
                relative_embeddings, [0, 0, pad_length, pad_length]
            )
        else:
            padded_relative_embeddings = relative_embeddings
        return padded_relative_embeddings[:, slice_start_position:slice_end_position]

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, [0, 1])

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(x_flat, [0, length - 1])

        # Reshape and slice out the padded elements.
        return x_flat.view([batch, heads, length + 1, 2 * length - 1])[
            :, :, :length, length - 1 :
        ]

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # Pad along the column (last) dimension.
        x = F.pad(x, [0, length - 1])
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        # Add 0's in the beginning that will skew the elements after reshape.
        x_flat = F.pad(x_flat, [length, 0])
        return x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
          length: an integer scalar.
        Returns:
          a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
360
-
361
-
362
class FFN(nn.Module):
    """Two-layer 1-D convolutional feed-forward block with input masking.

    The input is masked, padded and convolved to ``filter_channels``, passed
    through an activation and dropout, then masked, padded and convolved to
    ``out_channels``. Padding keeps the time length unchanged: it is split
    around the signal by default, or placed entirely on the left when
    ``causal`` is true.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        filter_channels: channel count between the two convolutions.
        kernel_size: kernel size of both convolutions.
        p_dropout: dropout probability applied after the activation.
        activation: ``"gelu"`` selects ``x * sigmoid(1.702 * x)``; any other
            value selects ReLU.
        causal: whether to pad on the left only.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        filter_channels,
        kernel_size,
        p_dropout=0.0,
        activation=None,
        causal=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        # Padding strategy is chosen once, at construction time.
        self.padding = self._causal_padding if causal else self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        """Apply the block to ``x``; ``x_mask`` multiplies every stage."""
        hidden = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            hidden = hidden * torch.sigmoid(1.702 * hidden)
        else:
            hidden = torch.relu(hidden)
        hidden = self.drop(hidden)
        out = self.conv_2(self.padding(hidden * x_mask))
        return out * x_mask

    def _causal_padding(self, x):
        """Left-pad the last dim by ``kernel_size - 1`` (no-op for kernel 1)."""
        if self.kernel_size == 1:
            return x
        left = self.kernel_size - 1
        # F.pad lists pads from the last dim backwards: (left, right) on the
        # last dim, nothing on the other two.
        return F.pad(x, [left, 0, 0, 0, 0, 0])

    def _same_padding(self, x):
        """Pad the last dim so the convolution keeps its length (no-op for kernel 1)."""
        if self.kernel_size == 1:
            return x
        left = (self.kernel_size - 1) // 2
        right = self.kernel_size // 2
        return F.pad(x, [left, right, 0, 0, 0, 0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer_pack/commons.py DELETED
@@ -1,164 +0,0 @@
1
- import math
2
- import numpy as np
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
-
7
-
8
def init_weights(m, mean=0.0, std=0.01):
    """Re-draw ``m.weight`` from N(mean, std) when ``m`` is a convolution.

    A module counts as a convolution when its class name contains ``"Conv"``;
    any other module is left untouched. Intended for ``Module.apply``.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
12
-
13
-
14
def get_padding(kernel_size, dilation=1):
    """Return ``(kernel_size * dilation - dilation) / 2`` truncated to an int."""
    span = kernel_size * dilation - dilation
    return int(span / 2)
16
-
17
-
18
def convert_pad_shape(pad_shape):
    """Flatten per-dimension ``[before, after]`` pairs, last dimension first.

    ``[[a0, a1], [b0, b1]]`` becomes ``[b0, b1, a0, a1]`` -- the flat,
    last-dimension-first layout that the file passes to ``F.pad``.
    """
    flat = []
    for pair in pad_shape[::-1]:
        flat.extend(pair)
    return flat
22
-
23
-
24
def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q), element-wise, from means ``m_*`` and log-std-devs ``logs_*``.

    Computes ``(logs_q - logs_p) - 0.5
    + 0.5 * (exp(2*logs_p) + (m_p - m_q)**2) * exp(-2*logs_q)``.
    """
    kl = (logs_q - logs_p) - 0.5
    variance_p = torch.exp(2.0 * logs_p)
    squared_mean_gap = (m_p - m_q) ** 2
    inverse_variance_q = torch.exp(-2.0 * logs_q)
    # In-place accumulation, as in the original expression.
    kl += 0.5 * (variance_p + squared_mean_gap) * inverse_variance_q
    return kl
31
-
32
-
33
def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows.

    Uniform samples are squeezed into [0.00001, 0.99999) so neither logarithm
    is evaluated at 0.
    """
    uniform = torch.rand(shape) * 0.99998 + 0.00001
    inner = -torch.log(uniform)
    return -torch.log(inner)
37
-
38
-
39
def rand_gumbel_like(x):
    """Return Gumbel noise with the size, dtype and device of ``x``."""
    return rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
42
-
43
-
44
def slice_segments(x, ids_str, segment_size=4):
    """Cut one window of ``segment_size`` along the last dim per batch item.

    Args:
        x: tensor indexed as ``x[i, :, start:end]``.
        ids_str: per-item start index.
        segment_size: window length.

    Returns:
        Tensor shaped like ``x[:, :, :segment_size]`` holding the windows.
    """
    out = torch.zeros_like(x[:, :, :segment_size])
    for item in range(x.size(0)):
        start = ids_str[item]
        out[item] = x[item, :, start : start + segment_size]
    return out
51
def slice_segments2(x, ids_str, segment_size=4):
    """Like ``slice_segments`` for tensors indexed as ``x[i, start:end]``."""
    out = torch.zeros_like(x[:, :segment_size])
    for item in range(x.size(0)):
        start = ids_str[item]
        out[item] = x[item, start : start + segment_size]
    return out
58
-
59
-
60
def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Cut one randomly positioned window per batch item from ``x``.

    Args:
        x: 3-D tensor; the window is taken along its last dim.
        x_lengths: valid length(s) bounding the start index; defaults to the
            full last-dim size.
        segment_size: window length.

    Returns:
        ``(segments, ids_str)`` -- the windows and their start indices.
    """
    batch, _, total = x.size()
    lengths = total if x_lengths is None else x_lengths
    max_start = lengths - segment_size + 1
    uniform = torch.rand([batch]).to(device=x.device)
    ids_str = (uniform * max_start).to(dtype=torch.long)
    return slice_segments(x, ids_str, segment_size), ids_str
68
-
69
-
70
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    """Build a sinusoidal timing signal of shape ``[1, channels, length]``.

    The first ``channels // 2`` rows are sines and the next ``channels // 2``
    rows are cosines of the position scaled by geometrically spaced inverse
    timescales from ``min_timescale`` to ``max_timescale``. If ``channels``
    is odd, one zero row is appended.

    Args:
        length: number of positions.
        channels: number of output channels.
        min_timescale: smallest timescale.
        max_timescale: largest timescale.

    Returns:
        Float tensor of shape ``[1, channels, length]``.
    """
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    # With a single timescale (channels 2 or 3) ``num_timescales - 1`` is 0
    # and the division used to raise ZeroDivisionError. Clamp the divisor to
    # at least 1; results for every other channel count are unchanged.
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
        max(num_timescales - 1, 1)
    )
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    # Append one zero row when channels is odd.
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal
84
-
85
-
86
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    """Add the sinusoidal timing signal to ``x`` ([b, channels, length])."""
    _, channels, length = x.size()
    timing = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    timing = timing.to(dtype=x.dtype, device=x.device)
    return x + timing
90
-
91
-
92
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    """Concatenate the sinusoidal timing signal to ``x`` along ``axis``."""
    _, channels, length = x.size()
    timing = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    timing = timing.to(dtype=x.dtype, device=x.device)
    return torch.cat([x, timing], axis)
96
-
97
-
98
def subsequent_mask(length):
    """Return a lower-triangular mask of ones with shape [1, 1, length, length]."""
    lower = torch.tril(torch.ones(length, length))
    return lower[None, None]
101
-
102
-
103
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    """Gated activation: tanh of the first channels times sigmoid of the rest.

    ``input_a + input_b`` is split along dim 1 at ``n_channels[0]``; the
    result is ``tanh(first part) * sigmoid(second part)``.
    """
    split_at = n_channels[0]
    summed = input_a + input_b
    gate_tanh = torch.tanh(summed[:, :split_at, :])
    gate_sigmoid = torch.sigmoid(summed[:, split_at:, :])
    return gate_tanh * gate_sigmoid
111
-
112
-
113
def convert_pad_shape(pad_shape):
    """Flatten per-dimension ``[before, after]`` pairs, last dimension first.

    This is a verbatim second definition of a function that already appears
    earlier in the module; it rebinds the same name to the same behaviour.
    """
    flat = []
    for pair in pad_shape[::-1]:
        flat.extend(pair)
    return flat
117
-
118
-
119
def shift_1d(x):
    """Shift a 3-D tensor right by one along the last dim, zero-filling the front.

    The last element along that dim is dropped, so the shape is unchanged.
    """
    # F.pad lists pads from the last dim backwards: one zero on the left of
    # the last dim, nothing on the other two.
    padded = F.pad(x, [1, 0, 0, 0, 0, 0])
    return padded[:, :, :-1]
122
-
123
-
124
def sequence_mask(length, max_length=None):
    """Return a boolean mask that is True where ``position < length``.

    Args:
        length: 1-D tensor of sequence lengths.
        max_length: mask width; defaults to ``length.max()``.

    Returns:
        Boolean tensor of shape ``[len(length), max_length]``.
    """
    width = length.max() if max_length is None else max_length
    positions = torch.arange(width, dtype=length.dtype, device=length.device)
    return positions.unsqueeze(0) < length.unsqueeze(1)
129
-
130
-
131
def generate_path(duration, mask):
    """Expand per-step durations into a hard alignment path.

    Cumulative durations are turned into step masks over ``t_y``; subtracting
    each mask's predecessor leaves the frames covered by that step alone.

    Args:
        duration: [b, 1, t_x]
        mask: [b, 1, t_y, t_x]

    Returns:
        Tensor of shape [b, 1, t_y, t_x] with ``mask``'s dtype, multiplied by
        ``mask``.
    """
    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    # Subtract the previous step's mask (shifted down one row along t_x).
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path
147
-
148
-
149
def clip_grad_value_(parameters, clip_value, norm_type=2):
    """Clamp gradients element-wise in place and return their total norm.

    Args:
        parameters: a tensor or an iterable of tensors; entries whose
            ``grad`` is ``None`` are ignored.
        clip_value: gradients are clamped to ``[-clip_value, clip_value]``;
            ``None`` disables clamping.
        norm_type: order of the norm.

    Returns:
        The ``norm_type``-norm over all gradients, measured before clamping.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    with_grad = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    accumulated = 0
    for param in with_grad:
        grad = param.grad.data
        # The norm is taken before this parameter's gradient is clamped.
        accumulated += grad.norm(norm_type).item() ** norm_type
        if clip_value is not None:
            grad.clamp_(min=-clip_value, max=clip_value)
    return accumulated ** (1.0 / norm_type)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer_pack/models.py DELETED
@@ -1,664 +0,0 @@
1
- import math,pdb,os
2
- from time import time as ttime
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
- from infer_pack import modules
7
- from infer_pack import attentions
8
- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
9
- from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
10
- from infer_pack.commons import init_weights
11
- import numpy as np
12
- from infer_pack import commons
13
class TextEncoder256(nn.Module):
    """Encode 256-dim phone features (plus optional pitch ids) into prior stats.

    Phone features go through a linear layer; when pitch is given, a pitch
    embedding is added. The sum is scaled, passed through LeakyReLU and an
    ``attentions.Encoder``, then projected to ``2 * out_channels`` channels
    that are split into a mean and a log-scale.

    Args:
        out_channels: channel count of each returned statistic.
        hidden_channels: width of the embeddings and the encoder.
        filter_channels: encoder FFN inner width.
        n_heads: encoder attention heads.
        n_layers: encoder layers.
        kernel_size: encoder FFN kernel size.
        p_dropout: encoder dropout probability.
        f0: whether to create the pitch embedding. When false, ``forward``
            must be called with ``pitch=None``.
    """

    def __init__(
        self,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        f0=True,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.emb_phone = nn.Linear(256, hidden_channels)
        self.lrelu = nn.LeakyReLU(0.1, inplace=True)
        if f0:
            self.emb_pitch = nn.Embedding(256, hidden_channels)  # pitch 256
        self.encoder = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, phone, pitch, lengths):
        """Return ``(m, logs, x_mask)`` for a batch of phone features.

        Args:
            phone: phone features fed to the linear layer.
            pitch: pitch ids for the pitch embedding, or ``None`` to skip it.
            lengths: valid length of each sequence, used to build the mask.
        """
        # Identity check: ``pitch == None`` would dispatch to Tensor.__eq__.
        if pitch is None:
            x = self.emb_phone(phone)
        else:
            x = self.emb_phone(phone) + self.emb_pitch(pitch)
        x = x * math.sqrt(self.hidden_channels)  # [b, t, h]
        x = self.lrelu(x)
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.encoder(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return m, logs, x_mask
49
class TextEncoder256km(nn.Module):
    """Encode discrete phone ids (plus optional pitch ids) into prior stats.

    Same pipeline as ``TextEncoder256`` except that phones are integer ids
    looked up in a 500-entry embedding table rather than 256-dim features
    passed through a linear layer.

    Args:
        out_channels: channel count of each returned statistic.
        hidden_channels: width of the embeddings and the encoder.
        filter_channels: encoder FFN inner width.
        n_heads: encoder attention heads.
        n_layers: encoder layers.
        kernel_size: encoder FFN kernel size.
        p_dropout: encoder dropout probability.
        f0: whether to create the pitch embedding. When false, ``forward``
            must be called with ``pitch=None``.
    """

    def __init__(
        self,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        f0=True,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        # Embedding lookup replaces the nn.Linear(256, hidden_channels)
        # front-end used by TextEncoder256.
        self.emb_phone = nn.Embedding(500, hidden_channels)
        self.lrelu = nn.LeakyReLU(0.1, inplace=True)
        if f0:
            self.emb_pitch = nn.Embedding(256, hidden_channels)  # pitch 256
        self.encoder = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, phone, pitch, lengths):
        """Return ``(m, logs, x_mask)`` for a batch of phone ids.

        Args:
            phone: phone ids for the phone embedding.
            pitch: pitch ids for the pitch embedding, or ``None`` to skip it.
            lengths: valid length of each sequence, used to build the mask.
        """
        # Identity check: ``pitch == None`` would dispatch to Tensor.__eq__.
        if pitch is None:
            x = self.emb_phone(phone)
        else:
            x = self.emb_phone(phone) + self.emb_pitch(pitch)
        x = x * math.sqrt(self.hidden_channels)  # [b, t, h]
        x = self.lrelu(x)
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.encoder(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return m, logs, x_mask
86
class ResidualCouplingBlock(nn.Module):
    """Chain of ``n_flows`` residual coupling layers, each followed by a ``Flip``.

    ``self.flows`` alternates ``modules.ResidualCouplingLayer`` (even indices)
    and ``modules.Flip`` (odd indices). The forward direction runs them in
    order; the reverse direction runs them back to front.

    Args:
        channels: channel count of the transformed tensor.
        hidden_channels: hidden width of each coupling layer.
        kernel_size: kernel size of each coupling layer.
        dilation_rate: dilation rate of each coupling layer.
        n_layers: layer count inside each coupling layer.
        n_flows: number of (coupling, flip) pairs.
        gin_channels: conditioning channel count passed to coupling layers.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            coupling = modules.ResidualCouplingLayer(
                channels,
                hidden_channels,
                kernel_size,
                dilation_rate,
                n_layers,
                gin_channels=gin_channels,
                mean_only=True,
            )
            self.flows.append(coupling)
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        """Run the flow chain forward, or inverted when ``reverse`` is true."""
        if reverse:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
            return x
        for flow in self.flows:
            # Forward flows return a pair; only the first element is kept.
            x, _ = flow(x, x_mask, g=g, reverse=reverse)
        return x

    def remove_weight_norm(self):
        """Strip weight norm from the coupling layers (even indices of ``flows``)."""
        for flow_idx in range(self.n_flows):
            self.flows[flow_idx * 2].remove_weight_norm()
133
class PosteriorEncoder(nn.Module):
    """Encode an input into a sampled latent ``z`` with its mean and log-scale.

    Pipeline: 1x1 conv to ``hidden_channels`` -> ``modules.WN`` -> 1x1 conv
    to ``2 * out_channels`` -> split into ``m`` and ``logs`` -> sample
    ``z = m + eps * exp(logs)``.

    Args:
        in_channels: input channel count.
        out_channels: channel count of ``z``, ``m`` and ``logs``.
        hidden_channels: width of the WN encoder.
        kernel_size: WN kernel size.
        dilation_rate: WN dilation rate.
        n_layers: WN layer count.
        gin_channels: conditioning channel count passed to WN.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        """Return ``(z, m, logs, x_mask)`` for input ``x`` with valid ``x_lengths``."""
        valid = commons.sequence_mask(x_lengths, x.size(2))
        x_mask = torch.unsqueeze(valid, 1).to(x.dtype)
        hidden = self.pre(x) * x_mask
        hidden = self.enc(hidden, x_mask, g=g)
        stats = self.proj(hidden) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        # Reparameterised sample, zeroed outside the valid region.
        noise = torch.randn_like(m)
        z = (m + noise * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask

    def remove_weight_norm(self):
        """Strip weight norm from the WN encoder."""
        self.enc.remove_weight_norm()
176
class Generator(torch.nn.Module):
    """Upsampling decoder from latent features to a 1-channel signal in [-1, 1].

    A 7-tap pre-convolution is followed by one weight-normalised transposed
    convolution per entry of ``upsample_rates``. After each upsampling stage
    the outputs of ``len(resblock_kernel_sizes)`` residual blocks are
    averaged. A final 7-tap convolution and ``tanh`` produce the output.

    Args:
        initial_channel: input channel count.
        resblock: ``"1"`` selects ``modules.ResBlock1``; anything else
            selects ``modules.ResBlock2``.
        resblock_kernel_sizes: kernel size per residual block of a stage.
        resblock_dilation_sizes: dilation spec per residual block of a stage.
        upsample_rates: stride per upsampling stage.
        upsample_initial_channel: channel count after the pre-convolution;
            halved at every stage.
        upsample_kernel_sizes: kernel size per upsampling stage.
        gin_channels: conditioning channel count; 0 disables conditioning.
    """

    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for stage, (rate, kernel) in enumerate(
            zip(upsample_rates, upsample_kernel_sizes)
        ):
            upsampler = ConvTranspose1d(
                upsample_initial_channel // (2**stage),
                upsample_initial_channel // (2 ** (stage + 1)),
                kernel,
                rate,
                padding=(kernel - rate) // 2,
            )
            self.ups.append(weight_norm(upsampler))

        self.resblocks = nn.ModuleList()
        for stage in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (stage + 1))
            for kernel, dilation in zip(
                resblock_kernel_sizes, resblock_dilation_sizes
            ):
                self.resblocks.append(resblock(ch, kernel, dilation))

        # ``ch`` is the channel count left by the last upsampling stage.
        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        """Decode ``x``; ``g`` is optional conditioning added after ``conv_pre``."""
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for stage in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[stage](x)
            first = stage * self.num_kernels
            xs = None
            for offset in range(self.num_kernels):
                branch = self.resblocks[first + offset](x)
                if xs is None:
                    xs = branch
                else:
                    xs += branch
            # Average the parallel residual branches.
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        return torch.tanh(x)

    def remove_weight_norm(self):
        """Strip weight norm from every upsampler and residual block."""
        for upsampler in self.ups:
            # Resolves to the module-level torch.nn.utils function.
            remove_weight_norm(upsampler)
        for block in self.resblocks:
            block.remove_weight_norm()
250
class SineGen(torch.nn.Module):
    """Sine-wave source generator driven by an F0 contour.

    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)

    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of the sine waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: accepted for interface compatibility; this class does
        not read it.
    """

    def __init__(
        self,
        samp_rate,
        harmonic_num=0,
        sine_amp=0.1,
        noise_std=0.003,
        voiced_threshold=0,
        flag_for_pulse=False,
    ):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        # One channel for the fundamental plus one per overtone.
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        """Return a voiced/unvoiced signal: 1 where ``f0`` exceeds the threshold."""
        voiced = f0 > self.voiced_threshold
        return torch.ones_like(f0) * voiced

    def forward(self, f0, upp):
        """Generate sines, the voiced flag and the noise, upsampled by ``upp``.

        Args:
            f0: F0 values, indexed by this method as (batch, length); values
                at or below ``voiced_threshold`` are treated as unvoiced.
            upp: integer upsampling factor along the time axis.

        Returns:
            ``(sine_waves, uv, noise)`` with shapes
            (batch, length * upp, dim), (batch, length * upp, 1) and
            (batch, length * upp, dim).
        """
        with torch.no_grad():
            f0 = f0[:, None].transpose(1, 2)
            f0_buf = torch.zeros(
                f0.shape[0], f0.shape[1], self.dim, device=f0.device
            )
            # fundamental component
            f0_buf[:, :, 0] = f0[:, :, 0]
            for idx in np.arange(self.harmonic_num):
                # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
            # The "% 1" here means the per-harmonic multiplication cannot be
            # optimised away in post-processing.
            rad_values = (f0_buf / self.sampling_rate) % 1
            # Random initial phase per harmonic; the fundamental starts at 0.
            rand_ini = torch.rand(
                f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
            )
            rand_ini[:, 0] = 0
            rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
            # Taking "% 1" at this point would mean the later cumsum could no
            # longer be optimised.
            tmp_over_one = torch.cumsum(rad_values, 1)
            tmp_over_one *= upp
            tmp_over_one = F.interpolate(
                tmp_over_one.transpose(2, 1),
                scale_factor=upp,
                mode='linear',
                align_corners=True,
            ).transpose(2, 1)
            rad_values = F.interpolate(
                rad_values.transpose(2, 1), scale_factor=upp, mode='nearest'
            ).transpose(2, 1)
            tmp_over_one %= 1
            # Where the wrapped phase decreases, the phase crossed an integer;
            # subtract one full cycle there before the cumulative sum.
            wrapped = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
            cumsum_shift = torch.zeros_like(rad_values)
            cumsum_shift[:, 1:, :] = wrapped * -1.0
            phase = torch.cumsum(rad_values + cumsum_shift, dim=1)
            sine_waves = torch.sin(phase * 2 * np.pi)
            sine_waves = sine_waves * self.sine_amp
            uv = self._f02uv(f0)
            uv = F.interpolate(
                uv.transpose(2, 1), scale_factor=upp, mode='nearest'
            ).transpose(2, 1)
            # Voiced frames get noise_std; unvoiced frames get sine_amp / 3.
            noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
            noise = noise_amp * torch.randn_like(sine_waves)
            sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
317
class SourceModuleHnNSF(torch.nn.Module):
    """Source module for hn-nsf: merge harmonic sines into one excitation.

    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)

    sampling_rate: sampling rate in Hz
    harmonic_num: number of harmonics above F0 (default: 0)
    sine_amp: amplitude of the sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that the amplitude of noise in unvoiced frames is decided
        by sine_amp
    voiced_threshod: threshold to set U/V given F0 (default: 0); the
        misspelled keyword is kept because callers may pass it by name
    is_half: cast the generated sines to float16 before the linear merge

    ``forward`` returns ``(sine_merge, None, None)``; the noise and uv
    outputs of the sine generator are not propagated.
    """

    def __init__(
        self,
        sampling_rate,
        harmonic_num=0,
        sine_amp=0.1,
        add_noise_std=0.003,
        voiced_threshod=0,
        is_half=True,
    ):
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std
        self.is_half = is_half
        # to produce sine waveforms
        self.l_sin_gen = SineGen(
            sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
        )

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x, upp=None):
        """Return the merged excitation for F0 input ``x`` upsampled by ``upp``.

        ``upp`` is passed straight to the sine generator, which multiplies
        and interpolates by it, so callers need to supply a number despite
        the ``None`` default.
        """
        sine_wavs, uv, _ = self.l_sin_gen(x, upp)
        if self.is_half:
            sine_wavs = sine_wavs.half()
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
        return sine_merge, None, None  # noise, uv
355
class GeneratorNSF(torch.nn.Module):
    """Upsampling generator whose stages are fed a harmonic source signal.

    Args:
        initial_channel: input channels of ``conv_pre``.
        resblock: the string ``"1"`` selects ``modules.ResBlock1``; any other
            value selects ``modules.ResBlock2``.
        resblock_kernel_sizes: kernel size of each residual block per stage.
        resblock_dilation_sizes: dilations paired with the kernel sizes.
        upsample_rates: stride of each transposed-convolution stage.
        upsample_initial_channel: channels after ``conv_pre``; halved at
            every upsampling stage.
        upsample_kernel_sizes: kernel size of each transposed convolution.
        gin_channels: when non-zero, a 1x1 ``cond`` convolution is created
            and applied to ``g`` in ``forward``.
        sr: sampling rate passed to ``SourceModuleHnNSF``.
        is_half: passed through to ``SourceModuleHnNSF``.
    """

    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
        sr=40000,
        is_half=False
    ):
        super(GeneratorNSF, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        # np.prod returns a NumPy scalar; keep plain Python ints in the
        # module state and in the convolution hyper-parameters.
        total_upsampling = int(np.prod(upsample_rates))

        self.f0_upsamp = torch.nn.Upsample(scale_factor=total_upsampling)
        self.m_source = SourceModuleHnNSF(
            sampling_rate=sr,
            harmonic_num=0,
            is_half=is_half
        )
        self.noise_convs = nn.ModuleList()
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        c_cur,
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )
            if i + 1 < len(upsample_rates):
                # Downsample the full-rate source to this stage's rate.
                stride_f0 = int(np.prod(upsample_rates[i + 1:]))
                self.noise_convs.append(Conv1d(
                    1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0,
                    padding=stride_f0 // 2))
            else:
                # Last stage already runs at the source rate.
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))

        self.resblocks = nn.ModuleList()
        # Defined up front so conv_post can still be built when there are no
        # upsampling stages (previously an unbound local in that case).
        ch = upsample_initial_channel
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(resblock(ch, k, d))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

        self.upp = total_upsampling

    def forward(self, x, f0, g=None):
        """Generate a waveform from features ``x`` and pitch ``f0``.

        ``g`` is optional conditioning; it is only usable when the module was
        built with a non-zero ``gin_channels``.
        """
        har_source, _noi_source, _uv = self.m_source(f0, self.upp)
        har_source = har_source.transpose(1, 2)
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)
            x = x + self.noise_convs[i](har_source)
            # Average the parallel residual blocks of this stage.
            xs = None
            for j in range(self.num_kernels):
                block_out = self.resblocks[i * self.num_kernels + j](x)
                xs = block_out if xs is None else xs + block_out
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from the upsamplers and residual blocks."""
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
451
class SynthesizerTrnMs256NSF(nn.Module):
    """Synthesizer with a linear speaker-embedding projection (``emb_g``).

    This class exposes ``infer`` and ``remove_weight_norm`` only; it defines
    no training ``forward``. The decoder is built with ``gin_channels=0`` and
    is called without conditioning; the embedding conditions the flow only.
    ``kwargs`` must contain ``is_half``, which is passed to the decoder.
    """

    def __init__(
        self,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        spk_embed_dim,
        gin_channels=0,
        sr=40000,
        **kwargs
    ):

        super().__init__()
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.spk_embed_dim = spk_embed_dim
        self.enc_p = TextEncoder256(
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
        )
        self.dec = GeneratorNSF(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=0,
            sr=sr,
            is_half=kwargs["is_half"]
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
        )
        self.emb_g = nn.Linear(self.spk_embed_dim, gin_channels)

    def remove_weight_norm(self):
        """Strip weight normalisation from the decoder, flow and posterior encoder."""
        self.dec.remove_weight_norm()
        self.flow.remove_weight_norm()
        self.enc_q.remove_weight_norm()

    def infer(self, phone, phone_lengths, pitch, pitchf, ds, max_len=None):
        """Run inference; returns ``(o, x_mask, (z, z_p, m_p, logs_p))``.

        ``ds`` is the speaker embedding fed to ``emb_g``. ``max_len``
        optionally truncates the latent along its last axis before decoding.
        """
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        # Match the encoder output exactly. The former substring test
        # ("float16" in str(dtype)) also matched torch.bfloat16 and then cast
        # the embedding to float16, which fails inside emb_g.
        ds = ds.to(device=m_p.device, dtype=m_p.dtype)
        g = self.emb_g(ds).unsqueeze(-1)  # [b, h, 1]
        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66) * x_mask

        z = self.flow(z_p, x_mask, g=g, reverse=True)
        o = self.dec((z * x_mask)[:, :, :max_len], pitchf, g=None)
        return o, x_mask, (z, z_p, m_p, logs_p)
547
class SynthesizerTrn256NSFkm(nn.Module):
    """Synthesizer without speaker conditioning, for training and inference.

    ``forward`` is the training pass over a random latent slice; ``infer``
    samples from the prior and decodes. ``kwargs`` must contain ``is_half``,
    which is passed to the decoder.
    """

    def __init__(
        self,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        spk_embed_dim,
        gin_channels=0,
        sr=40000,
        **kwargs
    ):
        super().__init__()

        # Keep the hyper-parameters on the instance.
        self.spec_channels = spec_channels
        self.segment_size = segment_size
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.gin_channels = gin_channels

        # Sub-modules are created in the same order as before: prior
        # encoder, decoder, posterior encoder, flow.
        prior_args = (
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
        )
        self.enc_p = TextEncoder256km(*prior_args)

        decoder_args = (
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
        )
        self.dec = GeneratorNSF(
            *decoder_args, gin_channels=0, sr=sr, is_half=kwargs["is_half"]
        )

        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
        )

    def remove_weight_norm(self):
        """Strip weight normalisation from the decoder, flow and posterior encoder."""
        for part in (self.dec, self.flow, self.enc_q):
            part.remove_weight_norm()

    def forward(self, phone, phone_lengths, pitch, pitchf, y, y_lengths):
        """Training pass.

        Returns ``(o, ids_slice, x_mask, y_mask,
        (z, z_p, m_p, logs_p, m_q, logs_q))``.
        """
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)

        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=None)
        z_p = self.flow(z, y_mask, g=None)

        # Decode only a random segment of the latent and the matching pitch.
        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        pitch_slice = commons.slice_segments2(pitchf, ids_slice, self.segment_size)

        o = self.dec(z_slice, pitch_slice, g=None)
        latents = (z, z_p, m_p, logs_p, m_q, logs_q)
        return o, ids_slice, x_mask, y_mask, latents

    def infer(self, phone, phone_lengths, pitch, nsff0, max_len=None):
        """Inference pass; returns ``(o, x_mask, (z, z_p, m_p, logs_p))``."""
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)

        # Sample from the prior with the noise scaled by 0.66.
        prior_noise = torch.exp(logs_p) * torch.randn_like(m_p) * 0.66
        z_p = (m_p + prior_noise) * x_mask

        z = self.flow(z_p, x_mask, g=None, reverse=True)
        masked = z * x_mask
        o = self.dec(masked[:, :, :max_len], nsff0, g=None)
        return o, x_mask, (z, z_p, m_p, logs_p)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer_pack/modules.py DELETED
@@ -1,522 +0,0 @@
1
- import copy
2
- import math
3
- import numpy as np
4
- import scipy
5
- import torch
6
- from torch import nn
7
- from torch.nn import functional as F
8
-
9
- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10
- from torch.nn.utils import weight_norm, remove_weight_norm
11
-
12
- from infer_pack import commons
13
- from infer_pack.commons import init_weights, get_padding
14
- from infer_pack.transforms import piecewise_rational_quadratic_transform
15
-
16
-
17
- LRELU_SLOPE = 0.1
18
-
19
-
20
class LayerNorm(nn.Module):
    """Layer normalisation applied over dimension 1 of the input."""

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        # Learnable scale and shift, one value per channel.
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        # F.layer_norm normalises the trailing dimension, so swap dim 1 to
        # the end, normalise, and swap back.
        channels_last = x.transpose(1, -1)
        normalized = F.layer_norm(
            channels_last, (self.channels,), self.gamma, self.beta, self.eps
        )
        return normalized.transpose(1, -1)
33
-
34
-
35
class ConvReluNorm(nn.Module):
    """Stack of Conv1d -> LayerNorm -> ReLU/Dropout layers with a residual.

    The final 1x1 projection is zero-initialised, so a freshly built module
    returns ``x * x_mask``.

    Args:
        in_channels: channels of the input (and of the residual branch).
        hidden_channels: channels of the intermediate layers.
        out_channels: channels produced by the projection; it is added to
            the input, so presumably it must match ``in_channels``.
        kernel_size: kernel size of every convolution (padding is
            ``kernel_size // 2``).
        n_layers: number of convolution layers; must be larger than 1.
        p_dropout: dropout probability applied after each ReLU.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        kernel_size,
        n_layers,
        p_dropout,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        # The message used to say "larger than 0", contradicting the check.
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(
            nn.Conv1d(
                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
            )
        )
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                nn.Conv1d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    padding=kernel_size // 2,
                )
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        # Zero init makes the block start out as a masked identity.
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        """Return ``(x + proj(stack(x))) * x_mask``."""
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask
85
-
86
-
87
class DDSConv(nn.Module):
    """
    Dilated and Depth-Separable Convolution
    """

    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for layer_idx in range(n_layers):
            # Dilation grows as kernel_size ** layer index; the padding is
            # chosen from the dilated kernel span.
            dilation = kernel_size**layer_idx
            padding = (kernel_size * dilation - dilation) // 2
            depthwise = nn.Conv1d(
                channels,
                channels,
                kernel_size,
                groups=channels,
                dilation=dilation,
                padding=padding,
            )
            self.convs_sep.append(depthwise)
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))

    def forward(self, x, x_mask, g=None):
        """Apply the residual layers; ``g`` (if given) is added to ``x`` first."""
        if g is not None:
            x = x + g
        for layer_idx in range(self.n_layers):
            # Depthwise conv -> norm -> GELU -> 1x1 conv -> norm -> GELU.
            branch = self.convs_sep[layer_idx](x * x_mask)
            branch = F.gelu(self.norms_1[layer_idx](branch))
            branch = self.convs_1x1[layer_idx](branch)
            branch = F.gelu(self.norms_2[layer_idx](branch))
            x = x + self.drop(branch)
        return x * x_mask
134
-
135
-
136
class WN(torch.nn.Module):
    """Stack of gated, dilated, weight-normalised convolutions.

    Each layer produces a residual half that updates the running input and a
    skip half that is accumulated into the output. Optional conditioning
    ``g`` is projected once and sliced per layer.
    """

    def __init__(
        self,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
        p_dropout=0,
    ):
        super(WN, self).__init__()
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        self.kernel_size = (kernel_size,)
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels != 0:
            # One projection yields the conditioning for every layer.
            self.cond_layer = torch.nn.utils.weight_norm(
                torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1),
                name="weight",
            )

        for layer_idx in range(n_layers):
            dilation = dilation_rate**layer_idx
            padding = int((kernel_size * dilation - dilation) / 2)
            gated_conv = torch.nn.Conv1d(
                hidden_channels,
                2 * hidden_channels,
                kernel_size,
                dilation=dilation,
                padding=padding,
            )
            self.in_layers.append(
                torch.nn.utils.weight_norm(gated_conv, name="weight")
            )

            # The last layer needs no residual half, only the skip output.
            is_last = layer_idx >= n_layers - 1
            res_skip_channels = hidden_channels if is_last else 2 * hidden_channels

            res_skip_conv = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            self.res_skip_layers.append(
                torch.nn.utils.weight_norm(res_skip_conv, name="weight")
            )

    def forward(self, x, x_mask, g=None, **kwargs):
        """Return the masked sum of the skip outputs of all layers."""
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])
        width = 2 * self.hidden_channels

        if g is not None:
            g = self.cond_layer(g)

        layer_pairs = zip(self.in_layers, self.res_skip_layers)
        for layer_idx, (in_layer, res_skip_layer) in enumerate(layer_pairs):
            x_in = in_layer(x)
            if g is None:
                g_l = torch.zeros_like(x_in)
            else:
                start = layer_idx * width
                g_l = g[:, start : start + width, :]

            acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
            acts = self.drop(acts)

            res_skip_acts = res_skip_layer(acts)
            if layer_idx < self.n_layers - 1:
                # First half updates the input, second half feeds the output.
                x = (x + res_skip_acts[:, : self.hidden_channels, :]) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        """Strip weight normalisation from every convolution in the stack."""
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for layer in self.in_layers:
            torch.nn.utils.remove_weight_norm(layer)
        for layer in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(layer)
222
-
223
-
224
class ResBlock1(torch.nn.Module):
    """Residual block of three (dilated conv, plain conv) pairs."""

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()

        def _normed_conv(rate):
            # Weight-normalised, stride-1 convolution padded via get_padding.
            return weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=rate,
                    padding=get_padding(kernel_size, rate),
                )
            )

        # First convolution of each pair uses the configured dilations.
        self.convs1 = nn.ModuleList([_normed_conv(dilation[idx]) for idx in range(3)])
        self.convs1.apply(init_weights)

        # Second convolution of each pair is undilated.
        self.convs2 = nn.ModuleList([_normed_conv(1) for _ in range(3)])
        self.convs2.apply(init_weights)

    def forward(self, x, x_mask=None):
        """Apply the three residual pairs, masking when ``x_mask`` is given."""
        use_mask = x_mask is not None
        for dilated_conv, plain_conv in zip(self.convs1, self.convs2):
            branch = F.leaky_relu(x, LRELU_SLOPE)
            if use_mask:
                branch = branch * x_mask
            branch = F.leaky_relu(dilated_conv(branch), LRELU_SLOPE)
            if use_mask:
                branch = branch * x_mask
            x = plain_conv(branch) + x
        if use_mask:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from all six convolutions."""
        for conv in self.convs1:
            remove_weight_norm(conv)
        for conv in self.convs2:
            remove_weight_norm(conv)
319
-
320
-
321
class ResBlock2(torch.nn.Module):
    """Residual block of two dilated convolutions."""

    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        layers = []
        for idx in range(2):
            rate = dilation[idx]
            layers.append(
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=rate,
                        padding=get_padding(kernel_size, rate),
                    )
                )
            )
        self.convs = nn.ModuleList(layers)
        self.convs.apply(init_weights)

    def forward(self, x, x_mask=None):
        """Apply both residual convolutions, masking when ``x_mask`` is given."""
        use_mask = x_mask is not None
        for conv in self.convs:
            branch = F.leaky_relu(x, LRELU_SLOPE)
            if use_mask:
                branch = branch * x_mask
            x = conv(branch) + x
        if use_mask:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from both convolutions."""
        for conv in self.convs:
            remove_weight_norm(conv)
364
-
365
-
366
class Log(nn.Module):
    """Element-wise log flow; the reverse direction applies ``exp``."""

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if reverse:
            # Inverse direction returns only the transformed tensor.
            return torch.exp(x) * x_mask

        # Clamp from below so the logarithm stays finite.
        y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
        logdet = torch.sum(-y, [1, 2])
        return y, logdet
375
-
376
-
377
class Flip(nn.Module):
    """Flow step that reverses the order of dimension 1."""

    def forward(self, x, *args, reverse=False, **kwargs):
        flipped = torch.flip(x, [1])
        if reverse:
            return flipped
        # Forward direction also returns a zero log-determinant per batch
        # item, matching the input's dtype and device.
        logdet = torch.zeros(flipped.size(0)).to(
            dtype=flipped.dtype, device=flipped.device
        )
        return flipped, logdet
385
-
386
-
387
class ElementwiseAffine(nn.Module):
    """Per-channel affine flow with learnable shift ``m`` and log-scale ``logs``."""

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        # Both start at zero, i.e. the identity transform.
        self.m = nn.Parameter(torch.zeros(channels, 1))
        self.logs = nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if reverse:
            # Inverse direction returns only the transformed tensor.
            return (x - self.m) * torch.exp(-self.logs) * x_mask

        y = (self.m + torch.exp(self.logs) * x) * x_mask
        logdet = torch.sum(self.logs * x_mask, [1, 2])
        return y, logdet
403
-
404
-
405
class ResidualCouplingLayer(nn.Module):
    """Affine coupling layer: the first half of the channels parameterises
    a shift (and, unless ``mean_only``, a log-scale) for the second half.

    The output projection is zero-initialised, so a fresh layer acts as the
    identity on the second half.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        p_dropout=0,
        gin_channels=0,
        mean_only=False,
    ):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            p_dropout=p_dropout,
            gin_channels=gin_channels,
        )
        # One output group when only the mean is predicted, otherwise two
        # (mean and log-scale).
        stats_channels = self.half_channels * (2 - mean_only)
        self.post = nn.Conv1d(hidden_channels, stats_channels, 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Forward returns ``(x, logdet)``; reverse returns only ``x``."""
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        hidden = self.pre(x0) * x_mask
        hidden = self.enc(hidden, x_mask, g=g)
        stats = self.post(hidden) * x_mask

        if self.mean_only:
            m = stats
            logs = torch.zeros_like(m)
        else:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)

        if reverse:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            return torch.cat([x0, x1], 1)

        x1 = m + x1 * torch.exp(logs) * x_mask
        logdet = torch.sum(logs, [1, 2])
        return torch.cat([x0, x1], 1), logdet

    def remove_weight_norm(self):
        """Strip weight normalisation from the inner WN encoder."""
        self.enc.remove_weight_norm()
463
-
464
-
465
class ConvFlow(nn.Module):
    """Coupling flow that transforms the second half of the channels with a
    piecewise rational-quadratic spline parameterised by the first half.

    The projection is zero-initialised.
    """

    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        n_layers,
        num_bins=10,
        tail_bound=5.0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.num_bins = num_bins
        self.tail_bound = tail_bound
        self.half_channels = in_channels // 2

        self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
        self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
        # Per transformed channel: num_bins widths, num_bins heights and
        # num_bins - 1 derivatives.
        spline_params = num_bins * 3 - 1
        self.proj = nn.Conv1d(filter_channels, self.half_channels * spline_params, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Forward returns ``(x, logdet)``; reverse returns only ``x``."""
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0)
        h = self.convs(h, x_mask, g=g)
        h = self.proj(h) * x_mask

        b, c, t = x0.shape
        # [b, c * params, t] -> [b, c, t, params]
        h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)

        n_bins = self.num_bins
        scale = math.sqrt(self.filter_channels)
        unnormalized_widths = h[..., :n_bins] / scale
        unnormalized_heights = h[..., n_bins : 2 * n_bins] / scale
        unnormalized_derivatives = h[..., 2 * n_bins :]

        x1, logabsdet = piecewise_rational_quadratic_transform(
            x1,
            unnormalized_widths,
            unnormalized_heights,
            unnormalized_derivatives,
            inverse=reverse,
            tails="linear",
            tail_bound=self.tail_bound,
        )

        x = torch.cat([x0, x1], 1) * x_mask
        logdet = torch.sum(logabsdet * x_mask, [1, 2])
        if reverse:
            return x
        return x, logdet
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer_pack/transforms.py DELETED
@@ -1,193 +0,0 @@
1
- import torch
2
- from torch.nn import functional as F
3
-
4
- import numpy as np
5
-
6
-
7
# Default lower bounds for the spline parameters. They are used as the default
# values of the ``min_bin_width`` / ``min_bin_height`` / ``min_derivative``
# arguments of the spline functions in this module.
DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3
10
-
11
-
12
def piecewise_rational_quadratic_transform(inputs,
                                           unnormalized_widths,
                                           unnormalized_heights,
                                           unnormalized_derivatives,
                                           inverse=False,
                                           tails=None,
                                           tail_bound=1.,
                                           min_bin_width=DEFAULT_MIN_BIN_WIDTH,
                                           min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
                                           min_derivative=DEFAULT_MIN_DERIVATIVE):
    """Dispatch to the bounded or the unbounded rational-quadratic spline.

    When ``tails`` is ``None`` the call is forwarded to
    ``rational_quadratic_spline``; otherwise it is forwarded to
    ``unconstrained_rational_quadratic_spline`` together with ``tails`` and
    ``tail_bound``. All remaining arguments are passed through unchanged.

    Returns:
        The ``(outputs, logabsdet)`` pair produced by the selected spline
        function.
    """
    # Arguments common to both spline implementations.
    call_kwargs = {
        'inputs': inputs,
        'unnormalized_widths': unnormalized_widths,
        'unnormalized_heights': unnormalized_heights,
        'unnormalized_derivatives': unnormalized_derivatives,
        'inverse': inverse,
        'min_bin_width': min_bin_width,
        'min_bin_height': min_bin_height,
        'min_derivative': min_derivative,
    }

    if tails is not None:
        # Only the unconstrained variant understands tail handling.
        call_kwargs['tails'] = tails
        call_kwargs['tail_bound'] = tail_bound
        transform = unconstrained_rational_quadratic_spline
    else:
        transform = rational_quadratic_spline

    outputs, logabsdet = transform(**call_kwargs)
    return outputs, logabsdet
45
-
46
-
47
def searchsorted(bin_locations, inputs, eps=1e-6):
    """Return, for every input, the index of the bin that contains it.

    The index is the number of entries of ``bin_locations`` (along its last
    dimension) that are less than or equal to the input, minus one. For the
    comparison the last bin edge is widened by ``eps`` so that an input equal
    to that edge is assigned to the final bin instead of one past it.

    ``bin_locations`` is left unmodified: the widened edge lives in a
    temporary copy, so calling this function repeatedly on the same tensor
    does not accumulate ``eps`` in the caller's data.

    Args:
        bin_locations: tensor of bin edges along the last dimension
            (presumably sorted ascending -- the code does not check this).
        inputs: tensor of values to locate; it is compared against
            ``bin_locations`` after gaining a trailing axis, so its shape
            must broadcast with the leading dimensions of ``bin_locations``.
        eps: amount added to the last edge for the comparison.

    Returns:
        Integer tensor shaped like ``inputs`` holding the bin indices.
    """
    # Build the widened edges out-of-place instead of `+=` on the argument.
    widened_last_edge = bin_locations[..., -1:] + eps
    edges = torch.cat((bin_locations[..., :-1], widened_last_edge), dim=-1)
    return torch.sum(inputs[..., None] >= edges, dim=-1) - 1
53
-
54
-
55
def unconstrained_rational_quadratic_spline(inputs,
                                            unnormalized_widths,
                                            unnormalized_heights,
                                            unnormalized_derivatives,
                                            inverse=False,
                                            tails='linear',
                                            tail_bound=1.,
                                            min_bin_width=DEFAULT_MIN_BIN_WIDTH,
                                            min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
                                            min_derivative=DEFAULT_MIN_DERIVATIVE):
    """Apply a rational-quadratic spline with identity tails outside a box.

    Inputs inside ``[-tail_bound, tail_bound]`` are transformed by
    ``rational_quadratic_spline`` on that interval (same bounds for the
    output range). Inputs outside the interval are passed through unchanged
    with a log-determinant contribution of zero.

    Args:
        inputs: tensor of values to transform.
        unnormalized_widths, unnormalized_heights: per-input spline
            parameters with the bins on the last dimension; they are indexed
            with a boolean mask shaped like ``inputs``.
        unnormalized_derivatives: per-input derivative parameters for the
            interior knots; one entry is padded on each side of the last
            dimension for the two boundary knots.
        inverse: forwarded to ``rational_quadratic_spline``.
        tails: only ``'linear'`` is supported.
        tail_bound: half-width of the interval handled by the spline.
        min_bin_width, min_bin_height, min_derivative: forwarded to
            ``rational_quadratic_spline``.

    Returns:
        ``(outputs, logabsdet)``, both shaped like ``inputs``.

    Raises:
        RuntimeError: if ``tails`` is anything other than ``'linear'``.
    """
    if tails != 'linear':
        raise RuntimeError('{} tails are not implemented.'.format(tails))

    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    # Both results start at zero, so logabsdet is already correct (0) for
    # every element that is passed through unchanged.
    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)

    # Add the two boundary knots. The constant is chosen so that
    # min_derivative + softplus(constant) == 1, i.e. the spline's slope at
    # both ends matches the identity tails.
    unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
    constant = np.log(np.exp(1 - min_derivative) - 1)
    unnormalized_derivatives[..., 0] = constant
    unnormalized_derivatives[..., -1] = constant

    outputs[outside_interval_mask] = inputs[outside_interval_mask]

    # Skip the spline when nothing falls inside the interval: its domain
    # check calls torch.min / torch.max, which raise on an empty tensor.
    if inside_interval_mask.any():
        (
            outputs[inside_interval_mask],
            logabsdet[inside_interval_mask],
        ) = rational_quadratic_spline(
            inputs=inputs[inside_interval_mask],
            unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
            unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
            unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
            inverse=inverse,
            left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
            min_bin_width=min_bin_width,
            min_bin_height=min_bin_height,
            min_derivative=min_derivative,
        )

    return outputs, logabsdet
95
-
96
def rational_quadratic_spline(inputs,
                              unnormalized_widths,
                              unnormalized_heights,
                              unnormalized_derivatives,
                              inverse=False,
                              left=0., right=1., bottom=0., top=1.,
                              min_bin_width=DEFAULT_MIN_BIN_WIDTH,
                              min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
                              min_derivative=DEFAULT_MIN_DERIVATIVE):
    """Evaluate a rational-quadratic spline, or its inverse, element-wise.

    Bin widths and heights come from a softmax over the last dimension of the
    unnormalized parameters, floored at ``min_bin_width`` /
    ``min_bin_height`` and scaled to span ``[left, right]`` and
    ``[bottom, top]``. Knot derivatives are
    ``min_derivative + softplus(unnormalized_derivatives)``.

    Args:
        inputs: values to transform; all must lie in ``[left, right]``.
        unnormalized_widths, unnormalized_heights: spline parameters with the
            bins on the last dimension.
        unnormalized_derivatives: knot derivative parameters (presumably one
            more entry than bins on the last dimension -- confirm with caller).
        inverse: if true, solve for the spline argument instead of
            evaluating the spline.
        left, right, bottom, top: bounds of the spline's domain and range.
        min_bin_width, min_bin_height, min_derivative: lower bounds applied
            to the normalized parameters.

    Returns:
        ``(outputs, logabsdet)``; in inverse mode the log-determinant is
        negated before being returned.

    Raises:
        ValueError: if an input is outside ``[left, right]`` or a minimum
            bin size is too large for the number of bins.
        AssertionError: in inverse mode, if the quadratic's discriminant is
            negative anywhere.
    """
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise ValueError('Input to a transform is not within its domain')

    num_bins = unnormalized_widths.shape[-1]

    for floor, message in (
        (min_bin_width, 'Minimal bin width too large for the number of bins'),
        (min_bin_height, 'Minimal bin height too large for the number of bins'),
    ):
        if floor * num_bins > 1.0:
            raise ValueError(message)

    def _knots_and_sizes(unnormalized, min_size, lower, upper):
        # Softmax -> floored bin sizes -> cumulative knot positions on
        # [lower, upper]; the end knots are pinned exactly to the bounds.
        sizes = F.softmax(unnormalized, dim=-1)
        sizes = min_size + (1 - min_size * num_bins) * sizes
        knots = torch.cumsum(sizes, dim=-1)
        knots = F.pad(knots, pad=(1, 0), mode='constant', value=0.0)
        knots = (upper - lower) * knots + lower
        knots[..., 0] = lower
        knots[..., -1] = upper
        return knots, knots[..., 1:] - knots[..., :-1]

    # Bin sizes are derived from the knots here, before the bin lookup
    # below, which may adjust the last knot.
    cumwidths, widths = _knots_and_sizes(unnormalized_widths, min_bin_width, left, right)
    derivatives = min_derivative + F.softplus(unnormalized_derivatives)
    cumheights, heights = _knots_and_sizes(unnormalized_heights, min_bin_height, bottom, top)

    # The inverse looks bins up along the output axis, the forward pass
    # along the input axis.
    lookup_knots = cumheights if inverse else cumwidths
    bin_idx = searchsorted(lookup_knots, inputs)[..., None]

    def _select(per_bin):
        # Pick, for every input, the entry belonging to its bin.
        return per_bin.gather(-1, bin_idx)[..., 0]

    input_cumwidths = _select(cumwidths)
    input_bin_widths = _select(widths)
    input_cumheights = _select(cumheights)
    input_delta = _select(heights / widths)
    input_derivatives = _select(derivatives)
    input_derivatives_plus_one = _select(derivatives[..., 1:])
    input_heights = _select(heights)

    # Term shared by the quadratic coefficients and the denominator.
    slope_mix = input_derivatives + input_derivatives_plus_one - 2 * input_delta

    if inverse:
        # Solve a * theta^2 + b * theta + c = 0 for the position in the bin.
        offset = inputs - input_cumheights
        a = offset * slope_mix + input_heights * (input_delta - input_derivatives)
        b = input_heights * input_derivatives - offset * slope_mix
        c = -input_delta * offset

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        theta = (2 * c) / (-b - torch.sqrt(discriminant))
    else:
        theta = (inputs - input_cumwidths) / input_bin_widths

    # Log-derivative of the spline at theta; same formula in both directions.
    theta_one_minus_theta = theta * (1 - theta)
    denominator = input_delta + (slope_mix * theta_one_minus_theta)
    derivative_numerator = input_delta.pow(2) * (
        input_derivatives_plus_one * theta.pow(2)
        + 2 * input_delta * theta_one_minus_theta
        + input_derivatives * (1 - theta).pow(2)
    )
    logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

    if inverse:
        outputs = theta * input_bin_widths + input_cumwidths
        return outputs, -logabsdet

    numerator = input_heights * (
        input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
    )
    outputs = input_cumheights + numerator / denominator
    return outputs, logabsdet