# Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional

import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from torch import nn

from .adapter import CNNAdapter, CNNSubsampling, LinearAdapter
from .cmvn import GlobalCMVN, load_cmvn
from .module.encoder.encoder import whaleEncoder


class audioEncoderProcessor:
    """Converts a wav file into fbank features for the audio encoder."""

    def __init__(
        self,
        dataset_conf: dict = None,
    ):
        self.dataset_conf = dataset_conf

    def process(self, wav_path):
        try:
            waveform, sample_rate = torchaudio.load(wav_path)
        except Exception as e:
            print(f"cannot open {wav_path}: {e}")
            raise
        resample_rate = self.dataset_conf["resample_conf"]["resample_rate"]
        if sample_rate != resample_rate:
            waveform = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=resample_rate
            )(waveform)
            sample_rate = resample_rate
        # Scale to the 16-bit signed-integer range expected by Kaldi-style fbank.
        waveform = waveform * (1 << 15)
        mat = kaldi.fbank(
            waveform,
            num_mel_bins=self.dataset_conf["fbank_conf"]["num_mel_bins"],
            frame_length=self.dataset_conf["fbank_conf"]["frame_length"],
            frame_shift=self.dataset_conf["fbank_conf"]["frame_shift"],
            dither=self.dataset_conf["fbank_conf"]["dither"],
            energy_floor=0.0,
            sample_frequency=sample_rate,
        )
        # Mimic the temporal subsampling applied by the encoder and adapter to
        # report how many audio embeddings these features will produce.
        attn_mask = torch.ones(mat.shape[0])
        attn_mask = attn_mask[2::2][2::2][0::2]
        return mat, attn_mask.shape[0]


class audioEncoder(torch.nn.Module):
    def __init__(
        self,
        encoder: torch.nn.Module,
        llm_path: str,
        freeze_llm: bool = True,
        enc_out_dim: int = 512,
        llm_embed_dim: int = 4096,
        kernel_size: int = 3,
        IGNORE_ID: int = -100,
        adpter_type: str = "cnn",
        add_audio_bos_eos: bool = False,
        task_num: int = 10,
        task_before_audio: bool = False,
        task_type: str = "prompt",
        freeze_encoder: bool = False,
        freeze_adpter: bool = False,
        audio_prompt_finetune: bool = False,
        audio_prompt_num: int = 25,
        activation_func: str = "relu",
        norm: str = "batch",
        chat_template=None,
    ):
        super().__init__()
        self.encoder = encoder
        self.enc_out_dim = enc_out_dim
        self.llm_embed_dim = llm_embed_dim
        self.IGNORE_ID = IGNORE_ID
        self.add_audio_bos_eos = add_audio_bos_eos
        self.task_before_audio = task_before_audio
        self.task_type = task_type
        self.freeze_encoder = freeze_encoder
        self.freeze_adpter = freeze_adpter
        self.audio_prompt_finetune = audio_prompt_finetune
        self.audio_prompt_num = audio_prompt_num

        # Adapter that projects encoder outputs into the LLM embedding space.
        if adpter_type == "cnn":
            self.adpter = CNNAdapter(enc_out_dim, llm_embed_dim, kernel_size)
        elif adpter_type == "linear":
            self.adpter = LinearAdapter(enc_out_dim, llm_embed_dim)
        elif adpter_type == "subsampling":
            self.adpter = CNNSubsampling(
                enc_out_dim, llm_embed_dim, kernel_size, activation_func, norm
            )
        else:
            raise ValueError(f"unknown adpter_type: {adpter_type}")

        if self.freeze_encoder:
            self.encoder.eval()
            for param in self.encoder.parameters():
                param.requires_grad = False
        if self.freeze_adpter:
            self.adpter.eval()
            for param in self.adpter.parameters():
                param.requires_grad = False

        if self.audio_prompt_finetune:
            # Learnable prompt embeddings prepended to the audio embeddings.
            self.prompt_embeddings = nn.Embedding(audio_prompt_num, llm_embed_dim)
            self.prompt_ids = torch.arange(audio_prompt_num, dtype=torch.long)

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Dict[str, Optional[torch.Tensor]]:
        speech = speech.to(next(self.parameters()).dtype)

        # 1. Encoder + adapter: map audio frames into the LLM embedding space.
        encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
        inputs_embeds, encoder_mask = self.adpter(encoder_out, encoder_mask)  # B, T, D
        attention_mask = encoder_mask.squeeze(1)  # B, T
        assert inputs_embeds.size(1) == attention_mask.size(1)

        # 2. Optional audio bos/eos embeddings around the audio segment.
        target = None
        if self.add_audio_bos_eos:
            inputs_embeds, attention_mask, target = self._add_bos_eos(
                "audio", "/audio", inputs_embeds, attention_mask, target
            )

        # 3. Optional learnable audio prompt prepended to the audio embeddings.
        B, _, _ = inputs_embeds.shape
        if self.audio_prompt_finetune:
            prompt_ids = self.prompt_ids.repeat(B, 1).to(inputs_embeds.device)
            prompt_embeds = self.prompt_embeddings(prompt_ids)  # B, audio_prompt_num, D
            inputs_embeds = torch.cat((prompt_embeds, inputs_embeds), 1)  # B, (T + audio_prompt_num), D

        outputs = {
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
        }
        return outputs

    def _add_bos_eos(self, bos, eos, inputs_embeds, attention_mask, target=None):
        # NOTE: relies on self.task_embeddings and self.task_ids being defined
        # elsewhere; they are not created in __init__ above.
        B = len(inputs_embeds)
        bos_embed = self.task_embeddings(
            torch.full([B, 1], self.task_ids[bos]).to(inputs_embeds.device)
        )  # B, 1, D
        eos_embed = self.task_embeddings(
            torch.full([B, 1], self.task_ids[eos]).to(inputs_embeds.device)
        )  # B, 1, D
        bos_eos_target = torch.full([B, 2], self.IGNORE_ID).to(inputs_embeds.device)  # B, 2
        bos_eos_mask = torch.full([B, 1], True).to(inputs_embeds.device)  # B, 1

        inputs_embeds = torch.cat((bos_embed, inputs_embeds), 1)  # B, (1+T), D
        inputs_embeds = torch.cat((inputs_embeds, eos_embed), 1)  # B, (1+T+1), D
        attention_mask = torch.cat((bos_eos_mask, attention_mask), 1)  # B, (1+T)
        attention_mask = torch.cat((attention_mask, bos_eos_mask), 1)  # B, (1+T+1)
        if target is not None:
            target = torch.cat((target, bos_eos_target), 1)  # B, (T+2)

        return inputs_embeds, attention_mask, target


def init_model(configs):
    if configs["cmvn_file"] is not None:
        mean, istd = load_cmvn(configs["cmvn_file"], configs["is_json_cmvn"])
        global_cmvn = GlobalCMVN(
            torch.from_numpy(mean).float(),
            torch.from_numpy(istd).float(),
        )
    else:
        global_cmvn = None

    input_dim = configs["input_dim"]
    encoder = whaleEncoder(input_dim, global_cmvn=global_cmvn, **configs["encoder_conf"])
    model = audioEncoder(encoder=encoder, **configs["model_conf"])

    # Attach the feature-extraction processor so callers can go from wav to features.
    processor = audioEncoderProcessor(dataset_conf=configs["dataset_conf"])
    model.audio_processor = processor

    return model
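

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, illustrative way
# to build the model from a YAML config and run one utterance through it.
# The config path, the wav path, and the use of the raw fbank frame count as
# `speech_lengths` are assumptions; the expected top-level config keys
# (cmvn_file, is_json_cmvn, input_dim, encoder_conf, model_conf, dataset_conf)
# come from init_model() above. Requires PyYAML.
if __name__ == "__main__":
    import yaml

    with open("conf/audio_encoder.yaml") as f:  # hypothetical config path
        configs = yaml.safe_load(f)

    model = init_model(configs)
    model.eval()

    # Fbank features (num_frames, num_mel_bins) and the estimated number of
    # audio embeddings after subsampling, for a hypothetical wav file.
    feat, num_audio_tokens = model.audio_processor.process("example.wav")

    with torch.no_grad():
        outputs = model(
            speech=feat.unsqueeze(0),
            speech_lengths=torch.tensor([feat.shape[0]]),
        )
    print(outputs["inputs_embeds"].shape, outputs["attention_mask"].shape)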