# ProtST-ESM1b / configuration_protst.py
# Uploaded by: MilaDeepGraph
# Commit 314a644 (verified): "init from Jiqing's repo"
# File size: 1.9 kB
from transformers import PretrainedConfig
from transformers.utils import logging
from transformers.models.esm import EsmConfig
from transformers.models.bert import BertConfig
# Module-level logger from transformers' logging utilities, named after this module.
logger = logging.get_logger(__name__)
class ProtSTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ProtSTModel`]. It bundles a protein
    encoder configuration ([`EsmConfig`]) and a text encoder configuration ([`BertConfig`]).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        protein_config (`dict` or [`EsmConfig`], *optional*):
            Dictionary of configuration options (or an existing [`EsmConfig`]) used to initialize
            [`EsmForProteinRepresentation`]. When `None`, an [`EsmConfig`] with default values is used.
        text_config (`dict` or [`BertConfig`], *optional*):
            Dictionary of configuration options (or an existing [`BertConfig`]) used to initialize
            [`BertForPubMed`]. When `None`, a [`BertConfig`] with default values is used.
        kwargs (*optional*):
            Remaining keyword arguments, forwarded unchanged to [`PretrainedConfig`].
    """

    model_type = "protst"

    def __init__(
        self,
        protein_config=None,
        text_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if protein_config is None:
            protein_config = {}
            logger.info("`protein_config` is `None`. Initializing the `EsmConfig` with default values.")
        elif isinstance(protein_config, EsmConfig):
            # Accept an already-built config object as well as a plain dict; convert it the same way
            # `from_protein_text_configs` does so a fresh `EsmConfig` is built below.
            protein_config = protein_config.to_dict()

        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the `BertConfig` with default values.")
        elif isinstance(text_config, BertConfig):
            # Same dict-or-object handling as for `protein_config`.
            text_config = text_config.to_dict()

        self.protein_config = EsmConfig(**protein_config)
        self.text_config = BertConfig(**text_config)

    @classmethod
    def from_protein_text_configs(
        cls, protein_config: EsmConfig, text_config: BertConfig, **kwargs
    ):
        r"""
        Instantiate a [`ProtSTConfig`] (or a derived class) from a protein model configuration and a text model
        configuration.

        Args:
            protein_config ([`EsmConfig`]):
                Configuration of the protein encoder.
            text_config ([`BertConfig`]):
                Configuration of the text encoder.
            kwargs (*optional*):
                Extra keyword arguments forwarded to the [`ProtSTConfig`] constructor.

        Returns:
            [`ProtSTConfig`]: An instance of a configuration object
        """
        return cls(protein_config=protein_config.to_dict(), text_config=text_config.to_dict(), **kwargs)