import copy
import random
import re
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModelForCausalLM, AutoTokenizer

from vita.model.vita_tts.adapter import *

IGNORE_ID = -1


class AudioLLM(torch.nn.Module):
    def __init__(
        self,
        encoder: torch.nn.Module,
        llm_path: str,
        freeze_llm: bool = True,
        enc_out_dim: int = 512,
        llm_embed_dim: int = 4096,
        kernel_size: int = 3,
        IGNORE_ID: int = -100,
        adpter_type: str = 'cnn',
        add_audio_bos_eos: bool = False,
        task_num: int = 10,
        add_ctc_prompt_ratio: float = 0.0,
        lang_dict: dict = None,
        ctc: torch.nn.Module = None,
        tokenize_ctc_char: bool = False,
        task_before_audio: bool = False,
        hyp_before_task: bool = False,
        prompt_finetune: bool = False,
        add_prompt_before: bool = False,
        prompt_num: int = 5,
        prefix_finetune: bool = False,
        prefix_num: int = 5,
        llm_head_num: int = 32,
        num_key_value_heads: int = None,
        task_type: str = 'prompt',
        freeze_encoder: bool = False,
        freeze_adpter: bool = False,
        activation_func: str = 'relu',
        norm: str = 'batch',
        use_lora: bool = False,
        clone_encoder: torch.nn.Module = None,
        chat_template: str = None,
        predict_usr_state: int = 0,
        chunk_size: int = -1,
    ):
        super().__init__()

        self.encoder = encoder
        self.llm_decoder = AutoModelForCausalLM.from_pretrained(
            llm_path, torch_dtype="auto", trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(
            llm_path, trust_remote_code=True)
        self.freeze_llm = freeze_llm
        self.enc_out_dim = enc_out_dim
        self.llm_embed_dim = llm_embed_dim
        self.IGNORE_ID = IGNORE_ID
        self.add_audio_bos_eos = add_audio_bos_eos
        self.add_ctc_prompt_ratio = add_ctc_prompt_ratio
        self.lang_dict = lang_dict
        self.tokenize_ctc_char = tokenize_ctc_char
        self.task_before_audio = task_before_audio
        self.hyp_before_task = hyp_before_task
        self.prompt_finetune = prompt_finetune
        self.add_prompt_before = add_prompt_before
        self.prompt_num = prompt_num
        self.prefix_finetune = prefix_finetune
        self.prefix_num = prefix_num
        self.llm_head_num = llm_head_num
        if num_key_value_heads is None:
            self.num_key_value_heads = llm_head_num
        else:
            self.num_key_value_heads = num_key_value_heads
        # Per-head dimension times the number of KV heads; differs from
        # llm_embed_dim when the LLM uses grouped-query attention.
        self.kv_cache_dim = llm_embed_dim // self.llm_head_num * self.num_key_value_heads
        self.task_type = task_type
        self.freeze_encoder = freeze_encoder
        self.freeze_adpter = freeze_adpter
        self.predict_usr_state = predict_usr_state
        self.chunk_size = chunk_size

        # Normalize attribute names across LLM families (e.g. Qwen-style
        # `transformer`/`h`/`wte` vs. Llama-style `model`/`layers`/`embed_tokens`)
        # so the rest of the code can use one set of accessors.
        if not hasattr(self.tokenizer, "eod_id"):
            self.tokenizer.eod_id = self.tokenizer.eos_token_id
        if not hasattr(self.llm_decoder, "transformer"):
            self.llm_decoder.transformer = self.llm_decoder.model
            self.llm_decoder.transformer.h = self.llm_decoder.transformer.layers
        if not hasattr(self.llm_decoder.transformer, "wte"):
            self.llm_decoder.transformer.wte = \
                self.llm_decoder.transformer.embed_tokens

        # for chat mode
        if chat_template is not None:
            self.tokenizer.eod_id = self.tokenizer('<|im_end|>')['input_ids'][0]
            self.chat_template = {}
            chat_template = chat_template.split('