|
|
|
|
|
|
|
|
|
""" |
|
GottBERT: a pure German Language Model |
|
""" |
|
|
|
from fairseq.models import register_model |
|
|
|
from .hub_interface import RobertaHubInterface |
|
from .model import RobertaModel |
|
|
|
|
|
@register_model('gottbert')
class GottbertModel(RobertaModel):
    """RobertaModel subclass registered with fairseq under the name 'gottbert'."""

    @classmethod
    def hub_models(cls):
        """Return the known model names mapped to their archive URLs.

        ``from_pretrained`` hands this mapping to
        ``hub_utils.from_pretrained`` as its ``archive_map`` argument.
        """
        archives = {
            'gottbert-base': 'https://dl.gottbert.de/fairseq/models/gottbert-base.tar.gz',
        }
        return archives

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file='model.pt',
        data_name_or_path='.',
        bpe='hf_byte_bpe',
        bpe_vocab='vocab.json',
        bpe_merges='merges.txt',
        bpe_add_prefix_space=False,
        **kwargs
    ):
        """Load a pretrained model and wrap it in a ``RobertaHubInterface``.

        All arguments are forwarded to ``hub_utils.from_pretrained``; this
        method only supplies the defaults and the ``archive_map``.

        Args:
            model_name_or_path: forwarded as the first positional argument.
            checkpoint_file: forwarded positionally (default ``'model.pt'``).
            data_name_or_path: forwarded positionally (default ``'.'``).
            bpe: forwarded as the ``bpe`` keyword (default ``'hf_byte_bpe'``).
            bpe_vocab: forwarded as ``bpe_vocab`` (default ``'vocab.json'``).
            bpe_merges: forwarded as ``bpe_merges`` (default ``'merges.txt'``).
            bpe_add_prefix_space: forwarded as ``bpe_add_prefix_space``
                (default ``False``).
            **kwargs: any further keyword arguments, forwarded unchanged.

        Returns:
            A ``RobertaHubInterface`` built from the ``'args'`` and ``'task'``
            entries of the loaded result and the first entry of ``'models'``.
        """
        # Resolved at call time, not at module import time.
        from fairseq import hub_utils

        # Keyword overrides, kept in the order they are handed to the loader.
        # ``load_checkpoint_heads`` is always forced to True.
        overrides = {
            'bpe': bpe,
            'load_checkpoint_heads': True,
            'bpe_vocab': bpe_vocab,
            'bpe_merges': bpe_merges,
            'bpe_add_prefix_space': bpe_add_prefix_space,
        }

        loaded = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            **overrides,
            **kwargs
        )

        model_args = loaded['args']
        task = loaded['task']
        first_model = loaded['models'][0]
        return RobertaHubInterface(model_args, task, first_model)
|
|