|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
import contextlib
import json
import os
import re
|
|
|
|
|
class InputExample:
    """
    One multiple-choice question together with its passage.

    Attributes:
        paragraph: the passage text the question refers to.
        qa_list: one question-plus-option string per candidate answer.
        label: index into ``qa_list`` of the correct candidate.
    """

    def __init__(self, paragraph, qa_list, label):
        self.paragraph = paragraph
        self.qa_list = qa_list
        self.label = label

    def __repr__(self):
        # Debugging aid only; nothing in this file parses the representation.
        return "{}(paragraph={!r}, qa_list={!r}, label={!r})".format(
            type(self).__name__, self.paragraph, self.qa_list, self.label
        )
|
|
|
|
|
def get_examples(data_dir, set_type):
    """
    Extract paragraph and question-answer list from each json file

    Every file found in ``<data_dir>/<split>/<level>/`` is parsed as JSON and
    its ``article``, ``questions``, ``options`` and ``answers`` entries are
    read.

    Args:
        data_dir: root directory of the downloaded dataset.
        set_type: split name such as ``"train"``. A ``"<split>-<level>"``
            value (e.g. ``"test-high"``) restricts reading to that single
            level directory; otherwise both ``middle`` and ``high`` are read.

    Returns:
        A list with one ``InputExample`` per question. Its ``paragraph`` is
        the article with every whitespace run collapsed to a single space,
        its ``qa_list`` holds the question combined with each of the first
        four options, and its ``label`` is the answer letter as a 0-based
        index (``"A"`` -> 0).

    Raises:
        FileNotFoundError: if an expected level directory does not exist.
        IndexError: if a question has fewer than four options, or if there
            are fewer questions/options than answers.
    """
    set_type_c = set_type.split("-")
    if len(set_type_c) == 2:
        set_type, single_level = set_type_c
        levels = [single_level]
    else:
        levels = ["middle", "high"]

    examples = []
    for level in levels:
        cur_dir = os.path.join(data_dir, set_type, level)
        # os.listdir returns entries in arbitrary order; sorting makes the
        # order of the extracted examples reproducible across runs/machines.
        for filename in sorted(os.listdir(cur_dir)):
            cur_path = os.path.join(cur_dir, filename)
            # Explicit encoding so parsing does not depend on the platform's
            # default locale; the handle is released as soon as the JSON is
            # loaded.
            with open(cur_path, "r", encoding="utf-8") as f:
                cur_data = json.load(f)

            answers = cur_data["answers"]
            options = cur_data["options"]
            questions = cur_data["questions"]
            # \s+ already covers newlines, so one substitution flattens the
            # article onto a single line.
            context = re.sub(r"\s+", " ", cur_data["article"])

            for i, answer in enumerate(answers):
                label = ord(answer) - ord("A")
                question = questions[i]
                qa_list = []
                # Exactly four candidates are emitted per question.
                for j in range(4):
                    option = options[i][j]
                    if "_" in question:
                        # Cloze-style question: substitute the option for
                        # every underscore.
                        qa_cat = question.replace("_", option)
                    else:
                        qa_cat = " ".join([question, option])
                    qa_list.append(re.sub(r"\s+", " ", qa_cat))
                examples.append(InputExample(context, qa_list, label))

    return examples
|
|
|
|
|
def main():
    """
    Helper script to extract paragraphs questions and answers from RACE datasets.

    For each of the splits ``train``, ``dev``, ``test-middle`` and
    ``test-high`` the examples returned by ``get_examples`` are written to
    the output directory, one line per example:

    * ``<split>.input0`` -- the paragraph
    * ``<split>.input1`` .. ``<split>.input4`` -- the question combined with
      each of the four options
    * ``<split>.label`` -- the label of the correct option
    """
    parser = argparse.ArgumentParser()
    # Both directories are mandatory: without them the path handling below
    # would fail on None, so let argparse report the problem instead.
    parser.add_argument(
        "--input-dir",
        required=True,
        help="input directory for downloaded RACE dataset",
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="output directory for extracted data",
    )
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    for set_type in ["train", "dev", "test-middle", "test-high"]:
        examples = get_examples(args.input_dir, set_type)
        prefix = os.path.join(args.output_dir, set_type)

        # ExitStack closes every output file even if a write fails part-way.
        with contextlib.ExitStack() as stack:
            qa_files = [
                stack.enter_context(
                    open(prefix + ".input" + str(i + 1), "w", encoding="utf-8")
                )
                for i in range(4)
            ]
            outf_context = stack.enter_context(
                open(prefix + ".input0", "w", encoding="utf-8")
            )
            outf_label = stack.enter_context(
                open(prefix + ".label", "w", encoding="utf-8")
            )

            for example in examples:
                outf_context.write(example.paragraph + "\n")
                for qa_file, qa_text in zip(qa_files, example.qa_list):
                    qa_file.write(qa_text + "\n")
                outf_label.write(str(example.label) + "\n")
|
|
|
|
|
# Run the extraction only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|