# Source: xlm-roberta-capu/utils.py — committed by dragonSwing, commit 6fe5f76 ("Initialize commit"), 7.9 kB
import os
from pathlib import Path
import re
# Directory containing this module; data files such as verb-form-vocab.txt
# are looked up relative to it.
VOCAB_DIR = Path(__file__).resolve().parent

# Special token strings (padding, unknown, sequence start).
PAD = "@@PADDING@@"
UNK = "@@UNKNOWN@@"
START_TOKEN = "$START"

# Separator strings keyed by field; presumably used to (de)serialise
# tokens / labels / operations into one line — usage lives outside this file.
# NOTE: the "DELIMETERS" spelling is kept because other modules may import it.
SEQ_DELIMETERS = {
    "tokens": " ",
    "labels": "SEPL|||SEPR",
    "operations": "SEPL__SEPR",
}
def get_verb_form_dicts(path_to_dict=None):
    """Load the verb-form vocabulary into encode/decode lookup tables.

    Each non-blank line of the file must have the form
    ``word1_word2:TAG1_TAG2``. Blank lines are skipped; the original code
    crashed on them with a ValueError during unpacking.

    Args:
        path_to_dict: Path of the vocabulary file. Defaults to
            ``verb-form-vocab.txt`` inside ``VOCAB_DIR``.

    Returns:
        A pair ``(encode, decode)``:

        * ``encode`` maps ``"word1_word2"`` to ``"TAG1_TAG2"``.
        * ``decode`` maps ``"word1_TAG1_TAG2"`` to ``"word2"``.

        When several lines share the same decode key, only the first one is
        recorded, in both tables.

    Raises:
        ValueError: If a non-blank line does not have exactly one ``:`` and
            exactly one ``_`` on each side of it.
    """
    if path_to_dict is None:
        path_to_dict = os.path.join(VOCAB_DIR, "verb-form-vocab.txt")
    encode, decode = {}, {}
    with open(path_to_dict, encoding="utf-8") as f:
        for line in f:
            # Drop the line terminator so encode values carry no trailing "\n".
            line = line.rstrip()
            if not line:
                continue
            words, tags = line.split(":")
            word1, word2 = words.split("_")
            tag1, tag2 = tags.split("_")
            decode_key = f"{word1}_{tag1}_{tag2.strip()}"
            if decode_key not in decode:
                encode[words] = tags
                decode[decode_key] = word2
    return encode, decode
# Built once at import time: importing this module reads verb-form-vocab.txt
# through get_verb_form_dicts().
(ENCODE_VERB_DICT, DECODE_VERB_DICT) = get_verb_form_dicts()
def get_target_sent_by_edits(source_tokens, edits):
    """Apply a sequence of token edits to ``source_tokens``.

    Each edit is unpacked as ``(start, end, label, _)``; the fourth field is
    ignored. ``start`` is shifted by a running offset that grows or shrinks
    as earlier edits insert or delete tokens. Edits with ``start < 0`` are
    skipped.

    The first matching rule applies:

    * ``label == ""``: delete the token at the position.
    * ``start == end``: insert ``label`` with any ``$APPEND_`` text removed,
      unless that word is already at, or just before, the position.
    * ``label`` starts with ``$TRANSFORM_``: rewrite the token with
      ``apply_reverse_transformation``; keep it unchanged if that returns
      ``None``.
    * ``start == end - 1``: replace the token with ``label``, with any
      ``$REPLACE_`` text removed.
    * ``label`` starts with ``$MERGE_``: insert the marker after the position.

    Returns:
        The edited tokens after ``replace_merge_transforms`` has joined words
        around merge markers. ``source_tokens`` is copied, not mutated.
    """
    tokens = source_tokens[:]
    offset = 0
    for start, end, label, _ in edits:
        if start < 0:
            continue
        pos = start + offset
        # Past the end of the current sequence there is no token to read.
        current = tokens[pos] if pos < len(tokens) else ""

        if label == "":
            del tokens[pos]
            offset -= 1
        elif start == end:
            word = label.replace("$APPEND_", "")
            # Skip the insertion if the same word is already adjacent.
            duplicate = (pos < len(tokens) and tokens[pos] == word) or (
                pos > 0 and tokens[pos - 1] == word
            )
            if duplicate:
                continue
            tokens[pos:pos] = [word]
            offset += 1
        elif label.startswith("$TRANSFORM_"):
            transformed = apply_reverse_transformation(current, label)
            tokens[pos] = current if transformed is None else transformed
        elif start == end - 1:
            tokens[pos] = label.replace("$REPLACE_", "")
        elif label.startswith("$MERGE_"):
            # NOTE(review): only reached when start != end and start != end - 1;
            # a $MERGE_ edit with start == end goes through the insertion branch.
            tokens[pos + 1:pos + 1] = [label]
            offset += 1
    return replace_merge_transforms(tokens)
def replace_merge_transforms(tokens):
    """Join tokens around ``$MERGE_HYPHEN`` / ``$MERGE_SPACE`` markers.

    Examples: ``["well", "$MERGE_HYPHEN", "known"]`` becomes
    ``["well-known"]`` and ``["note", "$MERGE_SPACE", "book"]`` becomes
    ``["notebook"]``.

    * Markers at either end have nothing to join and are dropped, however
      many there are.
    * A run of adjacent interior markers joins its two neighbours once: with
      ``-`` if the run contains ``$MERGE_HYPHEN``, with nothing otherwise.
    * When any merge marker is present, the tokens are re-joined with spaces
      and re-split on whitespace.
    * Also in that case, each run of ``<one of . , ? :><whitespace>``
      repeated is reduced to its last repetition.

    Args:
        tokens: List of token strings.

    Returns:
        ``tokens`` itself when it contains no ``$MERGE_`` marker, otherwise a
        new list of tokens (possibly empty).
    """
    if all(not tok.startswith("$MERGE_") for tok in tokens):
        return tokens
    start, stop = 0, len(tokens)
    # Strip every leading/trailing marker, not just one, so that an
    # all-marker list yields [] instead of raising IndexError.
    while start < stop and tokens[start].startswith("$MERGE_"):
        start += 1
    while stop > start and tokens[stop - 1].startswith("$MERGE_"):
        stop -= 1
    target_line = " ".join(tokens[start:stop])
    # Replace a whole run of interior markers at once. Non-overlapping
    # str.replace of " $MERGE_X " would leave the second marker of a run
    # glued to the preceding word.
    target_line = re.sub(
        r" (?:\$MERGE_(?:HYPHEN|SPACE) )+",
        lambda m: "-" if "$MERGE_HYPHEN" in m.group(0) else "",
        target_line,
    )
    target_line = re.sub(r'([\.\,\?\:]\s+)+', r'\1', target_line)
    return target_line.split()
def convert_using_case(token, smart_action):
    """Apply a ``$TRANSFORM_CASE_*`` action to ``token``.

    Supported suffixes:

    * ``LOWER``: ``token.lower()``.
    * ``UPPER``: ``token.upper()``.
    * ``CAPITAL``: ``token.capitalize()`` (first character upper-cased, the
      rest lower-cased).
    * ``CAPITAL_1``: keep the first character and capitalize the remainder.
    * ``UPPER_-1``: upper-case everything except the last character.

    Any action that does not start with ``$TRANSFORM_CASE_``, or has an
    unrecognised suffix, returns ``token`` unchanged. An empty token is
    returned as ``""`` for every suffix.
    """
    if not smart_action.startswith("$TRANSFORM_CASE_"):
        return token
    if smart_action.endswith("LOWER"):
        return token.lower()
    if smart_action.endswith("UPPER"):
        return token.upper()
    if smart_action.endswith("CAPITAL"):
        return token.capitalize()
    if smart_action.endswith("CAPITAL_1"):
        # Slice rather than index so an empty token does not raise IndexError.
        return token[:1] + token[1:].capitalize()
    if smart_action.endswith("UPPER_-1"):
        return token[:-1].upper() + token[-1:]
    return token
def convert_using_verb(token, smart_action):
    """Rewrite ``token`` according to a ``$TRANSFORM_VERB_<TAG1>_<TAG2>`` action.

    The lookup key ``"<token>_<TAG1>_<TAG2>"`` is resolved by
    ``decode_verb_form``. The result is ``None`` when the pair is not in the
    vocabulary.

    Raises:
        Exception: If ``smart_action`` is not a ``$TRANSFORM_VERB_`` action.
    """
    prefix = "$TRANSFORM_VERB_"
    if smart_action.startswith(prefix):
        verb_tags = smart_action[len(prefix):]
        return decode_verb_form(f"{token}_{verb_tags}")
    raise Exception(f"Unknown action type {smart_action}")
def convert_using_split(token, smart_action):
    """Split a hyphenated token into space-separated words.

    Example: ``"well-known"`` becomes ``"well known"``. The result is still a
    single string.

    Raises:
        Exception: If ``smart_action`` is not a ``$TRANSFORM_SPLIT`` action.
    """
    if smart_action.startswith("$TRANSFORM_SPLIT"):
        return token.replace("-", " ")
    raise Exception(f"Unknown action type {smart_action}")
def convert_using_plural(token, smart_action):
    """Naively toggle a token between singular and plural.

    * Actions ending in ``PLURAL`` append ``"s"``.
    * Actions ending in ``SINGULAR`` drop the last character.

    Raises:
        Exception: For any other action.
    """
    if smart_action.endswith("PLURAL"):
        return token + "s"
    if smart_action.endswith("SINGULAR"):
        return token[:-1]
    raise Exception(f"Unknown action type {smart_action}")
def apply_reverse_transformation(source_token, transform):
    """Apply a ``$TRANSFORM*`` label to ``source_token``.

    Labels not starting with ``$TRANSFORM`` (including ``$KEEP``) return
    ``source_token`` unchanged. The original's explicit ``$KEEP`` check was
    nested inside the ``$TRANSFORM`` branch and could never run; it has been
    removed without changing behavior.

    Dispatch by prefix:

    * ``$TRANSFORM_CASE``: ``convert_using_case``
    * ``$TRANSFORM_VERB``: ``convert_using_verb`` (may return ``None``)
    * ``$TRANSFORM_SPLIT``: ``convert_using_split``
    * ``$TRANSFORM_AGREEMENT``: ``convert_using_plural``

    Raises:
        Exception: For a ``$TRANSFORM`` label with none of these prefixes.
    """
    if not transform.startswith("$TRANSFORM"):
        return source_token
    if transform.startswith("$TRANSFORM_CASE"):
        return convert_using_case(source_token, transform)
    if transform.startswith("$TRANSFORM_VERB"):
        return convert_using_verb(source_token, transform)
    if transform.startswith("$TRANSFORM_SPLIT"):
        return convert_using_split(source_token, transform)
    if transform.startswith("$TRANSFORM_AGREEMENT"):
        return convert_using_plural(source_token, transform)
    raise Exception(f"Unknown action type {transform}")
# def read_parallel_lines(fn1, fn2):
# lines1 = read_lines(fn1, skip_strip=True)
# lines2 = read_lines(fn2, skip_strip=True)
# assert len(lines1) == len(lines2)
# out_lines1, out_lines2 = [], []
# for line1, line2 in zip(lines1, lines2):
# if not line1.strip() or not line2.strip():
# continue
# else:
# out_lines1.append(line1)
# out_lines2.append(line2)
# return out_lines1, out_lines2
def read_parallel_lines(fn1, fn2, *, strict=False):
    """Yield ``(line1, line2)`` pairs read in lockstep from two UTF-8 files.

    Both lines are stripped of surrounding whitespace. Blank lines are yielded
    as ``""``, not skipped.

    Args:
        fn1: Path of the first file.
        fn2: Path of the second file.
        strict: If False (the default), iteration stops silently at the end of
            the shorter file, as ``zip`` does. If True, a ValueError is raised
            when one file has more lines than the other.

    Yields:
        Tuples of two stripped strings.

    Raises:
        ValueError: If ``strict`` is True and the line counts differ.
    """
    from itertools import zip_longest

    with open(fn1, encoding='utf-8') as f1, open(fn2, encoding='utf-8') as f2:
        if not strict:
            for line1, line2 in zip(f1, f2):
                yield line1.strip(), line2.strip()
            return
        # zip_longest (not zip + a trailing next()) so that an extra line in
        # fn1 is not consumed and silently dropped before the check.
        missing = object()
        for line1, line2 in zip_longest(f1, f2, fillvalue=missing):
            if line1 is missing or line2 is missing:
                raise ValueError(f"{fn1} and {fn2} have a different number of lines")
            yield line1.strip(), line2.strip()
def read_lines(fn, skip_strip=False):
    """Return the whitespace-stripped lines of a UTF-8 text file.

    Args:
        fn: Path to read. A path that does not exist yields ``[]``.
        skip_strip: Despite its name, lines are always stripped. This flag
            only controls whether lines that become empty are kept (True) or
            dropped (False, the default).

    Returns:
        List of stripped lines.
    """
    if not os.path.exists(fn):
        return []
    with open(fn, 'r', encoding='utf-8') as handle:
        stripped = [raw.strip() for raw in handle]
    if skip_strip:
        return stripped
    return [line for line in stripped if line]
def write_lines(fn, lines, mode='w'):
    """Write each item of ``lines`` to ``fn`` followed by a newline (UTF-8).

    Args:
        fn: Destination path.
        lines: Iterable of items. Each is written as ``f"{item}\\n"``.
        mode: File mode passed to ``open`` (e.g. ``'w'`` to overwrite, ``'a'``
            to append).
    """
    if mode == 'w' and os.path.exists(fn):
        # The existing file is unlinked rather than only truncated by open().
        os.remove(fn)
    with open(fn, encoding='utf-8', mode=mode) as f:
        # f-string instead of '%s\n' % s: %-formatting treats a tuple item as
        # the argument tuple (TypeError for 2+ elements, unpacking for 1).
        f.writelines(f"{s}\n" for s in lines)
def decode_verb_form(original):
    """Look up ``original`` (``"word_TAG1_TAG2"``) in ``DECODE_VERB_DICT``.

    Returns the target word, or ``None`` when the key is absent.
    """
    try:
        return DECODE_VERB_DICT[original]
    except KeyError:
        return None
def encode_verb_form(original_word, corrected_word):
    """Return the ``TAG1_TAG2`` string for a word pair, or ``None``.

    Looks up ``"<original_word>_<corrected_word>"`` in ``ENCODE_VERB_DICT``
    and strips the stored value. ``None`` is returned when ``original_word``
    is falsy or the lookup yields an empty string.
    """
    lookup_key = original_word + "_" + corrected_word
    tags = ENCODE_VERB_DICT.get(lookup_key, "").strip()
    return tags if original_word and tags else None
def get_weights_name(transformer_name, lowercase):
    """Map a short transformer name to a pretrained-weights identifier.

    Args:
        transformer_name: Short model name such as ``'bert'`` or
            ``'roberta'``. Unknown names are returned unchanged.
        lowercase: Whether the input text is lower-cased. It selects the
            cased or uncased BERT variant. It also triggers a printed warning
            when it does not match what the chosen model was trained on.

    Returns:
        The weights identifier string.

    ``('bert-large', True)`` now returns ``'bert-large-uncased'``. Previously
    it fell through, printed the "cased" warning and returned
    ``'bert-large'``.
    """
    if transformer_name == 'bert':
        return 'bert-base-uncased' if lowercase else 'bert-base-cased'
    if transformer_name == 'bert-large':
        return 'bert-large-uncased' if lowercase else 'bert-large-cased'

    uncased_only = {
        'distilbert': 'distilbert-base-uncased',
        'albert': 'albert-base-v1',
    }
    if transformer_name in uncased_only:
        if not lowercase:
            print('Warning! This model was trained only on uncased sentences.')
        return uncased_only[transformer_name]

    # Every remaining model, including names passed through unchanged, is
    # treated as cased.
    if lowercase:
        print('Warning! This model was trained only on cased sentences.')
    cased_only = {
        'roberta': 'roberta-base',
        'roberta-large': 'roberta-large',
        'gpt2': 'gpt2',
        'transformerxl': 'transfo-xl-wt103',
        'xlnet': 'xlnet-base-cased',
        'xlnet-large': 'xlnet-large-cased',
    }
    return cased_only.get(transformer_name, transformer_name)