# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import warnings
from typing import Any, List, Optional, Tuple, Union
from PIL import Image, ImageDraw
from io import BytesIO
import librosa
import requests
import torch
import torch.distributed as dist
import torch.utils.checkpoint
from .modeling_internlm2 import InternLM2ForCausalLM
from peft import LoraConfig, get_peft_model
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
LlamaTokenizer, Qwen2ForCausalLM)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging
from .conversation import get_conv_template
from .configuration_internvl_chat import InternVLChatConfig
from .modeling_intern_vit import InternVisionModel
from .modeling_internvl_chat import InternVLChatModel
from .configuration_internvl_audio_chat import InternVLChatAudioConfig
from .modeling_whisper import AudioWhisperModel
def load_audio(audio_file, audio_processor):
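    """Load an audio file and preprocess it into model inputs.

    The waveform is resampled to 16 kHz with librosa and run through
    ``audio_processor``, which is expected to return Whisper-style
    ``input_features`` together with ``audio_len_after_cnn`` (feature length
    after the convolutional frontend) and ``audio_token_num`` (number of audio
    context tokens to reserve in the prompt).

    Illustrative usage (assuming the audio processor shipped with this model):
        audio_processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        audio_input = load_audio('speech.wav', audio_processor)
    """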
audio_values, _ = librosa.load(audio_file, sr=16000) # sample rate should be 16000
audio_process_values = audio_processor(audio_values, sampling_rate=16000, return_tensors="pt")
input_features = audio_process_values['input_features']
audio_len_after_cnn = audio_process_values['audio_len_after_cnn']
audio_token_num = audio_process_values['audio_token_num']
audio_input = {'audio_values': input_features,
'audio_len_after_cnn': audio_len_after_cnn,
'audio_token_num': audio_token_num,
}
return audio_input
class InternVLChatAudioModel(InternVLChatModel):
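    """InternVL chat model extended with an audio branch.

    On top of ``InternVLChatModel`` (vision encoder + LLM), this class adds a
    Whisper-based audio encoder (``audio_model``) and a small MLP projector
    (``mlp2``) that maps audio features into the LLM embedding space. The
    projected audio embeddings are scattered into the positions of the audio
    context tokens, analogously to the image context tokens.
    """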
def __init__(self, config: InternVLChatAudioConfig, vision_model=None, language_model=None, audio_model=None):
super().__init__(config, vision_model, language_model)
if audio_model is not None:
self.audio_model = audio_model
else:
self.audio_model = AudioWhisperModel(config.audio_config)
audio_hidden_size = config.audio_config.d_model
llm_hidden_size = config.llm_config.hidden_size
self.mlp2 = nn.Sequential(
nn.LayerNorm(audio_hidden_size),
nn.Linear(audio_hidden_size, llm_hidden_size),
nn.GELU(),
nn.Linear(llm_hidden_size, llm_hidden_size)
        )  # mlp2: projects audio encoder features into the LLM embedding space
self.audio_context_token_id = None
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=0.02)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def extract_audio_feature(self, audio_values, audio_len_after_cnn):
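        """Encode padded audio features and project them to the LLM hidden size.

        ``padding_mask`` is 0 for valid frames and 1 for padded frames, derived
        from ``audio_len_after_cnn``. The audio encoder output of shape
        (batch, max_token_num, d_model) is projected by ``mlp2`` to
        (batch, max_token_num, llm_hidden_size).
        """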
audio_values = audio_values.squeeze(1)
        # TODO: construct the audio padding_mask in the data loader
max_len_in_batch = int(torch.max(audio_len_after_cnn).item())
padding_mask = torch.ones([audio_values.size(0), max_len_in_batch]).to(dtype=audio_values.dtype,
device=audio_values.device)
for index in range(len(audio_values)):
padding_mask[index, :int(audio_len_after_cnn[index].item())] = 0
last_hidden_state = self.audio_model(audio_values, padding_mask, audio_len_after_cnn) # (bs, max_token_num, 1280)
audio_embeds = self.mlp2(last_hidden_state)
return audio_embeds
def forward(
self,
pixel_values: torch.FloatTensor,
input_ids: torch.LongTensor = None,
audio_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
image_flags: Optional[torch.LongTensor] = None,
audio_flags: Optional[torch.LongTensor] = None,
audio_len_after_cnn: Optional[torch.LongTensor] = None,
audio_token_num: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
statistics: Optional[torch.LongTensor] = None,
loss_weight: Optional[List] = None,
loss_reduction_all_gather: Optional[bool] = False,
) -> Union[Tuple, CausalLMOutputWithPast]:
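        """Fuse image and audio embeddings into the text embeddings and run the LLM.

        Image context tokens in ``input_ids`` are replaced by ViT embeddings and
        audio context tokens by projected audio-encoder embeddings; the merged
        sequence is fed to the language model. If ``loss_weight`` is given, a
        per-token weighted cross-entropy loss is computed instead of the plain
        mean-reduced loss.
        """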
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
image_flags = image_flags.squeeze(-1)
input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()
vit_embeds = self.extract_feature(pixel_values)
vit_embeds = vit_embeds[image_flags == 1]
vit_batch_size = pixel_values.shape[0]
B, N, C = input_embeds.shape
input_embeds = input_embeds.reshape(B * N, C)
if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')
if statistics is not None:
num_samples, num_padding_tokens, num_padding_images = statistics.tolist()
self.num_samples += num_samples
print(f'total_samples={self.num_samples}, {num_samples=}, {num_padding_tokens=}, {num_padding_images=}')
input_ids = input_ids.reshape(B * N)
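        # Replace the embeddings of the image context tokens with the ViT embeddings.
        # If the number of selected positions and ViT tokens mismatch (e.g. due to
        # truncation), fall back to a partial copy and zero the loss via ignore_flag.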
img_selected = (input_ids == self.img_context_token_id)
try:
input_embeds[img_selected] = input_embeds[img_selected] * 0.0 + vit_embeds.reshape(-1, C)
ignore_flag = False
except Exception as e:
vit_embeds = vit_embeds.reshape(-1, C)
print(f'warning: {e}, input_embeds[img_selected].shape={input_embeds[img_selected].shape}, '
f'vit_embeds.shape={vit_embeds.shape}')
n_token = img_selected.sum()
input_embeds[img_selected] = input_embeds[img_selected] * 0.0 + vit_embeds[:n_token]
ignore_flag = True
if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
audio_batch_size = audio_values.shape[0]
print(f'audio batch size: {audio_batch_size}, audios per sample: {audio_batch_size / B}')
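        # Encode the audio inputs, keep the valid tokens of each sample with
        # audio_flags > 0, and scatter them into the audio context token positions.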
audio_embeds = self.extract_audio_feature(audio_values, audio_len_after_cnn) # (audio_num, n_frame, C)
output_audios = []
for i in range(len(audio_token_num)):
if audio_flags[i] > 0:
token_num = int(audio_token_num[i].item())
                audio = audio_embeds[i][:token_num]  # keep only the valid (unpadded) audio tokens
output_audios.append(audio)
if len(output_audios):
output_audios = torch.cat(output_audios, dim=0)
audio_selected = (input_ids == self.audio_context_token_id)
input_embeds[audio_selected] = input_embeds[audio_selected] * 0.0 + output_audios.reshape(-1, C)
input_embeds = input_embeds.reshape(B, N, C)
outputs = self.language_model(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = outputs.logits
loss = None
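        # Weighted token-level cross entropy: each shifted target token is scaled by its
        # loss_weight and the sum is normalized by the total weight (optionally averaged
        # across ranks when loss_reduction_all_gather is set).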
if labels is not None and loss_weight is not None:
loss_weight = torch.tensor(loss_weight,
dtype=torch.float32,
device=labels.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
shift_weights = loss_weight[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss(reduction='none')
shift_logits = shift_logits.view(
-1, self.language_model.config.vocab_size)
shift_labels = shift_labels.view(-1)
shift_weights = shift_weights.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
shift_weights = shift_weights.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
shift_weights_sum = shift_weights.sum()
if loss_reduction_all_gather:
dist.all_reduce(shift_weights_sum, op=dist.ReduceOp.AVG)
loss = loss * shift_weights
loss = loss.sum() / shift_weights_sum
if ignore_flag:
loss = loss * 0.0
elif labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if ignore_flag:
loss = loss * 0.0
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
    def Audio_chat(self, tokenizer, pixel_values, audio, question, generation_config, history=None,
                   return_history=False, num_patches_list=None,
                   IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
                   AUDIO_START_TOKEN='<audio>', AUDIO_END_TOKEN='</audio>', AUDIO_CONTEXT_TOKEN='<AUDIO_CONTEXT>',
                   verbose=False):
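        """Chat with the model using an optional image (``pixel_values``) and an audio input.

        Mirrors ``InternVLChatModel.chat``: the ``<audio>`` (and ``<image>``) placeholders in
        the question are expanded into the corresponding context tokens before generation.
        """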
if history is None and audio is not None:
if question is None:
                question = '<audio>\n'