# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import Optional
from typing import Tuple
import torch
import torch.nn as nn
# from funasr_detach.layers.abs_normalize import AbsNormalize
# from funasr_detach.models.base_model import FunASRModel
# from funasr_detach.models.encoder.abs_encoder import AbsEncoder
from funasr_detach.frontends.abs_frontend import AbsFrontend
# from funasr_detach.models.preencoder.abs_preencoder import AbsPreEncoder
# from funasr_detach.models.specaug.abs_specaug import AbsSpecAug
from funasr_detach.train_utils.device_funcs import force_gatherable
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield
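# Either branch leaves `autocast` usable as a context manager: on torch < 1.6 the
# fallback above simply yields without enabling mixed precision, so the
# `with autocast(False):` block in encode() behaves identically across versions.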

class Data2VecPretrainModel(nn.Module):
    """Data2Vec Pretrain model"""

    def __init__(
        self,
        frontend=None,
        specaug=None,
        normalize=None,
        encoder=None,
        preencoder=None,
    ):
        super().__init__()

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.encoder = encoder
        self.num_updates = 0
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        # Check that batch_size is unified
        assert speech.shape[0] == speech_lengths.shape[0], (
            speech.shape,
            speech_lengths.shape,
        )

        self.encoder.set_num_updates(self.num_updates)

        # 1. Encoder
        encoder_out = self.encode(speech, speech_lengths)

        losses = encoder_out["losses"]
        loss = sum(losses.values())
        sample_size = encoder_out["sample_size"]
        loss = loss.sum() / sample_size

        target_var = float(encoder_out["target_var"])
        pred_var = float(encoder_out["pred_var"])
        ema_decay = float(encoder_out["ema_decay"])

        stats = dict(
            loss=torch.clone(loss.detach()),
            target_var=target_var,
            pred_var=pred_var,
            ema_decay=ema_decay,
        )

        loss, stats, weight = force_gatherable((loss, stats, sample_size), loss.device)
        return loss, stats, weight
    def collect_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        """Frontend + Encoder.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder
        if min(speech_lengths) == max(
            speech_lengths
        ):  # for clipping, set speech_lengths as None
            speech_lengths = None
        encoder_out = self.encoder(
            feats, speech_lengths, mask=True, features_only=False
        )

        return encoder_out
    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            # e.g. STFT and Feature extract
            # data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths
    def set_num_updates(self, num_updates):
        self.num_updates = num_updates

    def get_num_updates(self):
        return self.num_updates
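

# A minimal, self-contained smoke-test sketch of how this class is driven.
# `_ToyEncoder` below is a hypothetical stand-in, not part of funasr_detach: it only
# mimics the interface that forward()/encode() rely on, i.e. a `set_num_updates()`
# method and a call that returns a dict with "losses", "sample_size", "target_var",
# "pred_var" and "ema_decay".
if __name__ == "__main__":

    class _ToyEncoder(nn.Module):
        def __init__(self, dim: int = 80):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def set_num_updates(self, num_updates):
            # A real data2vec encoder typically uses this to schedule its EMA teacher decay.
            pass

        def forward(self, feats, lengths, mask=True, features_only=False):
            pred = self.proj(feats)
            loss = (pred - feats.detach()).pow(2).mean()
            return {
                "losses": {"regression": loss},
                "sample_size": 1,
                "target_var": feats.var(),
                "pred_var": pred.var(),
                "ema_decay": 0.999,
            }

    # No frontend/specaug/normalize/preencoder: `speech` is treated as precomputed
    # features of shape (Batch, Length, Dim).
    model = Data2VecPretrainModel(encoder=_ToyEncoder(dim=80))
    speech = torch.randn(2, 100, 80)
    speech_lengths = torch.full((2,), 100)  # equal lengths, so encode() drops them
    loss, stats, weight = model(speech, speech_lengths)
    print(loss, stats, weight)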