|
from transformers import PretrainedConfig |
|
from transformers.utils import logging |
|
from transformers.models.esm import EsmConfig |
|
from transformers.models.bert import BertConfig |
|
|
|
# Module-level logger; ProtSTConfig uses it to report when a sub-config falls back to default values.
logger = logging.get_logger(__name__)
|
|
|
|
|
class ProtSTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ProtSTModel`]. It bundles two
    sub-configurations: an [`EsmConfig`] for the protein side and a [`BertConfig`] for the text side.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        protein_config (`dict` or [`PretrainedConfig`], *optional*):
            Configuration options used to initialize the protein-side [`EsmConfig`] (for
            [`EsmForProteinRepresentation`]). A config object is converted with its `to_dict()` method.
            If `None`, an [`EsmConfig`] with default values is created.
        text_config (`dict` or [`PretrainedConfig`], *optional*):
            Configuration options used to initialize the text-side [`BertConfig`] (for [`BertForPubMed`]).
            A config object is converted with its `to_dict()` method. If `None`, a [`BertConfig`] with
            default values is created.
        kwargs (*optional*):
            Remaining keyword arguments, forwarded unchanged to [`PretrainedConfig.__init__`].
    """

    model_type = "protst"

    def __init__(
        self,
        protein_config=None,
        text_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if protein_config is None:
            protein_config = {}
            logger.info("`protein_config` is `None`. Initializing the `EsmConfig` with default values.")
        elif isinstance(protein_config, PretrainedConfig):
            # Accept an already-built config object as well as a plain dict; `EsmConfig(**obj)` would fail
            # on a non-mapping. Same dict round-trip as `from_protein_text_configs` uses.
            protein_config = protein_config.to_dict()

        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the `BertConfig` with default values.")
        elif isinstance(text_config, PretrainedConfig):
            text_config = text_config.to_dict()

        self.protein_config = EsmConfig(**protein_config)
        self.text_config = BertConfig(**text_config)

    @classmethod
    def from_protein_text_configs(
        cls, protein_config: EsmConfig, text_config: BertConfig, **kwargs
    ):
        r"""
        Instantiate a [`ProtSTConfig`] (or a derived class) from a protein model configuration and a text model
        configuration.

        Args:
            protein_config ([`EsmConfig`]):
                Configuration of the protein encoder; serialized with `to_dict()`.
            text_config ([`BertConfig`]):
                Configuration of the text encoder; serialized with `to_dict()`.
            kwargs (*optional*):
                Extra keyword arguments forwarded to the constructor.

        Returns:
            [`ProtSTConfig`]: An instance of a configuration object.
        """
        return cls(protein_config=protein_config.to_dict(), text_config=text_config.to_dict(), **kwargs)