from tokenizers import Tokenizer
from transformers import GPT2TokenizerFast


def fetch_encoder(params):
    # Some configurations run without a dataset at all; no tokenizer is needed then.
    no_dataset = params.get('no_dataset', False)
    if no_dataset:
        return None

    # Tokenizer settings are taken from the first dataset config.
    dataset = next(iter(params['dataset_configs'].values()))
    path = dataset["tokenizer_path"]
    is_pretrained = dataset.get("tokenizer_is_pretrained", False)

    if is_pretrained:
        tok = GPT2TokenizerFast.from_pretrained(path)
        # Register a dedicated padding token; it is appended to the end of the
        # vocabulary (id 50257 for the standard GPT-2 vocab).
        tok.add_special_tokens({'pad_token': '<|padding|>'})
        return tok

    # Otherwise the path points at a serialized `tokenizers` tokenizer file.
    return Tokenizer.from_file(path)
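
# Illustrative only (an assumption, not part of the original file): a params
# dict shaped the way fetch_encoder reads it above. The 'openwebtext' key and
# the pretrained 'gpt2' identifier are placeholder choices; any pretrained
# tokenizer name or a local tokenizer-file path would work the same way.
EXAMPLE_PARAMS = {
    'no_dataset': False,
    'dataset_configs': {
        'openwebtext': {
            'tokenizer_path': 'gpt2',
            'tokenizer_is_pretrained': True,
        }
    },
}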


# `tokenizers.Tokenizer.encode` returns an Encoding object, while the Hugging
# Face GPT-2 tokenizers return a plain list of ids; normalize both to a list.
# The `gpt` flag is accepted but unused here.
def encode(encoder, text, gpt=True):
    result = encoder.encode(text)
    if isinstance(result, list):
        return result
    return result.ids
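

# A minimal usage sketch, assuming the `transformers` package can fetch the
# pretrained "gpt2" tokenizer named in EXAMPLE_PARAMS above. Run the module
# directly to try it; nothing here is exercised on import.
if __name__ == "__main__":
    enc = fetch_encoder(EXAMPLE_PARAMS)
    ids = encode(enc, "Hello world")
    print(ids)              # token ids, e.g. [15496, 995] for the GPT-2 vocab
    print(enc.decode(ids))  # round-trips back to "Hello world"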