from __future__ import annotations

from typing import Optional

import torch
from transformers import BatchEncoding, PreTrainedTokenizer, T5Tokenizer
from transformers.tokenization_utils_base import (
    PaddingStrategy,
    PreTokenizedInput,
    TextInput,
    TruncationStrategy,
)
|
|
class CLYPTokenizer(PreTrainedTokenizer):
    """CLYPTokenizer based on rinna/japanese-roberta-base.

    This tokenizer is registered as a custom tokenizer so that the [CLS] token
    can be prepended to each text manually, since the underlying tokenizer does
    not add it automatically.
    """
|
    def __init__(self, max_length: int, padding: str, truncation: bool, **kwargs):
        # Wrapped SentencePiece tokenizer of rinna/japanese-roberta-base; it is
        # created before super().__init__() so that vocabulary lookups work
        # during initialization. do_lower_case is enabled as recommended for
        # this tokenizer.
        self.tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-roberta-base")
        self.tokenizer.do_lower_case = True

        super().__init__(
            max_length=max_length, padding=padding, truncation=truncation, **kwargs
        )
        self.max_length = max_length
        self.padding = padding
        self.truncation = truncation
|
    @property
    def vocab_size(self):
        return self.tokenizer.vocab_size
|
    def get_vocab(self) -> dict[str, int]:
        return self.tokenizer.get_vocab()
|
    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> tuple[str]:
        return self.tokenizer.save_vocabulary(
            save_directory, filename_prefix=filename_prefix
        )
|
    def _tokenize(self, text, **kwargs):
        return self.tokenizer._tokenize(text, **kwargs)
|
    def _convert_token_to_id(self, token):
        return self.tokenizer._convert_token_to_id(token)
|
    def _convert_id_to_token(self, index: int) -> str:
        return self.tokenizer._convert_id_to_token(index)
|
    def __call__(
        self,
        text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput],
        add_special_tokens: bool = True,
        padding: bool | str | PaddingStrategy | None = None,
        truncation: bool | str | TruncationStrategy | None = None,
        max_length: Optional[int] = None,
        **kwargs,
    ):
        # Fall back to the defaults supplied at construction time.
        if max_length is None:
            max_length = self.max_length
        if padding is None:
            padding = self.padding
        if truncation is None:
            truncation = self.truncation

        if add_special_tokens:
            # Reserve one position for the [CLS] token prepended below.
            max_length = max_length - 1

        if not isinstance(text, list):
            text = [text]

        # Tokenize without special tokens; [CLS] is handled manually.
        out = self.tokenizer(
            text,
            max_length=max_length,
            padding=padding,
            truncation=truncation,
            add_special_tokens=False,
            **kwargs,
        )

        if add_special_tokens:
            # Prepend [CLS] to every sequence and extend the attention mask
            # accordingly.
            input_ids = [
                [self.tokenizer.cls_token_id] + ids for ids in out["input_ids"]
            ]
            attention_mask = [[1] + am for am in out["attention_mask"]]
        else:
            input_ids = out["input_ids"]
            attention_mask = out["attention_mask"]

        # Position ids are provided explicitly, starting from 0. Every sequence
        # is assumed to share the length of the first one (e.g. when
        # padding="max_length"); otherwise the tensor construction below fails.
        position_ids = [list(range(0, len(input_ids[0])))] * len(input_ids)

        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attention_mask = torch.tensor(attention_mask, dtype=torch.long)
        position_ids = torch.tensor(position_ids, dtype=torch.long)

        data = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
        }
        return BatchEncoding(data=data)
|
|
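# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original tokenizer code). It assumes
# the rinna/japanese-roberta-base files can be downloaded and uses
# padding="max_length" with an arbitrary example max_length of 77, so every
# sequence in the batch shares one length and the tensors in __call__ can be
# built. In a model repository this class would typically be loaded through
# AutoTokenizer with trust_remote_code=True rather than instantiated directly.
if __name__ == "__main__":
    tokenizer = CLYPTokenizer(max_length=77, padding="max_length", truncation=True)
    batch = tokenizer(["吾輩は猫である。", "今日はいい天気ですね。"])
    print(batch["input_ids"].shape)       # -> torch.Size([2, 77])
    print(batch["attention_mask"].shape)  # -> torch.Size([2, 77])
    print(batch["position_ids"].shape)    # -> torch.Size([2, 77])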