from __future__ import annotations from transformers import PretrainedConfig from transformers import PreTrainedModel from torch import nn import torch class FastTextJpConfig(PretrainedConfig): model_type = "fasttext_jp" def __init__(self, **kwargs): super().__init__(**kwargs) class FastTextJpModel(PreTrainedModel): """FastTextのEmbeddingを行います。 """ config_class = FastTextJpConfig def __init__(self, config: FastTextJpConfig): super().__init__(config) self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) def forward(self, input_ids, **kwargs): return self.word_embeddings(torch.tensor([0])) FastTextJpConfig.register_for_auto_class() FastTextJpModel.register_for_auto_class("AutoModel")