import random
import torch
import copy
import re
from torch import nn
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from vita.model.vita_tts.adapter import CNNAdapter, CNNSubsampling, LinearAdapter
IGNORE_ID = -1
class AudioLLM(torch.nn.Module):
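    """Speech-in, text-out LLM wrapper for streaming chat.

    Audio features from `encoder` are projected into the LLM embedding space by an
    adapter ('cnn', 'linear', or 'subsampling') and fed to a causal LLM loaded from
    `llm_path`. Optional prompt/prefix tuning, a CTC-hypothesis prompt, and a user
    state predictor (which decides, while listening, when the model should start
    responding) are supported; the encoder, adapter, and LLM can each be frozen.
    """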
def __init__(
self,
encoder: torch.nn.Module,
llm_path: str,
freeze_llm: bool = True,
enc_out_dim: int = 512,
llm_embed_dim: int = 4096,
kernel_size: int = 3,
IGNORE_ID: int = -100,
adpter_type: str = 'cnn',
add_audio_bos_eos: bool = False,
task_num: int = 10,
add_ctc_prompt_ratio: float = 0.0,
lang_dict: dict = None,
ctc: torch.nn.Module = None,
tokenize_ctc_char: bool = False,
task_before_audio: bool = False,
hyp_before_task: bool = False,
prompt_finetune: bool = False,
add_prompt_before: bool = False,
prompt_num: int = 5,
prefix_finetune: bool = False,
prefix_num: int = 5,
llm_head_num: int = 32,
num_key_value_heads: int = None,
task_type: str = 'prompt',
freeze_encoder: bool = False,
freeze_adpter: bool = False,
activation_func: str = 'relu',
norm: str = 'batch',
use_lora: bool = False,
clone_encoder: torch.nn.Module = None,
chat_template: str = None,
predict_usr_state: int = 0,
chunk_size: int = -1,
):
super().__init__()
self.encoder = encoder
self.llm_decoder = AutoModelForCausalLM.from_pretrained(llm_path,
torch_dtype="auto",
trust_remote_code=True)
self.tokenizer = AutoTokenizer.from_pretrained(llm_path,
trust_remote_code=True)
self.freeze_llm = freeze_llm
self.enc_out_dim = enc_out_dim
self.llm_embed_dim = llm_embed_dim
self.IGNORE_ID = IGNORE_ID
self.add_audio_bos_eos = add_audio_bos_eos
self.add_ctc_prompt_ratio = add_ctc_prompt_ratio
self.lang_dict = lang_dict
self.tokenize_ctc_char = tokenize_ctc_char
self.task_before_audio = task_before_audio
self.hyp_before_task = hyp_before_task
self.prompt_finetune = prompt_finetune
self.add_prompt_before = add_prompt_before
self.prompt_num = prompt_num
self.prefix_finetune = prefix_finetune
self.prefix_num = prefix_num
self.llm_head_num = llm_head_num
if num_key_value_heads is None:
self.num_key_value_heads = llm_head_num
else:
self.num_key_value_heads = num_key_value_heads
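        # Feature size of each cached key/value tensor: per-head dim
        # (llm_embed_dim / llm_head_num) times the number of key/value heads
        # (equals llm_embed_dim unless grouped-query attention is used).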
self.kv_cache_dim = llm_embed_dim // self.llm_head_num * self.num_key_value_heads
self.task_type = task_type
self.freeze_encoder = freeze_encoder
self.freeze_adpter = freeze_adpter
self.predict_usr_state = predict_usr_state
self.chunk_size = chunk_size
if not hasattr(self.tokenizer, "eod_id"):
self.tokenizer.eod_id = self.tokenizer.eos_token_id
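        # Alias Llama-style submodule names (model / layers / embed_tokens) to the
        # Qwen-style names (transformer / h / wte) used throughout this class.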
if not hasattr(self.llm_decoder, "transformer"):
self.llm_decoder.transformer = self.llm_decoder.model
self.llm_decoder.transformer.h = self.llm_decoder.transformer.layers
if not hasattr(self.llm_decoder.transformer, "wte"):
self.llm_decoder.transformer.wte = \
self.llm_decoder.transformer.embed_tokens
# for chat mode
if chat_template is not None:
self.tokenizer.eod_id = self.tokenizer('<|im_end|>'
)['input_ids'][0]
self.chat_template = {}
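            # The template is expected to contain one '<audio>' placeholder and an
            # '<|im_end|>'-terminated system turn before it, e.g. (illustrative only):
            #   "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<audio><|im_end|>\n<|im_start|>assistant\n"
            # 'role'   -> everything up to and including the first '<|im_end|>'
            # 'prefix' -> the text between that '<|im_end|>' and '<audio>'
            # 'suffix' -> the text after '<audio>'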
chat_template = chat_template.split('<audio>')
chat_prefix = chat_template[0].split('<|im_end|>')
chat_role = chat_prefix[0] + '<|im_end|>'
self.chat_template['role'] = self.tokenizer(
[chat_role], return_tensors="pt")['input_ids']
self.chat_template['prefix'] = self.tokenizer(
[chat_prefix[1]], return_tensors="pt")['input_ids']
self.chat_template['suffix'] = self.tokenizer(
[chat_template[1]], return_tensors="pt")['input_ids']
else:
self.chat_template = None
# for CTC prompt
if self.add_ctc_prompt_ratio > 0.0:
assert lang_dict is not None
assert ctc is not None
self.ctc = ctc.eval()
if clone_encoder is None:
self.clone_encoder = copy.deepcopy(encoder)
else:
self.clone_encoder = clone_encoder
self.clone_encoder.eval()
for (name, param) in self.clone_encoder.named_parameters():
param.requires_grad = False
for (name, param) in self.ctc.named_parameters():
param.requires_grad = False
else:
self.clone_encoder = None
if self.freeze_llm:
self.llm_decoder.eval()
for (name, param) in self.llm_decoder.named_parameters():
param.requires_grad = False
        if use_lora:
            # NOTE: this LoRA branch is incomplete as written: `LoraConfig` (from the
            # `peft` package), `lora_r`, `lora_alpha`, `UNET_TARGET_MODULES` and `args`
            # are not defined in this file, and the resulting config is never applied
            # to `self.llm_decoder` (e.g. via `peft.get_peft_model`).
            config = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                target_modules=UNET_TARGET_MODULES,
                lora_dropout=args.lora_dropout,
                bias=args.lora_bias,
            )
if adpter_type == 'cnn':
self.adpter = CNNAdapter(enc_out_dim, llm_embed_dim, kernel_size)
elif adpter_type == 'linear':
self.adpter = LinearAdapter(enc_out_dim, llm_embed_dim)
elif adpter_type == 'subsampling':
self.adpter = CNNSubsampling(enc_out_dim, llm_embed_dim,
kernel_size, activation_func, norm)
self.task_embeddings = torch.nn.Embedding(task_num, llm_embed_dim)
if task_type == 'prefix':
self.prefix_embeddings = nn.ModuleList(
[
torch.nn.ModuleList(
[nn.Embedding(task_num, self.kv_cache_dim),
nn.Embedding(task_num, self.kv_cache_dim)]
)
for i in range(len(self.llm_decoder.transformer.h))
]
)
if self.prompt_finetune or self.prefix_finetune:
if self.prompt_finetune:
self.prompt_embeddings = nn.Embedding(prompt_num, llm_embed_dim)
                self.prompt_ids = torch.arange(prompt_num).long()
if self.prefix_finetune:
self.prefix_embeddings = nn.ModuleList(
[
torch.nn.ModuleList(
[nn.Embedding(prefix_num, self.kv_cache_dim),
nn.Embedding(prefix_num, self.kv_cache_dim)]
)
for i in range(len(self.llm_decoder.transformer.h))
]
)
                self.prefix_ids = torch.arange(prefix_num).long()
if self.freeze_encoder:
self.encoder.eval()
for (name, param) in self.encoder.named_parameters():
param.requires_grad = False
if self.freeze_adpter:
self.adpter.eval()
for (name, param) in self.adpter.named_parameters():
param.requires_grad = False
if self.predict_usr_state:
self.predictor_head = torch.nn.Linear(llm_embed_dim, predict_usr_state)
else:
self.predictor_head = None
# define task ids
self.task_ids = {
"sot": 0,
"transcribe": 1,
"translate": 2,
"zh": 3,
"en": 4,
"audio": 5,
"/audio": 6,
"hyps": 7,
"/hyps": 8,
}
def set_system_role(
self,
extra_inputs: Optional[dict] = None,
):
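        """
        Encode the system/role prompt once and run it through the LLM.

        Returns the resulting past_key_values, which must be passed back in
        `extra_inputs['past_key_values']` on the first call to `recognize`.
        """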
# Ensure 'past_key_values' does not exist in extra_inputs, raise an exception if it does
assert extra_inputs.get('past_key_values', None) is None, "past key values already exist!!!"
# If 'role' key is present in extra_inputs, use that role as the chat prefix
if extra_inputs.get('role', None) is not None:
chat_prefix = self.tokenizer([extra_inputs['role']],
return_tensors="pt")['input_ids'].to('cuda') # Convert role to tokens and move to CUDA device
else:
# If no 'role' is provided, use the default chat template and remove the last token (<|im_end|>)
chat_prefix = self.chat_template['role'][:, :-1].to('cuda')
# Use the LLM decoder's word embedding layer to convert the chat prefix into embeddings
inputs_embeds = self.llm_decoder.transformer.wte(chat_prefix)
# Create an attention mask with the same shape as the chat prefix, all values set to True
attention_mask = torch.full(chat_prefix.shape,
True).to(inputs_embeds.device)
# Prepare the input dictionary containing embeddings and attention mask
inputs = {
'inputs_embeds': inputs_embeds.half(), # Convert embeddings to half precision floats
'attention_mask': attention_mask,
}
# Call the _generate_one_step method to generate one step output, including past_key_values, etc.
_, past_key_values, stat, _ = self._generate_one_step(
copy.deepcopy(inputs), "sl")
# Return the generated past_key_values
return past_key_values
def recognize(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
extra_inputs: Optional[dict] = None,
):
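        """
        Run one streaming step. `extra_inputs['stat']` selects the branch:
        - 'sl' / 'cl': encode the incoming speech chunk (optionally prepending prompt
          embeddings and, for 'sl', the chat prefix) and let the state predictor decide
          whether to keep listening or start answering.
        - 'ss': feed the chat suffix so the LLM switches to generating the reply.
        - 'cs': feed the previously generated token (`extra_inputs['last_id']`) to
          continue generation.
        The codes are presumably short for start/continue listening and start/continue
        speaking; 'el' is an end state reported by the state predictor.

        Returns (last_id, stat, past_key_values, cnn_cache, buffer, pe_index,
        hidden_state) so the caller can thread the caches into the next call.
        """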
assert extra_inputs.get('past_key_values', None) is not None, "must set system role first!!!"
buffer = extra_inputs.get('encoder_cache', None)
cnn_cache = extra_inputs.get('adapter_cache', None)
pe_index = extra_inputs.get('pe_index', 0)
if extra_inputs['stat'] == 'sl' or extra_inputs['stat'] == 'cl':
# Encoder
if buffer is None:
buffer = [None] * self.encoder.enc[1].num_blocks
encoder_out, buffer, _, _, pe_index = self.encoder.infer(speech, buffer,
0, None, pe_index)
encoder_mask = torch.full(encoder_out.shape[:2], True).unsqueeze(1
).to(encoder_out.device)
# adapter
inputs_embeds, encoder_mask, cnn_cache = self.adpter(encoder_out, encoder_mask,
cache=cnn_cache, return_cache=True) # 1, T, D
attention_mask = encoder_mask.squeeze(1) # 1, T
# prompt
if extra_inputs['stat'] == 'sl':
if self.prompt_finetune:
prompt_ids = self.prompt_ids.repeat(1, 1).to(inputs_embeds.device)
prompt_embeds = self.prompt_embeddings(
prompt_ids.to(inputs_embeds.device)) # B, 5, D
prompt_mask = torch.full(prompt_ids.shape,
True).to(inputs_embeds.device) # B, 5
if self.add_prompt_before:
inputs_embeds = torch.cat((prompt_embeds, inputs_embeds), 1) # B, (T+5), D
attention_mask = torch.cat((prompt_mask, attention_mask), 1) # B, (T+5)
# chat mode
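        # 'sl': prepend the user-turn prefix (preceded by an <|im_end|> that closes the
        #       system/previous turn) to the audio embeddings.
        # 'ss': replace the inputs with the assistant suffix so generation can start.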
if self.chat_template is not None:
if extra_inputs['stat'] == 'sl':
chat_prefix = self.chat_template['prefix'].to(
inputs_embeds.device)
chat_prefix = torch.cat((torch.tensor([[self.tokenizer.eod_id]]
).to(inputs_embeds.device), chat_prefix), 1)
chat_prefix_embeds = self.llm_decoder.transformer.wte(chat_prefix)
chat_prefix_mask = torch.full(chat_prefix.shape,
True).to(inputs_embeds.device)
inputs_embeds = torch.cat((chat_prefix_embeds, inputs_embeds), 1)
attention_mask = torch.cat((chat_prefix_mask, attention_mask), 1)
if extra_inputs['stat'] == 'ss':
chat_suffix = self.chat_template['suffix'].to('cuda')
chat_suffix_embeds = self.llm_decoder.transformer.wte(chat_suffix)
chat_suffix_mask = torch.full(chat_suffix.shape, True).to('cuda')
inputs_embeds = chat_suffix_embeds
attention_mask = chat_suffix_mask
if extra_inputs['stat'] != 'cs':
inputs = {
'inputs_embeds': inputs_embeds.half(),
'attention_mask': attention_mask,
}
else:
attention_mask = torch.full([1, 1], True).to('cuda')
inputs = {
'input_ids': extra_inputs['last_id'],
'attention_mask': attention_mask
}
# add kv cache
inputs['past_key_values'] = extra_inputs['past_key_values']
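        # The attention mask must also cover the cached positions, so prepend a mask of
        # True values whose length equals the cached sequence length.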
past_mask = torch.full([1, inputs['past_key_values'][0][0].size(2)],
True).to('cuda')
attention_mask = torch.cat((past_mask, attention_mask), 1)
inputs['attention_mask'] = attention_mask
top_p = extra_inputs.get('top_p', 1.0)
top_k = extra_inputs.get('top_k', 0)
temperature = extra_inputs.get('temperature', 1.0)
last_id, past_key_values, stat, hidden_state = self._generate_one_step(copy.deepcopy(inputs),
extra_inputs['stat'],
top_p=top_p,
top_k=top_k,
temperature=temperature)
return last_id, stat, past_key_values, cnn_cache, buffer, pe_index, hidden_state
def _post_decode(self, output, temperature=1.0, top_k=0, top_p=0.0):
"""
Decoding function, based on the posterior probability output,
uses top_k, top_p, and temperature parameters for sampling.
Parameters:
- output: torch.Tensor, shaped as (1, 1, D), represents the posterior probability output by the model.
- top_k: int, indicates selecting the top k tokens with the highest probability for sampling.
If 0, no top_k filtering is performed.
- top_p: float, indicates selecting tokens with cumulative probability not exceeding p for sampling.
If 0.0, no top_p filtering is performed.
- temperature: float, represents the sampling temperature parameter.
The higher the value, the more random the sampling;
the lower the value, the more deterministic the sampling.
Returns:
- Selected token index.
"""
output = output.squeeze(0).squeeze(0)
# temperature
if temperature != 1.0:
output = output / temperature
probs = torch.nn.functional.softmax(output, dim=-1)
# top_k
if top_k > 0:
top_k_probs, top_k_indices = torch.topk(probs, top_k)
probs = torch.zeros_like(probs).scatter_(0, top_k_indices, top_k_probs)
probs = probs / probs.sum()
# top_p
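        # Nucleus sampling: keep the smallest set of tokens whose cumulative probability
        # stays within top_p. If even the single most likely token already exceeds top_p,
        # the removal mask is shifted by one position so that token still survives.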
if top_p > 0.0:
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
if sorted_indices_to_remove[0]:
sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
sorted_indices_to_remove[0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
probs[indices_to_remove] = 0
probs = probs / probs.sum()
token_index = torch.multinomial(probs, 1)
return token_index.unsqueeze(0)
def _generate_one_step(
self,
inputs,
stat,
top_p: float = 1.0,
top_k: int = 0,
temperature: float = 1.0,
):
"""
Generates the model's next output based on the current input and state.
Parameters:
- inputs: The input tensor containing the model's input data.
- stat: The current state information used to control the generation process.
- top_p: The threshold for controlling top-p sampling.
- top_k: The threshold for controlling top-k sampling.
- temperature: Controls the randomness of sampling.
        Returns:
        - last_id: The index of the sampled token, or None while listening or when the
          end-of-turn token is produced.
        - past_key_values: The model's cached key-value pairs, used for cross-step memory.
        - stat: The updated state information.
        - hidden_state: The last hidden state, used to maintain cross-step contextual information.
"""
outputs = self.llm_decoder.model(**inputs)
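        # While listening ('sl'/'cl'), the predictor head classifies the user state from
        # the last hidden state instead of emitting a token; otherwise the next token is
        # sampled from the lm_head logits at the final position.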
if stat == 'sl' or stat == 'cl':
state_logits = self.predictor_head(
outputs['last_hidden_state'])[0, :]
            prob = F.softmax(state_logits[:, :-1], dim=-1)
state_prob = prob[-1].clone()
state_1 = state_prob[1]
state_2 = state_prob[2]
print("State 1 prob: {:.4f}, State 2 prob: {:.4f}".format(state_1.item(), state_2.item()))
if state_2 > 0.5:
return None, outputs['past_key_values'], 'el', None
if state_1 > 0.5:
return None, outputs['past_key_values'], 'ss', None
return None, outputs['past_key_values'], 'cl', None
last_logit = self.llm_decoder.lm_head(outputs['last_hidden_state'][:, -1:, :])
last_id = self._post_decode(last_logit, temperature=temperature, top_k=top_k, top_p=top_p)
return_tts_state = outputs['last_hidden_state'][:, -1:, :]
if last_id[0][0] == self.tokenizer.eod_id:
return None, outputs['past_key_values'], 'sl', return_tts_state
else:
return last_id, outputs['past_key_values'], 'cs', return_tts_state