Spaces:
Runtime error
Runtime error
File size: 5,161 Bytes
26827a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import os
import random
import string
import numpy as np
import torch
from torch.utils.data import Dataset
class TokenClfDataset(Dataset):
    """Custom token-classification dataset: tokenize, frame and pad raw texts.

    Each item is one text tokenized with ``tokenizer``, wrapped in the model's
    CLS/SEP tokens, truncated or right-padded to exactly ``max_len`` tokens,
    and returned as ``{"ids": LongTensor, "mask": LongTensor}``.

    Args:
        texts: Indexable sequence of raw strings.
        max_len: Fixed output length per item (the original author notes 256
            for PhoBERT and 512 for XLM-RoBERTa).
        tokenizer: Object providing ``tokenize(str) -> list[str]`` and
            ``convert_tokens_to_ids(list[str]) -> list[int]``; presumably a
            Hugging Face tokenizer -- confirm against the caller.
        model_name: Selects the special-token strings by substring match
            ("m_bert", "xlm-roberta", "phobert"). For any other name the
            tokenizer's own special-token attributes are used when present.
    """

    # Special-token strings of BERT-style (WordPiece) vocabularies.
    _BERT_SPECIAL_TOKENS = {
        "cls_token": "[CLS]",
        "sep_token": "[SEP]",
        "unk_token": "[UNK]",
        "pad_token": "[PAD]",
        "mask_token": "[MASK]",
    }
    # Special-token strings shared by XLM-RoBERTa (base and large) and PhoBERT.
    _ROBERTA_SPECIAL_TOKENS = {
        "bos_token": "<s>",
        "eos_token": "</s>",
        "sep_token": "</s>",
        "cls_token": "<s>",
        "unk_token": "<unk>",
        "pad_token": "<pad>",
        "mask_token": "<mask>",
    }

    def __init__(
        self,
        texts,
        max_len=512,  # 256 (phobert) 512 (xlm-roberta)
        tokenizer=None,
        model_name="m_bert",
    ):
        self.len = len(texts)
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.model_name = model_name
        if "m_bert" in model_name:
            special_tokens = self._BERT_SPECIAL_TOKENS
        elif "xlm-roberta" in model_name or "phobert" in model_name:
            special_tokens = self._ROBERTA_SPECIAL_TOKENS
        else:
            # Unknown model family: previously no special tokens were set and
            # __getitem__ failed with AttributeError. Reuse whatever special
            # tokens the tokenizer itself declares instead.
            special_tokens = {
                name: getattr(tokenizer, name)
                for name in self._ROBERTA_SPECIAL_TOKENS
                if getattr(tokenizer, name, None) is not None
            }
        for name, value in special_tokens.items():
            setattr(self, name, value)

    def __getitem__(self, index):
        """Return ``{"ids", "mask"}`` long tensors of length ``max_len``."""
        text = self.texts[index]
        content = self.tokenizer.tokenize(text)
        # Truncate the content rather than the framed sequence so that the
        # closing SEP token survives truncation; max(...) guards max_len < 2.
        content = content[: max(self.max_len - 2, 0)]
        tokens = ([self.cls_token] + content + [self.sep_token])[: self.max_len]
        num_real = len(tokens)
        num_pad = self.max_len - num_real
        tokens += [self.pad_token] * num_pad
        # Mask by position, not by string comparison, so a literal pad-token
        # string occurring inside the text is still treated as a real token.
        attn_mask = [1] * num_real + [0] * num_pad
        ids = self.tokenizer.convert_tokens_to_ids(tokens)
        return {
            "ids": torch.tensor(ids, dtype=torch.long),
            "mask": torch.tensor(attn_mask, dtype=torch.long),
        }

    def __len__(self):
        """Return the number of texts in the dataset."""
        return self.len
def seed_everything(seed: int):
    """Seed the RNGs this module relies on and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (``manual_seed``
    plus ``cuda.manual_seed``), exports ``PYTHONHASHSEED``, and disables cuDNN
    autotuning in favour of deterministic kernels.

    Note: exporting ``PYTHONHASHSEED`` does not change hash randomization of
    the interpreter that is already running; it only affects child processes
    that inherit this environment.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seed_fn in seeders:
        seed_fn(seed)
    # Trade cuDNN's benchmark-selected (possibly nondeterministic) kernels for
    # run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
def is_begin_of_new_word(token, model_name, force_tokens, token_map):
    """Return whether ``token`` starts a new word for the given model family.

    A token always starts a new word when it (with the "##" prefix removed
    for m_bert) is in ``force_tokens`` or is one of ``token_map``'s values.
    For XLM-RoBERTa / PhoBERT, tokens that are a substring of
    ``string.punctuation`` also start a new word.

    Args:
        token: One subword token string produced by the tokenizer.
        model_name: Model identifier, matched by substring ("m_bert",
            "xlm-roberta", "phobert").
        force_tokens: Container of tokens that must be kept as separate words.
        token_map: Mapping original token -> replacement token; the
            replacement tokens are treated like forced tokens.

    Returns:
        A bool for recognized model families, or None for any other
        ``model_name`` (unchanged from the original behavior).
    """
    if "m_bert" in model_name:
        # WordPiece marks continuation pieces with a "##" prefix. Remove that
        # exact prefix: str.lstrip("##") strips any run of '#' characters and
        # would turn the continuation piece "###" into "" instead of "#".
        pure = token[2:] if token.startswith("##") else token
        if pure in force_tokens or pure in token_map.values():
            return True
        return not token.startswith("##")
    if "xlm-roberta" in model_name or "phobert" in model_name:
        # Note: `in string.punctuation` is a substring test, so multi-char
        # runs such as "()" (and the empty string) also count as punctuation.
        if (
            token in string.punctuation
            or token in force_tokens
            or token in token_map.values()
        ):
            return True
        if "xlm-roberta" in model_name:
            # Covers xlm-roberta-large too; word-initial pieces carry a
            # leading "▁" marker.
            return token.startswith("▁")
        # PhoBERT: pieces ending in "@@" are treated as NOT starting a word.
        # NOTE(review): in "@@"-suffix BPE the piece *following* an "@@" piece
        # is the continuation, so "hel@@" returns False and "lo" returns True
        # here; a per-token test cannot express that -- confirm against the
        # caller's word-merging logic.
        return not token.endswith("@@")
    # Unrecognized model family.
    return None
def replace_added_token(token, token_map):
    """Map placeholder tokens inside ``token`` back to their original text.

    ``token_map`` maps original token -> placeholder (added) token. Every
    occurrence of each placeholder is substituted, one mapping at a time in
    the mapping's iteration order.
    """
    restored = token
    for original, placeholder in token_map.items():
        restored = restored.replace(placeholder, original)
    return restored
def get_pure_token(token, model_name):
    """Return ``token`` with its model-specific subword marker removed.

    - m_bert: the "##" continuation prefix is removed.
    - xlm-roberta (any size): leading "▁" characters are stripped.
    - phobert: the "@@" continuation suffix is removed.

    Returns None for any other ``model_name`` (unchanged from the original).
    """
    if "m_bert" in model_name:
        # Remove exactly the "##" prefix; str.lstrip("##") strips any run of
        # '#' characters, turning the continuation piece "###" into "".
        return token[2:] if token.startswith("##") else token
    if "xlm-roberta" in model_name:
        # Strip the "▁" word-boundary marker from the left of the token.
        return token.lstrip("▁")
    if "phobert" in model_name:
        # Remove exactly the "@@" suffix; str.rstrip("@@") strips any run of
        # '@' characters, turning "@" or "@@@" into "".
        return token[:-2] if token.endswith("@@") else token
    # Unrecognized model family.
    return None