Commit fdfbe63
Parent: 6c5f0fe

upload files

Files changed:
- README.md +26 -0
- config.json +0 -0
- configuration_nombert.py +10 -0
- model.safetensors +3 -0
- modeling_nombert.py +234 -0
- special_tokens_map.json +1 -0
- tokenization_nombert.py +64 -0
- tokenizer_config.json +14 -0
- vocab.json +0 -0
README.md
ADDED
@@ -0,0 +1,26 @@
```python
import torch
from transformers import AutoModel, AutoTokenizer

model_path = 'CjangCjengh/NomBert-hn2qn-v0.1'
device = 'cuda'

model = AutoModel.from_pretrained(model_path, torch_dtype='auto', trust_remote_code=True).eval().to(device)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

with torch.inference_mode():
    output_text, output_probs = model.parse_nom_text(tokenizer, ['仍調𬖉𧡊㐌𤴬疸𢚸'])
print(output_text[0])
# những điều trông thấy đã đau đớn lòng
print(output_probs[0])
# [
# {'char': '仍', 'candidates': [('những', 0.5237383842468262), ('nhưng', 0.475042462348938), ('dưng', 0.0008663760963827372), ('nhang', 0.00022805406479164958), ('dừng', 8.42325171106495e-05), ('nhẵng', 1.6380783563363366e-05), ('nhùng', 1.5950208762660623e-05), ('nhửng', 3.0440487535088323e-06), ('nhăng', 2.9528700906666927e-06), ('nhẳng', 1.0688020211091498e-06), ('nhừng', 5.84112399337755e-07), ('nhâng', 5.119333650327462e-07)]},
# {'char': '調', 'candidates': [('điều', 0.8831620812416077), ('đều', 0.11558306217193604), ('điệu', 0.0012446790933609009), ('dìu', 8.889981472748332e-06), ('điu', 7.615183221787447e-07), ('đìu', 5.942594043517602e-07)]},
# {'char': '𬖉', 'candidates': [('trông', 1.0)]},
# {'char': '𧡊', 'candidates': [('thấy', 1.0)]},
# {'char': '㐌', 'candidates': [('đã', 0.9998464584350586), ('dã', 0.00014108473260421306), ('đà', 1.2395633348205592e-05)]},
# {'char': '𤴬', 'candidates': [('đau', 0.9999825954437256), ('đáu', 1.744620021781884e-05)]},
# {'char': '疸', 'candidates': [('đớn', 0.9998302459716797), ('đơn', 0.00014517175441142172), ('đảm', 2.457975824654568e-05)]},
# {'char': '𢚸', 'candidates': [('lòng', 1.0)]}
# ]
```
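The README example decodes a single string; `parse_nom_text` (defined in modeling_nombert.py below) also accepts a list of texts and an optional `batch_size`. A minimal batched sketch, in which the second input string and the per-character printout are illustrative additions, not taken from the model card:

```python
# Sketch only: batched decoding and inspection of the top candidate per character.
# Each entry of output_probs is {'char': ..., 'candidates': [(reading, prob), ...]}
# with candidates sorted by descending probability (see modeling_nombert.py).
import torch
from transformers import AutoModel, AutoTokenizer

model_path = 'CjangCjengh/NomBert-hn2qn-v0.1'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = AutoModel.from_pretrained(model_path, torch_dtype='auto', trust_remote_code=True).eval().to(device)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

texts = [
    '仍調𬖉𧡊㐌𤴬疸𢚸',   # example from the README
    '𤾓𢆥𥪞𡎝𠊛些',       # arbitrary extra Nom input, output not asserted here
]
with torch.inference_mode():
    output_texts, output_probs = model.parse_nom_text(tokenizer, texts, batch_size=2)

for text, probs in zip(output_texts, output_probs):
    print(text)
    for entry in probs:
        reading, p = entry['candidates'][0]   # top-ranked quoc ngu reading
        print(f"  {entry['char']} -> {reading} ({p:.3f})")
```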
config.json
ADDED
The diff for this file is too large to render.
See raw diff
configuration_nombert.py
ADDED
@@ -0,0 +1,10 @@
from transformers import BertConfig


class NomBertConfig(BertConfig):
    def __init__(self, unk_id=0, id_start=1, output_vocab_size=7430, lm_head_dict={}, **kwargs):
        super().__init__(**kwargs)
        self.unk_id = unk_id
        self.id_start = id_start
        self.output_vocab_size = output_vocab_size
        self.lm_head_dict = lm_head_dict
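A minimal sketch of how these fields fit together, with made-up values (the real ones live in config.json, which is too large to render above): `lm_head_dict` maps an input token id, stored as a string, to the ids of that character's candidate readings in the output vocabulary, and `output_vocab_size` sizes the shared `lm_head` embedding used in modeling_nombert.py. It assumes configuration_nombert.py is importable from the working directory.

```python
# Illustrative values only; the shipped config.json defines the real ones.
from configuration_nombert import NomBertConfig

config = NomBertConfig(
    unk_id=0,                  # id reserved for <UNK>/padding
    id_start=1,                # first id handed out to vocabulary characters
    output_vocab_size=7430,    # rows in the shared lm_head embedding
    lm_head_dict={
        '17': [5, 128, 903],   # token id 17 -> three candidate reading ids (hypothetical)
    },
    hidden_size=768,           # ordinary BertConfig fields pass through **kwargs
)
print(config.output_vocab_size, list(config.lm_head_dict))
```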
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dc19832810a9514c0daaa2fc1d5624f95e6793e5b453ba2e905a359ed03f45f6
size 255697712
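The entry above is a git-lfs pointer, not the weights themselves: `oid` is the SHA-256 of the real file and `size` its byte count. A small sketch, assuming the weights have already been downloaded to a local `model.safetensors`, for checking a copy against the pointer:

```python
# Verify a downloaded model.safetensors against the LFS pointer above.
import hashlib
import os

path = 'model.safetensors'  # assumed local path
expected_oid = 'dc19832810a9514c0daaa2fc1d5624f95e6793e5b453ba2e905a359ed03f45f6'
expected_size = 255697712

assert os.path.getsize(path) == expected_size
h = hashlib.sha256()
with open(path, 'rb') as f:
    for chunk in iter(lambda: f.read(1 << 20), b''):  # hash in 1 MiB chunks
        h.update(chunk)
assert h.hexdigest() == expected_oid
print('model.safetensors matches the LFS pointer')
```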
modeling_nombert.py
ADDED
@@ -0,0 +1,234 @@
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertPreTrainedModel, BertModel
from .configuration_nombert import NomBertConfig


class NomBertModel(BertPreTrainedModel):
    config_class = NomBertConfig

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.max_position_embeddings = config.max_position_embeddings
        self.lm_head_dict = config.lm_head_dict
        self.registered_token_ids = list(map(int, config.lm_head_dict.keys()))
        self.lm_head = nn.Embedding(config.output_vocab_size, config.hidden_size)

    def forward(self, input_ids, labels=None, attention_mask=None):
        outputs = self.bert(input_ids, attention_mask)
        hidden_states = outputs.last_hidden_state

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.long)

        registered_token_ids_tensor = torch.tensor(
            self.registered_token_ids,
            device=input_ids.device
        )
        valid_token_mask = torch.isin(input_ids, registered_token_ids_tensor)
        valid_mask = valid_token_mask & attention_mask.bool()

        loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)

        for token_id_str in self.lm_head_dict.keys():
            token_id = int(token_id_str)
            mask = (input_ids == token_id) & valid_mask
            selected_hidden = hidden_states[mask]
            selected_labels = labels[mask] if labels is not None else None

            if selected_hidden.size(0) == 0:
                continue

            lm_head_ids = self.lm_head_dict[token_id_str]
            lm_head_ids_tensor = torch.tensor(lm_head_ids, device=input_ids.device)
            lm_head = self.lm_head(lm_head_ids_tensor)
            logits = torch.matmul(selected_hidden, lm_head.T)

            if labels is not None:
                loss = loss + F.cross_entropy(
                    logits,
                    selected_labels,
                    ignore_index=-100
                )

        return {'loss': loss} if labels is not None else outputs

    def parse_nom_text(self, tokenizer, texts, post_normalize=True, batch_size=None):
        max_length = self.max_position_embeddings
        segments_info = []
        for text_idx, text in enumerate(texts):
            segments = [text[i:i+max_length] for i in range(0, len(text), max_length)]
            for seg_idx, seg in enumerate(segments):
                segments_info.append((text_idx, seg_idx, seg))

        all_segments = [seg for _, _, seg in segments_info]
        all_pred_chars = []
        all_pred_probs = []

        if batch_size is None:
            batch_size = len(texts)
        for i in range(0, len(all_segments), batch_size):
            batch_segments = all_segments[i:i+batch_size]
            batch_pred_chars, batch_pred_probs = self._parse_nom_text_batch(tokenizer, batch_segments)
            all_pred_chars.extend(batch_pred_chars)
            all_pred_probs.extend(batch_pred_probs)

        text_results = {}
        for text_idx in range(len(texts)):
            text_results[text_idx] = {'chars': [], 'probs': []}

        for (text_idx, seg_idx, _), pred_chars, pred_probs in zip(segments_info, all_pred_chars, all_pred_probs):
            text_results[text_idx]['chars'].append((seg_idx, pred_chars))
            text_results[text_idx]['probs'].append((seg_idx, pred_probs))

        output_texts = []
        all_outputs_probs = []
        for text_idx in range(len(texts)):
            sorted_chars = sorted(text_results[text_idx]['chars'], key=lambda x: x[0])
            sorted_probs = sorted(text_results[text_idx]['probs'], key=lambda x: x[0])

            merged_chars = []
            merged_probs = []
            for seg_idx, chars in sorted_chars:
                merged_chars.extend(chars)
            for seg_idx, probs in sorted_probs:
                merged_probs.extend(probs)

            output_text = ''
            for i, (char, processed) in enumerate(merged_chars):
                output_text += char
                if i < len(merged_chars)-1 and (processed or merged_chars[i+1][1]):
                    output_text += ' '

            if post_normalize:
                output_text = self.post_normalize(output_text)
            output_texts.append(output_text)
            all_outputs_probs.append(merged_probs)

        return output_texts, all_outputs_probs

    def _parse_nom_text_batch(self, tokenizer, segments):
        encoded = tokenizer.batch_encode_plus(
            segments,
            add_special_tokens=False,
            padding=True,
            return_tensors='pt',
            truncation=True,
            max_length=self.max_position_embeddings
        )
        input_ids = encoded['input_ids'].to(self.device)
        attention_mask = encoded['attention_mask'].to(self.device)

        batch_size = len(segments)
        id_to_options_ids = list(tokenizer.id_to_options.keys())
        id_to_options_tensor = torch.tensor(id_to_options_ids, device=self.device)
        registered_ids = torch.tensor(self.registered_token_ids, device=self.device)
        valid_mask = (
            torch.isin(input_ids, registered_ids) &
            attention_mask.bool()
        )

        pred_chars = [[(c, False) for c in seg] for seg in segments]
        pred_probs = [[] for _ in range(batch_size)]

        if valid_mask.any():
            outputs = self.bert(input_ids, attention_mask=attention_mask)
            hidden_states = outputs.last_hidden_state

            batch_indices, seq_indices = torch.where(valid_mask)
            token_ids = input_ids[batch_indices, seq_indices]
            hidden_vecs = hidden_states[batch_indices, seq_indices]

            for token_id_str in self.lm_head_dict:
                token_id = int(token_id_str)
                token_mask = (token_ids == token_id)
                if not token_mask.any():
                    continue

                token_hidden = hidden_vecs[token_mask]
                token_batch = batch_indices[token_mask]
                token_seq = seq_indices[token_mask]

                lm_head_ids = self.lm_head_dict[token_id_str]
                lm_head_ids_tensor = torch.tensor(lm_head_ids, device=input_ids.device)
                lm_head = self.lm_head(lm_head_ids_tensor)
                logits = torch.matmul(token_hidden, lm_head.T)
                probs = F.softmax(logits, dim=-1)
                preds = torch.argmax(logits, dim=-1)

                for i, (b, s) in enumerate(zip(token_batch.tolist(), token_seq.tolist())):
                    options = tokenizer.id_to_options[token_id]
                    char = options[preds[i].item()]

                    pred_chars[b][s] = (char, True)

                    candidates = sorted(
                        [(opt, probs[i][j].item()) for j, opt in enumerate(options)],
                        key=lambda x: x[1], reverse=True
                    )

                    if s >= len(pred_probs[b]):
                        pred_probs[b].extend([{}] * (s - len(pred_probs[b]) + 1))

                    pred_probs[b][s] = {
                        'char': segments[b][s],
                        'candidates': candidates
                    }

        single_option_mask = (
            attention_mask.bool() &
            torch.isin(input_ids, id_to_options_tensor) &
            ~torch.isin(input_ids, registered_ids)
        )
        batch_indices_single, seq_indices_single = torch.where(single_option_mask)

        for b, s in zip(batch_indices_single.tolist(), seq_indices_single.tolist()):
            token_id = input_ids[b, s].item()
            options = tokenizer.id_to_options[token_id]

            pred_chars[b][s] = (options[0], True)

            if s >= len(pred_probs[b]):
                pred_probs[b].extend([{}] * (s - len(pred_probs[b]) + 1))

            pred_probs[b][s] = {
                'char': segments[b][s],
                'candidates': [(options[0], 1.0)]
            }

        for b in range(batch_size):
            seg_len = len(segments[b])
            pred_chars[b] = pred_chars[b][:seg_len]
            for s in range(seg_len):
                if s < len(pred_probs[b]) and pred_probs[b][s]:
                    continue
                char = segments[b][s]
                if s >= input_ids.shape[1]:
                    token_id = 0
                else:
                    token_id = input_ids[b, s].item()

                candidates = [(char, 1.0)]
                if token_id != 0 and token_id in tokenizer.id_to_options:
                    options = tokenizer.id_to_options[token_id]
                    if len(options) == 1:
                        candidates = [(options[0], 1.0)]

                if s >= len(pred_probs[b]):
                    pred_probs[b].extend([{}] * (s - len(pred_probs[b]) + 1))

            pred_probs[b] = pred_probs[b][:seg_len]

        pred_probs = [[p for p in batch if p != {}] for batch in pred_probs]

        return pred_chars, pred_probs

    def post_normalize(self, text):
        text = re.sub(r'\s*[。\.]', '.', text)
        text = re.sub(r'\s*[,、,]', ',', text)
        text = re.sub(r'\s*[!!]', '!', text)
        text = re.sub(r'\s*[?\?]', '?', text)
        return text
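`forward` only returns a loss when `labels` are supplied; judging from the cross-entropy over each token's own candidate head, a label is an index into that character's candidate list (or -100 to ignore the position). A rough training-step sketch built on that reading, with a dummy label tensor that is purely illustrative:

```python
# Sketch of one optimisation step; the zero labels below are placeholders,
# not real training targets.
import torch
from transformers import AutoModel, AutoTokenizer

model_path = 'CjangCjengh/NomBert-hn2qn-v0.1'
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).train()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

enc = tokenizer.batch_encode_plus(['仍調𬖉𧡊㐌𤴬疸𢚸'], add_special_tokens=False,
                                  padding=True, return_tensors='pt')
# Illustrative: pretend the correct reading is candidate 0 for every character.
labels = torch.zeros_like(enc['input_ids'])

out = model(enc['input_ids'], labels=labels, attention_mask=enc['attention_mask'])
out['loss'].backward()
optimizer.step()
optimizer.zero_grad()
```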
special_tokens_map.json
ADDED
@@ -0,0 +1 @@
{}
tokenization_nombert.py
ADDED
@@ -0,0 +1,64 @@
import json
import os
from transformers import PreTrainedTokenizer


class NomTokenizer(PreTrainedTokenizer):
    vocab_files_names = {'vocab_file': 'vocab.json'}

    def __init__(
        self,
        vocab_file,
        unk_token='<UNK>',
        unk_token_id=0,
        id_start=1,
        **kwargs
    ):
        self.vocab_file = vocab_file
        self.id_start = id_start
        self.unk_token = unk_token
        self.unk_token_id = unk_token_id
        self.pad_token = unk_token
        self.pad_token_id = unk_token_id

        with open(vocab_file, 'r', encoding='utf-8') as f:
            self.vocab_dict = json.load(f)

        self.char2id = {}
        self.id2char = {}
        for i, char in enumerate(self.vocab_dict.keys(), start=id_start):
            self.char2id[char] = i
            self.id2char[i] = char
        self.id_to_options = {idx: v for idx, v in enumerate(self.vocab_dict.values(), start=id_start)}

        super().__init__(**kwargs)

    def _tokenize(self, text):
        return list(text)

    def _convert_token_to_id(self, token):
        return self.char2id.get(token, self.unk_token_id)

    def _convert_id_to_token(self, index):
        if index == self.unk_token_id:
            return self.unk_token
        return self.id2char.get(index, self.unk_token)

    @property
    def vocab_size(self):
        return len(self.char2id) + 1

    def get_vocab(self):
        vocab = {**self.char2id, **self.added_tokens_encoder}
        return vocab

    def save_vocabulary(self, save_directory, filename_prefix=None):
        if filename_prefix:
            vocab_file = os.path.join(save_directory, f'{filename_prefix}-vocab.json')
        else:
            vocab_file = os.path.join(save_directory, 'vocab.json')

        with open(vocab_file, 'w', encoding='utf-8') as f:
            json.dump(self.vocab_dict, f, ensure_ascii=False)

        return (vocab_file,)
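The tokenizer is character-level: `_tokenize` splits the text into characters, `char2id` is built from the keys of vocab.json, and `id_to_options` keeps every known quốc ngữ reading per character id. A small sketch of inspecting it, assuming the repo id from the README:

```python
# Sketch: character-level encoding plus the candidate readings the tokenizer knows.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('CjangCjengh/NomBert-hn2qn-v0.1',
                                          trust_remote_code=True)

text = '仍調𬖉𧡊㐌𤴬疸𢚸'
ids = tokenizer.encode(text, add_special_tokens=False)  # one id per character
print(ids)

first_id = ids[0]
print(tokenizer.id_to_options.get(first_id, []))  # readings recorded for '仍'
```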
tokenizer_config.json
ADDED
@@ -0,0 +1,14 @@
{
  "added_tokens_decoder": {},
  "clean_up_tokenization_spaces": false,
  "extra_special_tokens": {},
  "model_max_length": 1000000000000000019884624838656,
  "tokenizer_class": "NomTokenizer",
  "auto_map": {
    "AutoTokenizer": [
      "tokenization_nombert.NomTokenizer",
      null
    ]
  },
  "unk_token": "<UNK>"
}
vocab.json
ADDED
The diff for this file is too large to render.
See raw diff