# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import warnings
from io import BytesIO
from typing import Any, List, Optional, Tuple, Union

import librosa
import requests
import torch
import torch.distributed as dist
import torch.utils.checkpoint
from PIL import Image, ImageDraw
from peft import LoraConfig, get_peft_model
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
                          LlamaTokenizer, Qwen2ForCausalLM)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging

from .configuration_internvl_audio_chat import InternVLChatAudioConfig
from .configuration_internvl_chat import InternVLChatConfig
from .conversation import get_conv_template
from .modeling_intern_vit import InternVisionModel
from .modeling_internlm2 import InternLM2ForCausalLM
from .modeling_internvl_chat import InternVLChatModel
from .modeling_whisper import AudioWhisperModel


def load_audio(audio_file, audio_processor):
    """Load an audio file and run it through the Whisper feature extractor."""
    audio_values, _ = librosa.load(audio_file, sr=16000)  # sample rate should be 16000

    audio_process_values = audio_processor(audio_values, sampling_rate=16000, return_tensors="pt")
    input_features = audio_process_values['input_features']
    audio_len_after_cnn = audio_process_values['audio_len_after_cnn']
    audio_token_num = audio_process_values['audio_token_num']

    audio_input = {
        'audio_values': input_features,
        'audio_len_after_cnn': audio_len_after_cnn,
        'audio_token_num': audio_token_num,
    }
    return audio_input


class InternVLChatAudioModel(InternVLChatModel):

    def __init__(self, config: InternVLChatAudioConfig, vision_model=None, language_model=None, audio_model=None):
        super().__init__(config, vision_model, language_model)

        if audio_model is not None:
            self.audio_model = audio_model
        else:
            self.audio_model = AudioWhisperModel(config.audio_config)

        audio_hidden_size = config.audio_config.d_model
        llm_hidden_size = config.llm_config.hidden_size
        # mlp2: maps Whisper audio features into the LLM embedding space
        self.mlp2 = nn.Sequential(
            nn.LayerNorm(audio_hidden_size),
            nn.Linear(audio_hidden_size, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size)
        )

        self.audio_context_token_id = None

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def extract_audio_feature(self, audio_values, audio_len_after_cnn):
        audio_values = audio_values.squeeze(1)

        # TODO: construct audio padding_mask in loader
        max_len_in_batch = int(torch.max(audio_len_after_cnn).item())

        padding_mask = torch.ones([audio_values.size(0), max_len_in_batch]).to(
            dtype=audio_values.dtype, device=audio_values.device)
        for index in range(len(audio_values)):
            padding_mask[index, :int(audio_len_after_cnn[index].item())] = 0

        last_hidden_state = self.audio_model(audio_values, padding_mask, audio_len_after_cnn)  # (bs, max_token_num, 1280)
        audio_embeds = self.mlp2(last_hidden_state)

        return audio_embeds
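    # Usage sketch for the helpers above (illustrative comment only; assumes
    # `audio_processor` is the Whisper feature extractor shipped with this model,
    # returning 'audio_len_after_cnn' and 'audio_token_num' alongside the features,
    # and 'speech.wav' is a placeholder file name):
    #
    #   audio_input = load_audio('speech.wav', audio_processor)
    #   audio_embeds = model.extract_audio_feature(
    #       audio_input['audio_values'].to(model.device, dtype=model.dtype),
    #       audio_input['audio_len_after_cnn'])
    #   # audio_embeds: (num_audios, max_audio_token_num, llm_hidden_size)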
    def forward(
            self,
            pixel_values: torch.FloatTensor,
            input_ids: torch.LongTensor = None,
            audio_values: Optional[torch.FloatTensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            image_flags: Optional[torch.LongTensor] = None,
            audio_flags: Optional[torch.LongTensor] = None,
            audio_len_after_cnn: Optional[torch.LongTensor] = None,
            audio_token_num: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            statistics: Optional[torch.LongTensor] = None,
            loss_weight: Optional[List] = None,
            loss_reduction_all_gather: Optional[bool] = False,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        image_flags = image_flags.squeeze(-1)
        input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()

        vit_embeds = self.extract_feature(pixel_values)
        vit_embeds = vit_embeds[image_flags == 1]
        vit_batch_size = pixel_values.shape[0]

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
            print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, '
                  f'dynamic token length: {N}')
            if statistics is not None:
                num_samples, num_padding_tokens, num_padding_images = statistics.tolist()
                self.num_samples += num_samples
                print(f'total_samples={self.num_samples}, {num_samples=}, {num_padding_tokens=}, {num_padding_images=}')

        # Overwrite the image-context placeholder positions with the ViT embeddings.
        input_ids = input_ids.reshape(B * N)
        img_selected = (input_ids == self.img_context_token_id)
        try:
            input_embeds[img_selected] = input_embeds[img_selected] * 0.0 + vit_embeds.reshape(-1, C)
            ignore_flag = False
        except Exception as e:
            vit_embeds = vit_embeds.reshape(-1, C)
            print(f'warning: {e}, input_embeds[img_selected].shape={input_embeds[img_selected].shape}, '
                  f'vit_embeds.shape={vit_embeds.shape}')
            n_token = img_selected.sum()
            input_embeds[img_selected] = input_embeds[img_selected] * 0.0 + vit_embeds[:n_token]
            ignore_flag = True

        if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
            audio_batch_size = audio_values.shape[0]
            print(f'audio batch size: {audio_batch_size}, audios per sample: {audio_batch_size / B}')

        audio_embeds = self.extract_audio_feature(audio_values, audio_len_after_cnn)  # (audio_num, n_frame, C)

        output_audios = []
        for i in range(len(audio_token_num)):
            if audio_flags[i] > 0:
                token_num = int(audio_token_num[i].item())
                audio = audio_embeds[i][:token_num]  # keep only the valid (non-padded) audio tokens
                output_audios.append(audio)

        # Overwrite the audio-context placeholder positions with the audio embeddings.
        if len(output_audios):
            output_audios = torch.cat(output_audios, dim=0)
            audio_selected = (input_ids == self.audio_context_token_id)
            input_embeds[audio_selected] = input_embeds[audio_selected] * 0.0 + output_audios.reshape(-1, C)

        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits
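        # Loss: if a per-token `loss_weight` list is supplied, a weighted token-level
        # cross-entropy is computed (with the weight normaliser optionally averaged
        # across ranks); otherwise the standard shifted cross-entropy over the
        # vocabulary is used.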
        loss = None
        if labels is not None and loss_weight is not None:
            loss_weight = torch.tensor(loss_weight, dtype=torch.float32, device=labels.device)
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            shift_weights = loss_weight[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(reduction='none')
            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_weights = shift_weights.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            shift_weights = shift_weights.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

            shift_weights_sum = shift_weights.sum()
            if loss_reduction_all_gather:
                dist.all_reduce(shift_weights_sum, op=dist.ReduceOp.AVG)

            loss = loss * shift_weights
            loss = loss.sum() / shift_weights_sum
            if ignore_flag:
                loss = loss * 0.0
        elif labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            if ignore_flag:
                loss = loss * 0.0

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def Audio_chat(self, tokenizer, pixel_values, audio, question, generation_config, history=None,
                   return_history=False, num_patches_list=None,
                   IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
                   AUDIO_START_TOKEN='<audio>', AUDIO_CONTEXT_TOKEN='<AUDIO_CONTEXT>', verbose=None):

        if history is None and audio is not None:
            if question is None:
                question = '