File size: 2,810 Bytes
d1a642c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a108e1a
d1a642c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import re
from sentencepiece import SentencePieceProcessor


def replace_spaces_with_blank(match: re.Match[str]):
    """Regex replacement callback: turn the matched text into a blank marker.

    The marker has the form ``<|blank_N|>`` where ``N`` is the number of
    characters covered by the whole match.
    """
    run_length = match.end() - match.start()
    return "<|blank_" + str(run_length) + "|>"


def replace_blank_with_spaces(match: re.Match[str]):
    """Regex replacement callback: expand a blank marker back into spaces.

    The first capture group of the match is parsed as an integer and that
    many space characters are returned.
    """
    count = int(match[1])
    return count * " "


class ChatGLMTokenizer:
    """Wrapper around a SentencePiece model with extra text pre/post-processing.

    Before encoding, newlines, tabs and runs of spaces are rewritten into
    marker strings ("<n>", "<|tab|>", "<|blank_N|>"); ``decode`` reverses
    those rewrites after SentencePiece has turned ids back into text.
    """

    # Runs of 2 to 80 spaces; each run is replaced by one "<|blank_N|>" marker.
    _SPACE_RUN_RE = re.compile(r" {2,80}")
    # The "<|blank_N|>" markers (one or two digits) that decode() expands again.
    _BLANK_MARKER_RE = re.compile(r"<\|blank_(\d\d?)\|>")

    def __init__(self, vocab_file):
        """Load the SentencePiece model stored at ``vocab_file``.

        vocab_file: Path of the SentencePiece model file; converted with ``str()``.

        Raises ValueError if ``vocab_file`` is None.
        """
        # An explicit raise (rather than ``assert``) survives ``python -O``;
        # otherwise ``str(None)`` would be handed to SentencePiece as a path.
        if vocab_file is None:
            raise ValueError("vocab_file must not be None")
        self.vocab_file = vocab_file
        # Kept as a public attribute; not read by any method of this class.
        self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "<unused_0>", "<sop>", "<eop>", "<ENC>", "<dBLOCK>"]
        self.text_tokenizer = SentencePieceProcessor(str(vocab_file))

    def __len__(self):
        """Return ``len()`` of the underlying SentencePiece processor."""
        return len(self.text_tokenizer)

    def __getitem__(self, key: str):
        """Look ``key`` up in the underlying SentencePiece processor."""
        return self.text_tokenizer[key]

    def preprocess(self, text: str, linebreak=True, whitespaces=True):
        """Rewrite whitespace in ``text`` into marker strings.

        linebreak: Replace every newline with "<n>".
        whitespaces: Replace every tab with "<|tab|>" and every run of
            2 to 80 spaces with "<|blank_N|>" (N = length of the run).

        Returns the rewritten text.
        """
        if linebreak:
            text = text.replace("\n", "<n>")
        if whitespaces:
            text = text.replace("\t", "<|tab|>")
            text = self._SPACE_RUN_RE.sub(replace_spaces_with_blank, text)
        return text

    def encode(
        self, text: str, text_pair: str = None,
        linebreak=True, whitespaces=True,
        add_dummy_prefix=True, special_tokens=True,
    ) -> tuple[list[int], list[int]]:
        """
        Encode text (and optionally text_pair) into token ids plus a prefix mask.

        text: Text to encode. Bidirectional part with a [gMASK] and an <sop> for causal LM.
        text_pair: causal LM part.
        linebreak: Whether to encode newline in text.
        whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
        special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text.
        add_dummy_prefix: Whether to add dummy blank space in the beginning.

        Returns ``(tokens, prefix_mask)``, two lists of equal length. The mask
        is 1 for the tokens of ``text`` and for [gMASK], and 0 for <sop>, the
        tokens of ``text_pair`` and <eop>.
        """
        text = self.preprocess(text, linebreak, whitespaces)
        if not add_dummy_prefix:
            # Prepend a throw-away "<n>"; the leading tokens it produces are
            # removed again below. NOTE(review): this presumes the model
            # encodes the prefix as exactly two tokens (dummy blank + "<n>").
            text = "<n>" + text

        tokens = self.text_tokenizer.encode(text)
        prefix_mask = [1] * len(tokens)
        if special_tokens:
            tokens += [self.text_tokenizer["[gMASK]"], self.text_tokenizer["<sop>"]]
            prefix_mask += [1, 0]

        if text_pair is not None:
            text_pair = self.preprocess(text_pair, linebreak, whitespaces)
            pair_tokens = self.text_tokenizer.encode(text_pair)
            tokens += pair_tokens
            prefix_mask += [0] * len(pair_tokens)
            if special_tokens:
                tokens += [self.text_tokenizer["<eop>"]]
                prefix_mask += [0]

        if not add_dummy_prefix:
            # Drop the two helper tokens from BOTH lists so the mask stays
            # aligned with (and as long as) the returned token list.
            tokens = tokens[2:]
            prefix_mask = prefix_mask[2:]

        return tokens, prefix_mask

    def decode(self, text_ids: list[int]) -> str:
        """Decode token ids to text and undo the rewrites done by ``preprocess``.

        "<n>" becomes a newline, "<|tab|>" a tab, and each "<|blank_N|>"
        marker (N of one or two digits) becomes N spaces.
        """
        text = self.text_tokenizer.decode(text_ids)
        text = text.replace("<n>", "\n")
        text = text.replace("<|tab|>", "\t")
        text = self._BLANK_MARKER_RE.sub(replace_blank_with_spaces, text)
        return text