CjangCjengh committed
Commit fdfbe63 · 1 Parent(s): 6c5f0fe

upload files
README.md ADDED
@@ -0,0 +1,26 @@
+ ```python
+ import torch
+ from transformers import AutoModel, AutoTokenizer
+
+ model_path = 'CjangCjengh/NomBert-hn2qn-v0.1'
+ device = 'cuda'
+
+ model = AutoModel.from_pretrained(model_path, torch_dtype='auto', trust_remote_code=True).eval().to(device)
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+ with torch.inference_mode():
+     output_text, output_probs = model.parse_nom_text(tokenizer, ['仍調𬖉𧡊㐌𤴬疸𢚸'])
+ print(output_text[0])
+ # những điều trông thấy đã đau đớn lòng
+ print(output_probs[0])
+ # [
+ #     {'char': '仍', 'candidates': [('những', 0.5237383842468262), ('nhưng', 0.475042462348938), ('dưng', 0.0008663760963827372), ('nhang', 0.00022805406479164958), ('dừng', 8.42325171106495e-05), ('nhẵng', 1.6380783563363366e-05), ('nhùng', 1.5950208762660623e-05), ('nhửng', 3.0440487535088323e-06), ('nhăng', 2.9528700906666927e-06), ('nhẳng', 1.0688020211091498e-06), ('nhừng', 5.84112399337755e-07), ('nhâng', 5.119333650327462e-07)]},
+ #     {'char': '調', 'candidates': [('điều', 0.8831620812416077), ('đều', 0.11558306217193604), ('điệu', 0.0012446790933609009), ('dìu', 8.889981472748332e-06), ('điu', 7.615183221787447e-07), ('đìu', 5.942594043517602e-07)]},
+ #     {'char': '𬖉', 'candidates': [('trông', 1.0)]},
+ #     {'char': '𧡊', 'candidates': [('thấy', 1.0)]},
+ #     {'char': '㐌', 'candidates': [('đã', 0.9998464584350586), ('dã', 0.00014108473260421306), ('đà', 1.2395633348205592e-05)]},
+ #     {'char': '𤴬', 'candidates': [('đau', 0.9999825954437256), ('đáu', 1.744620021781884e-05)]},
+ #     {'char': '疸', 'candidates': [('đớn', 0.9998302459716797), ('đơn', 0.00014517175441142172), ('đảm', 2.457975824654568e-05)]},
+ #     {'char': '𢚸', 'candidates': [('lòng', 1.0)]}
+ # ]
+ ```
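The README above runs a single sentence; `parse_nom_text` (see modeling_nombert.py below) also accepts `post_normalize` and `batch_size` keyword arguments for processing several texts at once. A minimal sketch of batched use — the list of Nom lines here is just the README's example repeated as a placeholder:

```python
import torch
from transformers import AutoModel, AutoTokenizer

model_path = 'CjangCjengh/NomBert-hn2qn-v0.1'
model = AutoModel.from_pretrained(model_path, torch_dtype='auto', trust_remote_code=True).eval().to('cuda')
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Placeholder batch: any list of Nom strings works; texts longer than
# max_position_embeddings are split into segments and re-merged internally.
nom_lines = ['仍調𬖉𧡊㐌𤴬疸𢚸', '仍調𬖉𧡊㐌𤴬疸𢚸']

with torch.inference_mode():
    texts, probs = model.parse_nom_text(tokenizer, nom_lines, post_normalize=True, batch_size=2)
for t in texts:
    print(t)
```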
config.json ADDED
The diff for this file is too large to render. See raw diff
 
configuration_nombert.py ADDED
@@ -0,0 +1,10 @@
+ from transformers import BertConfig
+
+
+ class NomBertConfig(BertConfig):
+     def __init__(self, unk_id=0, id_start=1, output_vocab_size=7430, lm_head_dict={}, **kwargs):
+         super().__init__(**kwargs)
+         self.unk_id = unk_id
+         self.id_start = id_start
+         self.output_vocab_size = output_vocab_size
+         self.lm_head_dict = lm_head_dict
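The extra config fields are consumed by the model below: `lm_head_dict` maps an input token id (as a string key) to the list of rows of the shared `lm_head` embedding that serve as that character's candidate readings, and `output_vocab_size` is the number of rows in that embedding. A small sketch of the shape of this data, assuming configuration_nombert.py sits next to the script; the `lm_head_dict` entry is invented for illustration (the real mapping lives in config.json):

```python
from configuration_nombert import NomBertConfig

toy_config = NomBertConfig(
    unk_id=0,                      # id used for <UNK>, matching the tokenizer's unk_token_id default
    id_start=1,                    # first id assigned to vocabulary characters, matching the tokenizer
    output_vocab_size=7430,        # rows in the shared lm_head embedding (the code's default)
    lm_head_dict={'5': [12, 87]},  # hypothetical: token id 5 -> candidate rows 12 and 87
)
print(toy_config.output_vocab_size, toy_config.lm_head_dict)
```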
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dc19832810a9514c0daaa2fc1d5624f95e6793e5b453ba2e905a359ed03f45f6
+ size 255697712
modeling_nombert.py ADDED
@@ -0,0 +1,234 @@
+ import re
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import BertPreTrainedModel, BertModel
+ from .configuration_nombert import NomBertConfig
+
+
+ class NomBertModel(BertPreTrainedModel):
+     config_class = NomBertConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.bert = BertModel(config)
+         self.max_position_embeddings = config.max_position_embeddings
+         self.lm_head_dict = config.lm_head_dict
+         self.registered_token_ids = list(map(int, config.lm_head_dict.keys()))
+         self.lm_head = nn.Embedding(config.output_vocab_size, config.hidden_size)
+
+     def forward(self, input_ids, labels=None, attention_mask=None):
+         outputs = self.bert(input_ids, attention_mask)
+         hidden_states = outputs.last_hidden_state
+
+         if attention_mask is None:
+             attention_mask = torch.ones_like(input_ids, dtype=torch.long)
+
+         registered_token_ids_tensor = torch.tensor(
+             self.registered_token_ids,
+             device=input_ids.device
+         )
+         valid_token_mask = torch.isin(input_ids, registered_token_ids_tensor)
+         valid_mask = valid_token_mask & attention_mask.bool()
+
+         loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
+
+         for token_id_str in self.lm_head_dict.keys():
+             token_id = int(token_id_str)
+             mask = (input_ids == token_id) & valid_mask
+             selected_hidden = hidden_states[mask]
+             selected_labels = labels[mask] if labels is not None else None
+
+             if selected_hidden.size(0) == 0:
+                 continue
+
+             lm_head_ids = self.lm_head_dict[token_id_str]
+             lm_head_ids_tensor = torch.tensor(lm_head_ids, device=input_ids.device)
+             lm_head = self.lm_head(lm_head_ids_tensor)
+             logits = torch.matmul(selected_hidden, lm_head.T)
+
+             if labels is not None:
+                 loss = loss + F.cross_entropy(
+                     logits,
+                     selected_labels,
+                     ignore_index=-100
+                 )
+
+         return {'loss': loss} if labels is not None else outputs
+
+     def parse_nom_text(self, tokenizer, texts, post_normalize=True, batch_size=None):
+         max_length = self.max_position_embeddings
+         segments_info = []
+         for text_idx, text in enumerate(texts):
+             segments = [text[i:i+max_length] for i in range(0, len(text), max_length)]
+             for seg_idx, seg in enumerate(segments):
+                 segments_info.append((text_idx, seg_idx, seg))
+
+         all_segments = [seg for _, _, seg in segments_info]
+         all_pred_chars = []
+         all_pred_probs = []
+
+         if batch_size is None:
+             batch_size = len(texts)
+         for i in range(0, len(all_segments), batch_size):
+             batch_segments = all_segments[i:i+batch_size]
+             batch_pred_chars, batch_pred_probs = self._parse_nom_text_batch(tokenizer, batch_segments)
+             all_pred_chars.extend(batch_pred_chars)
+             all_pred_probs.extend(batch_pred_probs)
+
+         text_results = {}
+         for text_idx in range(len(texts)):
+             text_results[text_idx] = {'chars': [], 'probs': []}
+
+         for (text_idx, seg_idx, _), pred_chars, pred_probs in zip(segments_info, all_pred_chars, all_pred_probs):
+             text_results[text_idx]['chars'].append((seg_idx, pred_chars))
+             text_results[text_idx]['probs'].append((seg_idx, pred_probs))
+
+         output_texts = []
+         all_outputs_probs = []
+         for text_idx in range(len(texts)):
+             sorted_chars = sorted(text_results[text_idx]['chars'], key=lambda x: x[0])
+             sorted_probs = sorted(text_results[text_idx]['probs'], key=lambda x: x[0])
+
+             merged_chars = []
+             merged_probs = []
+             for seg_idx, chars in sorted_chars:
+                 merged_chars.extend(chars)
+             for seg_idx, probs in sorted_probs:
+                 merged_probs.extend(probs)
+
+             output_text = ''
+             for i, (char, processed) in enumerate(merged_chars):
+                 output_text += char
+                 if i < len(merged_chars)-1 and (processed or merged_chars[i+1][1]):
+                     output_text += ' '
+
+             if post_normalize:
+                 output_text = self.post_normalize(output_text)
+             output_texts.append(output_text)
+             all_outputs_probs.append(merged_probs)
+
+         return output_texts, all_outputs_probs
+
+     def _parse_nom_text_batch(self, tokenizer, segments):
+         encoded = tokenizer.batch_encode_plus(
+             segments,
+             add_special_tokens=False,
+             padding=True,
+             return_tensors='pt',
+             truncation=True,
+             max_length=self.max_position_embeddings
+         )
+         input_ids = encoded['input_ids'].to(self.device)
+         attention_mask = encoded['attention_mask'].to(self.device)
+
+         batch_size = len(segments)
+         id_to_options_ids = list(tokenizer.id_to_options.keys())
+         id_to_options_tensor = torch.tensor(id_to_options_ids, device=self.device)
+         registered_ids = torch.tensor(self.registered_token_ids, device=self.device)
+         valid_mask = (
+             torch.isin(input_ids, registered_ids) &
+             attention_mask.bool()
+         )
+
+         pred_chars = [[(c, False) for c in seg] for seg in segments]
+         pred_probs = [[] for _ in range(batch_size)]
+
+         if valid_mask.any():
+             outputs = self.bert(input_ids, attention_mask=attention_mask)
+             hidden_states = outputs.last_hidden_state
+
+             batch_indices, seq_indices = torch.where(valid_mask)
+             token_ids = input_ids[batch_indices, seq_indices]
+             hidden_vecs = hidden_states[batch_indices, seq_indices]
+
+             for token_id_str in self.lm_head_dict:
+                 token_id = int(token_id_str)
+                 token_mask = (token_ids == token_id)
+                 if not token_mask.any():
+                     continue
+
+                 token_hidden = hidden_vecs[token_mask]
+                 token_batch = batch_indices[token_mask]
+                 token_seq = seq_indices[token_mask]
+
+                 lm_head_ids = self.lm_head_dict[token_id_str]
+                 lm_head_ids_tensor = torch.tensor(lm_head_ids, device=input_ids.device)
+                 lm_head = self.lm_head(lm_head_ids_tensor)
+                 logits = torch.matmul(token_hidden, lm_head.T)
+                 probs = F.softmax(logits, dim=-1)
+                 preds = torch.argmax(logits, dim=-1)
+
+                 for i, (b, s) in enumerate(zip(token_batch.tolist(), token_seq.tolist())):
+                     options = tokenizer.id_to_options[token_id]
+                     char = options[preds[i].item()]
+
+                     pred_chars[b][s] = (char, True)
+
+                     candidates = sorted(
+                         [(opt, probs[i][j].item()) for j, opt in enumerate(options)],
+                         key=lambda x: x[1], reverse=True
+                     )
+
+                     if s >= len(pred_probs[b]):
+                         pred_probs[b].extend([{}] * (s - len(pred_probs[b]) + 1))
+
+                     pred_probs[b][s] = {
+                         'char': segments[b][s],
+                         'candidates': candidates
+                     }
+
+         single_option_mask = (
+             attention_mask.bool() &
+             torch.isin(input_ids, id_to_options_tensor) &
+             ~torch.isin(input_ids, registered_ids)
+         )
+         batch_indices_single, seq_indices_single = torch.where(single_option_mask)
+
+         for b, s in zip(batch_indices_single.tolist(), seq_indices_single.tolist()):
+             token_id = input_ids[b, s].item()
+             options = tokenizer.id_to_options[token_id]
+
+             pred_chars[b][s] = (options[0], True)
+
+             if s >= len(pred_probs[b]):
+                 pred_probs[b].extend([{}] * (s - len(pred_probs[b]) + 1))
+
+             pred_probs[b][s] = {
+                 'char': segments[b][s],
+                 'candidates': [(options[0], 1.0)]
+             }
+
+         for b in range(batch_size):
+             seg_len = len(segments[b])
+             pred_chars[b] = pred_chars[b][:seg_len]
+             for s in range(seg_len):
+                 if s < len(pred_probs[b]) and pred_probs[b][s]:
+                     continue
+                 char = segments[b][s]
+                 if s >= input_ids.shape[1]:
+                     token_id = 0
+                 else:
+                     token_id = input_ids[b, s].item()
+
+                 candidates = [(char, 1.0)]
+                 if token_id != 0 and token_id in tokenizer.id_to_options:
+                     options = tokenizer.id_to_options[token_id]
+                     if len(options) == 1:
+                         candidates = [(options[0], 1.0)]
+
+                 if s >= len(pred_probs[b]):
+                     pred_probs[b].extend([{}] * (s - len(pred_probs[b]) + 1))
+
+             pred_probs[b] = pred_probs[b][:seg_len]
+
+         pred_probs = [[p for p in batch if p != {}] for batch in pred_probs]
+
+         return pred_chars, pred_probs
+
+     def post_normalize(self, text):
+         text = re.sub(r'\s*[。\.]', '.', text)
+         text = re.sub(r'\s*[,、,]', ',', text)
+         text = re.sub(r'\s*[!!]', '!', text)
+         text = re.sub(r'\s*[?\?]', '?', text)
+         return text
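From the `forward` above, passing `labels` switches the model into loss mode: for each registered Nom character, the label is read as an index into that character's candidate-reading list, with -100 ignored. A minimal sketch that only exercises this code path — the all-zero labels ("first candidate is correct") are placeholders, not a published training recipe, and the label convention is inferred from the code:

```python
import torch
from transformers import AutoModel, AutoTokenizer

model_path = 'CjangCjengh/NomBert-hn2qn-v0.1'
model = AutoModel.from_pretrained(model_path, torch_dtype='auto', trust_remote_code=True).to('cuda')
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

enc = tokenizer(['仍調𬖉𧡊㐌𤴬疸𢚸'], return_tensors='pt', padding=True).to('cuda')
# Placeholder labels: index 0 selects each character's first candidate reading;
# positions without a registered (ambiguous) character are never used by the loss.
labels = torch.zeros_like(enc['input_ids'])

out = model(enc['input_ids'], labels=labels, attention_mask=enc['attention_mask'])
print(out['loss'])
```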
special_tokens_map.json ADDED
@@ -0,0 +1 @@
+ {}
tokenization_nombert.py ADDED
@@ -0,0 +1,64 @@
+ import json
+ import os
+ from transformers import PreTrainedTokenizer
+
+
+ class NomTokenizer(PreTrainedTokenizer):
+     vocab_files_names = {'vocab_file': 'vocab.json'}
+
+     def __init__(
+         self,
+         vocab_file,
+         unk_token='<UNK>',
+         unk_token_id=0,
+         id_start=1,
+         **kwargs
+     ):
+         self.vocab_file = vocab_file
+         self.id_start = id_start
+         self.unk_token = unk_token
+         self.unk_token_id = unk_token_id
+         self.pad_token = unk_token
+         self.pad_token_id = unk_token_id
+
+         with open(vocab_file, 'r', encoding='utf-8') as f:
+             self.vocab_dict = json.load(f)
+
+         self.char2id = {}
+         self.id2char = {}
+         for i, char in enumerate(self.vocab_dict.keys(), start=id_start):
+             self.char2id[char] = i
+             self.id2char[i] = char
+         self.id_to_options = {idx: v for idx, v in enumerate(self.vocab_dict.values(), start=id_start)}
+
+         super().__init__(**kwargs)
+
+     def _tokenize(self, text):
+         return list(text)
+
+     def _convert_token_to_id(self, token):
+         return self.char2id.get(token, self.unk_token_id)
+
+     def _convert_id_to_token(self, index):
+         if index == self.unk_token_id:
+             return self.unk_token
+         return self.id2char.get(index, self.unk_token)
+
+     @property
+     def vocab_size(self):
+         return len(self.char2id) + 1
+
+     def get_vocab(self):
+         vocab = {**self.char2id, **self.added_tokens_encoder}
+         return vocab
+
+     def save_vocabulary(self, save_directory, filename_prefix=None):
+         if filename_prefix:
+             vocab_file = os.path.join(save_directory, f'{filename_prefix}-vocab.json')
+         else:
+             vocab_file = os.path.join(save_directory, 'vocab.json')
+
+         with open(vocab_file, 'w', encoding='utf-8') as f:
+             json.dump(self.vocab_dict, f, ensure_ascii=False)
+
+         return (vocab_file,)
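The tokenizer builds two tables from vocab.json: `char2id` assigns each Nom character an id starting at `id_start`, and `id_to_options` maps that id to the character's list of candidate readings, which the model consults during decoding. A small sketch of looking these up after loading via `AutoTokenizer` as in the README (the sample character is taken from the README example):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('CjangCjengh/NomBert-hn2qn-v0.1', trust_remote_code=True)

char = '調'
token_id = tokenizer.convert_tokens_to_ids(char)            # unseen characters fall back to unk_token_id (0)
print(token_id, tokenizer.id_to_options.get(token_id, []))  # candidate quoc-ngu readings for this character
```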
tokenizer_config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "added_tokens_decoder": {},
+   "clean_up_tokenization_spaces": false,
+   "extra_special_tokens": {},
+   "model_max_length": 1000000000000000019884624838656,
+   "tokenizer_class": "NomTokenizer",
+   "auto_map": {
+     "AutoTokenizer": [
+       "tokenization_nombert.NomTokenizer",
+       null
+     ]
+   },
+   "unk_token": "<UNK>"
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff