brainiac / configuration_brainiac.py
Divytak's picture
Upload configuration_brainiac.py with huggingface_hub
a7a5e2e verified
raw
history blame contribute delete
605 Bytes
from transformers import PretrainedConfig
class BrainiacConfig(PretrainedConfig):
    """Configuration for the ``brainiac`` model.

    Holds the two model-specific hyperparameters ``in_channels`` and
    ``out_features``; every other keyword argument is forwarded unchanged to
    :class:`transformers.PretrainedConfig`.

    Args:
        in_channels: Number of input channels (default ``1``).
            NOTE(review): presumably the channel count of the model input;
            confirm against the model definition.
        out_features: Output feature size (default ``2048``).
            NOTE(review): meaning inferred from the name; confirm against the
            model definition.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "brainiac"

    def __init__(
        self,
        in_channels: int = 1,
        out_features: int = 2048,
        **kwargs,
    ):
        # Model-specific fields are assigned before the base initializer,
        # then the remaining kwargs are handled by PretrainedConfig.
        self.in_channels = in_channels
        self.out_features = out_features
        super().__init__(**kwargs)

    @classmethod
    def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs):
        """Load the raw config dict and overwrite its ``model_type``.

        The base ``PretrainedConfig.get_config_dict`` returns a
        ``(config_dict, unused_kwargs)`` tuple, so this override unpacks that
        tuple, patches the dictionary, and returns the same two-part shape.

        Args:
            pretrained_model_name_or_path: Model id or local path, forwarded
                to the base implementation.
            **kwargs: Forwarded to the base implementation.

        Returns:
            tuple[dict, dict]: The config dictionary with ``model_type`` set
            to this class's ``model_type``, and the unused keyword arguments
            reported by the base implementation.
        """
        config_dict, unused_kwargs = super().get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )
        # Force the stored type to match this class. Presumably this lets a
        # checkpoint saved under a different ``model_type`` still load as a
        # BrainiacConfig; verify against how checkpoints are produced.
        config_dict["model_type"] = cls.model_type
        return config_dict, unused_kwargs