"""Fast tokenization classes for Shami.""" import json from typing import TYPE_CHECKING, List, Optional, Tuple from tokenizers import pre_tokenizers from transformers.tokenization_utils_base import BatchEncoding from transformers.tokenization_utils_fast import PreTrainedTokenizerFast from transformers.utils import logging if TYPE_CHECKING: from transformers.pipelines.conversational import Conversation logger = logging.get_logger(__name__) VOCAB_FILES_NAMES = {"tokenizer_file": "tokenizer.json"} PRETRAINED_VOCAB_FILES_MAP = { "tokenizer_file": { }, } class ShamiTokenizerFast(PreTrainedTokenizerFast): vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP model_input_names = ["input_ids", "attention_mask"] slow_tokenizer_class = None def __init__( self, vocab_file=None, merges_file=None, tokenizer_file=None, unk_token="<|endoftext|>", bos_token="<|endoftext|>", eos_token="<|endoftext|>", pad_token="<|endoftext|>", add_prefix_space=False, **kwargs ): super().__init__( vocab_file, merges_file, tokenizer_file=tokenizer_file, unk_token=unk_token, bos_token=bos_token, eos_token=eos_token, pad_token=pad_token, add_prefix_space=add_prefix_space, **kwargs, ) pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__()) if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space: pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type")) pre_tok_state["add_prefix_space"] = add_prefix_space self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state) self.add_prefix_space = add_prefix_space def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding: is_split_into_words = kwargs.get("is_split_into_words", False) if not (self.add_prefix_space or not is_split_into_words): raise Exception( f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True to use it with" " pretokenized inputs." ) return super()._batch_encode_plus(*args, **kwargs) def _encode_plus(self, *args, **kwargs) -> BatchEncoding: is_split_into_words = kwargs.get("is_split_into_words", False) if not (self.add_prefix_space or not is_split_into_words): raise Exception( f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True to use it with" " pretokenized inputs." ) return super()._encode_plus(*args, **kwargs) def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: files = self._tokenizer.model.save(save_directory, name=filename_prefix) return tuple(files) def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]: """This corresponds to DialoGPT variants of models.""" input_ids = [] for is_user, text in conversation.iter_texts(): input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id]) if len(input_ids) > self.model_max_length: input_ids = input_ids[-self.model_max_length :] return input_ids