# japanese-clip-vit-h-14-bert-base / modeling_custom_clip.py
# Uploaded by bsyx001 ("Upload model"), commit 8b26dd7 (verified).
"""
Subclasses VisionTextDualEncoderModel to customize text pooler.
"""
from typing import ClassVar, Optional, Type

import torch
from transformers import AutoModel, VisionTextDualEncoderModel

from .configuration_custom_clip import CustomCLIPConfig, get_text_model_pooler
# @add_start_docstrings(CUSTOM_CLIP_START_DOCSTRING)
class CustomCLIPModel(VisionTextDualEncoderModel):
    """VisionTextDualEncoderModel whose text-model pooler is configurable.

    After the base model has been constructed, ``self.text_model.pooler`` is
    replaced by a pooler chosen through ``get_text_model_pooler``:

    * with a config: ``config.text_model_pooler`` selects the pooler and
      ``config.text_model_pooler_kwargs`` supplies its constructor kwargs;
    * without a config: the class-level defaults below are used.
    """

    config_class = CustomCLIPConfig

    # Pooler class used when no config is given. It is called with
    # DEFAULT_TEXT_MODEL_POOLER_KWARGS to build the pooler module. Resolved once,
    # at class-definition time, from the config's default pooler name.
    DEFAULT_TEXT_MODEL_POOLER_TYPE: ClassVar[Type[torch.nn.Module]] = (
        get_text_model_pooler(CustomCLIPConfig.DEFAULT_TEXT_MODEL_POOLER_STR)
    )
    # Constructor kwargs for DEFAULT_TEXT_MODEL_POOLER_TYPE. This is the same dict
    # object as on CustomCLIPConfig; it is only unpacked (never mutated) here.
    DEFAULT_TEXT_MODEL_POOLER_KWARGS: ClassVar[dict] = (
        CustomCLIPConfig.DEFAULT_TEXT_MODEL_POOLER_KWARGS
    )

    def __init__(
        self, config: Optional[CustomCLIPConfig.__base__] = None, *args, **kwargs
    ) -> None:
        """Build the dual encoder, then install the configured text pooler.

        Args:
            config: Model config, or ``None``. A non-None value is converted
                with ``CustomCLIPConfig.from_base`` before being handed to the
                base class. Presumably this lets a plain base config be passed
                in; confirm against ``configuration_custom_clip``.
            *args, **kwargs: Forwarded unchanged to
                ``VisionTextDualEncoderModel.__init__`` (presumably
                ``vision_model`` / ``text_model``).
        """
        if config is not None:
            config = CustomCLIPConfig.from_base(config)
        super().__init__(
            # Original author's note (intent unclear, TODO clarify): "surprisingly,
            # `super` is unnecessary, possibly due to implementation of
            # CustomCLIPConfig.__init__?"
            config,
            *args,
            **kwargs,
        )

        # Take the pooler class and kwargs from the config when one was given.
        # Otherwise use the class defaults, read through `self` so subclasses
        # can override them.
        if config is None:
            pooler_cls = self.DEFAULT_TEXT_MODEL_POOLER_TYPE
            pooler_kwargs = self.DEFAULT_TEXT_MODEL_POOLER_KWARGS
        else:
            pooler_cls = get_text_model_pooler(config.text_model_pooler)
            pooler_kwargs = config.text_model_pooler_kwargs

        # NOTE(review): the pooler is swapped in after the base constructor has
        # returned, so any weight initialization the base class performs does not
        # apply to it. Presumably fine when weights are later loaded from a
        # checkpoint; confirm for from-scratch training.
        self.text_model.pooler = pooler_cls(**pooler_kwargs)
# Register CustomCLIPModel as the AutoModel class for CustomCLIPConfig. This runs
# at import time, so importing this module is enough for AutoModel to map
# CustomCLIPConfig to this class (e.g. in AutoModel.from_config / from_pretrained).
AutoModel.register(CustomCLIPConfig, CustomCLIPModel)