File size: 442 Bytes
24c13a4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
from transformers import PreTrainedModel
from .configuration_fc import FCConfig
from torch.nn import Linear
class FCModel(PreTrainedModel):
    """Single fully connected layer wrapped as a ``PreTrainedModel``.

    The wrapped layer is a ``torch.nn.Linear`` with 10 input features and
    ``config.num_nodes`` output features.
    """

    config_class = FCConfig

    def __init__(self, config):
        """Build the linear layer from *config*.

        Args:
            config: Configuration object; its ``num_nodes`` attribute sets
                the number of output features of the linear layer.
        """
        super().__init__(config)
        self.model = Linear(in_features=10, out_features=config.num_nodes)

    def forward(self, tensor):
        """Apply the linear layer to *tensor* and return the result.

        Args:
            tensor: Input passed straight to the linear layer; its last
                dimension must match the layer's ``in_features`` (10).

        Returns:
            The output of the linear layer.
        """
        # Call the module itself rather than ``.forward`` so that any hooks
        # registered on the inner layer are run by ``Module.__call__``.
        return self.model(tensor)
|