clip-japanese-base / model.py
pfzhu's picture
Upload folder using huggingface_hub
071945c verified
raw
history blame contribute delete
No virus
17 kB
# coding=utf-8
# Copyright 2024 LY Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from typing import Optional, Union
import timm
import torch
import torch.distributed as dist
import torch.distributed.nn
import torch.nn as nn
import torch.nn.functional as F
from timm.models.swin_transformer import SwinTransformer as TimmSwinTransformer
from transformers import PreTrainedModel
from transformers.utils.logging import get_logger
from .configuration_clyp import (
CLYPTextBackboneConfig,
CLYPTextEncoderConfig,
CLYPVisionBackboneConfig,
CLYPVisionEncoderConfig,
)
from .model_rinna import RinnaCLIPConfig, RinnaCLIPModel
DEFAULT_LOGGER = get_logger(__name__)
class VisionEncoder(nn.Module):
"""Vision encoder to extract image feateurs.
Pooler and neck are optional.
Instead of defining pooler and neck in VisionEncoder, you can define them in algorithm classes.
Attributes:
backbone (nn.Module): backbone loaded from timm, huggingface or registry.
pooler (nn.Module): module to extract image-level features.
neck (nn.Module): module to adjust feature dimensions.
"""
def __init__(
self,
backbone: nn.Module,
pooler: Optional[nn.Module] = None,
neck: Optional[nn.Module] = None,
) -> None:
super().__init__()
self.backbone = backbone
self.pooler = pooler
self.neck = neck
def forward(self, imgs: torch.Tensor):
"""A method to extract image features.
Args:
imgs (torch.Tensor): shape=(batch_size, channels, height, width).
Returns:
out (torch.Tensor): the output shape changes depending on pooler, and the following shapes are usually expected.
- output only image-level features like CLIP: shape=(batch_size, embed_dim)
- output image-level and local patch features like BLIP2: shape=(batch_size, embed_dim, length)
"""
out = self.backbone(imgs) # Shape=(batch_size, channels, height, width)
if self.pooler:
out = self.pooler(out)
if self.neck:
out = self.neck(out)
return out
class SwinTransformerPerm(nn.Module):
"""Wrapper for SwinTransformer in timm.
This wrapper changes the output shape to (batch_size, channels, height, width).
The original shape of timm SwinTransformer is (batch_size, height, width, channels).
"""
def __init__(self, swin: nn.Module) -> None:
super().__init__()
self.swin = swin
def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self.swin(x)
out = out.permute(0, 3, 1, 2)
return out
def load_from_timm(
config: CLYPVisionBackboneConfig,
use_gradient_checkpointing: bool = False,
path_weights: Optional[str] = None,
logger: logging.Logger = DEFAULT_LOGGER,
):
"""Create a backbone using a method: timm.create_model.
Args:
config (TimmBackboneConfig): config fed to timm.create_model.
use_gradient_checkpointing (bool): True if use gradient checkpointing.
path_weights (str): path to weights for backbone initialization.
"""
# backbone
assert config is not None
backbone = timm.create_model(
model_name=config.model_name,
pretrained=config.pretrained,
**config.extra_kwargs,
)
backbone.reset_classifier(0, "")
logger.info(
f" - load from timm: model_name={config.model_name}, pretrained={config.pretrained}"
)
# gradient checkpointing
backbone.set_grad_checkpointing(enable=use_gradient_checkpointing)
if use_gradient_checkpointing:
logger.info(" - gradient checkpointing is enebled.")
# init weights
if path_weights:
state_dict = torch.load(path_weights, map_location="cpu")
checks = backbone.load_state_dict(state_dict, strict=False)
logger.info(f" - load weights from {path_weights}")
logger.info(f" - state dict checks: {checks}")
# swin
if isinstance(backbone, TimmSwinTransformer):
backbone = SwinTransformerPerm(backbone)
return backbone
def create_vision_encoder(
config: CLYPVisionEncoderConfig, logger: logging.Logger = DEFAULT_LOGGER
) -> VisionEncoder:
assert config.pooler_config.input_type
backbone = load_from_timm(config.backbone_config, logger=logger)
pooler = CLSTokenPooling(
config.pooler_config.input_type, config.pooler_config.return_patch_features
)
neck = Linear(
config.neck_config.in_channels,
config.neck_config.out_channels,
config.neck_config.bias,
)
return VisionEncoder(backbone, pooler=pooler, neck=neck)
class TextEncoder(nn.Module):
"""Text encoder to extract text features.
Pooler and neck are optional.
Instead of defining pooler and neck in TextEncoder, you can define them in algorithm classes.
Attributes:
backbone (nn.Module): backbone loaded from timm, huggingface or registry.
pooler (nn.Module): module to extract image-level features.
neck (nn.Module): module to adjust feature dimensions.
"""
def __init__(
self,
backbone: nn.Module,
pooler: Optional[nn.Module] = None,
neck: Optional[nn.Module] = None,
) -> None:
super().__init__()
self.backbone = backbone
self.pooler = pooler
self.neck = neck
def forward(self, inputs: dict) -> torch.Tensor:
"""A method to extract text features.
Args:
inputs (dict): basic keys are shown below:
- input_ids (torch.Tensor)
- attention_mask (Optional[torch.Tensor])
- position_ids (Optional[torch.Tensor])
- token_type_ids (Optional[torch.Tensor])
- output_attentions Optional[bool]
- output_hidden_states Optional[bool]
Returns:
out (torch.Tensor): the output shape changes depending on pooler, and the following shapes are usually expected.
- output only class token like CLIP: shape=(batch_size, embed_dim)
- output all token features like BLIP2: shape=(batch_size, embed_dim, length)
"""
out = self.backbone(**inputs)
if self.pooler:
out = self.pooler(out)
if self.neck:
out = self.neck(out)
return out
class TextBackboneModelWrapper(nn.Module):
def __init__(self, model: nn.Module) -> None:
super().__init__()
self.model = model.text_model
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
out = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
token_type_ids=token_type_ids,
)
return out
def set_gradient_checkpointing(self, enabled: bool) -> None:
if enabled:
self.model.gradient_checkpointing_enable()
def load_from_huggingface(
config: CLYPTextBackboneConfig,
use_gradient_checkpointing: bool = False,
path_weights: Optional[str] = None,
logger: logging.Logger = DEFAULT_LOGGER,
) -> nn.Module:
"""Load a backbone from huggingface.
Args:
config (HuggingfaceBackboneConfig): config fed to AutoModel.from_pretrained.
use_gradient_checkpointing (bool): True if use gradient checkpointing.
path_weights (str): path to weights for backbone initialization.
"""
# NOTE:
# Initialize Rinna CLIP without pretrained weights here,
# because CLYP model loads its whole weights afterward
auto_config = RinnaCLIPConfig.from_pretrained(config.model_name)
backbone = RinnaCLIPModel(auto_config)
logger.info(f" - load from huggingface: model_name={config.model_name}")
# gradient checkpointing
if isinstance(backbone, PreTrainedModel):
if use_gradient_checkpointing:
backbone.gradient_checkpointing_enable()
logger.info(" - gradient checkpointing is enabled")
else:
raise NotImplementedError()
# init weights
if path_weights:
raise NotImplementedError()
return backbone
def create_text_encoder(
config: CLYPTextEncoderConfig, logger: logging.Logger = DEFAULT_LOGGER
) -> TextEncoder:
assert config.pooler_config.input_type
backbone = TextBackboneModelWrapper(
load_from_huggingface(config.backbone_config, logger=logger)
)
pooler = CLSTokenPooling(
config.pooler_config.input_type, config.pooler_config.return_patch_features
)
neck = Linear(
config.neck_config.in_channels,
config.neck_config.out_channels,
bias=config.neck_config.bias,
)
return TextEncoder(backbone, pooler=pooler, neck=neck)
class Linear(nn.Module):
"""Linear layer."""
def __init__(self, in_channels: int, out_channels: int, bias: bool) -> None:
"""
Args:
in_channels (int): input feature dimension.
out_channels (out): output feature dimension.
bias (bool): True if use bias in nn.Linear.
"""
super().__init__()
self.linear = nn.Linear(in_channels, out_channels, bias=bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x (torch.Tensor): shape=(batch_size, ..., in_channels).
Returns:
out (torch.Tensor): shape=(batch_size, ..., out_channels).
"""
out = self.linear(x)
return out
class CLSTokenPooling(nn.Module):
"""A module to extract class token."""
def __init__(self, input_type: str, return_patch_features: bool) -> None:
"""
Args:
input_type (str): timm or huggingface.
- If input_type is timm, x[:, 0] is extracted as a class token.
- If input_type is huggingface, x.last_hidden_state[:,0] is extracted as a class token.
return_patch_features (bool): True if output local features.
"""
super().__init__()
assert input_type in ["timm", "huggingface"]
self.input_type = input_type
self.return_patch_features = return_patch_features
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x (torch.Tensor): shape=(batch_size, length, dim).
Returns:
out (torch.Tensor): shape=(batch_size, dim).
"""
# tensor: shape=(batch_size, length, dim)
if self.input_type == "timm":
assert x.ndim == 3, "CLSTokenPooling: dimension of input tensor must be 3."
if self.return_patch_features:
return x
else:
return x[:, 0]
# huggingface
elif self.input_type == "huggingface":
out = x.last_hidden_state
if self.return_patch_features:
return out
else:
return out[:, 0]
class InfoNCELoss(nn.Module):
def __init__(
self,
learn_temperature: bool,
init_temperature: float,
max_temperature: Optional[float] = None,
min_temperature: Optional[float] = None,
label_smoothing: float = 0.0,
gather_with_grad: bool = False,
):
super().__init__()
self.label_smoothing = label_smoothing
self.gather_with_grad = gather_with_grad
# set temperature
self.learn_temperature = learn_temperature
self.temperature = torch.ones([]) * init_temperature
if self.learn_temperature:
self.temperature = nn.Parameter(self.temperature)
self.max_temperature = max_temperature
self.min_temperature = min_temperature
# whether clip temperature or not
self.require_temperature_clipping = self.learn_temperature and (
self.max_temperature or self.min_temperature
)
def clip_temperature(self):
if self.require_temperature_clipping:
self.temperature.data = torch.clamp(
self.temperature, self.min_temperature, self.max_temperature
)
def forward(
self,
image_feats: torch.Tensor,
text_feats: torch.Tensor,
return_similarity: bool = False,
) -> Union[torch.Tensor, tuple[torch.Tensor]]:
# gather image and text features
image_feats_all = concat_all_gather(
image_feats, with_grad=self.gather_with_grad
)
text_feats_all = concat_all_gather(text_feats, with_grad=self.gather_with_grad)
# compute cosine similarity
sim_i2t = image_to_text_similarity(
image_feats=image_feats,
text_feats=text_feats_all,
)
sim_t2i = text_to_image_similarity(
text_feats=text_feats,
image_feats=image_feats_all,
)
# logits, scaled cosine similarity
logits_i2t = sim_i2t / self.temperature
logits_t2i = sim_t2i / self.temperature
# obtain targets
rank = dist.get_rank()
batch_size = image_feats.size(0)
targets = torch.arange(batch_size) + batch_size * rank
targets = targets.to(dtype=torch.long, device=image_feats.device)
# calculate loss
loss_i2t = F.cross_entropy(
logits_i2t, targets, label_smoothing=self.label_smoothing
)
loss_t2i = F.cross_entropy(
logits_t2i, targets, label_smoothing=self.label_smoothing
)
loss = (loss_i2t + loss_t2i) / 2.0
if not return_similarity:
return loss
else:
return loss, sim_i2t, sim_t2i
def image_to_text_similarity(
image_feats: torch.Tensor, text_feats: torch.Tensor
) -> torch.Tensor:
"""
Args:
image_feats (torch.Tensor): shape=(num_imgs, embed_dim) or (num_imgs, num_query_tokens, embed_dim).
text_feats (torch.Tensor): shape=(num_texts, embed_dim).
Returns:
sim_i2t (torch.Tensor): shape=(num_imgs, num_texts).
"""
assert image_feats.ndim in [2, 3]
assert text_feats.ndim == 2
# normalize features
image_feats = F.normalize(image_feats, dim=-1)
text_feats = F.normalize(text_feats, dim=-1)
if image_feats.ndim == 2:
sim_i2t = image_feats @ text_feats.T
else:
# a query token with maximum cosine similarity is selected
sim_i2t = torch.matmul(
image_feats.unsqueeze(1), text_feats.unsqueeze(0).unsqueeze(-1)
).squeeze() # shape=(num_imgs, num_texts, num_query_tokens)
sim_i2t, _ = sim_i2t.max(dim=-1) # shape=(num_imgs, num_texts)
return sim_i2t
def text_to_image_similarity(text_feats: torch.Tensor, image_feats: torch.Tensor):
"""
Args:
text_feats (torch.Tensor): shape=(num_texts, embed_dim).
image_feats (torch.Tensor): shape=(num_imgs, embed_dim) or (num_imgs, num_query_tokens, embed_dim).
Returns:
similarity_maxtrix (torch.Tensor): shape=(num_texts, num_imgs).
"""
assert image_feats.ndim in [2, 3]
assert text_feats.ndim == 2
# normalize features
image_feats = F.normalize(image_feats, dim=-1)
text_feats = F.normalize(text_feats, dim=-1)
if image_feats.ndim == 2:
sim_t2i = text_feats @ image_feats.T
else:
# a query token with maximum cosine similarity is selected
sim_t2i = torch.matmul(
text_feats.unsqueeze(1).unsqueeze(1),
image_feats.permute(0, 2, 1).unsqueeze(0),
).squeeze()
sim_t2i, _ = sim_t2i.max(dim=-1)
return sim_t2i
def concat_all_gather(tensor: torch.Tensor, with_grad: bool):
"""
Performs all_gather operation on the provided tensors.
*** Warning ***: torch.distributed.all_gather has no gradient.
Another implementation: https://github.com/salesforce/LAVIS/blob/main/lavis/models/base_model.py#L202-L237
"""
if with_grad:
output = torch.cat(torch.distributed.nn.all_gather(tensor), dim=0)
else:
tensors_gather = [torch.ones_like(tensor) for _ in range(dist.get_world_size())]
dist.all_gather(tensors_gather, tensor, async_op=False)
output = torch.cat(tensors_gather, dim=0)
return output