File size: 605 Bytes
a7a5e2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from transformers import PretrainedConfig

class BrainiacConfig(PretrainedConfig):
    """Configuration for the ``"brainiac"`` model type.

    Args:
        in_channels: Stored as ``self.in_channels`` (default 1). Presumably the
            number of input channels the model consumes -- confirm against the
            model implementation.
        out_features: Stored as ``self.out_features`` (default 2048).
            Presumably the width of the model's output features -- confirm
            against the model implementation.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "brainiac"

    def __init__(
        self,
        in_channels: int = 1,
        out_features: int = 2048,
        **kwargs,
    ):
        self.in_channels = in_channels
        self.out_features = out_features
        super().__init__(**kwargs)

    @classmethod
    def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs):
        """Load the raw config and force its ``model_type`` to ``"brainiac"``.

        The base ``PretrainedConfig.get_config_dict`` returns a
        ``(config_dict, unused_kwargs)`` tuple, and callers such as
        ``from_pretrained`` unpack it that way. This override therefore
        keeps that two-element shape.

        Args:
            pretrained_model_name_or_path: Passed through to the base loader.
            **kwargs: Passed through to the base loader.

        Returns:
            tuple: ``(config_dict, unused_kwargs)``, where
            ``config_dict["model_type"]`` has been set to ``"brainiac"``.
        """
        # Unpack the tuple; item assignment on the tuple itself raises TypeError.
        config_dict, unused_kwargs = super().get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )
        # Overwrite whatever model_type the stored config declares.
        # NOTE(review): presumably so checkpoints saved under a different
        # model_type still load as BrainiacConfig -- confirm intent.
        config_dict["model_type"] = "brainiac"
        return config_dict, unused_kwargs