|
from copy import deepcopy
from typing import Optional, Type

import torch
from transformers import AutoConfig, VisionTextDualEncoderConfig
from transformers.utils import logging
|
|
|
# Module-level logger obtained through ``transformers.utils.logging``.
logger = logging.get_logger(name=__name__)
|
|
|
|
|
class CustomCLIPPooler(torch.nn.Module):
    """Pooler with no parameters of its own that keeps only sequence position 0.

    It performs no projection or activation. It only slices the input tensor.
    """

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states[:, 0, :]``, i.e. index 0 along dimension 1.

        NOTE(review): presumably ``hidden_states`` is laid out as
        (batch, seq_len, hidden_size), making the result the first-token
        embedding of each batch item. Confirm against the calling text model.
        """
        return hidden_states[:, 0, :]
|
|
|
|
|
def get_text_model_pooler(text_model_pooler: str) -> Type[torch.nn.Module]:
    """Resolve a pooler name to its ``torch.nn.Module`` subclass.

    The class itself is returned, not an instance. Callers are responsible
    for instantiating it.

    Args:
        text_model_pooler: Pooler name. Currently only ``"CustomCLIPPooler"``
            is recognized.

    Returns:
        The pooler class (a ``torch.nn.Module`` subclass).

    Raises:
        ValueError: If ``text_model_pooler`` is not a recognized name.
    """
    if text_model_pooler == "CustomCLIPPooler":
        return CustomCLIPPooler
    raise ValueError(f"Unrecognized text model pooler type {text_model_pooler!r}.")
|
|
|
|
|
def is_valid_text_model_pooler(
    text_model_pooler: str, suppress_error: bool = False
) -> bool:
    """Report whether ``text_model_pooler`` is accepted by ``get_text_model_pooler``.

    Args:
        text_model_pooler: Pooler name to check.
        suppress_error: When True, an unrecognized name yields ``False``.
            When False (the default), the ``ValueError`` from
            ``get_text_model_pooler`` propagates instead, so in that mode the
            function either returns ``True`` or raises.

    Returns:
        ``True`` if the name resolves. ``False`` only when it does not and
        ``suppress_error`` is set.

    Raises:
        ValueError: If the name is unrecognized and ``suppress_error`` is False.
    """
    try:
        get_text_model_pooler(text_model_pooler)
    except ValueError:
        if suppress_error:
            return False
        raise
    return True
|
|
|
|
|
class CustomCLIPConfig(VisionTextDualEncoderConfig):
    """``VisionTextDualEncoderConfig`` extended with a configurable text pooler.

    Adds two attributes on top of the base config:

    * ``text_model_pooler``: pooler name, validated at construction time with
      ``is_valid_text_model_pooler`` (unknown names raise ``ValueError``).
    * ``text_model_pooler_kwargs``: dict of extra pooler options.
    """

    model_type = "custom-clip-model"

    # Class-wide defaults. The kwargs default is a mutable dict shared by the
    # class, so it is always deep-copied before being attached to an instance.
    # Otherwise mutating one config's kwargs would alter every other config.
    DEFAULT_TEXT_MODEL_POOLER_STR: str = "CustomCLIPPooler"
    DEFAULT_TEXT_MODEL_POOLER_KWARGS: dict = {}

    def __init__(
        self,
        *args,
        text_model_pooler: Optional[str] = None,
        text_model_pooler_kwargs: Optional[dict] = None,
        **kwargs,
    ):
        """Build the config.

        Args:
            *args, **kwargs: Forwarded unchanged to
                ``VisionTextDualEncoderConfig.__init__``.
            text_model_pooler: Pooler name. ``None`` selects
                ``DEFAULT_TEXT_MODEL_POOLER_STR``.
            text_model_pooler_kwargs: Pooler options. ``None`` selects a fresh
                copy of ``DEFAULT_TEXT_MODEL_POOLER_KWARGS``. A caller-supplied
                dict is stored as given (not copied).

        Raises:
            ValueError: If the resolved pooler name is not recognized.
        """
        super().__init__(*args, **kwargs)

        self.text_model_pooler = (
            self.DEFAULT_TEXT_MODEL_POOLER_STR
            if text_model_pooler is None
            else text_model_pooler
        )
        # Fail fast on an unknown pooler name instead of at model-build time.
        is_valid_text_model_pooler(self.text_model_pooler, suppress_error=False)

        self.text_model_pooler_kwargs = (
            deepcopy(self.DEFAULT_TEXT_MODEL_POOLER_KWARGS)
            if text_model_pooler_kwargs is None
            else text_model_pooler_kwargs
        )

    @classmethod
    def from_base(cls, obj: VisionTextDualEncoderConfig):
        """Convert a plain ``VisionTextDualEncoderConfig`` into this class.

        An object that is already an instance of ``cls`` is returned unchanged
        (same object, not a copy). Otherwise a deep copy of ``obj`` has its
        class switched to ``cls`` and its pooler attributes set to the class
        defaults, with a warning logged for each change. The input object is
        never modified.

        Args:
            obj: Config to convert.

        Returns:
            A ``cls`` instance.

        Raises:
            TypeError: If ``obj`` is neither a ``cls`` nor a
                ``VisionTextDualEncoderConfig`` instance.
        """
        if isinstance(obj, cls):
            return obj

        base = VisionTextDualEncoderConfig
        if not isinstance(obj, base):
            raise TypeError(f"obj must be of type {cls!r} or {base!r}.")

        # Copy first so the caller's config is left untouched.
        obj = deepcopy(obj)
        logger.warning(f"Changing config class from {obj.__class__!r} to {cls!r}.")
        obj.__class__ = cls

        def _setattr_with_warning(target, name, value):
            # Log every attribute the conversion injects.
            logger.warning(f"Setting {name!r} to {value!r}.")
            setattr(target, name, value)

        # The base config lacks the pooler attributes, so populate them here.
        _setattr_with_warning(
            obj, "text_model_pooler", cls.DEFAULT_TEXT_MODEL_POOLER_STR
        )
        _setattr_with_warning(
            obj,
            "text_model_pooler_kwargs",
            deepcopy(cls.DEFAULT_TEXT_MODEL_POOLER_KWARGS),
        )
        return obj
|
|
|
|
|
# Register the config class under its ``model_type`` ("custom-clip-model") so
# that ``AutoConfig`` can resolve configs of that type to ``CustomCLIPConfig``.
AutoConfig.register(model_type=CustomCLIPConfig.model_type, config=CustomCLIPConfig)
|
|