Spaces commit: update model ckpt
Build status: Build error
app.py
CHANGED
@@ -10,14 +10,16 @@ from weakly_supervised_parser.tree.evaluate import calculate_F1_for_spans, tree_
 from weakly_supervised_parser.inference import Predictor
 from weakly_supervised_parser.model.trainer import InsideOutsideStringClassifier
 
+from weakly_supervised_parser.model.span_classifier import LightningModel
+
 benepar.download('benepar_en3')
 
 nlp = spacy.load("en_core_web_md")
 nlp.add_pipe("benepar", config={"model": "benepar_en3"})
 
-inside_model = InsideOutsideStringClassifier(model_name_or_path="roberta-base", max_seq_length=256)
-fetch_url_inside_model = hf_hub_url(repo_id="nickil/weakly-supervised-parsing", filename="inside_model.onnx", revision="main")
-inside_model.load_model(pre_trained_model_path=cached_download(fetch_url_inside_model))
+# inside_model = InsideOutsideStringClassifier(model_name_or_path="roberta-base", max_seq_length=256)
+fetch_url_inside_model = hf_hub_url(repo_id="nickil/weakly-supervised-parsing", filename="inside_model.ckpt", revision="main")
+inside_model = LightningModel.load_from_checkpoint(checkpoint_path=cached_download(fetch_url_inside_model))
 
 # outside_model = InsideOutsideStringClassifier(model_name_or_path="roberta-base", max_seq_length=64)
 # outside_model.load_model(pre_trained_model_path=TRAINED_MODEL_PATH + "outside_model.onnx")
@@ -35,7 +37,7 @@ def predict(sentence, model):
     elif model == "inside-outside":
         best_parse = Predictor(sentence=sentence).obtain_best_parse(predict_type="inside_outside", model=inside_outside_model, scale_axis=1, predict_batch_size=128)
     sentence_f1 = calculate_F1_for_spans(tree_to_spans(gold_standard), tree_to_spans(best_parse))
-    return gold_standard, best_parse, sentence_f1
+    return gold_standard, best_parse, f"{sentence_f1:.2f}"
 
 
 iface = gradio.Interface(
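For reference, a minimal sketch of the checkpoint-loading flow app.py now uses; the eval() call is an assumption for inference-only use and is not part of the diff:

# Sketch: pull the Lightning checkpoint from the Hub instead of the old ONNX export.
from huggingface_hub import hf_hub_url, cached_download

from weakly_supervised_parser.model.span_classifier import LightningModel

# hf_hub_url builds the download URL; cached_download fetches it into the local HF cache.
url = hf_hub_url(repo_id="nickil/weakly-supervised-parsing", filename="inside_model.ckpt", revision="main")
inside_model = LightningModel.load_from_checkpoint(checkpoint_path=cached_download(url))
inside_model.eval()  # assumption: inference only, so disable dropout
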
weakly_supervised_parser/inference.py
CHANGED
@@ -8,6 +8,8 @@ from weakly_supervised_parser.tree.evaluate import calculate_F1_for_spans, tree_
 from weakly_supervised_parser.model.trainer import InsideOutsideStringClassifier
 from weakly_supervised_parser.settings import PTB_TEST_SENTENCES_WITHOUT_PUNCTUATION_PATH, PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH
 
+from weakly_supervised_parser.model.span_classifier import LightningModel
+
 
 class Predictor:
     def __init__(self, sentence):
@@ -96,7 +98,7 @@ def main():
     args = parser.parse_args()
 
     if args.use_inside:
-        pre_trained_model_path = TRAINED_MODEL_PATH + "inside_model.onnx"
+        pre_trained_model_path = TRAINED_MODEL_PATH + "inside_model.ckpt"
         max_seq_length = args.inside_max_seq_length
 
     if args.use_inside_self_train:
@@ -116,8 +118,10 @@ def main():
         outside_model = InsideOutsideStringClassifier(model_name_or_path=args.model_name_or_path, max_seq_length=args.outside_max_seq_length)
         outside_model.load_model(pre_trained_model_path=outside_pre_trained_model_path)
     else:
-        model = InsideOutsideStringClassifier(model_name_or_path=args.model_name_or_path, max_seq_length=max_seq_length)
-        model.load_model(pre_trained_model_path=pre_trained_model_path)
+        # model = InsideOutsideStringClassifier(model_name_or_path=args.model_name_or_path, max_seq_length=max_seq_length)
+        # model.load_model(pre_trained_model_path=pre_trained_model_path)
+
+        model = LightningModel.load_from_checkpoint(checkpoint_path=pre_trained_model_path)
 
     if args.use_inside or args.use_inside_self_train:
         predict_type = "inside"
@@ -126,7 +130,6 @@ def main():
         predict_type = "outside"
 
     with open(args.save_path, "w") as out_file:
-        print(type(args.scale_axis))
         test_sentences = DataLoaderHelper(input_file_object=PTB_TEST_SENTENCES_WITHOUT_PUNCTUATION_PATH).read_lines()
         test_gold_file_path = PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH
         for test_index, test_sentence in enumerate(test_sentences):
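The else branch now yields a LightningModel instead of an InsideOutsideStringClassifier, so downstream span scoring goes through Lightning's predict loop rather than predict_proba on an ONNX session. A hypothetical smoke test of the new path, assuming TRAINED_MODEL_PATH comes from weakly_supervised_parser.settings like the other path constants:

from pytorch_lightning import Trainer

from weakly_supervised_parser.settings import TRAINED_MODEL_PATH
from weakly_supervised_parser.model.span_classifier import LightningModel

# Load the trained inside model from its Lightning checkpoint.
model = LightningModel.load_from_checkpoint(checkpoint_path=TRAINED_MODEL_PATH + "inside_model.ckpt")
trainer = Trainer(accelerator="auto", enable_progress_bar=False)  # assumed minimal inference config
# Scores then come from LightningModel.predict_step via a predict dataloader
# (see the DataModule changes in data_module_loader.py below).
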
weakly_supervised_parser/model/data_module_loader.py
CHANGED
@@ -35,13 +35,15 @@ class PyTorchDataModule(Dataset):
             add_special_tokens=True,
             return_tensors="pt",
         )
-
+
         out = dict(
             sentence=sentence,
             input_ids=sentence_encoding["input_ids"].flatten(),
             attention_mask=sentence_encoding["attention_mask"].flatten(),
-            labels=data_row["label"].flatten(),
         )
+
+        if "label" in self.data.columns:
+            out.update(dict(labels=data_row["label"].flatten()))
 
         return out
 
@@ -52,6 +54,7 @@ class DataModule(LightningDataModule):
         model_name_or_path: str,
         train_df: pd.DataFrame,
         eval_df: pd.DataFrame,
+        test_df: pd.DataFrame,
         max_seq_length: int = 256,
         train_batch_size: int = 32,
         eval_batch_size: int = 32,
@@ -62,6 +65,7 @@ class DataModule(LightningDataModule):
         self.model_name_or_path = model_name_or_path
         self.train_df = train_df
         self.eval_df = eval_df
+        self.test_df = test_df
         self.max_seq_length = max_seq_length
         self.train_batch_size = train_batch_size
         self.eval_batch_size = eval_batch_size
@@ -71,9 +75,15 @@ class DataModule(LightningDataModule):
 
         self.train_dataset = PyTorchDataModule(self.model_name_or_path, self.train_df, self.max_seq_length)
         self.eval_dataset = PyTorchDataModule(self.model_name_or_path, self.eval_df, self.max_seq_length)
+
+        if isinstance(self.test_df, pd.DataFrame):
+            self.test_dataset = PyTorchDataModule(self.model_name_or_path, self.test_df, self.max_seq_length)
 
     def train_dataloader(self) -> DataLoader:
         return DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=False, num_workers=self.num_workers, pin_memory=True)
 
     def val_dataloader(self) -> DataLoader:
         return DataLoader(self.eval_dataset, batch_size=self.eval_batch_size, shuffle=False, num_workers=self.num_workers, pin_memory=True)
+
+    def predict_dataloader(self) -> DataLoader:
+        return DataLoader(self.test_dataset, batch_size=len(self.test_dataset), shuffle=False, num_workers=self.num_workers, pin_memory=True)
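Putting the new test_df plumbing together: a hypothetical end-to-end scoring call. The checkpoint path and sentence are made up; the "sentence" column name is what PyTorchDataModule reads, and passing the DataModule to trainer.predict mirrors how populate_chart.py uses it below:

import pandas as pd
from pytorch_lightning import Trainer

from weakly_supervised_parser.model.data_module_loader import DataModule
from weakly_supervised_parser.model.span_classifier import LightningModel

model = LightningModel.load_from_checkpoint(checkpoint_path="inside_model.ckpt")  # hypothetical path
dm = DataModule(model_name_or_path="roberta-base", train_df=None, eval_df=None,
                test_df=pd.DataFrame({"sentence": ["the cat sat on the mat"]}))
trainer = Trainer(accelerator="auto", enable_progress_bar=False)
# predict_dataloader() uses batch_size=len(test_dataset), so the whole frame arrives as one
# batch and trainer.predict returns a one-element list holding every span score.
scores = trainer.predict(model, dataloaders=dm)[0]
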
weakly_supervised_parser/model/span_classifier.py
CHANGED
@@ -45,6 +45,11 @@ class LightningModel(LightningModule):
         preds = torch.argmax(logits, axis=1)
         labels = batch["labels"]
         return {"loss": val_loss, "preds": preds, "labels": labels}
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        batch = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
+        outputs = self(**batch)
+        return torch.nn.functional.softmax(outputs.logits, dim=1)[:, 1]
 
     def validation_epoch_end(self, outputs):
         preds = torch.cat([x["preds"] for x in outputs])
@@ -58,6 +63,8 @@ class LightningModel(LightningModule):
         return loss
 
     def setup(self, stage=None):
+        if stage != "fit":
+            return None
         # Get dataloader by calling it - train_dataloader() is called after setup() by default
         train_loader = self.trainer.datamodule.train_dataloader()
 
weakly_supervised_parser/model/trainer.py
CHANGED
@@ -10,7 +10,7 @@ from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from transformers import AutoTokenizer, logging
 
-from onnxruntime import InferenceSession, SessionOptions
+from onnxruntime import InferenceSession
 from scipy.special import softmax
 
 from weakly_supervised_parser.model.data_module_loader import DataModule
@@ -37,7 +37,7 @@ class InsideOutsideStringClassifier:
         devices: int = 1,
         enable_progress_bar: bool = True,
         enable_model_summary: bool = False,
-        enable_checkpointing: bool = False,
+        enable_checkpointing: bool = True,
         logger: bool = False,
         accelerator: str = "auto",
         train_batch_size: int = 32,
@@ -52,6 +52,7 @@ class InsideOutsideStringClassifier:
             model_name_or_path=self.model_name_or_path,
             train_df=train_df,
             eval_df=eval_df,
+            test_df=None,
             max_seq_length=self.max_seq_length,
             train_batch_size=train_batch_size,
             eval_batch_size=eval_batch_size,
@@ -70,7 +71,7 @@ class InsideOutsideStringClassifier:
 
         callbacks = []
         callbacks.append(EarlyStopping(monitor="val_loss", patience=2, mode="min", check_finite=True))
-
+        callbacks.append(ModelCheckpoint(monitor="val_loss", dirpath=outputdir, filename=filename, save_top_k=1, save_weights_only=True, mode="min"))
 
         trainer = Trainer(
             accelerator=accelerator,
@@ -98,10 +99,7 @@ class InsideOutsideStringClassifier:
         )
 
     def load_model(self, pre_trained_model_path):
-        options = SessionOptions()
-        options.intra_op_num_threads = 32
-        options.inter_op_num_threads = 32
-        self.model = InferenceSession(pre_trained_model_path, options, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
+        self.model = InferenceSession(pre_trained_model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
 
     def preprocess_function(self, data):
@@ -129,3 +127,24 @@ class InsideOutsideStringClassifier:
 
     def predict(self, spans):
         return self.predict_proba(spans).argmax(axis=1)
+
+
+class InsideOutsideStringPredictor:
+
+    def __init__(self, model_name_or_path, max_seq_length, pre_trained_model_path, num_workers=32):
+        self.model_name_or_path = model_name_or_path
+        self.pre_trained_model_path = pre_trained_model_path
+        self.max_seq_length = max_seq_length
+        self.num_workers = num_workers
+
+    def predict_proba(self, test_df):
+        test_dataloader = data_module = DataModule(
+            model_name_or_path=self.model_name_or_path,
+            train_df=None,
+            eval_df=None,
+            test_df=test_df,
+            max_seq_length=self.max_seq_length,
+            num_workers=self.num_workers,
+        )
+
+        return trainer.predict(model, dataloaders=test_dataloader)
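Note that, as committed, InsideOutsideStringPredictor.predict_proba refers to trainer and model, neither of which is defined in trainer.py, so calling it as written would raise a NameError. A self-contained variant might look like the following sketch; the Trainer arguments and the load_from_checkpoint call are assumptions, not part of the diff:

from pytorch_lightning import Trainer

from weakly_supervised_parser.model.data_module_loader import DataModule
from weakly_supervised_parser.model.span_classifier import LightningModel


class InsideOutsideStringPredictor:
    def __init__(self, model_name_or_path, max_seq_length, pre_trained_model_path, num_workers=32):
        self.model_name_or_path = model_name_or_path
        self.max_seq_length = max_seq_length
        self.num_workers = num_workers
        # Load the checkpoint once up front instead of relying on module-level globals.
        self.model = LightningModel.load_from_checkpoint(checkpoint_path=pre_trained_model_path)
        self.trainer = Trainer(accelerator="auto", enable_progress_bar=False)

    def predict_proba(self, test_df):
        data_module = DataModule(
            model_name_or_path=self.model_name_or_path,
            train_df=None,
            eval_df=None,
            test_df=test_df,
            max_seq_length=self.max_seq_length,
            num_workers=self.num_workers,
        )
        return self.trainer.predict(self.model, dataloaders=data_module)
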
weakly_supervised_parser/utils/populate_chart.py
CHANGED
@@ -1,17 +1,24 @@
 import pandas as pd
 import numpy as np
+import logging
 
 from datasets.utils import set_progress_bar_enabled
 
 from weakly_supervised_parser.utils.prepare_dataset import NGramify
 from weakly_supervised_parser.utils.create_inside_outside_strings import InsideOutside
+from weakly_supervised_parser.model.trainer import InsideOutsideStringPredictor
 from weakly_supervised_parser.utils.cky_algorithm import get_best_parse
 from weakly_supervised_parser.utils.distant_supervision import RuleBasedHeuristic
 from weakly_supervised_parser.utils.prepare_dataset import PTBDataset
 from weakly_supervised_parser.settings import PTB_TRAIN_SENTENCES_WITHOUT_PUNCTUATION_PATH
 
+from weakly_supervised_parser.model.data_module_loader import DataModule
+from weakly_supervised_parser.model.span_classifier import LightningModel
+
 # Disable Dataset.map progress bar
 set_progress_bar_enabled(False)
+logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)
+
 
 # ptb = PTBDataset(data_path=PTB_TRAIN_SENTENCES_WITHOUT_PUNCTUATION_PATH)
 # ptb_top_100_common = [item.lower() for item in RuleBasedHeuristic(corpus=ptb.retrieve_all_sentences()).get_top_tokens(top_most_common_ptb=100)]
@@ -19,6 +26,10 @@ ptb_top_100_common = ['this', 'myself', 'shouldn', 'not', 'analysts', 'same', 'm
 # ptb_most_common_first_token = RuleBasedHeuristic(corpus=ptb.retrieve_all_sentences()).augment_using_most_frequent_starting_token(N=1)[0][0].lower()
 ptb_most_common_first_token = "the"
 
+from pytorch_lightning import Trainer
+
+trainer = Trainer(accelerator="auto", enable_progress_bar=False, gpus=-1)
+
 
 class PopulateCKYChart:
     def __init__(self, sentence):
@@ -43,17 +54,20 @@ class PopulateCKYChart:
 
         if predict_type == "inside":
 
-            if data.shape[0] > chunks:
-                data_chunks = np.array_split(data, data.shape[0] // chunks)
-                for data_chunk in data_chunks:
-                    inside_scores.extend(model.predict_proba(spans=data_chunk.rename(columns={"inside_sentence": "sentence"})[["sentence"]],
-                                                             scale_axis=scale_axis,
-                                                             predict_batch_size=predict_batch_size)[:, 1])
-            else:
-                inside_scores.extend(model.predict_proba(spans=data.rename(columns={"inside_sentence": "sentence"})[["sentence"]],
-                                                         scale_axis=scale_axis,
-                                                         predict_batch_size=predict_batch_size)[:, 1])
-
+            # if data.shape[0] > chunks:
+            #     data_chunks = np.array_split(data, data.shape[0] // chunks)
+            #     for data_chunk in data_chunks:
+            #         inside_scores.extend(model.predict_proba(spans=data_chunk.rename(columns={"inside_sentence": "sentence"})[["sentence"]],
+            #                              scale_axis=scale_axis,
+            #                              predict_batch_size=predict_batch_size)[:, 1])
+            # else:
+            #     inside_scores.extend(model.predict_proba(spans=data.rename(columns={"inside_sentence": "sentence"})[["sentence"]],
+            #                          scale_axis=scale_axis,
+            #                          predict_batch_size=predict_batch_size)[:, 1])
+            test_dataloader = DataModule(model_name_or_path="roberta-base", train_df=None, eval_df=None,
+                                         test_df=data.rename(columns={"inside_sentence": "sentence"})[["sentence"]])
+            inside_scores.extend(trainer.predict(model, dataloaders=test_dataloader)[0])
+
             data["inside_scores"] = inside_scores
             data.loc[
                 (data["inside_sentence"].str.lower().str.startswith(ptb_most_common_first_token))