File size: 361 Bytes
42d0fc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from transformers import PretrainedConfig
from typing import List


class TestConfig(PretrainedConfig):
    """Configuration for the ``my_test_model`` architecture.

    Holds the model's input and output sizes. Every other keyword argument
    is passed through unchanged to ``PretrainedConfig.__init__``.

    Args:
        input_dim: Size of the model input. Defaults to 20.
        output_dim: Size of the model output. Defaults to 10.
        **kwargs: Additional options consumed by ``PretrainedConfig``.
    """

    # Model-type key for this config class (the string transformers' Auto*
    # registration and serialized configs refer to).
    model_type = "my_test_model"

    def __init__(self, input_dim: int = 20, output_dim: int = 10, **kwargs) -> None:
        # Store the model-specific fields first (input_dim, then output_dim),
        # then hand the remaining keyword arguments to the base class.
        self.input_dim, self.output_dim = input_dim, output_dim
        super().__init__(**kwargs)