diff --git a/model/slam_model_s2s.py b/model/slam_model_s2s.py new file mode 100644 index 0000000000000000000000000000000000000000..65ab83aa911f642aacb837d713fbe9f43801fcf2 --- /dev/null +++ b/model/slam_model_s2s.py @@ -0,0 +1,444 @@ +import torch +import os +import logging +import torch.nn.functional as F +from slam_llm.models.slam_model import ( + slam_model, + setup_tokenizer, + setup_encoder, + setup_encoder_projector, + setup_llm, +) +from slam_llm.utils.train_utils import print_model_size +from typing import List, Optional +from slam_llm.utils.metric import compute_accuracy +from transformers import T5ForConditionalGeneration +from tqdm import tqdm +from utils.tts_adapter_utils import setup_tts_adapter +from utils.codec_utils import setup_codec +from utils.trick_utils import partial_freeze_weights, train_embedding_layer_only +from utils.snac_utils import layershift + +logger = logging.getLogger(__name__) + + +def model_factory(train_config, model_config, ckpt_path, **kwargs): + # return necessary components for training + tokenizer = setup_tokenizer(train_config, model_config, **kwargs) + + if train_config.task_type == "s2s" or train_config.task_type == "asr": + encoder = setup_encoder(train_config, model_config, **kwargs) + elif train_config.task_type == "tts": + encoder = None + else: + raise NotImplementedError + + # llm + llm = setup_llm(train_config, model_config, **kwargs) + + # projector + if encoder is not None: + encoder_projector = setup_encoder_projector( + train_config, model_config, **kwargs + ) + else: + encoder_projector = None + + codec_decoder = None + if model_config.codec_decode: + codec_decoder = setup_codec(train_config, model_config, **kwargs) + + tts_adapter = None + if model_config.tts_adapter: + adapter_config = model_config.tts_adapter_config + tts_adapter = setup_tts_adapter(adapter_config, model_config, **kwargs) + + model = slam_model_s2s( + encoder, + llm, + encoder_projector, + tokenizer, + tts_adapter, + codec_decoder, + train_config, + model_config, + **kwargs, + ) + + if ckpt_path is not None: + logger.info("loading other parts from: {}".format(ckpt_path)) + ckpt_dict = torch.load(ckpt_path, map_location="cpu") + model.load_state_dict(ckpt_dict, strict=False) + + if train_config.train_audio_embed_only: + partial_freeze_weights(model, model_config.vocab_config.padded_text_vocabsize, model_config.vocab_config.total_vocabsize) + + if train_config.train_embed_only: + train_embedding_layer_only(model) + + print_model_size( + model, + train_config, + ( + int(os.environ["RANK"]) + if train_config.enable_fsdp or train_config.enable_ddp + else 0 + ), + ) + return model, tokenizer + + +class slam_model_s2s(slam_model): + def __init__( + self, + encoder, + llm, + encoder_projector, + tokenizer, + tts_adapter, + codec_decoder, + train_config, + model_config, + **kwargs, + ): + super().__init__( + encoder, + llm, + encoder_projector, + tokenizer, + train_config, + model_config, + **kwargs, + ) + + # resize llm embedding layer + self.original_vocabsize = self.llm.lm_head.weight.size(0) + if self.model_config.vocab_config.total_vocabsize != self.original_vocabsize: + self.llm.resize_token_embeddings(self.model_config.vocab_config.total_vocabsize) + + if int(os.environ.get("RANK", "0")) == 0: + logger.info("Resize llm embedding layer's vocab size to {}".format(self.model_config.vocab_config.total_vocabsize)) + + self.codec_decoder = codec_decoder + self.tts_adapter = tts_adapter + self.code_layer = self.model_config.vocab_config.code_layer + + + def forward(self, + 
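+                # NOTE: audio inputs (audio_mel / audio), their masks and the modality_mask
+                # are passed in through **kwargs rather than as named arguments (see below).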
input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ): + audio_mel = kwargs.get("audio_mel", None) + audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None) # 2x downsample for whisper + + audio = kwargs.get("audio", None) + audio_mask = kwargs.get("audio_mask", None) + + modality_mask = kwargs.get("modality_mask", None) + + encoder_outs = None + if audio_mel is not None or audio is not None: + if self.train_config.freeze_encoder: # freeze encoder + self.encoder.eval() + + if self.model_config.encoder_name == "whisper": + encoder_outs = self.encoder.extract_variable_length_features(audio_mel.permute(0, 2, 1)) # bs*seq*dim + if self.model_config.encoder_name == "wavlm": + encoder_outs = self.encoder.extract_features(audio, 1 - audio_mask) #(FIX:MZY): 1-audio_mask is needed for wavlm as the padding mask + if self.model_config.encoder_name == "hubert": + results = self.encoder(source = audio, padding_mask = 1-audio_mask) + if self.model_config.encoder_type == "pretrain": + encoder_outs, audio_mel_post_mask = results["x"], results["padding_mask"] + if self.model_config.encoder_type == "finetune": + encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"] + encoder_outs = encoder_outs.transpose(0, 1) + if self.encoder is None: + encoder_outs = audio_mel if audio_mel is not None else audio + + if self.model_config.encoder_projector == "q-former": + encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask) + if self.model_config.encoder_projector == "linear": + encoder_outs = self.encoder_projector(encoder_outs) + if self.model_config.encoder_projector == "cov1d-linear": + encoder_outs = self.encoder_projector(encoder_outs) + + if input_ids is not None: + input_ids[input_ids == -1] = 0 # [btz, 8, seq_length] + + if isinstance(self.llm, T5ForConditionalGeneration): + inputs_embeds = self.llm.shared(input_ids) + else: + if hasattr(self.llm.model, "embed_tokens"): + inputs_embeds = self.llm.model.embed_tokens(input_ids) # [btz, 8, seq_length, emb_dim] + elif hasattr(self.llm.model.model, "embed_tokens"): + inputs_embeds = self.llm.model.model.embed_tokens(input_ids) + else: + inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids) + + if modality_mask is not None and encoder_outs is not None: + modality_mask = modality_mask.unsqueeze(1).repeat(1, self.code_layer, 1) # [btz, 8, seq_length] + modality_mask_start_indices = (modality_mask == True).float().argmax(dim=2) + modality_lengths = torch.clamp(modality_mask.sum(dim=2), max=encoder_outs.shape[1]).tolist() + + encoder_outs_pad = torch.zeros_like(inputs_embeds) + for i in range(encoder_outs.shape[0]): + for j in range(self.code_layer): + start_idx = modality_mask_start_indices[i, j].item() + length = modality_lengths[i][j] + encoder_outs_pad[i, j, start_idx:start_idx+length] = encoder_outs[i, :length] + + inputs_embeds[:, :self.code_layer, :, :] = encoder_outs_pad[:, :self.code_layer, :, :] + inputs_embeds[:, :self.code_layer, :, :] * (~modality_mask[:, :, :, None]) + + inputs_embeds = torch.mean(inputs_embeds, dim=1) # [btz, seq_length, emb_dim], average over the 8 layers + + 
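+        # In inference (generation) mode, skip the LLM call and hand the fused
+        # embeddings and attention mask back to generate(), which runs its own
+        # token-by-token decoding loop.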
+        if kwargs.get("inference_mode", False):
+            return inputs_embeds, attention_mask
+
+        text_labels = labels[:, self.code_layer] if labels is not None else None
+        audio_labels = labels[:, :self.code_layer] if labels is not None else None
+        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=text_labels)    # here we use the text token layer as the target label
+
+        # parallel generation
+        # TODO: add tts adapter forward
+        x_ori = model_outputs.logits
+        text_vocab_size = self.model_config.vocab_config.padded_text_vocabsize
+        audio_vocab_size = self.model_config.vocab_config.padded_audio_vocabsize
+        xt = x_ori[..., :text_vocab_size]
+        xa = []
+        for i in range(self.code_layer):
+            xa.append(x_ori[..., text_vocab_size + audio_vocab_size * i : text_vocab_size + audio_vocab_size * (i + 1)])
+
+        loss_recorder = []
+        total_loss, loss_recorder = self.compute_parallel_loss(xt, text_labels, xa, audio_labels)
+        model_outputs.loss = total_loss
+
+        text_acc = -1
+        audio_acc = [-1 for _ in range(self.code_layer)]
+        if self.metric:
+            with torch.no_grad():
+                preds = torch.argmax(xt, -1)
+                text_acc = compute_accuracy(preds.detach()[:, :-1], text_labels.detach()[:, 1:], ignore_label=-100)
+
+                preds_audio = [torch.argmax(xa[i], -1) for i in range(self.code_layer)]
+                audio_acc = [compute_accuracy(preds_audio[i].detach()[:, :-1], audio_labels[:, i, 1:], ignore_label=-100) for i in range(self.code_layer)]
+
+        # metrics = {"text_acc": text_acc, "audio_acc": audio_acc, "layer_loss": loss_recorder}
+        return model_outputs, text_acc, audio_acc, loss_recorder
+
+
+
+    def compute_parallel_loss(self, xt, text_labels, xa, audio_labels):
+        """
+        Compute the parallel loss for the text layer and each audio codebook layer.
+        The total loss is the average of the text-layer cross-entropy and the
+        per-codebook audio cross-entropies, i.e. (code_layer + 1) terms in total.
+        """
+        text_vocab_size = self.model_config.vocab_config.padded_text_vocabsize
+        audio_vocab_size = self.model_config.vocab_config.padded_audio_vocabsize
+        layer_loss = [0 for _ in range(self.code_layer+1)]
+
+        if text_labels is not None:
+            # text_loss = F.cross_entropy(xt.reshape(-1, text_vocab_size), text_labels.reshape(-1), ignore_index=-100)
+            text_loss = F.cross_entropy(xt[:, :-1, :].reshape(-1, text_vocab_size), text_labels[:, 1:].reshape(-1), ignore_index=-100)
+            layer_loss[self.code_layer] = text_loss
+        else:
+            text_loss = 0
+
+        total_audio_loss = 0
+        single_audio_loss = 0
+        for i in range(self.code_layer):
+            if audio_labels[:, i] is not None:
+                # audio_loss += F.cross_entropy(xa[i].reshape(-1, audio_vocab_size), audio_labels[:,i].reshape(-1), ignore_index=-100)
+                single_audio_loss = F.cross_entropy(xa[i][:, :-1, :].reshape(-1, audio_vocab_size), audio_labels[:, i, 1:].reshape(-1), ignore_index=-100)
+                layer_loss[i] = single_audio_loss
+                total_audio_loss += single_audio_loss
+
+        total_loss = (text_loss + total_audio_loss) / (self.code_layer+1)
+        return total_loss, layer_loss
+
+
+    @torch.no_grad()
+    def generate(self,
+                 input_ids: torch.LongTensor = None,
+                 attention_mask: Optional[torch.Tensor] = None,
+                 position_ids: Optional[torch.LongTensor] = None,
+                 past_key_values: Optional[List[torch.FloatTensor]] = None,
+                 inputs_embeds: Optional[torch.FloatTensor] = None,
+                 labels: Optional[torch.LongTensor] = None,
+                 use_cache: Optional[bool] = None,
+                 output_attentions: Optional[bool] = None,
+                 output_hidden_states: Optional[bool] = None,
+                 return_dict: Optional[bool] = None,
+                 **kwargs,
+                 ):
+        kwargs["inference_mode"] = True
+
+        # forward() is only used here to build the fused input embeddings;
+        # inference_mode=True makes it return (inputs_embeds, attention_mask) without running the LLM.
+        inputs_embeds, attention_mask = self.forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + + generated_ids = [[] for _ in range((self.code_layer+1))] + current_input_text = None + current_audio_tokens = [None for _ in range(self.code_layer)] + # input_pos = torch.arange(input_ids.size(-1), device=input_ids.device).unsqueeze(0) + past_key_values = None + + text_vocab_size = self.model_config.vocab_config.padded_text_vocabsize + audio_vocab_size = self.model_config.vocab_config.padded_audio_vocabsize + + max_new_tokens = kwargs.get("max_new_tokens", 360) + repetition_penalty = kwargs.get("repetition_penalty", 1.0) + decode_text_only = kwargs.get("decode_text_only", False) + + pad_t = self.model_config.vocab_config.pad_t + pad_a = self.model_config.vocab_config.pad_a + eot = self.model_config.vocab_config.eot + eoa = self.model_config.vocab_config.eoa + + text_end = False # Track whether text generation has ended + audio_end = False # Track whether audio generation has ended + + # NOTE: currently, we only support greedy decoding and sampling for parallel generation, no beam search + for step in tqdm(range(max_new_tokens), desc="Generating"): + if current_input_text is not None: + audio_tokens = torch.cat([layershift(current_audio_tokens[i], i).unsqueeze(1) for i in range(self.code_layer)], dim=1) + combined_input_ids = torch.cat([audio_tokens, current_input_text.unsqueeze(1)], dim=1) + inputs_embeds = self.llm.model.embed_tokens(combined_input_ids) + inputs_embeds = torch.mean(inputs_embeds, dim=1).unsqueeze(1) + + outputs = self.llm( + inputs_embeds=inputs_embeds, # [btz, seq_len / 1, emb_dim] + attention_mask=attention_mask, # single sample, no need for attention mask + past_key_values=past_key_values, + # position_ids=input_pos, + use_cache=True, + ) + + logits = outputs.logits + past_key_values = outputs.past_key_values # Update past_key_values for the next step + + # Split logits into text and audio layers based on vocab size + xt_logits = logits[..., :text_vocab_size] + xa_logits = [logits[..., text_vocab_size + audio_vocab_size * i : text_vocab_size + audio_vocab_size * (i + 1)] for i in range(self.code_layer)] + + # Apply repetition penalty to the logits + if repetition_penalty != 1.0: + xt_logits = self.repetition_penalty(xt_logits, generated_ids[self.code_layer], repetition_penalty) + for i in range(self.code_layer): + xa_logits[i] = self.repetition_penalty(xa_logits[i], generated_ids[i], repetition_penalty) + + if not text_end: + next_token_text = self.sample_next_token(xt_logits[:, -1, :], **kwargs) + else: + next_token_text = torch.tensor([pad_t], device=input_ids.device) + + next_tokens_audio = [] + for i in range(self.code_layer): + if not audio_end and not decode_text_only: + next_token_audio = self.sample_next_token(xa_logits[i][:, -1, :], **kwargs) + else: + next_token_audio = torch.full((input_ids.size(0),), pad_a, device=input_ids.device) + next_tokens_audio.append(next_token_audio) + + if next_tokens_audio[-1] == eoa or decode_text_only: + audio_end = True + if next_token_text == eot: + text_end = True + + # Update input_ids for the next step + current_input_text = next_token_text + for i in range(self.code_layer): + current_audio_tokens[i] = next_tokens_audio[i] + + # if input_pos.size(-1) > 1: + # input_pos = torch.tensor(input_pos.size(-1), device=input_ids.device).unsqueeze(0) + # else: + # input_pos = input_pos.add_(1) + attention_mask = 
torch.cat([attention_mask, torch.ones((input_ids.size(0), 1), device=input_ids.device)], dim=1)
+
+            if audio_end and text_end:
+                break
+
+            # Append generated tokens to the list
+            for i in range(self.code_layer):
+                generated_ids[i].append(next_tokens_audio[i].clone().tolist()[0])  # Audio layers
+            generated_ids[self.code_layer].append(next_token_text.clone().tolist()[0])  # Text layer
+
+        # Concatenate the generated tokens to form the complete sequence
+        text_tokens = generated_ids[-1]
+        generated_ids[-1] = text_tokens[: text_tokens.index(eot)] if eot in text_tokens else text_tokens
+        generated_ids = [torch.tensor(layer) for layer in generated_ids]
+        return generated_ids
+
+
+    @torch.no_grad()
+    def sample_next_token(self, logits, **kwargs):
+        """
+        Generate the next token based on the model output logits.
+        Supports greedy decoding, top-k sampling, and top-p (nucleus) sampling.
+        """
+        do_sample = kwargs.get("do_sample", False)
+        temperature = kwargs.get("temperature", 1.0)
+        top_k = kwargs.get("top_k", 50)
+        top_p = kwargs.get("top_p", 1.0)
+        num_samples = kwargs.get("num_samples", 1)
+
+        # Adjust logits with temperature
+        logits = logits.squeeze(0)
+        logits = logits / temperature
+
+        # Top-k filtering
+        if top_k > 0:
+            top_k = min(top_k, logits.size(-1))  # Make sure top_k is within the vocab size
+            values, indices = torch.topk(logits, top_k)
+            logits[logits < values[..., [-1]]] = -float('Inf')  # Filter tokens not in top_k
+
+        # Top-p filtering (nucleus sampling)
+        if top_p < 1.0:
+            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+            # Remove tokens with cumulative probability above the threshold
+            sorted_indices_to_remove = cumulative_probs > top_p
+            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+            sorted_indices_to_remove[..., 0] = 0
+
+            indices_to_remove = sorted_indices[sorted_indices_to_remove]
+            logits[indices_to_remove] = -float('Inf')
+
+        if do_sample:
+            # Perform sampling
+            return torch.multinomial(F.softmax(logits, dim=-1), num_samples=num_samples)
+        else:
+            # Greedy decoding (argmax)
+            return torch.argmax(logits, dim=-1, keepdim=True)
+
+
+    def repetition_penalty(self, logits, generated_ids, repetition_penalty):
+        """
+        Apply repetition penalty to the logits.
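+
+        Tokens that already appear in generated_ids are discouraged: a positive logit
+        is divided by the penalty and a negative logit is multiplied by it (the usual
+        CTRL-style formulation), so values > 1.0 reduce repetition.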
+ """ + for token_id in set(generated_ids): + if logits[0, -1, token_id] < 0: + logits[0, -1, token_id] *= repetition_penalty + else: + logits[0, -1, token_id] /= repetition_penalty + + return logits \ No newline at end of file diff --git a/s2s.py b/s2s.py new file mode 100644 index 0000000000000000000000000000000000000000..23018d5b50b0d0b5e42376b7485dd0d0b5c61c84 --- /dev/null +++ b/s2s.py @@ -0,0 +1,178 @@ +import random +import torch +from slam_llm.utils.model_utils import get_custom_model_factory +from utils.snac_utils import reconscruct_snac, reconstruct_tensors, layershift +import whisper +import numpy as np +from s2s_config import InferenceConfig, CKPT_PATH, CKPT_REPO, CKPT_LOCAL_DIR, CKPT_NAME +import os +from omegaconf import OmegaConf +from huggingface_hub import hf_hub_download +from typing import Callable + + +def update_progress(progress_callback: Callable[[str], None] | None, message: str): + if progress_callback: + progress_callback(message) + + +def pull_model_ckpt(): + if not os.path.exists(CKPT_LOCAL_DIR): + os.makedirs(CKPT_LOCAL_DIR) + if os.path.exists(CKPT_PATH): + return + hf_hub_download( + repo_id=CKPT_REPO, + filename=CKPT_NAME, + local_dir=CKPT_LOCAL_DIR, + token=os.getenv("HF_TOKEN"), + ) + + +pull_model_ckpt() + + +def extract_audio_feature(audio_path, mel_size): + print("Extracting audio features from", audio_path) + audio_raw = whisper.load_audio(audio_path) + audio_raw = whisper.pad_or_trim(audio_raw) + audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size).permute(1, 0) + audio_length = (audio_mel.shape[0] + 1) // 2 + audio_length = audio_length // 5 + audio_res = audio_mel + + return audio_res, audio_length + + +def get_input_ids(length, special_token_a, special_token_t, vocab_config): + input_ids = [] + for i in range(vocab_config.code_layer): + input_ids_item = [] + input_ids_item.append(layershift(vocab_config.input_a, i)) + input_ids_item += [layershift(vocab_config.pad_a, i)] * length + input_ids_item += [ + (layershift(vocab_config.eoa, i)), + layershift(special_token_a, i), + ] + input_ids.append(torch.tensor(input_ids_item).unsqueeze(0)) + input_id_T = torch.tensor( + [vocab_config.input_t] + + [vocab_config.pad_t] * length + + [vocab_config.eot, special_token_t] + ) + input_ids.append(input_id_T.unsqueeze(0)) + return input_ids + + +def generate_from_wav( + wav_path, model, codec_decoder, dataset_config, decode_config, device +): + mel_size = dataset_config.mel_size + prompt = dataset_config.prompt + prompt_template = "USER: {}\n ASSISTANT: " + vocab_config = dataset_config.vocab_config + special_token_a = vocab_config.answer_a + special_token_t = vocab_config.answer_t + code_layer = vocab_config.code_layer + task_type = dataset_config.task_type + + audio_mel, audio_length = extract_audio_feature(wav_path, mel_size) + + prompt = prompt_template.format(prompt) + prompt_ids = model.tokenizer.encode(prompt) + prompt_length = len(prompt_ids) + prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64) + + example_ids = get_input_ids( + audio_length + prompt_length, special_token_a, special_token_t, vocab_config + ) + text_layer = example_ids[code_layer] + text_layer = torch.cat( + ( + text_layer[:, : audio_length + 1], + prompt_ids.unsqueeze(0), + text_layer[:, -2:], + ), + dim=1, + ) #