edugp committed on
Commit
13451d2
1 Parent(s): c1a6926

Add model card with sample usage

Files changed (2)
  1. README.md +31 -0
  2. model.py +163 -0
README.md ADDED
@@ -0,0 +1,31 @@
+ # KenLM models
+ This repo contains several KenLM models trained on different tokenized datasets and languages.
+ KenLM models are probabilistic n-gram language models. One use case for these models is fast perplexity estimation for [filtering or sampling large datasets](https://huggingface.co/bertin-project/bertin-roberta-base-spanish). For example, one could use a KenLM model trained on French Wikipedia to run inference on a large dataset and filter out samples that are very unlikely to appear on Wikipedia (high perplexity), or very simple, non-informative sentences that could appear repeatedly (low perplexity).
+
+ At the root of this repo you will find different directories named after the dataset the models were trained on (e.g. `wikipedia`, `oscar`). Within each directory, you will find several models trained on different language subsets of the dataset (e.g. `en (English)`, `es (Spanish)`, `fr (French)`). For each language you will find three files:
+ * `{language}.arpa.bin`: The trained KenLM model binary
+ * `{language}.sp.model`: The trained SentencePiece model used for tokenization
+ * `{language}.sp.vocab`: The vocabulary file for the SentencePiece model
+
+ The models have been trained using some of the preprocessing steps from [cc_net](https://github.com/facebookresearch/cc_net), in particular replacing numbers with zeros and normalizing punctuation. It is therefore important to keep the default values of the `lower_case`, `remove_accents`, `normalize_numbers` and `punctuation` parameters when using the pre-trained models, in order to replicate the same pre-processing at inference time.
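+
+ As a rough illustration, this is what that normalization does to a sentence (a minimal sketch using the `normalize` method from `model.py`, passing the same parameter values that `get_perplexity` applies by default; it assumes the `wikipedia` model files are available locally, as in the example below):
+ ```python
+ from model import KenlmModel
+
+ model = KenlmModel.from_pretrained("wikipedia", "en")
+ # Digits are replaced with zeros and unicode punctuation is mapped to ASCII
+ model.normalize("Call me at 555-1234…", accent=False, case=False, numbers=True, punct=1)
+ # 'Call me at 000-0000...'
+ ```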
+
+ # Dependencies
+ * KenLM: `pip install https://github.com/kpu/kenlm/archive/master.zip`
+ * SentencePiece: `pip install sentencepiece`
+
+ # Example:
+ ```python
+ from model import KenlmModel
+
+
+ # Load model trained on English Wikipedia
+ model = KenlmModel.from_pretrained("wikipedia", "en")
+
+ # Get perplexity
+ model.get_perplexity("I am very perplexed")
+ # 341.3 (low perplexity, since the sentence is formal and has no grammar mistakes)
+
+ model.get_perplexity("im hella trippin")
+ # 46793.5 (high perplexity, since the sentence is colloquial and contains grammar mistakes)
+ ```
+ In the example above we see that, since Wikipedia is a collection of encyclopedic articles, a KenLM model trained on it will naturally give lower perplexity scores to sentences with formal language and no grammar mistakes than to colloquial sentences with grammar mistakes.
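+
+ Building on the filtering use case described at the top of this card, one could keep only the samples whose perplexity falls within a chosen range. This is only a minimal sketch: the `texts` list and the thresholds are hypothetical placeholders, and suitable thresholds depend on your dataset.
+ ```python
+ from model import KenlmModel
+
+ model = KenlmModel.from_pretrained("wikipedia", "en")
+
+ texts = ["Paris is the capital of France.", "click here click here click here", "im hella trippin"]
+ min_pp, max_pp = 50, 5000  # arbitrary example thresholds
+
+ # Drop both very repetitive, non-informative samples (very low perplexity)
+ # and samples very unlike Wikipedia text (very high perplexity)
+ filtered = [t for t in texts if min_pp < model.get_perplexity(t) < max_pp]
+ ```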
model.py ADDED
@@ -0,0 +1,163 @@
+ import os
+ import re
+ import unicodedata
+ from typing import Dict
+
+ import kenlm
+ import sentencepiece
+ from huggingface_hub import cached_download, hf_hub_url
+
+
+ class SentencePiece:
+     def __init__(
+         self,
+         model: str,
+     ):
+         super().__init__()
+         self.sp = sentencepiece.SentencePieceProcessor()
+         self.sp.load(str(model))
+
+     def do(self, text: str) -> str:
+         # Tokenize into SentencePiece pieces and join them with spaces,
+         # so KenLM scores space-separated tokens
+         tokenized = self.sp.encode_as_pieces(text)
+         return " ".join(tokenized)
+
+
+ class KenlmModel:
+     digit_re: re.Pattern = re.compile(r"\d")
+     unicode_punct: Dict[str, str] = {
+         ",": ",",
+         "。": ".",
+         "、": ",",
+         "„": '"',
+         "”": '"',
+         "“": '"',
+         "«": '"',
+         "»": '"',
+         "1": '"',
+         "」": '"',
+         "「": '"',
+         "《": '"',
+         "》": '"',
+         "´": "'",
+         "∶": ":",
+         ":": ":",
+         "?": "?",
+         "!": "!",
+         "(": "(",
+         ")": ")",
+         ";": ";",
+         "–": "-",
+         "—": " - ",
+         ".": ". ",
+         "~": "~",
+         "’": "'",
+         "…": "...",
+         "━": "-",
+         "〈": "<",
+         "〉": ">",
+         "【": "[",
+         "】": "]",
+         "%": "%",
+         "►": "-",
+     }
+     unicode_punct_re = re.compile(f"[{''.join(unicode_punct.keys())}]")
+     non_printing_chars_re = re.compile(
+         f"[{''.join(map(chr, list(range(0,32)) + list(range(127,160))))}]"
+     )
+     kenlm_model_dir = None
+     sentence_piece_model_dir = None
69
+
70
+ def __init__(
71
+ self,
72
+ model_dataset: str,
73
+ language: str,
74
+ lower_case: bool = False,
75
+ remove_accents: bool = False,
76
+ normalize_numbers: bool = True,
77
+ punctuation: int = 1,
78
+ ):
79
+ self.model = kenlm.Model(os.path.join(model_dataset, f"{language}.arpa.bin"))
80
+ self.tokenizer = SentencePiece(os.path.join(model_dataset, f"{language}.sp.model"))
81
+ self.accent = remove_accents
82
+ self.case = lower_case
83
+ self.numbers = normalize_numbers
84
+ self.punct = punctuation
85
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_dataset: str,
+         language: str,
+     ):
+         # Use the default normalization settings (no lower-casing, keep accents,
+         # normalize numbers, replace unicode punctuation)
+         return cls(
+             model_dataset,
+             language,
+             False,
+             False,
+             True,
+             1,
+         )
+
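+     # kenlm scores are log10 probabilities, so document perplexity is
+     # 10 ** (-total_log10_score / token_count)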
+     def pp(self, log_score, length):
+         return 10.0 ** (-log_score / length)
+
+     def get_perplexity(self, doc: str, normalize_cc_net: bool = True):
+         if normalize_cc_net:
+             doc = self.normalize(
+                 doc,
+                 accent=self.accent,
+                 case=self.case,
+                 numbers=self.numbers,
+                 punct=self.punct,
+             )
+         # Tokenize (after normalizing): see https://github.com/facebookresearch/cc_net/blob/bda555bd1cf1ee2e0b925363e62a61cd46c8b60d/cc_net/mine.py#L352 for the full pipeline
+         doc = self.tokenizer.do(doc)
+         doc_log_score, doc_length = 0, 0
+         for line in doc.split("\n"):
+             log_score = self.model.score(line)
+             length = len(line.split()) + 1
+             doc_log_score += log_score
+             doc_length += length
+         return round(self.pp(doc_log_score, doc_length), 1)
+
+     def normalize(
+         self,
+         line: str,
+         accent: bool = True,
+         case: bool = True,
+         numbers: bool = True,
+         punct: int = 1,
+     ) -> str:
+         line = line.strip()
+         if not line:
+             return line
+         if case:
+             line = line.lower()
+         if accent:
+             line = self.strip_accents(line)
+         if numbers:
+             line = self.digit_re.sub("0", line)
+         if punct == 1:
+             line = self.replace_unicode_punct(line)
+         elif punct == 2:
+             line = self.remove_unicode_punct(line)
+         line = self.remove_non_printing_char(line)
+         return line
+
+     def strip_accents(self, line: str) -> str:
+         """Strips accents from a piece of text."""
+         nfd = unicodedata.normalize("NFD", line)
+         output = [c for c in nfd if unicodedata.category(c) != "Mn"]
+         if len(output) == len(line):
+             return line
+         return "".join(output)
+
+     def replace_unicode_punct(self, text: str) -> str:
+         return "".join(self.unicode_punct.get(c, c) for c in text)
+
+     def remove_unicode_punct(self, text: str) -> str:
+         """More aggressive version of replace_unicode_punct but also faster."""
+         return self.unicode_punct_re.sub("", text)
+
+     def remove_non_printing_char(self, text: str) -> str:
+         return self.non_printing_chars_re.sub("", text)