add initial files
- app.py +64 -0
- requirements.txt +2 -0
- weakly_supervised_parser/__init__.py +0 -0
- weakly_supervised_parser/inference.py +145 -0
- weakly_supervised_parser/model/__init__.py +0 -0
- weakly_supervised_parser/model/data_module_loader.py +79 -0
- weakly_supervised_parser/model/span_classifier.py +95 -0
- weakly_supervised_parser/model/trainer.py +128 -0
- weakly_supervised_parser/settings.py +33 -0
- weakly_supervised_parser/tree/__init__.py +0 -0
- weakly_supervised_parser/tree/evaluate.py +221 -0
- weakly_supervised_parser/tree/helpers.py +177 -0
- weakly_supervised_parser/utils/__init__.py +0 -0
- weakly_supervised_parser/utils/cky_algorithm.py +91 -0
- weakly_supervised_parser/utils/create_inside_outside_strings.py +40 -0
- weakly_supervised_parser/utils/distant_supervision.py +40 -0
- weakly_supervised_parser/utils/populate_chart.py +95 -0
- weakly_supervised_parser/utils/prepare_dataset.py +165 -0
app.py
ADDED
@@ -0,0 +1,64 @@
import gradio
import benepar
import spacy
import nltk

from huggingface_hub import hf_hub_url, cached_download

from weakly_supervised_parser.tree.evaluate import calculate_F1_for_spans, tree_to_spans
from weakly_supervised_parser.inference import Predictor
from weakly_supervised_parser.model.trainer import InsideOutsideStringClassifier

benepar.download('benepar_en3')

nlp = spacy.load("en_core_web_md")
nlp.add_pipe("benepar", config={"model": "benepar_en3"})

inside_model = InsideOutsideStringClassifier(model_name_or_path="roberta-base", max_seq_length=256)
fetch_url_inside_model = hf_hub_url(repo_id="nickil/weakly-supervised-parsing", filename="inside_model.onnx", revision="main")
inside_model.load_model(pre_trained_model_path=cached_download(fetch_url_inside_model))

# outside_model = InsideOutsideStringClassifier(model_name_or_path="roberta-base", max_seq_length=64)
# outside_model.load_model(pre_trained_model_path=TRAINED_MODEL_PATH + "outside_model.onnx")

# inside_outside_model = InsideOutsideStringClassifier(model_name_or_path="roberta-base", max_seq_length=256)
# inside_outside_model.load_model(pre_trained_model_path=TRAINED_MODEL_PATH + "inside_outside_model.onnx")


def predict(sentence, model):
    gold_standard = list(nlp(sentence).sents)[0]._.parse_string
    if model == "inside":
        best_parse = Predictor(sentence=sentence).obtain_best_parse(predict_type="inside", model=inside_model, scale_axis=1, predict_batch_size=128)
    elif model == "outside":
        best_parse = Predictor(sentence=sentence).obtain_best_parse(predict_type="outside", model=outside_model, scale_axis=1, predict_batch_size=128)
    elif model == "inside-outside":
        best_parse = Predictor(sentence=sentence).obtain_best_parse(predict_type="inside_outside", model=inside_outside_model, scale_axis=1, predict_batch_size=128)
    sentence_f1 = calculate_F1_for_spans(tree_to_spans(gold_standard), tree_to_spans(best_parse))
    return gold_standard, best_parse, sentence_f1


iface = gradio.Interface(
    title="Co-training an Unsupervised Constituency Parser with Weak Supervision",
    description="Demo for the repository - [weakly-supervised-parsing](https://github.com/Nickil21/weakly-supervised-parsing) (ACL Findings 2022)",
    theme="default",
    article="""<h4 class='text-lg font-semibold my-2'>Note</h4>
    - We use a strong supervised parsing model `benepar_en3` which is based on T5-small to compute the gold parse.<br>
    - Sentence F1 score corresponds to the macro F1 score.
    """,
    allow_flagging="never",
    fn=predict,
    inputs=[
        gradio.inputs.Textbox(label="Sentence", placeholder="Enter a sentence in English"),
        gradio.inputs.Radio(["inside", "outside", "inside-outside"], default="inside", label="Choose Model"),
    ],
    outputs=[
        gradio.outputs.Textbox(label="Gold Parse Tree"),
        gradio.outputs.Textbox(label="Predicted Parse Tree"),
        gradio.outputs.Textbox(label="F1 score"),
    ],
    examples=[
        ["Russia 's war on Ukraine unsettles investors expecting carve-out deal uptick for 2022 .", "inside-outside"],
        ["Bitcoin community under pressure to cut energy use .", "inside"],
    ],
)
iface.launch(share=True)
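For reference, the `predict` function above is the callback the Gradio interface wires to the submit button. A minimal sketch of calling it directly (assuming the inside model loaded as above; the sentence is just one of the demo examples):

gold, parse, f1 = predict("Bitcoin community under pressure to cut energy use .", model="inside")
print(gold)   # benepar's bracketed parse, used as the gold standard
print(parse)  # best parse decoded from the inside model's span scores
print(f1)     # sentence-level (macro) F1 between the two span sets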
requirements.txt
ADDED
@@ -0,0 +1,2 @@
spacy==3.1.4
benepar==0.2.0
weakly_supervised_parser/__init__.py
ADDED
File without changes
weakly_supervised_parser/inference.py
ADDED
@@ -0,0 +1,145 @@
from argparse import ArgumentParser
from loguru import logger

from weakly_supervised_parser.settings import TRAINED_MODEL_PATH
from weakly_supervised_parser.utils.prepare_dataset import DataLoaderHelper
from weakly_supervised_parser.utils.populate_chart import PopulateCKYChart
from weakly_supervised_parser.tree.evaluate import calculate_F1_for_spans, tree_to_spans
from weakly_supervised_parser.model.trainer import InsideOutsideStringClassifier
from weakly_supervised_parser.settings import PTB_TEST_SENTENCES_WITHOUT_PUNCTUATION_PATH, PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH


class Predictor:
    def __init__(self, sentence):
        self.sentence = sentence
        self.sentence_list = sentence.split()

    def obtain_best_parse(self, predict_type, model, scale_axis, predict_batch_size, return_df=False):
        unique_tokens_flag, span_scores, df = PopulateCKYChart(sentence=self.sentence).fill_chart(predict_type=predict_type,
                                                                                                  model=model,
                                                                                                  scale_axis=scale_axis,
                                                                                                  predict_batch_size=predict_batch_size)

        if unique_tokens_flag:
            best_parse = "(S " + " ".join(["(S " + item + ")" for item in self.sentence_list]) + ")"
            logger.info("BEST PARSE", best_parse)
        else:
            best_parse = PopulateCKYChart(sentence=self.sentence).best_parse_tree(span_scores)
        if return_df:
            return best_parse, df
        return best_parse


def process_test_sample(index, sentence, gold_file_path, predict_type, model, scale_axis, predict_batch_size, return_df=False):
    best_parse, df = Predictor(sentence=sentence).obtain_best_parse(predict_type=predict_type,
                                                                    model=model,
                                                                    scale_axis=scale_axis,
                                                                    predict_batch_size=predict_batch_size,
                                                                    return_df=True)
    gold_standard = DataLoaderHelper(input_file_object=gold_file_path)
    sentence_f1 = calculate_F1_for_spans(tree_to_spans(gold_standard[index]), tree_to_spans(best_parse))
    if sentence_f1 < 25.0:
        logger.warning(f"Index: {index} <> F1: {sentence_f1:.2f}")
    else:
        logger.info(f"Index: {index} <> F1: {sentence_f1:.2f}")
    if return_df:
        return best_parse, df
    else:
        return best_parse


def process_co_train_test_sample(index, sentence, gold_file_path, inside_model, outside_model, return_df=False):
    _, df_inside = PopulateCKYChart(sentence=sentence).compute_scores(predict_type="inside", model=inside_model, return_df=True)
    _, df_outside = PopulateCKYChart(sentence=sentence).compute_scores(predict_type="outside", model=outside_model, return_df=True)
    df = df_inside.copy()
    df["scores"] = df_inside["scores"] * df_outside["scores"]
    _, span_scores, df = PopulateCKYChart(sentence=sentence).fill_chart(data=df)
    best_parse = PopulateCKYChart(sentence=sentence).best_parse_tree(span_scores)
    gold_standard = DataLoaderHelper(input_file_object=gold_file_path)
    sentence_f1 = calculate_F1_for_spans(tree_to_spans(gold_standard[index]), tree_to_spans(best_parse))
    if sentence_f1 < 25.0:
        logger.warning(f"Index: {index} <> F1: {sentence_f1:.2f}")
    else:
        logger.info(f"Index: {index} <> F1: {sentence_f1:.2f}")
    return best_parse


def main():
    parser = ArgumentParser(description="Inference Pipeline for the Inside Outside String Classifier", add_help=True)

    group = parser.add_mutually_exclusive_group(required=True)

    group.add_argument("--use_inside", action="store_true", help="Whether to predict using inside model")

    group.add_argument("--use_inside_self_train", action="store_true", help="Whether to predict using inside model with self-training")

    group.add_argument("--use_outside", action="store_true", help="Whether to predict using outside model")

    group.add_argument("--use_inside_outside_co_train", action="store_true", help="Whether to predict using inside-outside model with co-training")

    parser.add_argument("--model_name_or_path", type=str, default="roberta-base", help="Path to the model identifier from huggingface.co/models")

    parser.add_argument("--save_path", type=str, required=True, help="Path to save the final trees")

    parser.add_argument("--scale_axis", choices=[None, 1], default=None, help="Whether to scale axis globally (None) or sequentially (1) across batches during softmax computation")

    parser.add_argument("--predict_batch_size", type=int, help="Batch size during inference")

    parser.add_argument(
        "--inside_max_seq_length", default=256, type=int, help="The maximum total input sequence length after tokenization for the inside model"
    )

    parser.add_argument(
        "--outside_max_seq_length", default=64, type=int, help="The maximum total input sequence length after tokenization for the outside model"
    )

    args = parser.parse_args()

    if args.use_inside:
        pre_trained_model_path = TRAINED_MODEL_PATH + "inside_model.onnx"
        max_seq_length = args.inside_max_seq_length

    if args.use_inside_self_train:
        pre_trained_model_path = TRAINED_MODEL_PATH + "inside_model_self_trained.onnx"
        max_seq_length = args.inside_max_seq_length

    if args.use_outside:
        pre_trained_model_path = TRAINED_MODEL_PATH + "outside_model.onnx"
        max_seq_length = args.outside_max_seq_length

    if args.use_inside_outside_co_train:
        inside_pre_trained_model_path = "inside_model_co_trained.onnx"
        inside_model = InsideOutsideStringClassifier(model_name_or_path=args.model_name_or_path, max_seq_length=args.inside_max_seq_length)
        inside_model.load_model(pre_trained_model_path=inside_pre_trained_model_path)

        outside_pre_trained_model_path = "outside_model_co_trained.onnx"
        outside_model = InsideOutsideStringClassifier(model_name_or_path=args.model_name_or_path, max_seq_length=args.outside_max_seq_length)
        outside_model.load_model(pre_trained_model_path=outside_pre_trained_model_path)
    else:
        model = InsideOutsideStringClassifier(model_name_or_path=args.model_name_or_path, max_seq_length=max_seq_length)
        model.load_model(pre_trained_model_path=pre_trained_model_path)

    if args.use_inside or args.use_inside_self_train:
        predict_type = "inside"

    if args.use_outside:
        predict_type = "outside"

    with open(args.save_path, "w") as out_file:
        print(type(args.scale_axis))
        test_sentences = DataLoaderHelper(input_file_object=PTB_TEST_SENTENCES_WITHOUT_PUNCTUATION_PATH).read_lines()
        test_gold_file_path = PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH
        for test_index, test_sentence in enumerate(test_sentences):
            if args.use_inside_outside_co_train:
                best_parse = process_co_train_test_sample(
                    test_index, test_sentence, test_gold_file_path, inside_model=inside_model, outside_model=outside_model
                )
            else:
                best_parse = process_test_sample(test_index, test_sentence, test_gold_file_path, predict_type=predict_type, model=model,
                                                 scale_axis=args.scale_axis, predict_batch_size=args.predict_batch_size)

            out_file.write(best_parse + "\n")


if __name__ == "__main__":
    main()
weakly_supervised_parser/model/__init__.py
ADDED
File without changes
weakly_supervised_parser/model/data_module_loader.py
ADDED
@@ -0,0 +1,79 @@
import pandas as pd

from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningDataModule


class PyTorchDataModule(Dataset):
    """PyTorch Dataset class"""

    def __init__(self, model_name_or_path: str, data: pd.DataFrame, max_seq_length: int = 256):
        """
        Initiates a PyTorch Dataset Module for input data
        """
        self.model_name_or_path = model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
        self.data = data
        self.max_seq_length = max_seq_length

    def __len__(self):
        """returns length of data"""
        return len(self.data)

    def __getitem__(self, index: int):
        """returns dictionary of input tensors to feed into the model"""

        data_row = self.data.iloc[index]
        sentence = data_row["sentence"]

        sentence_encoding = self.tokenizer(
            sentence,
            max_length=self.max_seq_length,
            padding="max_length",
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

        out = dict(
            sentence=sentence,
            input_ids=sentence_encoding["input_ids"].flatten(),
            attention_mask=sentence_encoding["attention_mask"].flatten(),
            labels=data_row["label"].flatten(),
        )

        return out


class DataModule(LightningDataModule):
    def __init__(
        self,
        model_name_or_path: str,
        train_df: pd.DataFrame,
        eval_df: pd.DataFrame,
        max_seq_length: int = 256,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        num_workers: int = 16,
        **kwargs
    ):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.train_df = train_df
        self.eval_df = eval_df
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):

        self.train_dataset = PyTorchDataModule(self.model_name_or_path, self.train_df, self.max_seq_length)
        self.eval_dataset = PyTorchDataModule(self.model_name_or_path, self.eval_df, self.max_seq_length)

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=False, num_workers=self.num_workers, pin_memory=True)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.eval_dataset, batch_size=self.eval_batch_size, shuffle=False, num_workers=self.num_workers, pin_memory=True)
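A minimal sketch of how `DataModule` expects to be fed: the column names `sentence` and `label` come from `PyTorchDataModule.__getitem__`, while the toy rows, `num_workers=0`, and the small batch sizes are illustrative only.

import pandas as pd

train_df = pd.DataFrame({"sentence": ["the big dog", "dog the big"], "label": [1, 0]})
eval_df = pd.DataFrame({"sentence": ["a small cat"], "label": [1]})

dm = DataModule(model_name_or_path="roberta-base", train_df=train_df, eval_df=eval_df,
                max_seq_length=32, train_batch_size=2, eval_batch_size=1, num_workers=0)
dm.setup()
batch = next(iter(dm.train_dataloader()))
print(batch["input_ids"].shape)  # (train_batch_size, max_seq_length)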
weakly_supervised_parser/model/span_classifier.py
ADDED
@@ -0,0 +1,95 @@
import torch
import torchmetrics
from torch.optim import AdamW
from pytorch_lightning import LightningModule
from transformers import AutoConfig, AutoModelForSequenceClassification, get_linear_schedule_with_warmup


class LightningModel(LightningModule):
    def __init__(
        self,
        model_name_or_path: str,
        num_labels: int = 2,
        lr: float = 5e-6,
        train_batch_size: int = 32,
        adam_epsilon=1e-8,
        warmup_steps: int = 0,
        weight_decay: float = 0.0,
        **kwargs
    ):
        super().__init__()

        self.save_hyperparameters()

        self.num_labels = num_labels
        self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=self.num_labels)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)
        self.model.gradient_checkpointing_enable()
        self.lr = lr
        self.train_batch_size = train_batch_size
        self.accuracy = torchmetrics.Accuracy()
        self.f1score = torchmetrics.F1Score(num_classes=2)
        self.mcc = torchmetrics.MatthewsCorrCoef(num_classes=2)

    def forward(self, input_ids, attention_mask, labels=None):
        return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    def training_step(self, batch, batch_idx):
        outputs = self(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
        loss = outputs[0]
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
        val_loss, logits = outputs[:2]
        preds = torch.argmax(logits, axis=1)
        labels = batch["labels"]
        return {"loss": val_loss, "preds": preds, "labels": labels}

    def validation_epoch_end(self, outputs):
        preds = torch.cat([x["preds"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])
        loss = torch.stack([x["loss"] for x in outputs]).mean()

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_accuracy", self.accuracy(preds, labels.squeeze()), prog_bar=True)
        self.log("val_f1", self.f1score(preds, labels.squeeze()), prog_bar=True)
        self.log("val_mcc", self.mcc(preds, labels.squeeze()), prog_bar=True)
        return loss

    def setup(self, stage=None):
        # Get dataloader by calling it - train_dataloader() is called after setup() by default
        train_loader = self.trainer.datamodule.train_dataloader()

        # Calculate total steps
        tb_size = self.train_batch_size * max(1, self.trainer.gpus)
        ab_size = tb_size * self.trainer.accumulate_grad_batches
        self.total_steps = int((len(train_loader.dataset) / ab_size) * float(self.trainer.max_epochs))

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=self.lr,
            eps=self.hparams.adam_epsilon,
        )

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.total_steps,
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]
weakly_supervised_parser/model/trainer.py
ADDED
@@ -0,0 +1,128 @@
import os
import torch
import datasets
import numpy as np
import pandas as pd

os.environ["TOKENIZERS_PARALLELISM"] = "false"

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from transformers import AutoTokenizer, logging

from onnxruntime import InferenceSession
from scipy.special import softmax

from weakly_supervised_parser.model.data_module_loader import DataModule
from weakly_supervised_parser.model.span_classifier import LightningModel


# Disable model checkpoint warnings
logging.set_verbosity_error()


class InsideOutsideStringClassifier:
    def __init__(self, model_name_or_path: str, num_labels: int = 2, max_seq_length: int = 256):

        self.model_name_or_path = model_name_or_path
        self.num_labels = num_labels
        self.max_seq_length = max_seq_length

    def fit(
        self,
        train_df: pd.DataFrame,
        eval_df: pd.DataFrame,
        outputdir: str,
        filename: str,
        devices: int = 1,
        enable_progress_bar: bool = True,
        enable_model_summary: bool = False,
        enable_checkpointing: bool = False,
        logger: bool = False,
        accelerator: str = "auto",
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        learning_rate: float = 5e-6,
        max_epochs: int = 10,
        dataloader_num_workers: int = 16,
        seed: int = 42,
    ):

        data_module = DataModule(
            model_name_or_path=self.model_name_or_path,
            train_df=train_df,
            eval_df=eval_df,
            max_seq_length=self.max_seq_length,
            train_batch_size=train_batch_size,
            eval_batch_size=eval_batch_size,
            num_workers=dataloader_num_workers,
        )

        model = LightningModel(
            model_name_or_path=self.model_name_or_path,
            lr=learning_rate,
            num_labels=self.num_labels,
            train_batch_size=train_batch_size,
            eval_batch_size=eval_batch_size,
        )

        seed_everything(seed, workers=True)

        callbacks = []
        callbacks.append(EarlyStopping(monitor="val_loss", patience=2, mode="min", check_finite=True))
        # callbacks.append(ModelCheckpoint(monitor="val_loss", dirpath=outputdir, filename=filename, save_top_k=1, save_weights_only=True, mode="min"))

        trainer = Trainer(
            accelerator=accelerator,
            devices=devices,
            max_epochs=max_epochs,
            callbacks=callbacks,
            enable_progress_bar=enable_progress_bar,
            enable_model_summary=enable_model_summary,
            enable_checkpointing=enable_checkpointing,
            logger=logger,
        )
        trainer.fit(model, data_module)
        trainer.validate(model, data_module.val_dataloader())

        train_batch = next(iter(data_module.train_dataloader()))

        model.to_onnx(
            file_path=f"{outputdir}/{filename}.onnx",
            input_sample=(train_batch["input_ids"].cuda(), train_batch["attention_mask"].cuda()),
            export_params=True,
            opset_version=11,
            input_names=["input", "attention_mask"],
            output_names=["output"],
            dynamic_axes={"input": {0: "batch_size"}, "attention_mask": {0: "batch_size"}, "output": {0: "batch_size"}},
        )

    def load_model(self, pre_trained_model_path):
        self.model = InferenceSession(pre_trained_model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

    def preprocess_function(self, data):
        features = self.tokenizer(
            data["sentence"], max_length=self.max_seq_length, padding="max_length", add_special_tokens=True, truncation=True, return_tensors="np"
        )
        return features

    def process_spans(self, spans, scale_axis):
        spans_dataset = datasets.Dataset.from_pandas(spans)
        processed = spans_dataset.map(self.preprocess_function, batched=True, batch_size=None)
        inputs = {"input": processed["input_ids"], "attention_mask": processed["attention_mask"]}
        with torch.no_grad():
            return softmax(self.model.run(None, inputs)[0], axis=scale_axis)

    def predict_proba(self, spans, scale_axis, predict_batch_size):
        if spans.shape[0] > predict_batch_size:
            output = []
            span_batches = np.array_split(spans, spans.shape[0] // predict_batch_size)
            for span_batch in span_batches:
                output.extend(self.process_spans(span_batch, scale_axis))
            return np.vstack(output)
        else:
            return self.process_spans(spans, scale_axis)

    def predict(self, spans):
        return self.predict_proba(spans).argmax(axis=1)
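A rough usage sketch of this wrapper (the paths, the training DataFrames, and `spans_df` are placeholders rather than files shipped with this commit): `fit` trains the Lightning model and exports it to ONNX, `load_model` opens an exported file with ONNX Runtime, and `predict_proba` expects a DataFrame with a `sentence` column.

import pandas as pd

clf = InsideOutsideStringClassifier(model_name_or_path="roberta-base", max_seq_length=256)
# clf.fit(train_df=train_df, eval_df=eval_df, outputdir="TEMP", filename="inside_model")  # would write TEMP/inside_model.onnx
clf.load_model(pre_trained_model_path="TEMP/inside_model.onnx")  # placeholder path
spans_df = pd.DataFrame({"sentence": ["the big dog", "dog the big"]})  # illustrative span strings
probs = clf.predict_proba(spans=spans_df, scale_axis=1, predict_batch_size=128)
constituent_probability = probs[:, 1]  # column 1 is the score the rest of the repo treats as "is a constituent"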
weakly_supervised_parser/settings.py
ADDED
@@ -0,0 +1,33 @@
PROJECT_DIR = "weakly_supervised_parser/"
PTB_TREES_ROOT_DIR = "data/PROCESSED/english/trees/"
PTB_SENTENCES_ROOT_DIR = "data/PROCESSED/english/sentences/"

PTB_TRAIN_SENTENCES_WITH_PUNCTUATION_PATH = PTB_SENTENCES_ROOT_DIR + "ptb-train-sentences-with-punctuation.txt"
PTB_VALID_SENTENCES_WITH_PUNCTUATION_PATH = PTB_SENTENCES_ROOT_DIR + "ptb-valid-sentences-with-punctuation.txt"
PTB_TEST_SENTENCES_WITH_PUNCTUATION_PATH = PTB_SENTENCES_ROOT_DIR + "ptb-test-sentences-with-punctuation.txt"

PTB_TRAIN_SENTENCES_WITHOUT_PUNCTUATION_PATH = PTB_SENTENCES_ROOT_DIR + "ptb-train-sentences-without-punctuation.txt"
PTB_VALID_SENTENCES_WITHOUT_PUNCTUATION_PATH = PTB_SENTENCES_ROOT_DIR + "ptb-valid-sentences-without-punctuation.txt"
PTB_TEST_SENTENCES_WITHOUT_PUNCTUATION_PATH = PTB_SENTENCES_ROOT_DIR + "ptb-test-sentences-without-punctuation.txt"

PTB_TRAIN_GOLD_WITH_PUNCTUATION_PATH = PTB_TREES_ROOT_DIR + "ptb-train-gold-with-punctuation.txt"
PTB_VALID_GOLD_WITH_PUNCTUATION_PATH = PTB_TREES_ROOT_DIR + "ptb-valid-gold-with-punctuation.txt"
PTB_TEST_GOLD_WITH_PUNCTUATION_PATH = PTB_TREES_ROOT_DIR + "ptb-test-gold-with-punctuation.txt"

PTB_TRAIN_GOLD_WITHOUT_PUNCTUATION_PATH = PTB_TREES_ROOT_DIR + "ptb-train-gold-without-punctuation.txt"
PTB_VALID_GOLD_WITHOUT_PUNCTUATION_PATH = PTB_TREES_ROOT_DIR + "ptb-valid-gold-without-punctuation.txt"
PTB_TEST_GOLD_WITHOUT_PUNCTUATION_PATH = PTB_TREES_ROOT_DIR + "ptb-test-gold-without-punctuation.txt"

PTB_TRAIN_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH = PTB_TREES_ROOT_DIR + "ptb-train-gold-without-punctuation-aligned.txt"
PTB_VALID_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH = PTB_TREES_ROOT_DIR + "ptb-valid-gold-without-punctuation-aligned.txt"
PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH = PTB_TREES_ROOT_DIR + "ptb-test-gold-without-punctuation-aligned.txt"

YOON_KIM_TRAIN_GOLD_WITHOUT_PUNCTUATION_PATH = PTB_TREES_ROOT_DIR + "Yoon_Kim/ptb-train-gold-filtered.txt"
YOON_KIM_VALID_GOLD_WITHOUT_PUNCTUATION_PATH = PTB_TREES_ROOT_DIR + "Yoon_Kim/ptb-valid-gold-filtered.txt"
YOON_KIM_TEST_GOLD_WITHOUT_PUNCTUATION_PATH = PTB_TREES_ROOT_DIR + "Yoon_Kim/ptb-test-gold-filtered.txt"

# Predictions
PTB_SAVE_TREES_PATH = "TEMP/predictions/english/"

# Training
TRAINED_MODEL_PATH = PROJECT_DIR + "/model/TRAINED_MODEL/"
weakly_supervised_parser/tree/__init__.py
ADDED
File without changes
weakly_supervised_parser/tree/evaluate.py
ADDED
@@ -0,0 +1,221 @@
import argparse
import collections
import os
import subprocess

import nltk


def tree_to_spans(tree, keep_labels=False, keep_leaves=False, keep_whole_span=False):
    if isinstance(tree, str):
        tree = nltk.Tree.fromstring(tree)

    length = len(tree.pos())
    queue = collections.deque(tree.treepositions())
    stack = [(queue.popleft(), 0)]
    j = 0
    spans = []
    while stack != []:
        (p, i) = stack[-1]
        if not queue or queue[0][:-1] != p:
            if isinstance(tree[p], nltk.tree.Tree):
                if j - i > 1:
                    spans.append((tree[p].label(), (i, j)))
            else:
                j = i + 1
            stack.pop()
        else:
            q = queue.popleft()
            stack.append((q, j))
    if not keep_whole_span:
        spans = [span for span in spans if span[1] != (0, length)]
    if not keep_labels:
        spans = [span[1] for span in spans]
    return spans


def test_tree_to_spans():
    assert [(0, 2), (0, 3), (0, 4)] == tree_to_spans("(S (S (S (S (S 1) (S 2)) (S 3)) (S 4)) (S 5))", keep_labels=False)
    assert [] == tree_to_spans("(S 1)", keep_labels=False)
    assert [] == tree_to_spans("(S (S 1) (S 2))", keep_labels=False)
    assert [(1, 3)] == tree_to_spans("(S (S 1) (S (S 2) (S 3)))", keep_labels=False)
    assert [("S", (1, 3))] == tree_to_spans("(S (S 1) (S (S 2) (S 3)))", keep_labels=True)


def get_F1_score_intermediates(gold_spans, pred_spans):
    """Get intermediate results for calculating the F1 score"""
    n_true_positives = 0
    gold_span_counter = collections.Counter(gold_spans)
    pred_span_counter = collections.Counter(pred_spans)
    unique_spans = set(gold_spans + pred_spans)
    for span in unique_spans:
        n_true_positives += min(gold_span_counter[span], pred_span_counter[span])
    return n_true_positives, len(gold_spans), len(pred_spans)


def calculate_F1_score_from_intermediates(n_true_positives, n_golds, n_predictions, precision_recall_f_score=False):
    """Calculate F1 score"""
    if precision_recall_f_score:
        zeros = (0, 0, 0)
    else:
        zeros = 0
    if n_golds == 0:
        return 100 if n_predictions == 0 else zeros
    if n_true_positives == 0 or n_predictions == 0:
        return zeros
    recall = n_true_positives / n_golds
    precision = n_true_positives / n_predictions
    F1 = 2 * precision * recall / (precision + recall)
    if precision_recall_f_score:
        return precision, recall, F1 * 100
    return F1 * 100


def calculate_F1_for_spans(gold_spans, pred_spans, precision_recall_f_score=False):
    # CHANGE THIS LATER
    # gold_spans = list(set(gold_spans))
    ###################################
    tp, n_gold, n_pred = get_F1_score_intermediates(gold_spans, pred_spans)
    if precision_recall_f_score:
        p, r, F1 = calculate_F1_score_from_intermediates(tp, len(gold_spans), len(pred_spans), precision_recall_f_score=precision_recall_f_score)
        return p, r, F1
    F1 = calculate_F1_score_from_intermediates(tp, len(gold_spans), len(pred_spans))
    return F1


def test_calculate_F1_for_spans():
    pred = [(0, 1)]
    gold = [(0, 1)]
    assert calculate_F1_for_spans(gold, pred) == 100
    pred = [(0, 0)]
    gold = [(0, 1)]
    assert calculate_F1_for_spans(gold, pred) == 0
    pred = [(0, 0), (0, 1)]
    gold = [(0, 1), (1, 1)]
    assert calculate_F1_for_spans(gold, pred) == 50
    pred = [(0, 0), (0, 0)]
    gold = [(0, 0), (0, 0), (0, 1)]
    assert calculate_F1_for_spans(gold, pred) == 80
    pred = [(0, 0), (1, 0)]
    gold = [(0, 0), (0, 0), (0, 1)]
    assert calculate_F1_for_spans(gold, pred) == 40


def read_lines_from_file(filepath, len_limit):
    with open(filepath, "r") as f:
        for line in f:
            tree = nltk.Tree.fromstring(line)
            if len_limit is not None and len(tree.pos()) > len_limit:
                continue
            yield line.strip()


def read_spans_from_file(filepath, len_limit):
    for line in read_lines_from_file(filepath, len_limit):
        yield tree_to_spans(line, keep_labels=False, keep_leaves=False, keep_whole_span=False)


def calculate_corpus_level_F1_for_spans(gold_list, pred_list):
    n_true_positives = 0
    n_golds = 0
    n_predictions = 0
    for gold_spans, pred_spans in zip(gold_list, pred_list):
        n_tp, n_g, n_p = get_F1_score_intermediates(gold_spans, pred_spans)
        n_true_positives += n_tp
        n_golds += n_g
        n_predictions += n_p
    F1 = calculate_F1_score_from_intermediates(n_true_positives, n_golds, n_predictions)
    return F1


def calculate_sentence_level_F1_for_spans(gold_list, pred_list):
    f1_scores = []
    for gold_spans, pred_spans in zip(gold_list, pred_list):
        f1 = calculate_F1_for_spans(gold_spans, pred_spans)
        f1_scores.append(f1)
    F1 = sum(f1_scores) / len(f1_scores)
    return F1


def parse_evalb_results_from_file(filepath):
    i_th_score = 0
    score_of_all_length = None
    score_of_length_10 = None
    prefix_of_the_score_line = "Bracketing FMeasure ="

    with open(filepath, "r") as f:
        for line in f:
            if line.startswith(prefix_of_the_score_line):
                i_th_score += 1
                if i_th_score == 1:
                    score_of_all_length = float(line.split()[-1])
                elif i_th_score == 2:
                    score_of_length_10 = float(line.split()[-1])
                else:
                    raise ValueError("Too many lines for F score")
    return score_of_all_length, score_of_length_10


def execute_evalb(gold_file, pred_file, out_file, len_limit):
    EVALB_PATH = "model/EVALB/"
    subprocess.run("{} -p {} {} {} > {}".format(EVALB_PATH + "/evalb", EVALB_PATH + "unlabelled.prm", gold_file, pred_file, out_file), shell=True)


def calculate_evalb_F1_for_file(gold_file, pred_file, len_limit):
    evalb_out_file = pred_file + ".evalb_out"
    execute_evalb(gold_file, pred_file, evalb_out_file, len_limit)
    F1_len_all, F1_len_10 = parse_evalb_results_from_file(evalb_out_file)
    if len_limit is None:
        return F1_len_all
    elif len_limit == 10:
        return F1_len_10
    else:
        raise ValueError(f"Unexpected len_limit: {len_limit}")


def calculate_sentence_level_F1_for_file(gold_file, pred_file, len_limit):
    gold_list = list(read_spans_from_file(gold_file, len_limit))
    pred_list = list(read_spans_from_file(pred_file, len_limit))
    F1 = calculate_sentence_level_F1_for_spans(gold_list, pred_list)
    return F1


def calculate_corpus_level_F1_for_file(gold_file, pred_file, len_limit):
    gold_list = list(read_spans_from_file(gold_file, len_limit))
    pred_list = list(read_spans_from_file(pred_file, len_limit))
    F1 = calculate_corpus_level_F1_for_spans(gold_list, pred_list)
    return F1


def evaluate_prediction_file(gold_file, pred_file, len_limit):
    corpus_F1 = calculate_corpus_level_F1_for_file(gold_file, pred_file, len_limit)
    sentence_F1 = calculate_sentence_level_F1_for_file(gold_file, pred_file, len_limit)
    # evalb_F1 = calculate_evalb_F1_for_file(gold_file, pred_file, len_limit)

    print("=====> Evaluation Results <=====")
    print(f"Length constraint: f{len_limit}")
    print(f"Micro F1: {corpus_F1:.2f}, Macro F1: {sentence_F1:.2f}")  # , evalb_F1))
    print("=====> Evaluation Results <=====")


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--gold_file", "-g", help="path to gold file")
    parser.add_argument("--pred_file", "-p", help="path to prediction file")
    parser.add_argument(
        "--len_limit", default=None, type=int, choices=(None, 10, 20, 30, 40, 50, 100), help="length constraint for evaluation, 10 or None"
    )
    args = parser.parse_args()

    return args


def main():
    args = parse_args()
    evaluate_prediction_file(args.gold_file, args.pred_file, args.len_limit)


if __name__ == "__main__":
    main()

# python helper/evaluate.py -g TEMP/preprocessed_dev.txt -p TEMP/pred_dev_m_None.txt
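To make the two aggregate metrics concrete, here is a small traced example (the spans are illustrative): the corpus-level score pools true positives, gold counts, and prediction counts before computing one F1, while the sentence-level score averages the per-sentence F1 values.

gold_list = [[(0, 2)], [(0, 2), (1, 3)]]
pred_list = [[(0, 2)], [(2, 4), (1, 3)]]

print(calculate_corpus_level_F1_for_spans(gold_list, pred_list))    # 2 TP / 3 gold / 3 pred -> ~66.67
print(calculate_sentence_level_F1_for_spans(gold_list, pred_list))  # mean of 100 and 50 -> 75.0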
weakly_supervised_parser/tree/helpers.py
ADDED
@@ -0,0 +1,177 @@
import nltk
from collections import Counter
from weakly_supervised_parser.tree.evaluate import tree_to_spans


class Tree(object):
    def __init__(self, label, children, word):
        self.label = label
        self.children = children
        self.word = word

    def __str__(self):
        return self.linearize()

    def linearize(self):
        if not self.children:
            return f"({self.label} {self.word})"
        return f"({self.label} {' '.join(c.linearize() for c in self.children)})"

    def spans(self, start=0):
        if not self.children:
            return [(start, start + 1)]
        span_list = []
        position = start
        for c in self.children:
            cspans = c.spans(start=position)
            span_list.extend(cspans)
            position = cspans[0][1]
        return [(start, position)] + span_list

    def spans_labels(self, start=0):
        if not self.children:
            return [(start, start + 1, self.label)]
        span_list = []
        position = start
        for c in self.children:
            cspans = c.spans_labels(start=position)
            span_list.extend(cspans)
            position = cspans[0][1]
        return [(start, position, self.label)] + span_list


def extract_sentence(sentence):
    t = nltk.Tree.fromstring(sentence)
    return " ".join(item[0] for item in t.pos())


def get_constituents(sample_string, want_spans_mapping=False, whole_sentence=True, labels=False):
    t = nltk.Tree.fromstring(sample_string)
    if want_spans_mapping:
        spans = tree_to_spans(t, keep_labels=True)
        return dict(Counter(item[1] for item in spans))
    spans = tree_to_spans(t, keep_labels=True)
    sentence = extract_sentence(sample_string).split()

    labeled_consituents_lst = []
    constituents = []
    for span in spans:
        labeled_consituents = {}
        labeled_consituents["labels"] = span[0]
        i, j = span[1][0], span[1][1]
        constituents.append(" ".join(sentence[i:j]))
        labeled_consituents["constituent"] = " ".join(sentence[i:j])
        labeled_consituents_lst.append(labeled_consituents)

    # Add original sentence
    if whole_sentence:
        constituents = constituents + [" ".join(sentence)]

    if labels:
        return labeled_consituents_lst

    return constituents


def get_distituents(sample_string):
    sentence = extract_sentence(sample_string).split()

    def get_all_combinations(sentence):
        L = sentence.split()
        N = len(L)
        out = []
        for n in range(2, N):
            for i in range(N - n + 1):
                out.append((i, i + n))
        return out

    combinations = get_all_combinations(extract_sentence(sample_string))
    constituents = list(get_constituents(sample_string, want_spans_mapping=True).keys())
    spans = [item for item in combinations if item not in constituents]
    distituents = []
    for span in spans:
        i, j = span[0], span[1]
        distituents.append(" ".join(sentence[i:j]))
    return distituents


def get_leaves(tree):
    if not tree.children:
        return [tree]
    leaves = []
    for c in tree.children:
        leaves.extend(get_leaves(c))
    return leaves


def unlinearize(string):
    """
    (TOP (S (NP (PRP He)) (VP (VBD was) (ADJP (JJ right))) (. .)))
    """
    tokens = string.replace("(", " ( ").replace(")", " ) ").split()

    def read_tree(start):
        if tokens[start + 2] != "(":
            return Tree(tokens[start + 1], None, tokens[start + 2]), start + 4
        i = start + 2
        children = []
        while tokens[i] != ")":
            tree, i = read_tree(i)
            children.append(tree)
        return Tree(tokens[start + 1], children, None), i + 1

    tree, _ = read_tree(0)
    return tree


def recall_by_label(gold_standard, best_parse):
    correct = {}
    total = {}
    for tree1, tree2 in zip(gold_standard, best_parse):
        try:
            leaves1, leaves2 = get_leaves(tree1["tree"]), get_leaves(tree2["tree"])
            for l1, l2 in zip(leaves1, leaves2):
                assert l1.word.lower() == l2.word.lower(), f"{l1.word} =/= {l2.word}"
            spanlabels = tree1["tree"].spans_labels()
            spans = tree2["tree"].spans()

            for (i, j, label) in spanlabels:
                if j - i != 1:
                    if label not in correct:
                        correct[label] = 0
                        total[label] = 0
                    if (i, j) in spans:
                        correct[label] += 1
                    total[label] += 1
        except Exception as e:
            print(e)
    acc = {}
    for label in total.keys():
        acc[label] = correct[label] / total[label]
    return acc


def label_recall_output(gold_standard, best_parse):
    best_parse_trees = []
    gold_standard_trees = []
    for t1, t2 in zip(gold_standard, best_parse):
        gold_standard_trees.append({"tree": unlinearize(t1)})
        best_parse_trees.append({"tree": unlinearize(t2)})

    dct = recall_by_label(gold_standard=gold_standard_trees, best_parse=best_parse_trees)
    labels = ["SBAR", "NP", "VP", "PP", "ADJP", "ADVP"]
    l = [{label: f"{recall * 100:.2f}"} for label, recall in dct.items() if label in labels]
    df = pd.DataFrame([item.values() for item in l], index=[item.keys() for item in l], columns=["recall"])
    df.index = df.index.map(lambda x: list(x)[0])
    df_out = df.reindex(labels)
    return df_out


if __name__ == "__main__":
    import pandas as pd
    from weakly_supervised_parser.utils.prepare_dataset import PTBDataset
    from weakly_supervised_parser.settings import PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH, PTB_SAVE_TREES_PATH

    best_parse = PTBDataset(PTB_SAVE_TREES_PATH + "inside_model_predictions.txt").retrieve_all_sentences()
    gold_standard = PTBDataset(PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH).retrieve_all_sentences()
    print(label_recall_output(gold_standard, best_parse))
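A small traced example of `get_constituents` (the bracketing is illustrative): single-word spans and the whole-sentence span are dropped by `tree_to_spans`, and the full sentence is appended back because `whole_sentence=True` by default.

tree = "(S (NP (DT The) (NN cat)) (VP (VBD sat)))"
print(get_constituents(tree))
# ['The cat', 'The cat sat']
print(get_constituents(tree, labels=True))
# [{'labels': 'NP', 'constituent': 'The cat'}]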
weakly_supervised_parser/utils/__init__.py
ADDED
File without changes
weakly_supervised_parser/utils/cky_algorithm.py
ADDED
@@ -0,0 +1,91 @@
import re
import numpy as np
from weakly_supervised_parser.tree.helpers import Tree


def CKY(sent_all, prob_s, label_s, verbose=False):
    r"""
    choose tree with maximum expected number of constituents,
    or max \sum_{(i,j) \in tree} p((i,j) is constituent)
    """

    def backpt_to_tree(sent, backpt, label_table):
        def to_tree(i, j):
            if j - i == 1:
                return Tree(sent[i], None, sent[i])
            else:
                k = backpt[i][j]
                return Tree(label_table[i][j], [to_tree(i, k), to_tree(k, j)], None)

        return to_tree(0, len(sent))

    def to_table(value_s, i_s, j_s):
        table = [[None for _ in range(np.max(j_s) + 1)] for _ in range(np.max(i_s) + 1)]
        for value, i, j in zip(value_s, i_s, j_s):
            table[i][j] = value
        return table

    # produce list of spans to pass to is_constituent, while keeping track of which sentence
    sent_s, i_s, j_s = [], [], []
    idx_all = []
    for sent in sent_all:
        start = len(sent_s)
        for i in range(len(sent)):
            for j in range(i + 1, len(sent) + 1):
                sent_s.append(sent)
                i_s.append(i)
                j_s.append(j)
        idx_all.append((start, len(sent_s)))

    # feed spans to is_constituent
    # prob_s, label_s = self.is_constituent(sent_s, i_s, j_s, verbose = verbose)

    # given span probs, perform CKY to get best tree for each sentence.
    tree_all, prob_all = [], []
    for sent, idx in zip(sent_all, idx_all):
        # first, use tables to keep track of things
        k, l = idx
        prob, label = prob_s[k:l], label_s[k:l]
        i, j = i_s[k:l], j_s[k:l]

        prob_table = to_table(prob, i, j)
        label_table = to_table(label, i, j)

        # perform cky using scores and backpointers
        score_table = [[None for _ in range(len(sent) + 1)] for _ in range(len(sent))]
        backpt_table = [[None for _ in range(len(sent) + 1)] for _ in range(len(sent))]
        for i in range(len(sent)):  # base case: single words
            score_table[i][i + 1] = 1
        for j in range(2, len(sent) + 1):
            for i in range(j - 2, -1, -1):
                best, argmax = -np.inf, None
                for k in range(i + 1, j):  # find splitpoint
                    score = score_table[i][k] + score_table[k][j]
                    if score > best:
                        best, argmax = score, k
                score_table[i][j] = best + prob_table[i][j]
                backpt_table[i][j] = argmax

        tree = backpt_to_tree(sent, backpt_table, label_table)
        tree_all.append(tree)
        prob_all.append(prob_table)

    return tree_all, prob_all


def get_best_parse(sentence, spans):
    flattened_scores = []
    for i in range(spans.shape[0]):
        for j in range(spans.shape[1]):
            if i > j:
                continue
            else:
                flattened_scores.append(spans[i, j])
    prob_s, label_s = flattened_scores, ["S"] * len(flattened_scores)
    # print(prob_s, label_s)
    trees, _ = CKY(sent_all=sentence, prob_s=prob_s, label_s=label_s)
    s = str(trees[0])
    # Replace previous occurrence of string
    out = re.sub(r"(?<![^\s()])([^\s()]+)(?=\s+\1(?![^\s()]))", "S", s)
    # best_parse = "(ROOT " + out + ")"
    return out  # best_parse
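Written out, the objective in the `CKY` docstring is to pick, among all binary trees $T$ over the sentence, $T^{*} = \arg\max_{T} \sum_{(i,j) \in T} p\big((i,j)\ \text{is a constituent}\big)$. The dynamic program above solves this with the recurrence $\text{score}(i, j) = p(i, j) + \max_{i < k < j} \big[\text{score}(i, k) + \text{score}(k, j)\big]$ and the base case $\text{score}(i, i+1) = 1$, which correspond to the split-point loop and to `score_table[i][i + 1] = 1` in the code.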
weakly_supervised_parser/utils/create_inside_outside_strings.py
ADDED
@@ -0,0 +1,40 @@
class InsideOutside:
    def __init__(self, sentence):
        self.sentence = sentence.split()
        self.sentence_length = len(self.sentence)

    def calculate_inside(self, idx_start, idx_end):
        # get inside string
        return self.sentence[idx_start:idx_end]

    def calculate_outside(self, idx_start, idx_end):
        # get outside string
        if idx_start == 0 and idx_end == self.sentence_length:
            left_outside = ["<s>"]  # bos_token roberta # ["[UNK]"]
            right_outside = ["</s>"]  # eos_token roberta # ["[UNK]"]
        elif idx_start == 0:
            left_outside = ["<s>"]  # ["[UNK]"]
            right_outside = self.sentence[idx_end:]
        elif idx_end == self.sentence_length:
            left_outside = self.sentence[:idx_start]
            right_outside = ["</s>"]  # ["[UNK]"]
        else:
            left_outside = self.sentence[:idx_start]
            right_outside = self.sentence[idx_end:]
        return left_outside, right_outside

    def create_inside_outside_matrix(self, ngram):
        i, j = ngram[0][0], ngram[0][-1]
        inside_string = self.calculate_inside(i, j)
        outside_string = self.calculate_outside(i, j)
        output_dict = {
            "span": ngram[0],
            "inside_string": " ".join(inside_string),
            "left_outside_string": " ".join(outside_string[0]),
            "right_outside_string": " ".join(outside_string[-1]),
        }
        inside_string_template = output_dict["inside_string"]
        outside_string_template = (
            output_dict["left_outside_string"].split()[-1] + " " + "<mask>" + " " + output_dict["right_outside_string"].split()[0]
        )
        return output_dict, inside_string_template, outside_string_template
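A traced example of the inside/outside decomposition (the sentence and span are illustrative; `ngram[0]` is the `(start, end)` span exactly as the method above indexes it):

io = InsideOutside(sentence="the cat sat on the mat")
output_dict, inside_template, outside_template = io.create_inside_outside_matrix(((1, 3),))
print(inside_template)   # "cat sat"           (the span itself)
print(outside_template)  # "the <mask> on"     (one context word on each side around a mask)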
weakly_supervised_parser/utils/distant_supervision.py
ADDED
@@ -0,0 +1,40 @@
from collections import defaultdict, Counter
from nltk.corpus import stopwords


class RuleBasedHeuristic:
    def __init__(self, sentence=None, corpus=None):
        self.sentence = sentence
        self.corpus = corpus

    def add_contiguous_titlecase_words(self, row):
        matches = []
        dd = defaultdict(list)
        count = 0
        for i, j in zip(row, row[1:]):
            if j[0] - i[0] == 1:
                dd[count].append(i[-1] + " " + j[-1])
            else:
                count += 1
        for key, value in dd.items():
            if len(value) > 1:
                out = value[0]
                inter = ""
                for item in value[1:]:
                    inter += " " + item.split()[-1]
                matches.append(out + inter)
            else:
                matches.extend(value)
        return matches

    def augment_using_most_frequent_starting_token(self, N=1):
        first_token = []
        for sentence in self.corpus:
            first_token.append(sentence.split()[0])
        return Counter(first_token).most_common(N)

    def get_top_tokens(self, top_most_common_ptb=None):
        out = set(stopwords.words("english"))
        if top_most_common_ptb:
            out.update([token for token, counts in self.augment_using_most_frequent_starting_token(N=top_most_common_ptb)])
        return out
|
weakly_supervised_parser/utils/populate_chart.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import pandas as pd
import numpy as np

from datasets.utils import set_progress_bar_enabled

from weakly_supervised_parser.utils.prepare_dataset import NGramify
from weakly_supervised_parser.utils.create_inside_outside_strings import InsideOutside
from weakly_supervised_parser.utils.cky_algorithm import get_best_parse
from weakly_supervised_parser.utils.distant_supervision import RuleBasedHeuristic
from weakly_supervised_parser.utils.prepare_dataset import PTBDataset
from weakly_supervised_parser.settings import PTB_TRAIN_SENTENCES_WITHOUT_PUNCTUATION_PATH

# Disable Dataset.map progress bar
set_progress_bar_enabled(False)

# ptb = PTBDataset(data_path=PTB_TRAIN_SENTENCES_WITHOUT_PUNCTUATION_PATH)
# ptb_top_100_common = [item.lower() for item in RuleBasedHeuristic(corpus=ptb.retrieve_all_sentences()).get_top_tokens(top_most_common_ptb=100)]
ptb_top_100_common = ['this', 'myself', 'shouldn', 'not', 'analysts', 'same', 'mightn', 'we', 'american', 'the', 'another', 'until', "aren't", 'when', 'if', 'am', 'over', 'ma', 'as', 'of', 'with', 'even', 'couldn', 'not', "needn't", 'where', 'there', 'isn', 'however', 'my', 'sales', 'here', 'at', 'yours', 'into', 'wouldn', 'officials', 'no', "hasn't", 'to', 'wasn', 'any', 'ours', 'out', 'each', "wasn't", 'is', 'and', 'me', 'off', 'once', "it's", 'they', 'most', 'also', 'through', 'hasn', 'our', 'or', 'after', "weren't", 'about', 'mr.', 'first', 'haven', 'needn', 'have', "isn't", 'now', "didn't", 'on', 'theirs', 'these', 'before', 'there', 'was', 'which', 'those', 'having', 'do', 'most', 'own', 'among', 'because', 'for', "should've", "shan't", 'so', 'being', 'few', 'too', 'to', 'at', 'people', 'her', 'meanwhile', 'both', 'down', 'doesn', 'below', 'mustn', 'an', 'two', 'more', 'japanese', 'ford', "you'd", 'about', 'but', 'doing', 'itself', 've', 'under', 'what', 'again', 'then', 'your', 'himself', 'now', 'against', 'just', 'does', 'net', "couldn't", 'that', 'he', 'revenue', 'because', 'yesterday', 'them', 'i', 'their', 'all', 'under', 'up', "haven't", 'while', "won't", 'it', 'more', 'it', 'ain', 'him', 'still', 'a', 'he', 'despite', 'should', 'during', 'nor', "shouldn't", 'such', "doesn't", 'are', "that'll", 'since', 'yourselves', 'such', 'those', 'after', 'weren', "you're", 'd', 'like', 'did', 'hadn', 'themselves', 'its', 'but', 'been', 's', "don't", 'these', 'they', 'this', 'his', "mightn't", 'moreover', 'how', 'new', 'above', 'ourselves', 'so', 'why', 'between', 'their', 'general', "wouldn't", 'who', 'i', 'in', 'don', 'shan', 'u.s.', 'ibm', 'separately', 'had', 'you', 'federal', 'if', 'our', 'and', 'only', 'y', 'many', 'one', 'no', 'though', 'won', 'last', 'from', 'each', 'traders', 'john', 'further', 'hers', 'both', "you've", "you'll", 'that', 'all', 'its', 'only', 'here', 'according', "mustn't", 'while', 'in', 'what', 'didn', 'when', 'some', 'on', 'can', 'yourself', 'herself', 'than', 'with', 'has', 'she', 'during', 'will', 'of', 'thus', 'you', 'very', 'o', 'investors', 'a', 'ms.', 'japan', 'were', 'the', 'we', 'm', 'as', 'll', 'be', 'by', 'other', 'yet', 'whom', 'some', 'indeed', 'other', "she's", "hadn't", 'by', 'earlier', 'for', 'instead', 'she', 'an', 't', 're', 'his', 'then', 'aren', 'although']
# ptb_most_common_first_token = RuleBasedHeuristic(corpus=ptb.retrieve_all_sentences()).augment_using_most_frequent_starting_token(N=1)[0][0].lower()
ptb_most_common_first_token = "the"


class PopulateCKYChart:
    def __init__(self, sentence):
        self.sentence = sentence
        self.sentence_list = sentence.split()
        self.sentence_length = len(sentence.split())
        self.span_scores = np.zeros((self.sentence_length + 1, self.sentence_length + 1), dtype=float)
        self.all_spans = NGramify(self.sentence).generate_ngrams(single_span=True, whole_span=True)

    def compute_scores(self, model, predict_type, scale_axis, predict_batch_size, chunks=128):
        inside_strings = []
        outside_strings = []
        inside_scores = []
        outside_scores = []

        for span in self.all_spans:
            _, inside_string, outside_string = InsideOutside(sentence=self.sentence).create_inside_outside_matrix(span)
            inside_strings.append(inside_string)
            outside_strings.append(outside_string)

        data = pd.DataFrame({"inside_sentence": inside_strings, "outside_sentence": outside_strings, "span": [span[0] for span in self.all_spans]})

        if predict_type == "inside":

            if data.shape[0] > chunks:
                data_chunks = np.array_split(data, data.shape[0] // chunks)
                for data_chunk in data_chunks:
                    inside_scores.extend(model.predict_proba(spans=data_chunk.rename(columns={"inside_sentence": "sentence"})[["sentence"]],
                                                             scale_axis=scale_axis,
                                                             predict_batch_size=predict_batch_size)[:, 1])
            else:
                inside_scores.extend(model.predict_proba(spans=data.rename(columns={"inside_sentence": "sentence"})[["sentence"]],
                                                         scale_axis=scale_axis,
                                                         predict_batch_size=predict_batch_size)[:, 1])

            data["inside_scores"] = inside_scores
            data.loc[
                (data["inside_sentence"].str.lower().str.startswith(ptb_most_common_first_token))
                & (data["inside_sentence"].str.lower().str.split().str.len() == 2)
                & (~data["inside_sentence"].str.lower().str.split().str[-1].isin(RuleBasedHeuristic().get_top_tokens())),
                "inside_scores",
            ] = 1

            is_upper_or_title = all([item.istitle() or item.isupper() for item in self.sentence.split()])
            is_stop = any([item for item in self.sentence.split() if item.lower() in ptb_top_100_common])

            flags = is_upper_or_title and not is_stop

            data["scores"] = data["inside_scores"]

        elif predict_type == "outside":
            outside_scores.extend(model.predict_proba(spans=data.rename(columns={"outside_sentence": "sentence"})[["sentence"]],
                                                      scale_axis=scale_axis,
                                                      predict_batch_size=predict_batch_size)[:, 1])
            data["outside_scores"] = outside_scores
            flags = False
            data["scores"] = data["outside_scores"]

        return flags, data

    def fill_chart(self, model, predict_type, scale_axis, predict_batch_size, data=None):
        if data is None:
            flags, data = self.compute_scores(model, predict_type, scale_axis, predict_batch_size)
        for span in self.all_spans:
            for i in range(0, self.sentence_length):
                for j in range(i + 1, self.sentence_length + 1):
                    if span[0] == (i, j):
                        self.span_scores[i, j] = data.loc[data["span"] == span[0], "scores"].item()
self.span_scores[i, j] = data.loc[data["span"] == span[0], "scores"].item()
|
90 |
+
return flags, self.span_scores, data
|
91 |
+
|
92 |
+
def best_parse_tree(self, span_scores):
|
93 |
+
span_scores_cky_format = span_scores[:-1, 1:]
|
94 |
+
best_parse = get_best_parse(sentence=[self.sentence_list], spans=span_scores_cky_format)
|
95 |
+
return best_parse
|
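For orientation, a minimal usage sketch of PopulateCKYChart follows. It is not part of the module: `inside_model` stands in for an already-loaded, trained InsideOutsideStringClassifier and the sentence is illustrative only.

# Hypothetical driver: score all spans of one sentence with a trained
# inside-string classifier, then decode the best parse from the chart.
chart = PopulateCKYChart(sentence="Investors welcomed the modest recovery")
flags, span_scores, data = chart.fill_chart(model=inside_model, predict_type="inside", scale_axis=1, predict_batch_size=128)
best_parse = chart.best_parse_tree(span_scores)  # bracketed tree returned by the CKY decoder
print(best_parse)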
weakly_supervised_parser/utils/prepare_dataset.py
ADDED
@@ -0,0 +1,165 @@
import csv
import pandas as pd

from sklearn.model_selection import train_test_split

from weakly_supervised_parser.utils.process_ptb import punctuation_words, currency_tags_words
from weakly_supervised_parser.utils.distant_supervision import RuleBasedHeuristic


filterchars = punctuation_words + currency_tags_words
filterchars = [char for char in filterchars if char not in list(",;-") and char not in "``" and char not in "''"]


class NGramify:
    def __init__(self, sentence):
        self.sentence = sentence.split()
        self.sentence_length = len(self.sentence)
        self.ngrams = []

    def generate_ngrams(self, single_span=True, whole_span=True):
        # number of substrings possible is N*(N+1)/2
        # exclude substring or spans of length 1 and length N
        if single_span:
            start = 1
        else:
            start = 2
        if whole_span:
            end = self.sentence_length + 1
        else:
            end = self.sentence_length
        for n in range(start, end):
            for i in range(self.sentence_length - n + 1):
                self.ngrams.append(((i, i + n), self.sentence[i : i + n]))
        return self.ngrams

    def generate_all_possible_spans(self):
        for n in range(2, self.sentence_length):
            for i in range(self.sentence_length - n + 1):
                if i > 0 and (i + n) < self.sentence_length:
                    self.ngrams.append(
                        (
                            (i, i + n),
                            " ".join(self.sentence[i : i + n]),
                            " ".join(self.sentence[0:i])
                            + " ("
                            + " ".join(self.sentence[i : i + n])
                            + ") "
                            + " ".join(self.sentence[i + n : self.sentence_length]),
                        )
                    )
                elif i == 0:
                    self.ngrams.append(
                        (
                            (i, i + n),
                            " ".join(self.sentence[i : i + n]),
                            "(" + " ".join(self.sentence[i : i + n]) + ") " + " ".join(self.sentence[i + n : self.sentence_length]),
                        )
                    )
                elif (i + n) == self.sentence_length:
                    self.ngrams.append(
                        (
                            (i, i + n),
                            " ".join(self.sentence[i : i + n]),
                            " ".join(self.sentence[0:i]) + " (" + " ".join(self.sentence[i : i + n]) + ")",
                        )
                    )
        return self.ngrams


class DataLoaderHelper:
    def __init__(self, input_file_object=None, output_file_object=None):
        self.input_file_object = input_file_object
        self.output_file_object = output_file_object

    def read_lines(self):
        with open(self.input_file_object, "r") as f:
            lines = f.read().splitlines()
        return lines

    def __getitem__(self, index):
        return self.read_lines()[index]

    def write_lines(self, keys, values):
        with open(self.output_file_object, "w", newline="\n") as output_file:
            dict_writer = csv.DictWriter(output_file, keys, delimiter="\t")
            dict_writer.writeheader()
            dict_writer.writerows(values)


class PTBDataset:
    def __init__(self, data_path):
        self.data = pd.read_csv(data_path, sep="\t", header=None, names=["sentence"])
        self.data["sentence"] = self.data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data["sentence"].loc[index]

    def retrieve_all_sentences(self, N=None):
        if N:
            return self.data["sentence"].iloc[:N].tolist()
        return self.data["sentence"].tolist()

    def preprocess(self):
        # Remove punctuation and currency tokens from every sentence.
        self.data["sentence"] = self.data["sentence"].apply(
            lambda row: " ".join([sentence for sentence in row.split() if sentence not in filterchars])
        )
        return self.data

    def seed_bootstrap_constituent(self):
        # Whole sentences and contiguous title-case spans serve as positive (constituent) examples.
        whole_span_slice = self.data["sentence"]
        func = lambda x: RuleBasedHeuristic().add_contiguous_titlecase_words(
            row=[(index, character) for index, character in enumerate(x) if character.istitle() or "'" in character]
        )
        titlecase_matches = [item for sublist in self.data["sentence"].str.split().apply(func).tolist() for item in sublist if len(item.split()) > 1]
        titlecase_matches_df = pd.Series(titlecase_matches)
        titlecase_matches_df = titlecase_matches_df[~titlecase_matches_df.str.split().str[0].str.contains("'")].str.replace("''", "")
        most_frequent_start_token = RuleBasedHeuristic(corpus=self.retrieve_all_sentences()).augment_using_most_frequent_starting_token(N=1)[0][0]
        most_frequent_start_token_df = titlecase_matches_df[titlecase_matches_df.str.startswith(most_frequent_start_token)].str.lower()
        constituent_samples = pd.DataFrame(dict(sentence=pd.concat([whole_span_slice, titlecase_matches_df, most_frequent_start_token_df]), label=1))
        return constituent_samples

    def seed_bootstrap_distituent(self):
        # Right-truncated sentence prefixes serve as negative (distituent) examples.
        avg_sent_len = int(self.data["sentence"].str.split().str.len().mean())
        last_but_one_slice = self.data["sentence"].str.split().str[:-1].str.join(" ")
        last_but_two_slice = self.data[self.data["sentence"].str.split().str.len() > avg_sent_len + 10]["sentence"].str.split().str[:-2].str.join(" ")
        last_but_three_slice = (
            self.data[self.data["sentence"].str.split().str.len() > avg_sent_len + 20]["sentence"].str.split().str[:-3].str.join(" ")
        )
        last_but_four_slice = (
            self.data[self.data["sentence"].str.split().str.len() > avg_sent_len + 30]["sentence"].str.split().str[:-4].str.join(" ")
        )
        last_but_five_slice = (
            self.data[self.data["sentence"].str.split().str.len() > avg_sent_len + 40]["sentence"].str.split().str[:-5].str.join(" ")
        )
        last_but_six_slice = self.data[self.data["sentence"].str.split().str.len() > avg_sent_len + 50]["sentence"].str.split().str[:-6].str.join(" ")
        distituent_samples = pd.DataFrame(
            dict(
                sentence=pd.concat(
                    [
                        last_but_one_slice,
                        last_but_two_slice,
                        last_but_three_slice,
                        last_but_four_slice,
                        last_but_five_slice,
                        last_but_six_slice,
                    ]
                ),
                label=0,
            )
        )
        return distituent_samples

    def train_validation_split(self, seed, test_size=0.5, shuffle=True):
        # Combine the bootstrapped examples, deduplicate, and split into train/validation sets.
        self.preprocess()
        bootstrap_constituent_samples = self.seed_bootstrap_constituent()
        bootstrap_distituent_samples = self.seed_bootstrap_distituent()
        df = pd.concat([bootstrap_constituent_samples, bootstrap_distituent_samples], ignore_index=True)
        df = df.drop_duplicates(subset=["sentence"]).dropna(subset=["sentence"])
        df["sentence"] = df["sentence"].str.strip()
        df = df[df["sentence"].str.split().str.len() > 1]
        train, validation = train_test_split(df, test_size=test_size, random_state=seed, shuffle=shuffle)
        return train.head(8000), validation.head(2000)
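For orientation, a short usage sketch of the helpers above. It is not part of the module; the file path and seed are placeholders, not values taken from the repository.

# Span enumeration: each entry pairs (start, end) token indices with the covered
# tokens, e.g. ((0, 2), ["The", "cat"]); PopulateCKYChart writes scores at these indices.
spans = NGramify("The cat sat on the mat").generate_ngrams(single_span=True, whole_span=True)

# Bootstrap training data: whole sentences and contiguous title-case spans become
# constituent examples (label=1), right-truncated prefixes become distituent examples (label=0).
ptb = PTBDataset(data_path="ptb_train_sentences.txt")  # placeholder path: one sentence per line
train_df, validation_df = ptb.train_validation_split(seed=42)
print(len(spans), train_df.shape, validation_df.shape)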