import os
import types
import torch
import soundfile as sf
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from typing import List, Optional, Tuple, Union
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModel, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from slam_llm.utils.config_utils import generate_peft_config
from slam_llm.utils.train_utils import print_module_size, print_model_size
from peft import PeftModel, PeftConfig
from torch.nn import CrossEntropyLoss
from slam_llm.utils.metric import compute_accuracy

import logging
logger = logging.getLogger(__name__)


def model_factory(train_config, model_config, **kwargs):
    # return necessary components for training
    tokenizer = setup_tokenizer(train_config, model_config, **kwargs)

    encoder = setup_encoder(train_config, model_config, **kwargs)

    # llm
    llm = setup_llm(train_config, model_config, **kwargs)

    # projector
    encoder_projector = setup_encoder_projector(
        train_config, model_config, **kwargs
    )

    model = slam_model(
        encoder,
        llm,
        encoder_projector,
        tokenizer,
        train_config,
        model_config,
        **kwargs,
    )

    ckpt_path = kwargs.get("ckpt_path", None)  # FIX(MZY): load model ckpt (mainly the projector; see model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft)
    if ckpt_path is not None:
        logger.info("loading other parts from: {}".format(ckpt_path))
        ckpt_dict = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt_dict, strict=False)

    print_model_size(model, train_config, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)
    return model, tokenizer
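
# Illustrative only: a minimal sketch of how model_factory might be called. Field names
# are limited to those read in this file, paths and values are hypothetical, and the
# repo's recipes normally build these configs for you. train_config must support both
# attribute access and .get() (e.g. an omegaconf DictConfig):
#
#   from omegaconf import OmegaConf
#   model_config = OmegaConf.create({
#       "llm_name": "vicuna-7b",        # hypothetical
#       "llm_path": "/path/to/llm",     # hypothetical
#       "encoder_name": "whisper",
#       "encoder_projector": "linear",
#   })
#   train_config = OmegaConf.create({
#       "enable_fsdp": False, "enable_ddp": False, "low_cpu_fsdp": False,
#       "use_fast_kernels": False, "quantization": False,
#       "freeze_encoder": True, "freeze_llm": True, "use_peft": False,
#   })
#   model, tokenizer = model_factory(train_config, model_config, ckpt_path="/path/to/projector_ckpt.pt")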


def setup_tokenizer(train_config, model_config, **kwargs):
    # Load the tokenizer and add special tokens
    if "vallex" in model_config.llm_name.lower():
        return None
    elif "mupt" in model_config.llm_name.lower():
        tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path,
                                                  trust_remote_code=True,
                                                  use_fast=False)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path)
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer
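
# Note: many causal-LM tokenizers (e.g. the Llama family) ship without a dedicated
# padding token, so the eos token is reused as the pad token above; padded positions
# are then excluded from attention via the attention mask built by the collator.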


def setup_encoder(train_config, model_config, **kwargs):
    encoder_list = model_config.encoder_name.split(",") if model_config.encoder_name else []
    if len(encoder_list) == 0:
        return None
    if len(encoder_list) == 1:
        encoder_name = encoder_list[0]
        if encoder_name == "whisper" or encoder_name == "qwen-audio":
            from slam_llm.models.encoder import WhisperWrappedEncoder
            encoder = WhisperWrappedEncoder.load(model_config)
        if encoder_name == "beats":
            from slam_llm.models.encoder import BEATsEncoder
            encoder = BEATsEncoder.load(model_config)
        if encoder_name == "eat":
            from slam_llm.models.encoder import EATEncoder
            encoder = EATEncoder.load(model_config)
        if encoder_name == "SpatialAST":
            from slam_llm.models.encoder import SpatialASTEncoder
            encoder = SpatialASTEncoder.load(model_config)
        if encoder_name == "wavlm":
            from slam_llm.models.encoder import WavLMEncoder
            encoder = WavLMEncoder.load(model_config)
        if encoder_name == "av_hubert":
            from slam_llm.models.encoder import AVHubertEncoder
            encoder = AVHubertEncoder.load(model_config)
        if encoder_name == "hubert":
            from slam_llm.models.encoder import HubertEncoder
            encoder = HubertEncoder.load(model_config)
        if encoder_name == "musicfm":
            from slam_llm.models.encoder import MusicFMEncoder
            encoder = MusicFMEncoder.load(model_config)
        if "llama" in encoder_name.lower():
            from slam_llm.models.encoder import HfTextEncoder
            encoder = HfTextEncoder.load(model_config)
        print_module_size(encoder, encoder_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)

    if train_config.freeze_encoder:
        for name, param in encoder.named_parameters():
            param.requires_grad = False
        encoder.eval()
    print_module_size(encoder, encoder_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)

    return encoder
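
# The encoder is selected purely by string matching on model_config.encoder_name, e.g.
# encoder_name="whisper" loads WhisperWrappedEncoder, while an empty/unset encoder_name
# returns None so forward() feeds the raw features straight to the projector/LLM. Any
# additional fields each wrapper's .load() reads (checkpoint paths etc.) also live in
# model_config and are defined by the respective encoder class, not by this function.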


def setup_llm(train_config, model_config, **kwargs):
    from pkg_resources import packaging
    use_cache = False if train_config.enable_fsdp or train_config.enable_ddp else None
    if (train_config.enable_fsdp or train_config.enable_ddp) and train_config.low_cpu_fsdp:
        """
        For FSDP, we can save CPU memory by loading the pretrained model on rank 0 only.
        This avoids CPU OOM when loading large models like Llama 70B, for which the
        model alone would consume 2+ TB of CPU memory (70 * 4 * 8). It adds some comms
        overhead and currently requires the latest PyTorch nightly.
        """
        # v = packaging.version.parse(torch.__version__)
        # verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
        # if not verify_latest_nightly:
        #     raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
        #                     "please install latest nightly.")
        rank = int(os.environ["RANK"])
        if rank == 0:
            if "vallex" in model_config.llm_name.lower():
                from src.slam_llm.models.vallex.vallex_config import VallexConfig
                from src.slam_llm.models.vallex.vallex_model import VALLE
                vallex_config = VallexConfig(
                    **model_config
                )
                model = VALLE(vallex_config)
            elif "aya" in model_config.llm_name.lower():
                model = AutoModelForSeq2SeqLM.from_pretrained(
                    model_config.llm_path,
                    load_in_8bit=True if train_config.quantization else None,
                    device_map="auto" if train_config.quantization else None,
                    use_cache=use_cache,
                )
            else:
                model = AutoModelForCausalLM.from_pretrained(
                    model_config.llm_path,
                    load_in_8bit=True if train_config.quantization else None,
                    device_map="auto" if train_config.quantization else None,
                    use_cache=use_cache,
                )
        else:
            llama_config = AutoConfig.from_pretrained(model_config.llm_path)
            llama_config.use_cache = use_cache
            # with torch.device("meta"):
            # Auto* classes cannot be instantiated directly; build from the config instead.
            if "aya" in model_config.llm_name.lower():
                model = AutoModelForSeq2SeqLM.from_config(llama_config)
            else:
                model = AutoModelForCausalLM.from_config(llama_config)  # (FIX:MZY): torch 2.0.1 does not support `meta`
    else:
        if "vallex" in model_config.llm_name.lower():
            from src.slam_llm.models.vallex.vallex_config import VallexConfig
            from src.slam_llm.models.vallex.vallex_model import VALLE
            vallex_config = VallexConfig(
                **model_config
            )
            model = VALLE(vallex_config)
        elif "aya" in model_config.llm_name.lower():
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_config.llm_path,
                load_in_8bit=True if train_config.quantization else None,
                device_map="auto" if train_config.quantization else None,
                use_cache=use_cache,
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_config.llm_path,
                load_in_8bit=True if train_config.quantization else None,
                device_map="auto" if train_config.quantization else None,
                use_cache=use_cache,
            )

    if (train_config.enable_fsdp or train_config.enable_ddp) and train_config.use_fast_kernels:
        """
        For FSDP and FSDP+PEFT, setting 'use_fast_kernels' enables Flash Attention or
        xFormers memory-efficient kernels, depending on the hardware being used.
        This speeds up fine-tuning.
        """
        try:
            from optimum.bettertransformer import BetterTransformer
            model = BetterTransformer.transform(model)
        except ImportError:
            logger.warning("Module 'optimum' not found. Please install 'optimum' before proceeding.")

    print_module_size(model, model_config.llm_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)

    # Prepare the model for int8 training if quantization is enabled
    if train_config.quantization:
        model = prepare_model_for_kbit_training(model)

    if train_config.freeze_llm:  # TODO: test the official `freeze_layers` and `num_freeze_layers`
        for name, param in model.named_parameters():
            param.requires_grad = False
        model.eval()

    if kwargs.get("peft_ckpt", None):  # (FIX:MZY): reloading otherwise gives wrong results when decoding
        logger.info("loading peft_ckpt from: {}".format(kwargs.get("peft_ckpt")))
        model = PeftModel.from_pretrained(model=model, model_id=kwargs.get("peft_ckpt"), is_trainable=True)
        model.print_trainable_parameters()
    elif train_config.use_peft:
        logger.info("setup peft...")
        peft_config = generate_peft_config(train_config)
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    print_module_size(model, model_config.llm_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)
    return model
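
# For reference, generate_peft_config(train_config) derives a peft config from the
# training config; an equivalent hand-written LoRA setup (values hypothetical) would
# look roughly like:
#
#   peft_config = LoraConfig(
#       task_type=TaskType.CAUSAL_LM,
#       r=8, lora_alpha=32, lora_dropout=0.05,
#       target_modules=["q_proj", "v_proj"],
#   )
#   model = get_peft_model(model, peft_config)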


def setup_encoder_projector(train_config, model_config, **kwargs):
    if model_config.encoder_projector == "linear":
        from slam_llm.models.projector import EncoderProjectorConcat
        encoder_projector = EncoderProjectorConcat(model_config)
    elif model_config.encoder_projector == "cov1d-linear":
        from slam_llm.models.projector import EncoderProjectorCov1d
        encoder_projector = EncoderProjectorCov1d(model_config)
    elif model_config.encoder_projector == "q-former":
        from slam_llm.models.projector import EncoderProjectorQFormer
        encoder_projector = EncoderProjectorQFormer(model_config)
    else:
        return None
    print_module_size(encoder_projector, model_config.encoder_projector, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)
    return encoder_projector
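
# The projector maps encoder features of shape (bs, seq, encoder_dim) into the LLM
# embedding space (bs, seq', llm_dim): "linear" stacks adjacent frames and applies
# linear layers, "cov1d-linear" applies a 1-D convolution first, and "q-former" uses
# learned query tokens. The exact model_config fields each class reads (dimensions,
# downsample rate, ...) are defined in slam_llm/models/projector.py, not here.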


class slam_model(nn.Module):
    def __init__(
        self,
        encoder: nn.Module,
        llm: nn.Module,
        encoder_projector: nn.Module,
        tokenizer,
        train_config,
        model_config,
        **kwargs
    ):
        super().__init__()
        # modality encoder
        self.encoder = encoder

        # llm
        self.llm = llm

        # projector
        self.encoder_projector = encoder_projector

        # tokenizer
        self.tokenizer = tokenizer
        self.metric = kwargs.get("metric", "acc")

        self.train_config = train_config
        self.model_config = model_config

        if train_config.get("enable_deepspeed", False):
            def new_forward(self, input):
                output = F.layer_norm(
                    input.float(),
                    self.normalized_shape,
                    self.weight.float() if self.weight is not None else None,
                    self.bias.float() if self.bias is not None else None,
                    self.eps,
                )
                return output.type_as(input)

            for item in self.modules():
                if isinstance(item, nn.LayerNorm):
                    item.forward = types.MethodType(new_forward, item)
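
    # Note on the block above: under DeepSpeed fp16/bf16 training, every nn.LayerNorm is
    # patched so its statistics are computed in fp32 (input, weight and bias are upcast,
    # and the result is cast back to the input dtype). This is purely a numerical-stability
    # measure and does not change any parameters.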

    def forward(self,
                input_ids: torch.LongTensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[List[torch.FloatTensor]] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                labels: Optional[torch.LongTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                **kwargs,
                ):
        audio_mel = kwargs.get("audio_mel", None)
        audio_mel_mask = kwargs.get("audio_mel_mask", None)
        audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None)  # 2x downsample for whisper

        audio = kwargs.get("audio", None)
        audio_mask = kwargs.get("audio_mask", None)
        visual = kwargs.get("visual", None)
        visual_mask = kwargs.get("visual_mask", None)

        # for text encoder
        instruct_ids = kwargs.get("instruct_ids", None)
        instruct_mask = kwargs.get("instruct_mask", None)

        modality_mask = kwargs.get("modality_mask", None)

        zh_data = kwargs.get("zh", None)
        en_data = kwargs.get("en", None)

        encoder_outs = None
        if audio_mel is not None or audio is not None or visual is not None:
            if self.train_config.freeze_encoder:  # freeze encoder
                self.encoder.eval()

            if self.model_config.encoder_name == "whisper":
                encoder_outs = self.encoder.extract_variable_length_features(audio_mel.permute(0, 2, 1))  # bs*seq*dim
            if self.model_config.encoder_name == "beats":
                encoder_outs, audio_mel_post_mask = self.encoder.extract_features(audio_mel, audio_mel_mask)  # bs*seq*dim
            if self.model_config.encoder_name == "eat":
                encoder_outs = self.encoder.model.extract_features(audio_mel.unsqueeze(dim=1), padding_mask=None, mask=False, remove_extra_tokens=False)['x']
            if self.model_config.encoder_name == "SpatialAST":
                encoder_outs = self.encoder(audio)  # output: [bs, seq_len=3+512, dim=768]
            if self.model_config.encoder_name == "wavlm":
                encoder_outs = self.encoder.extract_features(audio, 1 - audio_mask)  # (FIX:MZY): 1 - audio_mask is needed for wavlm as the padding mask
            if self.model_config.encoder_name == "hubert":
                results = self.encoder(source=audio, padding_mask=1 - audio_mask)
                if self.model_config.encoder_type == "pretrain":
                    encoder_outs, audio_mel_post_mask = results["x"], results["padding_mask"]
                if self.model_config.encoder_type == "finetune":
                    encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"]
                    encoder_outs = encoder_outs.transpose(0, 1)
            if self.model_config.encoder_name == "av_hubert":
                results = self.encoder(source={'video': visual, 'audio': audio}, padding_mask=visual_mask)  # bs*seq*dim
                encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"]
                encoder_outs = encoder_outs.transpose(0, 1)
                audio_mel_post_mask = (~audio_mel_post_mask).float()
            if self.model_config.encoder_name == 'musicfm':
                encoder_outs = self.encoder.extract_features(audio, padding_mask=None)  # MusicFM doesn't support padding mask
            if self.encoder is None:
                encoder_outs = audio_mel if audio_mel is not None else audio

            if self.model_config.encoder_projector == "q-former":
                encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
            if self.model_config.encoder_projector == "linear":
                encoder_outs = self.encoder_projector(encoder_outs)
            if self.model_config.encoder_projector == "cov1d-linear":
                encoder_outs = self.encoder_projector(encoder_outs)

        if instruct_ids is not None:
            if self.encoder is not None:
                encoder_outs = self.encoder(input_ids=instruct_ids, attention_mask=instruct_mask).last_hidden_state

            if self.model_config.encoder_projector == "q-former":
                encoder_outs = self.encoder_projector(encoder_outs, instruct_mask)
            if self.model_config.encoder_projector == "linear":
                encoder_outs = self.encoder_projector(encoder_outs)

        if input_ids is not None:
            input_ids[input_ids == -1] = 0
            if isinstance(self.llm, T5ForConditionalGeneration):
                inputs_embeds = self.llm.shared(input_ids)
            else:
                if hasattr(self.llm.model, "embed_tokens"):
                    inputs_embeds = self.llm.model.embed_tokens(input_ids)
                elif hasattr(self.llm.model.model, "embed_tokens"):
                    inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
                else:
                    inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)

        if modality_mask is not None:
            modality_mask_start_indices = (modality_mask == True).float().argmax(dim=1)
            modality_lengths = torch.clamp(modality_mask.sum(dim=1), max=encoder_outs.shape[1]).tolist()

            encoder_outs_pad = torch.zeros_like(inputs_embeds)
            for i in range(encoder_outs.shape[0]):
                encoder_outs_pad[
                    i, modality_mask_start_indices[i]:modality_mask_start_indices[i] + modality_lengths[i]
                ] = encoder_outs[i][:modality_lengths[i]]

            inputs_embeds = encoder_outs_pad + inputs_embeds * (~modality_mask[:, :, None])

        if kwargs.get("inference_mode", False):
            return inputs_embeds, attention_mask

        if zh_data is not None and en_data is not None:
            model_outputs, acc = self.llm(zh=zh_data, en=en_data)
        else:
            model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
            acc = -1
            if self.metric:
                with torch.no_grad():
                    preds = torch.argmax(model_outputs.logits, -1)
                    acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=-100)

        return model_outputs, acc
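
    # Worked example of the modality_mask splice above (shapes hypothetical): if a batch
    # row has modality_mask True at positions 4..103, then modality_mask_start_indices[i]
    # is 4, modality_lengths[i] is 100, and encoder_outs[i][:100] overwrites
    # inputs_embeds[i, 4:104] (the placeholder token embeddings), while every other
    # position keeps its original text embedding.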

    def generate(self,
                 input_ids: torch.LongTensor = None,
                 attention_mask: Optional[torch.Tensor] = None,
                 position_ids: Optional[torch.LongTensor] = None,
                 past_key_values: Optional[List[torch.FloatTensor]] = None,
                 inputs_embeds: Optional[torch.FloatTensor] = None,
                 labels: Optional[torch.LongTensor] = None,
                 use_cache: Optional[bool] = None,
                 output_attentions: Optional[bool] = None,
                 output_hidden_states: Optional[bool] = None,
                 return_dict: Optional[bool] = None,
                 **kwargs,
                 ):
        kwargs["inference_mode"] = True

        inputs_embeds, attention_mask = self.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )

        model_outputs = self.llm.generate(
            inputs_embeds=inputs_embeds,
            # max_length=kwargs.get("max_length", 200),
            max_new_tokens=kwargs.get("max_new_tokens", 200),
            num_beams=kwargs.get("num_beams", 4),
            do_sample=kwargs.get("do_sample", False),
            min_length=kwargs.get("min_length", 1),
            top_p=kwargs.get("top_p", 1.0),
            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
            length_penalty=kwargs.get("length_penalty", 1.0),
            temperature=kwargs.get("temperature", 1.0),
            attention_mask=attention_mask,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )

        return model_outputs
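

# Illustrative inference flow (hypothetical batch; the real entry points are the repo's
# inference scripts, which build the batch via the dataset collator):
#
#   model, tokenizer = model_factory(train_config, model_config, ckpt_path="/path/to/projector_ckpt.pt")
#   model.eval()
#   with torch.no_grad():
#       output_ids = model.generate(
#           input_ids=batch["input_ids"],            # prompt with placeholder tokens for the audio span
#           attention_mask=batch["attention_mask"],
#           audio_mel=batch["audio_mel"],            # e.g. whisper log-mel features
#           audio_mel_mask=batch.get("audio_mel_mask"),
#           modality_mask=batch["modality_mask"],
#           max_new_tokens=200, num_beams=4,
#       )
#   text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)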