Add model card with sample usage
Browse files
README.md
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# KenLM models
|
2 |
+
This repo contains several KenLM models trained on different tokenized datasets and languages.
|
3 |
+
KenLM models are probabilistic n-gram languge models that models. One use case of these models consist on fast perplexity estimation for [filtering or sampling large datasets](https://huggingface.co/bertin-project/bertin-roberta-base-spanish). For example, one could use a KenLM model trained on French Wikipedia to run inference on a large dataset and filter out samples that are very unlike to appear on Wikipedia (high perplexity), or very simple non-informative sentences that could appear repeatedly (low perplexity).
|
4 |
+
|
5 |
+
At the root of this repo you will find different directories named after the dataset models were trained on (e.g. `wikipedia`, `oscar`). Within each directory, you will find several models trained on different language subsets of the dataset (e.g. `en (English)`, `es (Spanish)`, `fr (French)`). For each language you will find three different files
|
6 |
+
* `{language}.arpa.bin`: The trained KenLM model binary
|
7 |
+
* `{language}.sp.model`: The trained SentencePiece model used for tokenization
|
8 |
+
* `{language}.sp.vocab`: The vocabulary file for the SentencePiece model
|
9 |
+
|
10 |
+
The models have been trained using some of the preprocessing steps from [cc_net](https://github.com/facebookresearch/cc_net), in particular replacing numbers with zeros and normalizing punctuation. So, it is important to keep the default values for the parameters: `lower_case`, `remove_accents`, `normalize_numbers` and `punctuation` when using the pre-trained models in order to replicate the same pre-processing steps at inference time.
|
11 |
+
|
12 |
+
# Dependencies
|
13 |
+
* KenLM: `pip install https://github.com/kpu/kenlm/archive/master.zip`
|
14 |
+
* SentencePiece: `pip install https://github.com/kpu/kenlm/archive/master.zip`
|
15 |
+
|
16 |
+
# Example:
|
17 |
+
```
|
18 |
+
from model import KenlmModel
|
19 |
+
|
20 |
+
|
21 |
+
# Load model trained on English wikipedia
|
22 |
+
model = KenlmModel.from_pretrained("wikipedia", "en")
|
23 |
+
|
24 |
+
# Get perplexity
|
25 |
+
model.get_perplexity("I am very perplexed")
|
26 |
+
# 341.3 (low perplexity, since sentence style is formal and with no grammar mistakes)
|
27 |
+
|
28 |
+
model.get_perplexity("im hella trippin")
|
29 |
+
# 46793.5 (high perplexity, since the sentence is colloquial and contains grammar mistakes)
|
30 |
+
```
|
31 |
+
In the example above we see that, since Wikipedia is a collection of encyclopedic articles, a KenLM model trained on it will naturally give lower perplexity scores to sentences with formal language and no grammar mistakes than colloquial sentences with grammar mistakes.
|
model.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import unicodedata
|
4 |
+
from typing import Dict
|
5 |
+
|
6 |
+
import kenlm
|
7 |
+
import sentencepiece
|
8 |
+
from huggingface_hub import cached_download, hf_hub_url
|
9 |
+
|
10 |
+
|
11 |
+
class SentencePiece:
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
model: str,
|
15 |
+
):
|
16 |
+
super().__init__()
|
17 |
+
self.sp = sentencepiece.SentencePieceProcessor()
|
18 |
+
self.sp.load(str(model))
|
19 |
+
|
20 |
+
def do(self, text: dict) -> dict:
|
21 |
+
tokenized = self.sp.encode_as_pieces(text)
|
22 |
+
return " ".join(tokenized)
|
23 |
+
|
24 |
+
|
25 |
+
class KenlmModel:
|
26 |
+
digit_re: re.Pattern = re.compile(r"\d")
|
27 |
+
unicode_punct: Dict[str, str] = {
|
28 |
+
",": ",",
|
29 |
+
"。": ".",
|
30 |
+
"、": ",",
|
31 |
+
"„": '"',
|
32 |
+
"”": '"',
|
33 |
+
"“": '"',
|
34 |
+
"«": '"',
|
35 |
+
"»": '"',
|
36 |
+
"1": '"',
|
37 |
+
"」": '"',
|
38 |
+
"「": '"',
|
39 |
+
"《": '"',
|
40 |
+
"》": '"',
|
41 |
+
"´": "'",
|
42 |
+
"∶": ":",
|
43 |
+
":": ":",
|
44 |
+
"?": "?",
|
45 |
+
"!": "!",
|
46 |
+
"(": "(",
|
47 |
+
")": ")",
|
48 |
+
";": ";",
|
49 |
+
"–": "-",
|
50 |
+
"—": " - ",
|
51 |
+
".": ". ",
|
52 |
+
"~": "~",
|
53 |
+
"’": "'",
|
54 |
+
"…": "...",
|
55 |
+
"━": "-",
|
56 |
+
"〈": "<",
|
57 |
+
"〉": ">",
|
58 |
+
"【": "[",
|
59 |
+
"】": "]",
|
60 |
+
"%": "%",
|
61 |
+
"►": "-",
|
62 |
+
}
|
63 |
+
unicode_punct_re = re.compile(f"[{''.join(unicode_punct.keys())}]")
|
64 |
+
non_printing_chars_re = re.compile(
|
65 |
+
f"[{''.join(map(chr, list(range(0,32)) + list(range(127,160))))}]"
|
66 |
+
)
|
67 |
+
kenlm_model_dir = None
|
68 |
+
sentence_piece_model_dir = None
|
69 |
+
|
70 |
+
def __init__(
|
71 |
+
self,
|
72 |
+
model_dataset: str,
|
73 |
+
language: str,
|
74 |
+
lower_case: bool = False,
|
75 |
+
remove_accents: bool = False,
|
76 |
+
normalize_numbers: bool = True,
|
77 |
+
punctuation: int = 1,
|
78 |
+
):
|
79 |
+
self.model = kenlm.Model(os.path.join(model_dataset, f"{language}.arpa.bin"))
|
80 |
+
self.tokenizer = SentencePiece(os.path.join(model_dataset, f"{language}.sp.model"))
|
81 |
+
self.accent = remove_accents
|
82 |
+
self.case = lower_case
|
83 |
+
self.numbers = normalize_numbers
|
84 |
+
self.punct = punctuation
|
85 |
+
|
86 |
+
@classmethod
|
87 |
+
def from_pretrained(
|
88 |
+
cls,
|
89 |
+
model_dataset: str,
|
90 |
+
language: str,
|
91 |
+
):
|
92 |
+
return cls(
|
93 |
+
model_dataset,
|
94 |
+
language,
|
95 |
+
False,
|
96 |
+
False,
|
97 |
+
True,
|
98 |
+
1,
|
99 |
+
)
|
100 |
+
|
101 |
+
def pp(self, log_score, length):
|
102 |
+
return 10.0 ** (-log_score / length)
|
103 |
+
|
104 |
+
def get_perplexity(self, doc: str, normalize_cc_net: bool = True):
|
105 |
+
if normalize_cc_net:
|
106 |
+
doc = self.normalize(
|
107 |
+
doc,
|
108 |
+
accent=self.accent,
|
109 |
+
case=self.case,
|
110 |
+
numbers=self.numbers,
|
111 |
+
punct=self.punct,
|
112 |
+
)
|
113 |
+
# Tokenize (after normalizing): See https://github.com/facebookresearch/cc_net/blob/bda555bd1cf1ee2e0b925363e62a61cd46c8b60d/cc_net/mine.py#L352 for full pipeline
|
114 |
+
doc = self.tokenizer.do(doc)
|
115 |
+
doc_log_score, doc_length = 0, 0
|
116 |
+
for line in doc.split("\n"):
|
117 |
+
log_score = self.model.score(line)
|
118 |
+
length = len(line.split()) + 1
|
119 |
+
doc_log_score += log_score
|
120 |
+
doc_length += length
|
121 |
+
return round(self.pp(doc_log_score, doc_length), 1)
|
122 |
+
|
123 |
+
def normalize(
|
124 |
+
self,
|
125 |
+
line: str,
|
126 |
+
accent: bool = True,
|
127 |
+
case: bool = True,
|
128 |
+
numbers: bool = True,
|
129 |
+
punct: int = 1,
|
130 |
+
) -> str:
|
131 |
+
line = line.strip()
|
132 |
+
if not line:
|
133 |
+
return line
|
134 |
+
if case:
|
135 |
+
line = line.lower()
|
136 |
+
if accent:
|
137 |
+
line = self.strip_accents(line)
|
138 |
+
if numbers:
|
139 |
+
line = self.digit_re.sub("0", line)
|
140 |
+
if punct == 1:
|
141 |
+
line = self.replace_unicode_punct(line)
|
142 |
+
elif punct == 2:
|
143 |
+
line = self.remove_unicode_punct(line)
|
144 |
+
line = self.remove_non_printing_char(line)
|
145 |
+
return line
|
146 |
+
|
147 |
+
def strip_accents(self, line: str) -> str:
|
148 |
+
"""Strips accents from a piece of text."""
|
149 |
+
nfd = unicodedata.normalize("NFD", line)
|
150 |
+
output = [c for c in nfd if unicodedata.category(c) != "Mn"]
|
151 |
+
if len(output) == line:
|
152 |
+
return line
|
153 |
+
return "".join(output)
|
154 |
+
|
155 |
+
def replace_unicode_punct(self, text: str) -> str:
|
156 |
+
return "".join(self.unicode_punct.get(c, c) for c in text)
|
157 |
+
|
158 |
+
def remove_unicode_punct(self, text: str) -> str:
|
159 |
+
"""More aggressive version of replace_unicode_punct but also faster."""
|
160 |
+
return self.unicode_punct_re.sub("", text)
|
161 |
+
|
162 |
+
def remove_non_printing_char(self, text: str) -> str:
|
163 |
+
return self.non_printing_chars_re.sub("", text)
|