# Uploaded by huseinzol05 ("Upload MM_LLMs", commit fb8e179, verified).
from collections import Counter, defaultdict
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
from torch.nn import CrossEntropyLoss
import copy
import math
from transformers.activations import gelu
from typing import List, Optional, Tuple, Union
from transformers.modeling_utils import PreTrainedModel, PretrainedConfig
from transformers import CONFIG_MAPPING
from transformers.modeling_outputs import BaseModelOutput
from transformers import GenerationConfig
from transformers import CLIPConfig, CLIPProcessor, CLIPModel, AutoModel
from transformers import WhisperConfig, WhisperPreTrainedModel, WhisperModel
from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig
def most_frequent_element(tensor):
    """Return how many times the most common value occurs in ``tensor``.

    Despite the name, this returns the *count* of the most frequent value,
    not the value itself. ``prepare_inputs_for_generation`` uses it on
    ``audio_index`` to get the largest number of audios sharing one index.

    Args:
        tensor: any object with ``.flatten().tolist()`` (torch tensor or
            numpy array).

    Returns:
        The highest occurrence count of any single value, or 0 when the
        input is empty.
    """
    counts = Counter(tensor.flatten().tolist())
    # default=0 makes empty input return 0 instead of raising IndexError.
    return max(counts.values(), default=0)
class MM_LLMs_Config(PretrainedConfig):
    """Composite configuration holding an audio-encoder config and an LLM config.

    Args:
        audio_config: a config object, a dict, or None. A dict is converted
            through ``CONFIG_MAPPING`` keyed by its ``"model_type"`` entry,
            which defaults to ``"whisper"`` when absent. The caller's dict is
            not modified.
        llm_config: same as ``audio_config``, with a default model type of
            ``"llama"``.
        audio_select_layer: index into the audio encoder's ``hidden_states``
            used as the audio features (default -2).
        **kwargs: forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = 'mm_llms'
    is_composition = True

    def __init__(
        self,
        audio_config=None,
        llm_config=None,
        audio_select_layer=-2,
        **kwargs
    ):
        if isinstance(audio_config, dict):
            # Work on a copy so the caller's dict is not mutated.
            audio_config = dict(audio_config)
            audio_config.setdefault("model_type", "whisper")
            audio_config = CONFIG_MAPPING[audio_config["model_type"]](**audio_config)
        if isinstance(llm_config, dict):
            llm_config = dict(llm_config)
            llm_config.setdefault("model_type", "llama")
            llm_config = CONFIG_MAPPING[llm_config["model_type"]](**llm_config)
        self.audio_config = audio_config
        self.llm_config = llm_config
        self.audio_select_layer = audio_select_layer
        super().__init__(**kwargs)
class LlavaMultiModalProjector(nn.Module):
    """Two-layer projector from encoder features to the LLM hidden size.

    If ``conv_kernel`` is truthy, the first layer is an ``nn.Conv1d``
    (kernel ``conv_kernel``, stride ``conv_stride``), which takes input
    shaped (batch, channels, length). Its output is transposed back to
    (batch, length, channels) before the GELU and the second ``nn.Linear``.
    Otherwise the first layer is an ``nn.Linear``.
    """

    def __init__(self, in_hidden_size, out_hidden_size, conv_kernel=None, conv_stride=3):
        super().__init__()
        self.conv_kernel = conv_kernel
        if not conv_kernel:
            first_layer = nn.Linear(in_hidden_size, out_hidden_size, bias=True)
        else:
            first_layer = nn.Conv1d(
                in_hidden_size,
                out_hidden_size,
                kernel_size=conv_kernel,
                stride=conv_stride,
            )
        self.linear_1 = first_layer
        self.act = gelu
        self.linear_2 = nn.Linear(out_hidden_size, out_hidden_size, bias=True)

    def forward(self, image_features):
        projected = self.linear_1(image_features)
        if self.conv_kernel:
            # Conv1d outputs channels on dim 1; move them last for linear_2.
            projected = projected.transpose(1, 2).contiguous()
        return self.linear_2(self.act(projected))
class MM_LLMs(PreTrainedModel):
    """Causal language model conditioned on audio.

    ``audio_encoder`` encodes the audio, and hidden layer
    ``config.audio_select_layer`` of its encoder is projected to the LLM
    hidden size by ``audio_projector``. The projected frames are then
    spliced into the token embeddings right after each audio-start token
    (token id ``audio_starts[0]``).
    """

    config_class = MM_LLMs_Config
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True

    def __init__(self, config, flash_attention=False, dtype=torch.float32):
        super().__init__(config)
        self.config = config
        self.audio_encoder = AutoModel.from_config(config.audio_config)
        self.llm = AutoModelForCausalLM.from_config(
            config.llm_config,
            use_flash_attention_2=flash_attention,
            torch_dtype=dtype,
        )
        # Conv1d(kernel=40, stride=3) maps d_model -> LLM hidden size and
        # also shortens the audio time axis.
        self.audio_projector = LlavaMultiModalProjector(
            config.audio_config.d_model,
            config.llm_config.hidden_size,
            conv_kernel=40,
            conv_stride=3,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        image_index: torch.LongTensor = None,
        audio_index: torch.LongTensor = None,
        image_starts: torch.int = None,
        image_ends: torch.int = None,
        audio_starts: torch.int = None,
        audio_ends: torch.int = None,
        images: torch.FloatTensor = None,
        audios: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """Run the LLM on text embeddings with the encoded audio spliced in.

        ``position_ids``, ``past_key_values``, ``inputs_embeds`` and
        ``use_cache`` are accepted for interface compatibility but are not
        used. The image arguments are passed to
        ``prepare_inputs_for_generation``, which ignores them.

        Returns:
            Whatever ``self.llm`` returns for the spliced ``inputs_embeds``,
            ``attention_mask`` and ``labels``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if audios is not None:
            # Match the audio encoder's dtype (e.g. fp16/bf16 weights).
            audios = audios.type(self.audio_encoder.dtype)
        model_inputs = self.prepare_inputs_for_generation(
            input_ids=input_ids,
            image_index=image_index,
            audio_index=audio_index,
            image_starts=image_starts,
            image_ends=image_ends,
            audio_starts=audio_starts,
            audio_ends=audio_ends,
            images=images,
            audios=audios,
            attention_mask=attention_mask,
            labels=labels,
        )
        return self.llm(
            inputs_embeds=model_inputs["inputs_embeds"],
            attention_mask=model_inputs["attention_mask"],
            labels=model_inputs["labels"],
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        images=None,
        audios=None,
        audio_starts=None,
        audio_ends=None,
        image_starts=None,
        image_ends=None,
        attention_mask=None,
        labels=None,
        audio_index=None,
        image_index=None,
        **kwargs,
    ):
        """Build LLM inputs with projected audio inserted after audio-start tokens.

        ``torch.where`` returns the audio-start tokens in row-major order.
        Each occurrence of token id ``audio_starts[0]`` in ``input_ids``
        consumes the next encoded audio in that order. Its frames are
        inserted right after the token, and the rest of the row shifts right.

        Every row is widened to ``seq_len + frames_per_audio * max_audios_per_row``.
        Unused tail positions hold zero embeddings, attention 0 and label -100.
        Inserted audio frames get attention 1 and label -100. Original token
        positions keep their attention-mask value and label.

        Returns:
            dict with ``input_ids``, ``inputs_embeds``, ``use_cache`` (taken
            from kwargs), ``attention_mask`` and ``labels`` (None when no
            labels are given).

        Raises:
            ValueError: if ``audios`` is given without ``audio_starts``, or
                ``input_ids`` holds more audio-start tokens than there are
                encoded audios.
        """
        audio_features = self.encode_audio(audios) if audios is not None else None
        text_embeddings = self.llm.model.embed_tokens(input_ids)
        batch_size, seq_len, embed_dim = text_embeddings.shape

        if attention_mask is None:
            # No mask supplied: treat every input token as real.
            attention_mask = torch.ones(
                batch_size, seq_len, dtype=torch.long, device=input_ids.device
            )

        # (row, column) of every audio-start token; stays empty without audio.
        audio_rows, audio_cols = [], []
        if audio_features is not None:
            if audio_starts is None:
                raise ValueError("`audio_starts` must be provided together with `audios`.")
            audio_id = int(audio_starts[0])
            rows, cols = torch.where(input_ids == audio_id)
            audio_rows, audio_cols = rows.tolist(), cols.tolist()
            if len(audio_rows) > audio_features.shape[0]:
                raise ValueError(
                    f"Found {len(audio_rows)} audio tokens but only "
                    f"{audio_features.shape[0]} encoded audios."
                )

        # Largest number of audios in a single row. This sets how much extra
        # room every row needs.
        if audio_index is not None and len(audio_index):
            max_count_audio = most_frequent_element(audio_index)
        elif audio_rows:
            max_count_audio = max(Counter(audio_rows).values())
        else:
            max_count_audio = 0
        seq_audio = audio_features.shape[1] if audio_features is not None else 0
        new_len = seq_len + seq_audio * max_count_audio

        final_embedding = text_embeddings.new_zeros((batch_size, new_len, embed_dim))
        final_embedding[:, :seq_len] = text_embeddings
        final_attention_mask = attention_mask.new_zeros((batch_size, new_len))
        final_attention_mask[:, :seq_len] = attention_mask
        if labels is not None:
            final_labels = labels.new_full((batch_size, new_len), -100)
            final_labels[:, :seq_len] = labels
        else:
            final_labels = None

        # Frames already inserted into each row; they shift later insertion
        # points in the same row to the right.
        inserted = defaultdict(int)
        for audio_i, (b, k) in enumerate(zip(audio_rows, audio_cols)):
            frames = audio_features[audio_i]
            n_frames = len(frames)
            # Position just after the audio-start token in the widened row.
            cut = k + 1 + inserted[b]
            row = torch.cat([final_embedding[b, :cut], frames, text_embeddings[b, k + 1:]])
            final_embedding[b, :len(row)] = row
            # Audio frames are always attended. Original positions keep their
            # mask values, so padding stays masked out.
            row_mask = torch.cat([
                final_attention_mask[b, :cut],
                attention_mask.new_ones(n_frames),
                attention_mask[b, k + 1:],
            ])
            final_attention_mask[b, :len(row_mask)] = row_mask
            if labels is not None:
                row_labels = torch.cat([
                    final_labels[b, :cut],
                    labels.new_full((n_frames,), -100),
                    labels[b, k + 1:],
                ])
                final_labels[b, :len(row_labels)] = row_labels
            inserted[b] += n_frames

        return {
            "input_ids": input_ids,
            "inputs_embeds": final_embedding,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": final_attention_mask,
            "labels": final_labels,
        }

    def encode_audio(self, audios):
        """Encode ``audios`` and project the selected hidden layer to the LLM size.

        Takes ``hidden_states[config.audio_select_layer]`` from the audio
        encoder and runs it through ``audio_projector``.
        """
        encoded = self.audio_encoder.encoder(audios, output_hidden_states=True)
        encoded = encoded.hidden_states[self.config.audio_select_layer]
        # Put d_model on dim 1: it is the Conv1d projector's channel axis.
        return self.audio_projector(encoded.transpose(1, 2).contiguous())