File size: 442 Bytes
24c13a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from transformers import PreTrainedModel
from .configuration_fc import FCConfig
from torch.nn import Linear

class FCModel(PreTrainedModel):
    """A single fully connected (linear) layer exposed as a ``PreTrainedModel``.

    ``config_class`` ties this model to ``FCConfig``, so the layer's size is
    driven by that configuration object.
    """

    config_class = FCConfig

    def __init__(self, config):
        """Build the linear layer described by ``config``.

        Args:
            config: An ``FCConfig``. ``config.num_nodes`` sets the output
                width. An optional ``config.in_features`` sets the input
                width; when absent it defaults to 10 (the previously
                hard-coded value) so existing configs keep working.
        """
        super().__init__(config)
        # Fall back to 10 so configs/checkpoints created before this
        # attribute existed still build a layer of the same shape.
        in_features = getattr(config, "in_features", 10)
        self.model = Linear(in_features=in_features, out_features=config.num_nodes)

    def forward(self, tensor):
        """Apply the linear layer to ``tensor``.

        Args:
            tensor: Input tensor whose last dimension equals the layer's
                ``in_features``.

        Returns:
            The layer's output; its last dimension is ``config.num_nodes``.
        """
        # Call the module rather than ``.forward`` directly so any forward
        # hooks / pre-hooks registered on ``self.model`` are executed.
        return self.model(tensor)