|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
This code contains the spectrogram and Hybrid version of Demucs. |
|
""" |
|
import math |
|
|
|
from .filtering import wiener |
|
import torch |
|
from torch import nn |
|
from torch.nn import functional as F |
|
from fractions import Fraction |
|
from einops import rearrange |
|
|
|
from .transformer import CrossTransformerEncoder |
|
|
|
from .demucs import rescale_module |
|
from .states import capture_init |
|
from .spec import spectro, ispectro |
|
from .hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer |
|
|
|
|
|
class HTDemucs(nn.Module): |
|
""" |
|
Spectrogram and hybrid Demucs model. |
|
The spectrogram model has the same structure as Demucs, except the first few layers are over the |
|
frequency axis, until there is only 1 frequency, and then it moves to time convolutions. |
|
Frequency layers can still access information across time steps thanks to the DConv residual. |
|
|
|
Hybrid model have a parallel time branch. At some layer, the time branch has the same stride |
|
as the frequency branch and then the two are combined. The opposite happens in the decoder. |
|
|
|
Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
|
or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on |
|
Open Unmix implementation [Stoter et al. 2019]. |
|
|
|
The loss is always on the temporal domain, by backpropagating through the above |
|
output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks |
|
a bit Wiener filtering, as doing more iteration at test time will change the spectrogram |
|
contribution, without changing the one from the waveform, which will lead to worse performance. |
|
I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve. |
|
CaC on the other hand provides similar performance for hybrid, and works naturally with |
|
hybrid models. |
|
|
|
This model also uses frequency embeddings to improve efficiency on convolutions
|
over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf). |
|
|
|
Unlike classic Demucs, there is no resampling here, and normalization is always applied. |
|
""" |
|
|
|
    @capture_init
    def __init__(
        self,
        sources,
        # Channels
        audio_channels=2,
        channels=48,
        channels_time=None,
        growth=2,
        # STFT
        nfft=4096,
        wiener_iters=0,
        end_iters=0,
        wiener_residual=False,
        cac=True,
        # Main structure
        depth=4,
        rewrite=True,
        # Frequency branch
        multi_freqs=None,
        multi_freqs_depth=3,
        freq_emb=0.2,
        emb_scale=10,
        emb_smooth=True,
        # Convolutions
        kernel_size=8,
        time_stride=2,
        stride=4,
        context=1,
        context_enc=0,
        # Normalization
        norm_starts=4,
        norm_groups=4,
        # DConv residual branch
        dconv_mode=1,
        dconv_depth=2,
        dconv_comp=8,
        dconv_init=1e-3,
        # Before the Transformer
        bottom_channels=0,
        # Transformer
        t_layers=5,
        t_emb="sin",
        t_hidden_scale=4.0,
        t_heads=8,
        t_dropout=0.0,
        t_max_positions=10000,
        t_norm_in=True,
        t_norm_in_group=False,
        t_group_norm=False,
        t_norm_first=True,
        t_norm_out=True,
        t_max_period=10000.0,
        t_weight_decay=0.0,
        t_lr=None,
        t_layer_scale=True,
        t_gelu=True,
        t_weight_pos_embed=1.0,
        t_sin_random_shift=0,
        t_cape_mean_normalize=True,
        t_cape_augment=True,
        # NOTE(review): mutable default (list). It is only passed through,
        # never mutated here, but a tuple would be safer — TODO confirm
        # CrossTransformerEncoder does not mutate it.
        t_cape_glob_loc_scale=[5000.0, 1.0, 1.4],
        t_sparse_self_attn=False,
        t_sparse_cross_attn=False,
        t_mask_type="diag",
        t_mask_random_seed=42,
        t_sparse_attn_window=500,
        t_global_window=100,
        t_sparsity=0.95,
        t_auto_sparsity=False,
        # Cross-transformer layout
        t_cross_first=False,
        # Weight init
        rescale=0.1,
        # Metadata
        samplerate=44100,
        segment=10,
        use_train_segment=True,
    ):
        """
        Build the hybrid (spectrogram + waveform) Demucs model.

        Args:
            sources (list[str]): list of source names.
            audio_channels (int): input/output audio channels.
            channels (int): initial number of hidden channels.
            channels_time: if not None, use a different `channels` value for the time branch.
            growth: increase the number of hidden channels by this factor at each layer.
            nfft: number of fft bins. Note that changing this require careful computation of
                various shape parameters and will not work out of the box for hybrid models.
            wiener_iters: when using Wiener filtering, number of iterations at test time.
            end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
            wiener_residual: add residual source before wiener filtering.
            cac: uses complex as channels, i.e. complex numbers are 2 channels each
                in input and output. no further processing is done before ISTFT.
            depth (int): number of layers in the encoder and in the decoder.
            rewrite (bool): add 1x1 convolution to each layer.
            multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
            multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
                layers will be wrapped.
            freq_emb: add frequency embedding after the first frequency layer if > 0,
                the actual value controls the weight of the embedding.
            emb_scale: equivalent to scaling the embedding learning rate
            emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
            kernel_size: kernel_size for encoder and decoder layers.
            stride: stride for encoder and decoder layers.
            time_stride: stride for the final time layer, after the merge.
            context: context for 1x1 conv in the decoder.
            context_enc: context for 1x1 conv in the encoder.
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_init: initial scale for the DConv branch LayerScale.
            bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the
                transformer in order to change the number of channels
            t_layers: number of layers in each branch (waveform and spec) of the transformer
            t_emb: "sin", "cape" or "scaled"
            t_hidden_scale: the hidden scale of the Feedforward parts of the transformer
                for instance if C = 384 (the number of channels in the transformer) and
                t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension
                384 * 4 = 1536
            t_heads: number of heads for the transformer
            t_dropout: dropout in the transformer
            t_max_positions: max_positions for the "scaled" positional embedding, only
                useful if t_emb="scaled"
            t_norm_in: (bool) norm before adding positional embedding and getting into the
                transformer layers
            t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the
                timesteps (GroupNorm with group=1)
            t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the
                timesteps (GroupNorm with group=1)
            t_norm_first: (bool) if True the norm is before the attention and before the FFN
            t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer
            t_max_period: (float) denominator in the sinusoidal embedding expression
            t_weight_decay: (float) weight decay for the transformer
            t_lr: (float) specific learning rate for the transformer
            t_layer_scale: (bool) Layer Scale for the transformer
            t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else
            t_weight_pos_embed: (float) weighting of the positional embedding
            t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings
                see: https://arxiv.org/abs/2106.03143
            t_cape_augment: (bool) if t_emb="cape", must be True during training and False
                during the inference, see: https://arxiv.org/abs/2106.03143
            t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters
                see: https://arxiv.org/abs/2106.03143
            t_sparse_self_attn: (bool) if True, the self attentions are sparse
            t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it
                unless you designed really specific masks)
            t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination
                with '_' between: i.e. "diag_jmask_random" (note that this is permutation
                invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag")
            t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed
                that generated the random part of the mask
            t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and
                a key (j), the mask is True if |i-j|<=t_sparse_attn_window
            t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :]
                and mask[:, :t_global_window] will be True
            t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity
                level of the random part of the mask.
            t_cross_first: (bool) if True cross attention is the first layer of the
                transformer (False seems to be better)
            rescale: weight rescaling trick
            use_train_segment: (bool) if True, the actual size that is used during the
                training is used during inference.
        """
        super().__init__()
        self.cac = cac
        self.wiener_residual = wiener_residual
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.bottom_channels = bottom_channels
        self.channels = channels
        self.samplerate = samplerate
        self.segment = segment
        self.use_train_segment = use_train_segment
        self.nfft = nfft
        # `_spec` asserts this exact hop/window ratio.
        self.hop_length = nfft // 4
        self.wiener_iters = wiener_iters
        self.end_iters = end_iters
        self.freq_emb = None
        assert wiener_iters == end_iters

        # Frequency (spectrogram) branch encoder/decoder.
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        # Time (waveform) branch encoder/decoder.
        self.tencoder = nn.ModuleList()
        self.tdecoder = nn.ModuleList()

        chin = audio_channels
        chin_z = chin  # input channels of the freq. branch
        if self.cac:
            # Complex-as-channels doubles the input channels (real + imag).
            chin_z *= 2
        chout = channels_time or channels
        chout_z = channels
        # `_spec` drops the Nyquist bin, hence nfft // 2 frequency bins.
        freqs = nfft // 2

        for index in range(depth):
            norm = index >= norm_starts
            # Are we still convolving over the frequency axis at this depth?
            freq = freqs > 1
            stri = stride
            ker = kernel_size
            if not freq:
                assert freqs == 1
                # Past the merge point: pure time convolutions.
                ker = time_stride * 2
                stri = time_stride

            pad = True
            last_freq = False
            if freq and freqs <= kernel_size:
                # Last frequency layer: collapse all remaining bins at once.
                ker = freqs
                pad = False
                last_freq = True

            kw = {
                "kernel_size": ker,
                "stride": stri,
                "freq": freq,
                "pad": pad,
                "norm": norm,
                "rewrite": rewrite,
                "norm_groups": norm_groups,
                "dconv_kw": {
                    "depth": dconv_depth,
                    "compress": dconv_comp,
                    "init": dconv_init,
                    "gelu": True,
                },
            }
            # Time-branch layers always use plain time convolutions with the
            # base kernel/stride, regardless of the freq.-branch settings.
            kwt = dict(kw)
            kwt["freq"] = 0
            kwt["kernel_size"] = kernel_size
            kwt["stride"] = stride
            kwt["pad"] = True
            kw_dec = dict(kw)
            multi = False
            if multi_freqs and index < multi_freqs_depth:
                multi = True
                kw_dec["context_freq"] = False

            if last_freq:
                # Align channel counts so the two branches can be merged.
                chout_z = max(chout, chout_z)
                chout = chout_z

            enc = HEncLayer(
                chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw
            )
            if freq:
                # The time branch only has layers while the freq. branch is
                # still over frequencies; `empty=last_freq` makes the final
                # one a bare conv so its output can be injected below.
                tenc = HEncLayer(
                    chin,
                    chout,
                    dconv=dconv_mode & 1,
                    context=context_enc,
                    empty=last_freq,
                    **kwt
                )
                self.tencoder.append(tenc)

            if multi:
                enc = MultiWrap(enc, multi_freqs)
            self.encoder.append(enc)
            if index == 0:
                # The outermost decoder layer must output all sources,
                # so the decoder "input" channel count is adjusted here.
                chin = self.audio_channels * len(self.sources)
                chin_z = chin
                if self.cac:
                    chin_z *= 2
            dec = HDecLayer(
                chout_z,
                chin_z,
                dconv=dconv_mode & 2,
                last=index == 0,
                context=context,
                **kw_dec
            )
            if multi:
                dec = MultiWrap(dec, multi_freqs)
            if freq:
                tdec = HDecLayer(
                    chout,
                    chin,
                    dconv=dconv_mode & 2,
                    empty=last_freq,
                    last=index == 0,
                    context=context,
                    **kwt
                )
                self.tdecoder.insert(0, tdec)
            self.decoder.insert(0, dec)

            chin = chout
            chin_z = chout_z
            chout = int(growth * chout)
            chout_z = int(growth * chout_z)
            if freq:
                if freqs <= kernel_size:
                    freqs = 1
                else:
                    freqs //= stride
            if index == 0 and freq_emb:
                # Frequency embedding added after the first freq. layer
                # (see `forward`), weighted by `freq_emb`.
                self.freq_emb = ScaledEmbedding(
                    freqs, chin_z, smooth=emb_smooth, scale=emb_scale
                )
                self.freq_emb_scale = freq_emb

        if rescale:
            rescale_module(self, reference=rescale)

        transformer_channels = channels * growth ** (depth - 1)
        if bottom_channels:
            # Optional 1x1 convs to change channel count around the
            # cross-transformer, for each branch separately.
            self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
            self.channel_downsampler = nn.Conv1d(
                bottom_channels, transformer_channels, 1
            )
            self.channel_upsampler_t = nn.Conv1d(
                transformer_channels, bottom_channels, 1
            )
            self.channel_downsampler_t = nn.Conv1d(
                bottom_channels, transformer_channels, 1
            )

            transformer_channels = bottom_channels

        if t_layers > 0:
            self.crosstransformer = CrossTransformerEncoder(
                dim=transformer_channels,
                emb=t_emb,
                hidden_scale=t_hidden_scale,
                num_heads=t_heads,
                num_layers=t_layers,
                cross_first=t_cross_first,
                dropout=t_dropout,
                max_positions=t_max_positions,
                norm_in=t_norm_in,
                norm_in_group=t_norm_in_group,
                group_norm=t_group_norm,
                norm_first=t_norm_first,
                norm_out=t_norm_out,
                max_period=t_max_period,
                weight_decay=t_weight_decay,
                lr=t_lr,
                layer_scale=t_layer_scale,
                gelu=t_gelu,
                sin_random_shift=t_sin_random_shift,
                weight_pos_embed=t_weight_pos_embed,
                cape_mean_normalize=t_cape_mean_normalize,
                cape_augment=t_cape_augment,
                cape_glob_loc_scale=t_cape_glob_loc_scale,
                sparse_self_attn=t_sparse_self_attn,
                sparse_cross_attn=t_sparse_cross_attn,
                mask_type=t_mask_type,
                mask_random_seed=t_mask_random_seed,
                sparse_attn_window=t_sparse_attn_window,
                global_window=t_global_window,
                sparsity=t_sparsity,
                auto_sparsity=t_auto_sparsity,
            )
        else:
            self.crosstransformer = None
|
|
|
def _spec(self, x): |
|
hl = self.hop_length |
|
nfft = self.nfft |
|
x0 = x |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert hl == nfft // 4 |
|
le = int(math.ceil(x.shape[-1] / hl)) |
|
pad = hl // 2 * 3 |
|
x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect") |
|
|
|
z = spectro(x, nfft, hl)[..., :-1, :] |
|
assert z.shape[-1] == le + 4, (z.shape, x.shape, le) |
|
z = z[..., 2: 2 + le] |
|
return z |
|
|
|
def _ispec(self, z, length=None, scale=0): |
|
hl = self.hop_length // (4**scale) |
|
z = F.pad(z, (0, 0, 0, 1)) |
|
z = F.pad(z, (2, 2)) |
|
pad = hl // 2 * 3 |
|
le = hl * int(math.ceil(length / hl)) + 2 * pad |
|
x = ispectro(z, hl, length=le) |
|
x = x[..., pad: pad + length] |
|
return x |
|
|
|
def _magnitude(self, z): |
|
|
|
|
|
if self.cac: |
|
B, C, Fr, T = z.shape |
|
m = torch.view_as_real(z).permute(0, 1, 4, 2, 3) |
|
m = m.reshape(B, C * 2, Fr, T) |
|
else: |
|
m = z.abs() |
|
return m |
|
|
|
def _mask(self, z, m): |
|
|
|
|
|
niters = self.wiener_iters |
|
if self.cac: |
|
B, S, C, Fr, T = m.shape |
|
out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3) |
|
out = torch.view_as_complex(out.contiguous()) |
|
return out |
|
if self.training: |
|
niters = self.end_iters |
|
if niters < 0: |
|
z = z[:, None] |
|
return z / (1e-8 + z.abs()) * m |
|
else: |
|
return self._wiener(m, z, niters) |
|
|
|
def _wiener(self, mag_out, mix_stft, niters): |
|
|
|
init = mix_stft.dtype |
|
wiener_win_len = 300 |
|
residual = self.wiener_residual |
|
|
|
B, S, C, Fq, T = mag_out.shape |
|
mag_out = mag_out.permute(0, 4, 3, 2, 1) |
|
mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1)) |
|
|
|
outs = [] |
|
for sample in range(B): |
|
pos = 0 |
|
out = [] |
|
for pos in range(0, T, wiener_win_len): |
|
frame = slice(pos, pos + wiener_win_len) |
|
z_out = wiener( |
|
mag_out[sample, frame], |
|
mix_stft[sample, frame], |
|
niters, |
|
residual=residual, |
|
) |
|
out.append(z_out.transpose(-1, -2)) |
|
outs.append(torch.cat(out, dim=0)) |
|
out = torch.view_as_complex(torch.stack(outs, 0)) |
|
out = out.permute(0, 4, 3, 2, 1).contiguous() |
|
if residual: |
|
out = out[:, :-1] |
|
assert list(out.shape) == [B, S, C, Fq, T] |
|
return out.to(init) |
|
|
|
def valid_length(self, length: int): |
|
""" |
|
Return a length that is appropriate for evaluation. |
|
In our case, always return the training length, unless |
|
it is smaller than the given length, in which case this |
|
raises an error. |
|
""" |
|
if not self.use_train_segment: |
|
return length |
|
training_length = int(self.segment * self.samplerate) |
|
if training_length < length: |
|
raise ValueError( |
|
f"Given length {length} is longer than " |
|
f"training length {training_length}") |
|
return training_length |
|
|
|
    def forward(self, mix):
        """Separate the waveform `mix` into one waveform per source.

        Runs the spectrogram and waveform branches in parallel through the
        encoders, optionally mixes them with the cross-transformer, decodes
        both, and sums the iSTFT of the spectrogram output with the time
        branch output. Returns a tensor of shape (B, num_sources, C, T).
        """
        length = mix.shape[-1]
        length_pre_pad = None
        if self.use_train_segment:
            if self.training:
                # Record the current training segment duration so that
                # inference can later pad shorter inputs up to it.
                self.segment = Fraction(mix.shape[-1], self.samplerate)
            else:
                training_length = int(self.segment * self.samplerate)
                if mix.shape[-1] < training_length:
                    # Pad up to the training length; trimmed back at the end.
                    length_pre_pad = mix.shape[-1]
                    mix = F.pad(mix, (0, training_length - length_pre_pad))
        z = self._spec(mix)
        mag = self._magnitude(z).to(mix.device)
        x = mag  # freq. branch input

        B, C, Fq, T = x.shape

        # Normalization is always applied (see class docstring).
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)

        # Time branch input, normalized the same way over its own stats.
        xt = mix
        meant = xt.mean(dim=(1, 2), keepdim=True)
        stdt = xt.std(dim=(1, 2), keepdim=True)
        xt = (xt - meant) / (1e-5 + stdt)

        saved = []  # skip connections, freq. branch
        saved_t = []  # skip connections, time branch
        lengths = []  # lengths saved to undo padding, freq. branch
        lengths_t = []  # same, time branch
        for idx, encode in enumerate(self.encoder):
            lengths.append(x.shape[-1])
            inject = None
            if idx < len(self.tencoder):
                # The branches have not merged yet: run the time encoder too.
                lengths_t.append(xt.shape[-1])
                tenc = self.tencoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    saved_t.append(xt)
                else:
                    # Last time layer is a bare conv (`empty`): both branches
                    # now have matching shapes, inject it into the freq. one.
                    inject = xt
            x = encode(x, inject)
            if idx == 0 and self.freq_emb is not None:
                # Add the frequency embedding after the first freq. layer so
                # convolutions over the frequency axis need not be equivariant.
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)
        if self.crosstransformer:
            if self.bottom_channels:
                # 1x1 convs to reach `bottom_channels` before the transformer.
                b, c, f, t = x.shape
                x = rearrange(x, "b c f t-> b c (f t)")
                x = self.channel_upsampler(x)
                x = rearrange(x, "b c (f t)-> b c f t", f=f)
                xt = self.channel_upsampler_t(xt)

            x, xt = self.crosstransformer(x, xt)

            if self.bottom_channels:
                # Back to the original channel count after the transformer.
                x = rearrange(x, "b c f t-> b c (f t)")
                x = self.channel_downsampler(x)
                x = rearrange(x, "b c (f t)-> b c f t", f=f)
                xt = self.channel_downsampler_t(xt)

        for idx, decode in enumerate(self.decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))
            # `pre` is presumably the decoder output before its final
            # transposed conv — used below to split the branches again;
            # TODO confirm against HDecLayer.

            # Time decoder layers only exist for the outermost `len(tdecoder)`
            # decoder layers.
            offset = self.depth - len(self.tdecoder)
            if idx >= offset:
                tdec = self.tdecoder[idx - offset]
                length_t = lengths_t.pop(-1)
                if tdec.empty:
                    # Branch split point: `pre` has a single frequency bin,
                    # which is squeezed out to seed the time decoder.
                    assert pre.shape[2] == 1, pre.shape
                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)

        # All stored skip connections must have been consumed.
        assert len(saved) == 0
        assert len(lengths_t) == 0
        assert len(saved_t) == 0

        S = len(self.sources)
        x = x.view(B, S, -1, Fq, T)
        # Undo the input normalization, broadcast over the source dimension.
        x = x * std[:, None] + mean[:, None]

        # Complex ops below may be unsupported on non-cuda/cpu backends
        # (e.g. mps/xpu), so compute them on CPU and move back afterwards.
        device_type = x.device.type
        device_load = f"{device_type}:{x.device.index}" if not device_type == 'mps' else device_type
        x_is_other_gpu = not device_type in ["cuda", "cpu"]

        if x_is_other_gpu:
            x = x.cpu()

        zout = self._mask(z, x)
        if self.use_train_segment:
            if self.training:
                x = self._ispec(zout, length)
            else:
                # Inference on padded input: invert at the padded length.
                x = self._ispec(zout, training_length)
        else:
            x = self._ispec(zout, length)

        if x_is_other_gpu:
            x = x.to(device_load)

        if self.use_train_segment:
            if self.training:
                xt = xt.view(B, S, -1, length)
            else:
                xt = xt.view(B, S, -1, training_length)
        else:
            xt = xt.view(B, S, -1, length)
        # Undo the time-branch normalization.
        xt = xt * stdt[:, None] + meant[:, None]
        # Hybrid output: sum of the spectrogram and waveform estimates.
        x = xt + x
        if length_pre_pad:
            x = x[..., :length_pre_pad]
        return x
|
|