#!/usr/bin/env python3 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. # MIT License (https://opensource.org/licenses/MIT) import logging import random from contextlib import contextmanager from distutils.version import LooseVersion from itertools import permutations from typing import Dict from typing import Optional from typing import Tuple, List import numpy as np import torch from torch.nn import functional as F from funasr_detach.models.transformer.utils.nets_utils import to_device from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask from funasr_detach.models.decoder.abs_decoder import AbsDecoder from funasr_detach.models.encoder.abs_encoder import AbsEncoder from funasr_detach.frontends.abs_frontend import AbsFrontend from funasr_detach.models.specaug.abs_specaug import AbsSpecAug from funasr_detach.models.specaug.abs_profileaug import AbsProfileAug from funasr_detach.layers.abs_normalize import AbsNormalize from funasr_detach.train_utils.device_funcs import force_gatherable from funasr_detach.models.base_model import FunASRModel from funasr_detach.losses.label_smoothing_loss import ( LabelSmoothingLoss, SequenceBinaryCrossEntropy, ) from funasr_detach.utils.misc import int2vec from funasr_detach.utils.hinter import hint_once if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): from torch.cuda.amp import autocast else: # Nothing to do if torch<1.6.0 @contextmanager def autocast(enabled=True): yield class DiarSondModel(FunASRModel): """Speaker overlap-aware neural diarization model reference: https://arxiv.org/abs/2211.10243 """ def __init__( self, vocab_size: int, frontend: Optional[AbsFrontend], specaug: Optional[AbsSpecAug], profileaug: Optional[AbsProfileAug], normalize: Optional[AbsNormalize], encoder: torch.nn.Module, speaker_encoder: Optional[torch.nn.Module], ci_scorer: torch.nn.Module, cd_scorer: Optional[torch.nn.Module], decoder: torch.nn.Module, token_list: list, lsm_weight: float = 0.1, length_normalized_loss: bool = False, max_spk_num: int = 16, label_aggregator: Optional[torch.nn.Module] = None, normalize_speech_speaker: bool = False, ignore_id: int = -1, speaker_discrimination_loss_weight: float = 1.0, inter_score_loss_weight: float = 0.0, inputs_type: str = "raw", model_regularizer_weight: float = 0.0, freeze_encoder: bool = False, onfly_shuffle_speaker: bool = True, ): super().__init__() self.encoder = encoder self.speaker_encoder = speaker_encoder self.ci_scorer = ci_scorer self.cd_scorer = cd_scorer self.normalize = normalize self.frontend = frontend self.specaug = specaug self.profileaug = profileaug self.label_aggregator = label_aggregator self.decoder = decoder self.token_list = token_list self.max_spk_num = max_spk_num self.normalize_speech_speaker = normalize_speech_speaker self.ignore_id = ignore_id self.model_regularizer_weight = model_regularizer_weight self.freeze_encoder = freeze_encoder self.onfly_shuffle_speaker = onfly_shuffle_speaker self.criterion_diar = LabelSmoothingLoss( size=vocab_size, padding_idx=ignore_id, smoothing=lsm_weight, normalize_length=length_normalized_loss, ) self.criterion_bce = SequenceBinaryCrossEntropy( normalize_length=length_normalized_loss ) self.pse_embedding = self.generate_pse_embedding() self.power_weight = torch.from_numpy( 2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :] ).float() self.int_token_arr = torch.from_numpy( np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :] ).int() self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight self.inter_score_loss_weight = inter_score_loss_weight self.forward_steps = 0 self.inputs_type = inputs_type self.to_regularize_parameters = None def get_regularize_parameters(self): to_regularize_parameters, normal_parameters = [], [] for name, param in self.named_parameters(): if ( "encoder" in name and "weight" in name and "bn" not in name and ( "conv2" in name or "conv1" in name or "conv_sc" in name or "dense" in name ) ): to_regularize_parameters.append((name, param)) else: normal_parameters.append((name, param)) self.to_regularize_parameters = to_regularize_parameters return to_regularize_parameters, normal_parameters def generate_pse_embedding(self): embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=np.float32) for idx, pse_label in enumerate(self.token_list): emb = int2vec(int(pse_label), vec_dim=self.max_spk_num, dtype=np.float32) embedding[idx] = emb return torch.from_numpy(embedding) def rand_permute_speaker(self, raw_profile, raw_binary_labels): """ raw_profile: B, N, D raw_binary_labels: B, T, N """ assert ( raw_profile.shape[1] == raw_binary_labels.shape[2] ), "Num profile: {}, Num label: {}".format( raw_profile.shape[1], raw_binary_labels.shape[-1] ) profile = torch.clone(raw_profile) binary_labels = torch.clone(raw_binary_labels) bsz, num_spk = profile.shape[0], profile.shape[1] for i in range(bsz): idx = list(range(num_spk)) random.shuffle(idx) profile[i] = profile[i][idx, :] binary_labels[i] = binary_labels[i][:, idx] return profile, binary_labels def forward( self, speech: torch.Tensor, speech_lengths: torch.Tensor = None, profile: torch.Tensor = None, profile_lengths: torch.Tensor = None, binary_labels: torch.Tensor = None, binary_labels_lengths: torch.Tensor = None, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: """Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss Args: speech: (Batch, samples) or (Batch, frames, input_size) speech_lengths: (Batch,) default None for chunk interator, because the chunk-iterator does not have the speech_lengths returned. see in espnet2/iterators/chunk_iter_factory.py profile: (Batch, N_spk, dim) profile_lengths: (Batch,) binary_labels: (Batch, frames, max_spk_num) binary_labels_lengths: (Batch,) """ assert speech.shape[0] <= binary_labels.shape[0], ( speech.shape, binary_labels.shape, ) batch_size = speech.shape[0] if self.freeze_encoder: hint_once("Freeze encoder", "freeze_encoder", rank=0) self.encoder.eval() self.forward_steps = self.forward_steps + 1 if self.pse_embedding.device != speech.device: self.pse_embedding = self.pse_embedding.to(speech.device) self.power_weight = self.power_weight.to(speech.device) self.int_token_arr = self.int_token_arr.to(speech.device) if self.onfly_shuffle_speaker: hint_once( "On-the-fly shuffle speaker permutation.", "onfly_shuffle_speaker", rank=0, ) profile, binary_labels = self.rand_permute_speaker(profile, binary_labels) # 0a. Aggregate time-domain labels to match forward outputs if self.label_aggregator is not None: binary_labels, binary_labels_lengths = self.label_aggregator( binary_labels, binary_labels_lengths ) # 0b. augment profiles if self.profileaug is not None and self.training: speech, profile, binary_labels = self.profileaug( speech, speech_lengths, profile, profile_lengths, binary_labels, binary_labels_lengths, ) # 1. Calculate power-set encoding (PSE) labels pad_bin_labels = F.pad( binary_labels, (0, self.max_spk_num - binary_labels.shape[2]), "constant", 0.0, ) raw_pse_labels = torch.sum( pad_bin_labels * self.power_weight, dim=2, keepdim=True ) pse_labels = torch.argmax( (raw_pse_labels.int() == self.int_token_arr).float(), dim=2 ) # 2. Network forward pred, inter_outputs = self.prediction_forward( speech, speech_lengths, profile, profile_lengths, return_inter_outputs=True ) (speech, speech_lengths), (profile, profile_lengths), (ci_score, cd_score) = ( inter_outputs ) # If encoder uses conv* as input_layer (i.e., subsampling), # the sequence length of 'pred' might be slightly less than the # length of 'spk_labels'. Here we force them to be equal. length_diff_tolerance = 2 length_diff = abs(pse_labels.shape[1] - pred.shape[1]) if length_diff <= length_diff_tolerance: min_len = min(pred.shape[1], pse_labels.shape[1]) pse_labels = pse_labels[:, :min_len] pred = pred[:, :min_len] cd_score = cd_score[:, :min_len] ci_score = ci_score[:, :min_len] loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths) loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths) loss_inter_ci, loss_inter_cd = self.internal_score_loss( cd_score, ci_score, pse_labels, binary_labels_lengths ) regularizer_loss = None if ( self.model_regularizer_weight > 0 and self.to_regularize_parameters is not None ): regularizer_loss = self.calculate_regularizer_loss() label_mask = make_pad_mask( binary_labels_lengths, maxlen=pse_labels.shape[1] ).to(pse_labels.device) loss = ( loss_diar + self.speaker_discrimination_loss_weight * loss_spk_dis + self.inter_score_loss_weight * (loss_inter_ci + loss_inter_cd) ) # if regularizer_loss is not None: # loss = loss + regularizer_loss * self.model_regularizer_weight ( correct, num_frames, speech_scored, speech_miss, speech_falarm, speaker_scored, speaker_miss, speaker_falarm, speaker_error, ) = self.calc_diarization_error( pred=F.embedding(pred.argmax(dim=2) * (~label_mask), self.pse_embedding), label=F.embedding(pse_labels * (~label_mask), self.pse_embedding), length=binary_labels_lengths, ) if speech_scored > 0 and num_frames > 0: sad_mr, sad_fr, mi, fa, cf, acc, der = ( speech_miss / speech_scored, speech_falarm / speech_scored, speaker_miss / speaker_scored, speaker_falarm / speaker_scored, speaker_error / speaker_scored, correct / num_frames, (speaker_miss + speaker_falarm + speaker_error) / speaker_scored, ) else: sad_mr, sad_fr, mi, fa, cf, acc, der = 0, 0, 0, 0, 0, 0, 0 stats = dict( loss=loss.detach(), loss_diar=loss_diar.detach() if loss_diar is not None else None, loss_spk_dis=loss_spk_dis.detach() if loss_spk_dis is not None else None, loss_inter_ci=loss_inter_ci.detach() if loss_inter_ci is not None else None, loss_inter_cd=loss_inter_cd.detach() if loss_inter_cd is not None else None, regularizer_loss=( regularizer_loss.detach() if regularizer_loss is not None else None ), sad_mr=sad_mr, sad_fr=sad_fr, mi=mi, fa=fa, cf=cf, acc=acc, der=der, forward_steps=self.forward_steps, ) loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) return loss, stats, weight def calculate_regularizer_loss(self): regularizer_loss = 0.0 for name, param in self.to_regularize_parameters: regularizer_loss = regularizer_loss + torch.norm(param, p=2) return regularizer_loss def classification_loss( self, predictions: torch.Tensor, labels: torch.Tensor, prediction_lengths: torch.Tensor, ) -> torch.Tensor: mask = make_pad_mask(prediction_lengths, maxlen=labels.shape[1]) pad_labels = labels.masked_fill( mask.to(predictions.device), value=self.ignore_id ) loss = self.criterion_diar(predictions.contiguous(), pad_labels) return loss def speaker_discrimination_loss( self, profile: torch.Tensor, profile_lengths: torch.Tensor ) -> torch.Tensor: profile_mask = ( torch.linalg.norm(profile, ord=2, dim=2, keepdim=True) > 0 ).float() # (B, N, 1) mask = torch.matmul(profile_mask, profile_mask.transpose(1, 2)) # (B, N, N) mask = mask * (1.0 - torch.eye(self.max_spk_num).unsqueeze(0).to(mask)) eps = 1e-12 coding_norm = ( torch.linalg.norm( profile * profile_mask + (1 - profile_mask) * eps, dim=2, keepdim=True ) * profile_mask ) # profile: Batch, N, dim cos_theta = ( F.cosine_similarity( profile.unsqueeze(2), profile.unsqueeze(1), dim=-1, eps=eps ) * mask ) cos_theta = torch.clip(cos_theta, -1 + eps, 1 - eps) loss = (F.relu(mask * coding_norm * (cos_theta - 0.0))).sum() / mask.sum() return loss def calculate_multi_labels(self, pse_labels, pse_labels_lengths): mask = make_pad_mask(pse_labels_lengths, maxlen=pse_labels.shape[1]) padding_labels = pse_labels.masked_fill(mask.to(pse_labels.device), value=0).to( pse_labels ) multi_labels = F.embedding(padding_labels, self.pse_embedding) return multi_labels def internal_score_loss( self, cd_score: torch.Tensor, ci_score: torch.Tensor, pse_labels: torch.Tensor, pse_labels_lengths: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: multi_labels = self.calculate_multi_labels(pse_labels, pse_labels_lengths) ci_loss = self.criterion_bce(ci_score, multi_labels, pse_labels_lengths) cd_loss = self.criterion_bce(cd_score, multi_labels, pse_labels_lengths) return ci_loss, cd_loss def collect_feats( self, speech: torch.Tensor, speech_lengths: torch.Tensor, profile: torch.Tensor = None, profile_lengths: torch.Tensor = None, binary_labels: torch.Tensor = None, binary_labels_lengths: torch.Tensor = None, ) -> Dict[str, torch.Tensor]: feats, feats_lengths = self._extract_feats(speech, speech_lengths) return {"feats": feats, "feats_lengths": feats_lengths} def encode_speaker( self, profile: torch.Tensor, profile_lengths: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: with autocast(False): if profile.shape[1] < self.max_spk_num: profile = F.pad( profile, [0, 0, 0, self.max_spk_num - profile.shape[1], 0, 0], "constant", 0.0, ) profile_mask = ( torch.linalg.norm(profile, ord=2, dim=2, keepdim=True) > 0 ).float() profile = F.normalize(profile, dim=2) if self.speaker_encoder is not None: profile = self.speaker_encoder(profile, profile_lengths)[0] return profile * profile_mask, profile_lengths else: return profile, profile_lengths def encode_speech( self, speech: torch.Tensor, speech_lengths: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: if self.encoder is not None and self.inputs_type == "raw": speech, speech_lengths = self.encode(speech, speech_lengths) speech_mask = ~make_pad_mask(speech_lengths, maxlen=speech.shape[1]) speech_mask = speech_mask.to(speech.device).unsqueeze(-1).float() return speech * speech_mask, speech_lengths else: return speech, speech_lengths @staticmethod def concate_speech_ivc(speech: torch.Tensor, ivc: torch.Tensor) -> torch.Tensor: nn, tt = ivc.shape[1], speech.shape[1] speech = speech.unsqueeze(dim=1) # B x 1 x T x D speech = speech.expand(-1, nn, -1, -1) # B x N x T x D ivc = ivc.unsqueeze(dim=2) # B x N x 1 x D ivc = ivc.expand(-1, -1, tt, -1) # B x N x T x D sd_in = torch.cat([speech, ivc], dim=3) # B x N x T x 2D return sd_in def calc_similarity( self, speech_encoder_outputs: torch.Tensor, speaker_encoder_outputs: torch.Tensor, seq_len: torch.Tensor = None, spk_len: torch.Tensor = None, ) -> Tuple[torch.Tensor, torch.Tensor]: bb, tt = speech_encoder_outputs.shape[0], speech_encoder_outputs.shape[1] d_sph, d_spk = speech_encoder_outputs.shape[2], speaker_encoder_outputs.shape[2] if self.normalize_speech_speaker: speech_encoder_outputs = F.normalize(speech_encoder_outputs, dim=2) speaker_encoder_outputs = F.normalize(speaker_encoder_outputs, dim=2) ge_in = self.concate_speech_ivc(speech_encoder_outputs, speaker_encoder_outputs) ge_in = torch.reshape(ge_in, [bb * self.max_spk_num, tt, d_sph + d_spk]) ge_len = seq_len.unsqueeze(1).expand(-1, self.max_spk_num) ge_len = torch.reshape(ge_len, [bb * self.max_spk_num]) cd_simi = self.cd_scorer(ge_in, ge_len)[0] cd_simi = torch.reshape(cd_simi, [bb, self.max_spk_num, tt, 1]) cd_simi = cd_simi.squeeze(dim=3).permute([0, 2, 1]) if isinstance(self.ci_scorer, AbsEncoder): ci_simi = self.ci_scorer(ge_in, ge_len)[0] ci_simi = torch.reshape(ci_simi, [bb, self.max_spk_num, tt]).permute( [0, 2, 1] ) else: ci_simi = self.ci_scorer(speech_encoder_outputs, speaker_encoder_outputs) return ci_simi, cd_simi def post_net_forward(self, simi, seq_len): logits = self.decoder(simi, seq_len)[0] return logits def prediction_forward( self, speech: torch.Tensor, speech_lengths: torch.Tensor, profile: torch.Tensor, profile_lengths: torch.Tensor, return_inter_outputs: bool = False, ) -> [torch.Tensor, Optional[list]]: # speech encoding speech, speech_lengths = self.encode_speech(speech, speech_lengths) # speaker encoding profile, profile_lengths = self.encode_speaker(profile, profile_lengths) # calculating similarity ci_simi, cd_simi = self.calc_similarity( speech, profile, speech_lengths, profile_lengths ) similarity = torch.cat([cd_simi, ci_simi], dim=2) # post net forward logits = self.post_net_forward(similarity, speech_lengths) if return_inter_outputs: return logits, [ (speech, speech_lengths), (profile, profile_lengths), (ci_simi, cd_simi), ] return logits def encode( self, speech: torch.Tensor, speech_lengths: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """Frontend + Encoder Args: speech: (Batch, Length, ...) speech_lengths: (Batch,) """ with autocast(False): # 1. Extract feats feats, feats_lengths = self._extract_feats(speech, speech_lengths) # 2. Data augmentation if self.specaug is not None and self.training: feats, feats_lengths = self.specaug(feats, feats_lengths) # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN if self.normalize is not None: feats, feats_lengths = self.normalize(feats, feats_lengths) # 4. Forward encoder # feats: (Batch, Length, Dim) # -> encoder_out: (Batch, Length2, Dim) encoder_outputs = self.encoder(feats, feats_lengths) encoder_out, encoder_out_lens = encoder_outputs[:2] assert encoder_out.size(0) == speech.size(0), ( encoder_out.size(), speech.size(0), ) assert encoder_out.size(1) <= encoder_out_lens.max(), ( encoder_out.size(), encoder_out_lens.max(), ) return encoder_out, encoder_out_lens def _extract_feats( self, speech: torch.Tensor, speech_lengths: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: batch_size = speech.shape[0] speech_lengths = ( speech_lengths if speech_lengths is not None else torch.ones(batch_size).int() * speech.shape[1] ) assert speech_lengths.dim() == 1, speech_lengths.shape # for data-parallel speech = speech[:, : speech_lengths.max()] if self.frontend is not None: # Frontend # e.g. STFT and Feature extract # data_loader may send time-domain signal in this case # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim) feats, feats_lengths = self.frontend(speech, speech_lengths) else: # No frontend and no feature extract feats, feats_lengths = speech, speech_lengths return feats, feats_lengths @staticmethod def calc_diarization_error(pred, label, length): # Note (jiatong): Credit to https://github.com/hitachi-speech/EEND (batch_size, max_len, num_output) = label.size() # mask the padding part mask = ~make_pad_mask(length, maxlen=label.shape[1]).unsqueeze(-1).numpy() # pred and label have the shape (batch_size, max_len, num_output) label_np = label.data.cpu().numpy().astype(int) pred_np = (pred.data.cpu().numpy() > 0).astype(int) label_np = label_np * mask pred_np = pred_np * mask length = length.data.cpu().numpy() # compute speech activity detection error n_ref = np.sum(label_np, axis=2) n_sys = np.sum(pred_np, axis=2) speech_scored = float(np.sum(n_ref > 0)) speech_miss = float(np.sum(np.logical_and(n_ref > 0, n_sys == 0))) speech_falarm = float(np.sum(np.logical_and(n_ref == 0, n_sys > 0))) # compute speaker diarization error speaker_scored = float(np.sum(n_ref)) speaker_miss = float(np.sum(np.maximum(n_ref - n_sys, 0))) speaker_falarm = float(np.sum(np.maximum(n_sys - n_ref, 0))) n_map = np.sum(np.logical_and(label_np == 1, pred_np == 1), axis=2) speaker_error = float(np.sum(np.minimum(n_ref, n_sys) - n_map)) correct = float(1.0 * np.sum((label_np == pred_np) * mask) / num_output) num_frames = np.sum(length) return ( correct, num_frames, speech_scored, speech_miss, speech_falarm, speaker_scored, speaker_miss, speaker_falarm, speaker_error, )