# coding=utf-8
# Copyright 2024 oshizo
#
# This implementation is based on:
# 1. Qwen2-VL (https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/)
#    Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
#    Originally based on EleutherAI's GPT-NeoX library and GPT-NeoX/OPT implementations.
#
# 2. CLIP (https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/)
#    Copyright 2021 The OpenAI Team Authors and The HuggingFace Team.
#    CLIP Configuration
#    Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CLIPQwen2VL model implementation."""

from __future__ import annotations

import itertools
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
import transformers
from PIL import Image
from torch import nn
from transformers import BertConfig, BertModel, PretrainedConfig, PreTrainedModel
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
    Qwen2VisionTransformerPretrainedModel,
)


class CLIPQwen2VLConfig(PretrainedConfig):
    model_type = "clip_qwen2vl"

    def __init__(
        self,
        text_config: Optional[Dict[str, Any]] = None,
        vision_config: Optional[Dict[str, Any]] = None,
        projection_dim: int = 1024,
        logit_scale_init_value: float = 2.6592,
        **kwargs,
    ):
        super().__init__(**kwargs)
        text_config = text_config or {}
        vision_config = vision_config or {}
        self.text_config = BertConfig(**text_config)
        self.vision_config = Qwen2VLVisionConfig(**vision_config)
        self.projection_dim = projection_dim
        self.logit_scale_init_value = logit_scale_init_value


class CLIPQwen2VLModel(PreTrainedModel):
    """CLIP-style dual encoder: a BERT text encoder and a Qwen2-VL vision encoder."""

    config_class = CLIPQwen2VLConfig

    def __init__(self, config: CLIPQwen2VLConfig):
        super().__init__(config)

        self.projection_dim = config.text_config.hidden_size  # 1024
        self.text_embed_dim = config.text_config.hidden_size  # 1024
        self.vision_embed_dim = config.vision_config.hidden_size  # 1536

        # Text encoder
        self.text_model = BertModel(config.text_config)

        # Vision encoder
        self.vision_model = Qwen2VisionTransformerPretrainedModel(config.vision_config)

        # Vision projection (1536 -> 1024)
        self.vision_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)

        self.logit_scale = nn.Parameter(torch.ones([]) * config.logit_scale_init_value)

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Mean pooling over non-padding tokens
        attention_mask = attention_mask.to(text_outputs.last_hidden_state.dtype)
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(
            text_outputs.last_hidden_state.size()
        )
        sum_embeddings = torch.sum(text_outputs.last_hidden_state * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        text_embeds = sum_embeddings / sum_mask

        return text_embeds

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        batch_size = image_grid_thw.shape[0]
        spatial_merge_size = 2

        cu_seqlens = torch.repeat_interleave(
            image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        vision_output = self.vision_model(hidden_states=pixel_values, grid_thw=image_grid_thw)

        # The vision tower merges spatial_merge_size x spatial_merge_size patches,
        # so recover the per-image boundaries in the merged, flattened patch sequence.
        merged_patches_per_image = [
            ((h // spatial_merge_size) * (w // spatial_merge_size) * t).item()
            for t, h, w in image_grid_thw
        ]
        merged_cu_seqlens = torch.tensor(
            [0] + list(itertools.accumulate(merged_patches_per_image)),
            device=vision_output.device,
        )

        # Mean-pool each image's merged patch embeddings, then project to the text dimension.
        image_features = []
        for i in range(batch_size):
            start_idx = merged_cu_seqlens[i]
            end_idx = merged_cu_seqlens[i + 1]
            image_features.append(vision_output[start_idx:end_idx].mean(dim=0))
        image_features = torch.stack(image_features)

        image_embeds = self.vision_projection(image_features)

        return image_embeds
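

def clip_style_logits(
    text_embeds: torch.Tensor, image_embeds: torch.Tensor, logit_scale: torch.Tensor
) -> torch.Tensor:
    """Illustrative sketch only; nothing in this module calls it.

    ``CLIPQwen2VLModel`` defines ``logit_scale`` (initialised to 2.6592, the same
    default as Hugging Face's ``CLIPConfig``) but never applies it itself. In a
    CLIP-style contrastive setup the two embedding sets would typically be
    L2-normalised and their cosine-similarity matrix scaled by
    ``exp(logit_scale)``, as sketched below.
    """
    text_embeds = F.normalize(text_embeds, p=2, dim=-1)
    image_embeds = F.normalize(image_embeds, p=2, dim=-1)
    # Shape: (num_texts, num_images)
    return logit_scale.exp() * (text_embeds @ image_embeds.t())
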
class CLIPQwen2VLWrapper(nn.Module):
    """Custom module exposing the sentence-transformers interface (tokenize / forward / save / load)."""

    save_in_root: bool = True

    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        backend: str = "torch",
        **kwargs,
    ) -> None:
        super().__init__()
        model_args = kwargs.get("model_args", {})
        if "torch_dtype" not in model_args:
            model_args["torch_dtype"] = torch.bfloat16

        self.model = CLIPQwen2VLModel.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **model_args
        )
        # Tokenizer for the text encoder and image processor for the vision encoder
        self.tokenizer = transformers.AutoTokenizer.from_pretrained("cl-nagoya/ruri-large")
        self.processor = transformers.AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

    def __repr__(self) -> str:
        return "CLIPQwen2VLWrapper()"

    def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        image_embeds = []
        text_embeds = []

        if "pixel_values" in features:
            image_embeds = self.model.get_image_features(
                pixel_values=features["pixel_values"],
                image_grid_thw=features["image_grid_thw"],
            )
        if "input_ids" in features:
            text_embeds = self.model.get_text_features(
                input_ids=features["input_ids"],
                attention_mask=features.get("attention_mask", None),
                position_ids=features.get("position_ids", None),
                output_attentions=features.get("output_attentions", None),
                output_hidden_states=features.get("output_hidden_states", None),
            )

        # Reassemble the embeddings in the original input order
        # (image_text_info: 0 = image, 1 = text).
        sentence_embedding = []
        image_features = iter(image_embeds)
        text_features = iter(text_embeds)
        for input_type in features["image_text_info"]:
            if input_type == 0:
                sentence_embedding.append(next(image_features))
            else:
                sentence_embedding.append(next(text_features))

        features["sentence_embedding"] = torch.stack(sentence_embedding).float()
        return features

    def tokenize(
        self, texts: List[Union[str, Image.Image]], padding: str | bool = True
    ) -> dict[str, torch.Tensor]:
        images = []
        texts_values = []
        image_text_info = []

        for data in texts:
            if isinstance(data, Image.Image):
                images.append(data)
                image_text_info.append(0)
            else:
                texts_values.append(data)
                image_text_info.append(1)

        encoding = {}
        if len(texts_values):
            encoding = self.tokenizer(
                texts_values,
                return_tensors="pt",
                padding=padding,
                truncation=True,
                max_length=512,
            )
        if len(images):
            image_features = self.processor.image_processor(images, return_tensors="pt")
            encoding.update(image_features)

        encoding["image_text_info"] = image_text_info
        return dict(encoding)

    @property
    def processor(self) -> transformers.ProcessorMixin:
        return self._processor

    @processor.setter
    def processor(self, processor):
        self._processor = processor

    def save(self, output_path: str) -> None:
        self.model.save_pretrained(output_path)
        self.tokenizer.save_pretrained(output_path)
        self.processor.save_pretrained(output_path)

    @staticmethod
    def load(input_path: str) -> CLIPQwen2VLWrapper:
        return CLIPQwen2VLWrapper(model_name_or_path=input_path)
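

if __name__ == "__main__":
    # Minimal usage sketch, assuming sentence-transformers is installed.
    # "path/to/clip-qwen2vl-checkpoint" is a placeholder for a directory produced
    # by CLIPQwen2VLWrapper.save(), and "tokyo_tower.jpg" is a placeholder local
    # image; neither refers to a published artifact. The wrapper follows the
    # sentence-transformers custom-module contract, so it can serve as the sole
    # module of a SentenceTransformer and encode mixed lists of texts and images.
    from sentence_transformers import SentenceTransformer

    wrapper = CLIPQwen2VLWrapper("path/to/clip-qwen2vl-checkpoint")
    st_model = SentenceTransformer(modules=[wrapper])
    embeddings = st_model.encode(
        ["東京タワーの写真", Image.open("tokyo_tower.jpg")]
    )
    print(embeddings.shape)  # e.g. (2, 1024), matching the text hidden size noted above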