|
""" |
|
Subclasses VisionTextDualEncoderModel to customize the text pooler.
|
""" |
|
|
|
from typing import Optional |
|
|
|
import torch |
|
from transformers import AutoModel, VisionTextDualEncoderModel |
|
|
|
from .configuration_custom_clip import CustomCLIPConfig, get_text_model_pooler |
|
|
|
|
|
|
|
class CustomCLIPModel(VisionTextDualEncoderModel):
    """VisionTextDualEncoderModel variant with a replaceable text pooler.

    Construction is delegated to the parent class; afterwards the ``pooler``
    attribute of ``self.text_model`` is overwritten with a newly built
    module. Which module is built depends on whether a config was supplied:
    the class-level defaults are used when it was not, otherwise the
    ``text_model_pooler`` / ``text_model_pooler_kwargs`` fields of the config.
    """

    config_class = CustomCLIPConfig

    # Fallback pooler, resolved once at class-definition time from the
    # default pooler string on CustomCLIPConfig.
    # NOTE(review): __init__ calls this value with keyword arguments to build
    # the pooler, so it appears to be a class/factory rather than a module
    # instance, despite the annotation -- confirm.
    DEFAULT_TEXT_MODEL_POOLER_TYPE: torch.nn.Module = get_text_model_pooler(
        CustomCLIPConfig.DEFAULT_TEXT_MODEL_POOLER_STR
    )
    # Keyword arguments passed to the fallback pooler when it is built.
    DEFAULT_TEXT_MODEL_POOLER_KWARGS: dict = (
        CustomCLIPConfig.DEFAULT_TEXT_MODEL_POOLER_KWARGS
    )

    def __init__(
        self, config: Optional[CustomCLIPConfig.__base__] = None, *args, **kwargs
    ):
        """Build the dual encoder, then install the text pooler.

        Args:
            config: Optional configuration. When given, it is first passed
                through ``CustomCLIPConfig.from_base`` and the result is what
                the parent constructor receives. ``None`` is forwarded to the
                parent unchanged.
            *args: Extra positional arguments forwarded to the parent
                constructor.
            **kwargs: Extra keyword arguments forwarded to the parent
                constructor.
        """
        if config is not None:
            config = CustomCLIPConfig.from_base(config)
        super().__init__(config, *args, **kwargs)

        # Pick the pooler factory and its arguments. The choice keys off the
        # ``config`` argument of this constructor, not off ``self.config``.
        if config is None:
            pooler_factory = self.DEFAULT_TEXT_MODEL_POOLER_TYPE
            pooler_kwargs = self.DEFAULT_TEXT_MODEL_POOLER_KWARGS
        else:
            pooler_factory = get_text_model_pooler(config.text_model_pooler)
            pooler_kwargs = config.text_model_pooler_kwargs

        self.text_model.pooler = pooler_factory(**pooler_kwargs)
|
|
|
|
|
# Register the config/model pair with AutoModel at import time.
AutoModel.register(config_class=CustomCLIPConfig, model_class=CustomCLIPModel)
|
|