my_model / custom_model.py
nsorros's picture
add model
1ca0f10
raw
history blame contribute delete
350 Bytes
from transformers import PreTrainedModel, BertConfig, AutoModel
class Model(PreTrainedModel):
    """Thin ``PreTrainedModel`` wrapper around a pretrained encoder.

    The wrapped network is stored on ``self.model`` and is always obtained
    via ``AutoModel.from_pretrained("bert-base-uncased")``.
    """

    # Configuration class associated with this model.
    config_class = BertConfig

    def __init__(self, config):
        """Initialise the parent class with *config* and load the wrapped model.

        NOTE(review): *config* is only forwarded to the parent initialiser;
        the wrapped model is loaded from the fixed "bert-base-uncased"
        checkpoint regardless of what *config* contains -- confirm that this
        is intended.
        """
        super().__init__(config)
        self.model = AutoModel.from_pretrained("bert-base-uncased")

    def forward(self, **inputs):
        """Call the wrapped model with ``**inputs`` and return its result.

        The keyword arguments are passed through untouched, and whatever the
        wrapped model returns is handed back to the caller unchanged.
        """
        return self.model(**inputs)