"""CLIPQwen2VL model implementation."""

from __future__ import annotations

import itertools
from typing import Any, Dict, List, Optional, Union

import torch
import transformers
from PIL import Image
from torch import nn
from transformers import BertConfig, BertModel, PretrainedConfig, PreTrainedModel
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
    Qwen2VisionTransformerPretrainedModel,
)


class CLIPQwen2VLConfig(PretrainedConfig):
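    """Configuration bundling a BERT text config and a Qwen2-VL vision config."""
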
    model_type = "clip_qwen2vl"

    def __init__(
        self,
        text_config: Optional[Dict[str, Any]] = None,
        vision_config: Optional[Dict[str, Any]] = None,
        projection_dim: int = 1024,
        logit_scale_init_value: float = 2.6592,
        **kwargs,
    ):
        super().__init__(**kwargs)

        text_config = text_config or {}
        vision_config = vision_config or {}

        self.text_config = BertConfig(**text_config)
        self.vision_config = Qwen2VLVisionConfig(**vision_config)

        self.projection_dim = projection_dim
        self.logit_scale_init_value = logit_scale_init_value


class CLIPQwen2VLModel(PreTrainedModel):
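    """CLIP-style dual encoder: a BERT text model paired with the Qwen2-VL vision transformer."""
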
    config_class = CLIPQwen2VLConfig

    def __init__(self, config: CLIPQwen2VLConfig):
        super().__init__(config)

        # Vision features are projected into the text encoder's embedding space,
        # so the projection width follows the text hidden size rather than
        # config.projection_dim.
        self.projection_dim = config.text_config.hidden_size
        self.text_embed_dim = config.text_config.hidden_size
        self.vision_embed_dim = config.vision_config.hidden_size

        self.text_model = BertModel(config.text_config)
        self.vision_model = Qwen2VisionTransformerPretrainedModel(config.vision_config)

        self.vision_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
        self.logit_scale = nn.Parameter(torch.ones([]) * config.logit_scale_init_value)

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
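        """Encode text with BERT and mean-pool token embeddings over the attention mask."""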
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
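
        # Masked mean pooling: zero out padded positions, sum over the sequence,
        # and divide by the number of real tokens (clamped to avoid division by zero).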
        attention_mask = attention_mask.to(text_outputs.last_hidden_state.dtype)
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(
            text_outputs.last_hidden_state.size()
        )
        sum_embeddings = torch.sum(text_outputs.last_hidden_state * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        text_embeds = sum_embeddings / sum_mask

        return text_embeds

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
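        """Encode images, mean-pool the merged patches of each image, and project to the text space."""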
        batch_size = image_grid_thw.shape[0]
        spatial_merge_size = self.config.vision_config.spatial_merge_size
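
        # The vision tower returns one flat sequence of merged patches for the whole
        # batch, so per-image boundaries are recovered below from image_grid_thw.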
        vision_output = self.vision_model(hidden_states=pixel_values, grid_thw=image_grid_thw)

        merged_patches_per_image = [
            ((h // spatial_merge_size) * (w // spatial_merge_size) * t).item()
            for t, h, w in image_grid_thw
        ]
        merged_cu_seqlens = torch.tensor(
            [0] + list(itertools.accumulate(merged_patches_per_image)),
            device=vision_output.device,
        )

        image_features = []
        for i in range(batch_size):
            start_idx = merged_cu_seqlens[i]
            end_idx = merged_cu_seqlens[i + 1]
            image_features.append(vision_output[start_idx:end_idx].mean(dim=0))

        image_features = torch.stack(image_features)
        image_embeds = self.vision_projection(image_features)
        return image_embeds


class CLIPQwen2VLWrapper(nn.Module):
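    """Wrapper for mixed image-text batches.

    Exposes ``tokenize``/``forward``/``save``/``load`` in the style of a
    sentence-transformers custom module and writes the result to
    ``features["sentence_embedding"]``.
    """
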
    save_in_root: bool = True

    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        backend: str = "torch",
        **kwargs,
    ) -> None:
        super().__init__()

        model_args = kwargs.get("model_args", {})
        if "torch_dtype" not in model_args:
            model_args["torch_dtype"] = torch.bfloat16

        self.model = CLIPQwen2VLModel.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **model_args
        )
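        # The text side reuses the cl-nagoya/ruri-large tokenizer; the image side
        # uses the Qwen2-VL processor.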
        self.tokenizer = transformers.AutoTokenizer.from_pretrained("cl-nagoya/ruri-large")
        self.processor = transformers.AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

    def __repr__(self) -> str:
        return "CLIPQwen2VLWrapper()"

    def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
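        """Embed a mixed batch and store the result under ``features["sentence_embedding"]``."""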
        image_embeds = []
        text_embeds = []

        if "pixel_values" in features:
            image_embeds = self.model.get_image_features(
                pixel_values=features["pixel_values"],
                image_grid_thw=features["image_grid_thw"],
            )

        if "input_ids" in features:
            text_embeds = self.model.get_text_features(
                input_ids=features["input_ids"],
                attention_mask=features.get("attention_mask", None),
                position_ids=features.get("position_ids", None),
                output_attentions=features.get("output_attentions", None),
                output_hidden_states=features.get("output_hidden_states", None),
            )
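
        # Re-interleave image and text embeddings back into the original input order
        # recorded by tokenize() (0 = image, 1 = text).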
        sentence_embedding = []
        image_features = iter(image_embeds)
        text_features = iter(text_embeds)

        for input_type in features["image_text_info"]:
            if input_type == 0:
                sentence_embedding.append(next(image_features))
            else:
                sentence_embedding.append(next(text_features))

        features["sentence_embedding"] = torch.stack(sentence_embedding).float()

        return features

    def tokenize(
        self, texts: List[Union[str, Image.Image]], padding: str | bool = True
    ) -> dict[str, torch.Tensor]:
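        """Split a mixed batch into texts and images, tokenize the texts, preprocess
        the images, and record the original input order in ``image_text_info``."""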
        images = []
        texts_values = []
        image_text_info = []

        for data in texts:
            if isinstance(data, Image.Image):
                images.append(data)
                image_text_info.append(0)
            else:
                texts_values.append(data)
                image_text_info.append(1)

        encoding = {}
        if len(texts_values):
            encoding = self.tokenizer(
                texts_values,
                return_tensors="pt",
                padding=padding,
                truncation=True,
                max_length=512,
            )

        if len(images):
            image_features = self.processor.image_processor(images, return_tensors="pt")
            encoding.update(image_features)

        encoding["image_text_info"] = image_text_info
        return dict(encoding)

    @property
    def processor(self) -> transformers.ProcessorMixin:
        return self._processor

    @processor.setter
    def processor(self, processor) -> None:
        self._processor = processor

    def save(self, output_path: str) -> None:
        self.model.save_pretrained(output_path)
        self.tokenizer.save_pretrained(output_path)
        self.processor.save_pretrained(output_path)

    @staticmethod
    def load(input_path: str) -> CLIPQwen2VLWrapper:
        return CLIPQwen2VLWrapper(model_name_or_path=input_path)
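

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed on import). The checkpoint path
# and image file are placeholders, not assets shipped with this module:
#
#     wrapper = CLIPQwen2VLWrapper(model_name_or_path="path/to/clip-qwen2vl")
#     batch = wrapper.tokenize(["富士山の写真", Image.open("fuji.jpg")])
#     outputs = wrapper.forward(batch)
#     print(outputs["sentence_embedding"].shape)  # (2, text hidden size)
# ---------------------------------------------------------------------------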