from torch.nn import Linear
from transformers import PreTrainedModel

from .configuration_fc import FCConfig


class FCModel(PreTrainedModel):
    """A single fully connected (``torch.nn.Linear``) layer exposed as a
    Hugging Face ``PreTrainedModel`` configured by ``FCConfig``.

    The layer maps ``in_features`` inputs to ``config.num_nodes`` outputs.
    ``in_features`` is read from ``config.in_features`` when the config
    defines it, and otherwise defaults to 10 (the original hard-coded width).
    """

    config_class = FCConfig

    # Input width used when the config has no ``in_features`` attribute;
    # keeps the layer shape (and saved weights) compatible with the original.
    _DEFAULT_IN_FEATURES = 10

    def __init__(self, config):
        """Build the linear layer from ``config``.

        Args:
            config: An ``FCConfig``. ``config.num_nodes`` sets the output
                width; an optional ``config.in_features`` sets the input width.
        """
        super().__init__(config)
        in_features = getattr(config, "in_features", self._DEFAULT_IN_FEATURES)
        self.model = Linear(in_features=in_features, out_features=config.num_nodes)

    def forward(self, tensor):
        """Apply the linear layer to ``tensor``.

        Args:
            tensor: Input tensor. Presumably its last dimension must equal
                the layer's ``in_features``, as ``torch.nn.Linear`` requires.

        Returns:
            The output of the linear layer, whose last dimension is
            ``config.num_nodes``.
        """
        # Invoke the submodule itself rather than ``.forward`` so that
        # nn.Module.__call__ runs any registered forward hooks / pre-hooks.
        return self.model(tensor)