import random
from string import punctuation
import re
import os
from typing import List

from transformers import AutoTokenizer
from tqdm import tqdm

from constants import KURDISH_CHARS, KEYBOARD_BLANK, KEYBOARD_KEYS, NUMBERS


def tokenizer_check_if_text_too_long(text, tokenizer, max_length):
    data = tokenizer.batch_encode_plus(
        [text], max_length=max_length, truncation=True, return_overflowing_tokens=True
    )
    # More than one chunk means the text does not fit into `max_length` tokens.
    return len(data["input_ids"]) > 1


def delete_characters(text, char_delete_percentage=0.01):
    # Randomly drop characters; digits are never deleted.
    modified_line = []
    for char in text:
        if random.random() > char_delete_percentage or char in NUMBERS:
            modified_line.append(char)
    return "".join(modified_line)


def insert_characters(text, augmentation_probability=0.01):
    # Randomly insert a Kurdish character in front of a character; digits are left alone.
    modified_line = []
    for char in text:
        if random.random() <= augmentation_probability and char not in NUMBERS:
            modified_line.append(random.choice(KURDISH_CHARS))
        modified_line.append(char)
    return "".join(modified_line)


def replace_characters(text, augmentation_probability=0.01):
    # Randomly replace a character with a random Kurdish character; digits are left alone.
    modified_line = []
    for char in text:
        if random.random() <= augmentation_probability and char not in NUMBERS:
            modified_line.append(random.choice(KURDISH_CHARS))
        else:
            modified_line.append(char)
    return "".join(modified_line)


def random_neighbor_replace(line: str, keyboard_rows: List[str], blank: str) -> str:
    """Replace one random character of `line` with one of its keyboard neighbors."""
    lines = keyboard_rows
    n_rows = len(keyboard_rows)
    _mapper = {}

    def __get_left(row_idx: int, col_idx: int) -> List[str]:
        if col_idx == 0:
            return []
        return [lines[row_idx][col_idx - 1]]

    def __get_right(row_idx: int, col_idx: int) -> List[str]:
        if col_idx == (len(lines[row_idx]) - 1):
            return []
        return [lines[row_idx][col_idx + 1]]

    def __get_upper(row_idx: int, col_idx: int) -> List[str]:
        if row_idx == 0:
            return []
        row = lines[row_idx - 1]
        start = max(0, col_idx - 1)
        end = min(len(row), col_idx + 2)
        return list(row[start:end])

    def __get_lower(row_idx: int, col_idx: int) -> List[str]:
        if row_idx == (n_rows - 1):
            return []
        row = lines[row_idx + 1]
        start = max(0, col_idx - 1)
        end = min(len(row), col_idx + 2)
        return list(row[start:end])

    # Build a map from every key to its (non-blank) left/right/upper/lower neighbors.
    funcs = [__get_left, __get_right, __get_upper, __get_lower]
    for row_idx in range(n_rows):
        for col_idx in range(len(lines[row_idx])):
            items = []
            for func in funcs:
                items.extend(func(row_idx, col_idx))
            items = list(filter(lambda x: x != blank, items))
            char = lines[row_idx][col_idx]
            _mapper[char] = items.copy()

    def get_char(char: str) -> str:
        if char not in _mapper or not _mapper[char]:
            return char
        return random.choice(_mapper[char])

    if not line:
        return line
    idx = random.randint(0, len(line) - 1)
    return line[:idx] + get_char(line[idx]) + line[idx + 1:]


def lower_case_words(text, augmentation_probability=0.5):
    # Randomly lower-case words that start with an upper-case letter.
    modified_line = []
    for word in text.split():
        if not word[0].islower() and random.random() <= augmentation_probability:
            word = word.lower()
        modified_line.append(word)
    return " ".join(modified_line)


# Note: this pattern only keeps Latin/German letters, digits, basic punctuation and
# currency signs.
clean_chars = re.compile(r'[^A-Za-zöäüÖÄÜß,.!?’\'$%€0-9\(\)\- ]', re.MULTILINE)


def cleanup(text):
    text = clean_chars.sub('', text)
    # print("bug: somehow all numbers are removed - this might be due to this regex")
    # exit()
    # text = text.replace("\n", "")
    # text = text.replace('"', '\\"')
    return text


# --- reconstructed helpers ----------------------------------------------------
# The original definitions of `clean_punctuation`, `remove_punctuation` and
# `delete_word` are truncated in this file; the versions below are assumptions
# inferred from the `string.punctuation` import and their call sites in the
# generation loop further down.
clean_punctuation = re.compile(r"([" + re.escape(punctuation) + r"])")


def remove_punctuation(text, augmentation_probability=0.01):
    # Randomly drop punctuation marks (assumed behaviour).
    return clean_punctuation.sub(
        lambda match: "" if random.random() <= augmentation_probability else match.group(1),
        text,
    )


def delete_word(text, augmentation_probability=0.01):
    # Randomly drop whole words (assumed behaviour).
    words = [word for word in text.split() if random.random() > augmentation_probability]
    return " ".join(words)
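# Illustration (not part of the original script): for a toy two-row keyboard
#     rows  = ["qwe", "asd"]
#     blank = " "
# random_neighbor_replace() builds a neighbor map such as
#     {'q': ['w', 'a', 's'], 'w': ['q', 'e', 'a', 's', 'd'], 'e': ['w', 's', 'd'],
#      'a': ['s', 'q', 'w'], 's': ['a', 'd', 'q', 'w', 'e'], 'd': ['s', 'w', 'e']}
# and then swaps exactly one random character of the input for one of its neighbors,
# e.g. random_neighbor_replace("saw", ["qwe", "asd"], " ") may return "sqw".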
# --- dataset generation -------------------------------------------------------
# The set-up of the generation loop is also truncated in this file; the language
# code, tokenizer checkpoint, maximum length and file names below are assumptions.
language = "ckb"  # assumed: Central Kurdish (Sorani) language code
max_length = 128  # assumed maximum model input length in tokens
tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")  # assumed checkpoint

with open(f"{language}.txt", encoding="utf-8") as source, \
        open(f"{language}.csv", "w", encoding="utf-8") as output:
    for line in tqdm(source):
        line = line.strip()
        if not line or tokenizer_check_if_text_too_long(line, tokenizer, max_length):
            continue
        if random.random() > 0.02:
            # Leave 2% of the data untouched so the model also learns not to
            # "over-correct" text that is already clean.
            new_line = delete_word(line)
            new_line = delete_characters(new_line)
            new_line = insert_characters(new_line)
            new_line = replace_characters(new_line)
            new_line = random_neighbor_replace(new_line, KEYBOARD_KEYS, KEYBOARD_BLANK)
            new_line = remove_punctuation(new_line)
        else:
            new_line = line
        output.write(f'"{new_line.strip()}","{line.strip()}"\n')

# Split the generated rows into train/test CSVs with a "text,summary" header;
# the last 2000 rows become the test set.
os.system(f"echo \"text,summary\" > {language}.train.csv")
with open(f"{language}.csv", "r", encoding="utf-8") as generated:
    num_lines = sum(1 for _ in generated)
os.system(f"head -n {num_lines - 2000} {language}.csv >> {language}.train.csv")
os.system(f"echo \"text,summary\" > {language}.test.csv")
os.system(f"tail -n 2000 {language}.csv >> {language}.test.csv")
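# Usage note (assumption, not part of the original script): the generated CSVs can
# be loaded directly with the Hugging Face `datasets` library for seq2seq
# fine-tuning, where "text" is the corrupted input and "summary" is the clean
# target, e.g.:
#
#     from datasets import load_dataset
#     dataset = load_dataset(
#         "csv",
#         data_files={"train": f"{language}.train.csv", "test": f"{language}.test.csv"},
#     )
#     print(dataset["train"][0])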