from transformers import PretrainedConfig | |
class BrainiacConfig(PretrainedConfig):
    """Configuration for the ``brainiac`` model type.

    Stores the two hyperparameters passed to the constructor
    (``in_channels`` and ``out_features``) and forwards every other keyword
    argument to ``PretrainedConfig``.
    """

    model_type = "brainiac"

    def __init__(
        self,
        in_channels: int = 1,
        out_features: int = 2048,
        **kwargs,
    ):
        """Create a configuration.

        Args:
            in_channels: Stored as ``self.in_channels``. Presumably the number
                of input channels the model expects -- confirm against the
                model implementation.
            out_features: Stored as ``self.out_features``. Presumably the size
                of the feature vector the model emits -- confirm against the
                model implementation.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        self.in_channels = in_channels
        self.out_features = out_features
        super().__init__(**kwargs)

    @classmethod
    def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs):
        """Load the raw config dict and force its ``model_type`` to "brainiac".

        The base implementation is a classmethod that returns a
        ``(config_dict, remaining_kwargs)`` tuple, and ``from_pretrained``
        unpacks two values from the result, so this override must keep both
        the decorator and the tuple-shaped return value.

        Args:
            pretrained_model_name_or_path: Passed through to the base
                ``get_config_dict``.
            **kwargs: Passed through to the base ``get_config_dict``.

        Returns:
            A ``(config_dict, remaining_kwargs)`` tuple where
            ``config_dict["model_type"]`` is ``"brainiac"``.
        """
        config_dict, remaining_kwargs = super().get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )
        # Overwrite (or add) the key on the dict itself, not on the tuple.
        config_dict["model_type"] = "brainiac"
        return config_dict, remaining_kwargs