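"""Inference pipeline for the inside-outside string classifier.

Parses each PTB test sentence with a CKY chart populated by the selected model
(inside, self-trained inside, outside, or co-trained inside-outside), logs the
per-sentence span F1 against the aligned gold trees, and writes the predicted
trees to the file given by --save_path.
"""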
from argparse import ArgumentParser
from loguru import logger
from weakly_supervised_parser.settings import TRAINED_MODEL_PATH
from weakly_supervised_parser.utils.prepare_dataset import DataLoaderHelper
from weakly_supervised_parser.utils.populate_chart import PopulateCKYChart
from weakly_supervised_parser.tree.evaluate import calculate_F1_for_spans, tree_to_spans
from weakly_supervised_parser.model.trainer import InsideOutsideStringClassifier
from weakly_supervised_parser.settings import PTB_TEST_SENTENCES_WITHOUT_PUNCTUATION_PATH, PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH
from weakly_supervised_parser.model.span_classifier import LightningModel


class Predictor:
    def __init__(self, sentence):
        self.sentence = sentence
        self.sentence_list = sentence.split()

    def obtain_best_parse(self, predict_type, model, scale_axis, predict_batch_size, return_df=False):
        unique_tokens_flag, span_scores, df = PopulateCKYChart(sentence=self.sentence).fill_chart(predict_type=predict_type,
                                                                                                  model=model,
                                                                                                  scale_axis=scale_axis,
                                                                                                  predict_batch_size=predict_batch_size)
        if unique_tokens_flag:
            # Fall back to a flat parse (each token in its own span) when the chart signals the unique-tokens case
            best_parse = "(S " + " ".join(["(S " + item + ")" for item in self.sentence_list]) + ")"
            logger.info(f"BEST PARSE: {best_parse}")
        else:
            best_parse = PopulateCKYChart(sentence=self.sentence).best_parse_tree(span_scores)
        if return_df:
            return best_parse, df
        return best_parse
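
# Illustrative usage of Predictor (a sketch, not part of the pipeline below; the
# sentence and batch size are made up, and the checkpoint path assumes a trained
# inside model is available under TRAINED_MODEL_PATH):
#
#   model = LightningModel.load_from_checkpoint(checkpoint_path=TRAINED_MODEL_PATH + "inside_model.ckpt")
#   predictor = Predictor(sentence="the quick brown fox jumps over the lazy dog")
#   best_parse = predictor.obtain_best_parse(predict_type="inside", model=model,
#                                            scale_axis=None, predict_batch_size=32)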


def process_test_sample(index, sentence, gold_file_path, predict_type, model, scale_axis, predict_batch_size, return_df=False):
    best_parse, df = Predictor(sentence=sentence).obtain_best_parse(predict_type=predict_type,
                                                                    model=model,
                                                                    scale_axis=scale_axis,
                                                                    predict_batch_size=predict_batch_size,
                                                                    return_df=True)
    gold_standard = DataLoaderHelper(input_file_object=gold_file_path)
    sentence_f1 = calculate_F1_for_spans(tree_to_spans(gold_standard[index]), tree_to_spans(best_parse))
    if sentence_f1 < 25.0:
        logger.warning(f"Index: {index} <> F1: {sentence_f1:.2f}")
    else:
        logger.info(f"Index: {index} <> F1: {sentence_f1:.2f}")
    if return_df:
        return best_parse, df
    else:
        return best_parse


def process_co_train_test_sample(index, sentence, gold_file_path, inside_model, outside_model, return_df=False):
    _, df_inside = PopulateCKYChart(sentence=sentence).compute_scores(predict_type="inside", model=inside_model, return_df=True)
    _, df_outside = PopulateCKYChart(sentence=sentence).compute_scores(predict_type="outside", model=outside_model, return_df=True)
    df = df_inside.copy()
    # Combine the inside and outside span scores multiplicatively before refilling the chart
    df["scores"] = df_inside["scores"] * df_outside["scores"]
    _, span_scores, df = PopulateCKYChart(sentence=sentence).fill_chart(data=df)
    best_parse = PopulateCKYChart(sentence=sentence).best_parse_tree(span_scores)
    gold_standard = DataLoaderHelper(input_file_object=gold_file_path)
    sentence_f1 = calculate_F1_for_spans(tree_to_spans(gold_standard[index]), tree_to_spans(best_parse))
    if sentence_f1 < 25.0:
        logger.warning(f"Index: {index} <> F1: {sentence_f1:.2f}")
    else:
        logger.info(f"Index: {index} <> F1: {sentence_f1:.2f}")
    return best_parse


def main():
    parser = ArgumentParser(description="Inference Pipeline for the Inside Outside String Classifier", add_help=True)
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--use_inside", action="store_true", help="Whether to predict using the inside model")
    group.add_argument("--use_inside_self_train", action="store_true", help="Whether to predict using the inside model with self-training")
    group.add_argument("--use_outside", action="store_true", help="Whether to predict using the outside model")
    group.add_argument("--use_inside_outside_co_train", action="store_true", help="Whether to predict using the inside-outside model with co-training")
    parser.add_argument("--model_name_or_path", type=str, default="roberta-base", help="Path to a pretrained model or model identifier from huggingface.co/models")
    parser.add_argument("--save_path", type=str, required=True, help="Path to save the final trees")
parser.add_argument("--scale_axis", choices=[None, 1], default=None, help="Whether to scale axis globally (None) or sequentially (1) across batches during softmax computation")
parser.add_argument("--predict_batch_size", type=int, help="Batch size during inference")
parser.add_argument(
"--inside_max_seq_length", default=256, type=int, help="The maximum total input sequence length after tokenization for the inside model"
)
parser.add_argument(
"--outside_max_seq_length", default=64, type=int, help="The maximum total input sequence length after tokenization for the outside model"
)
args = parser.parse_args()
    if args.use_inside:
        pre_trained_model_path = TRAINED_MODEL_PATH + "inside_model.ckpt"
        max_seq_length = args.inside_max_seq_length
    elif args.use_inside_self_train:
        pre_trained_model_path = TRAINED_MODEL_PATH + "inside_model_self_trained.onnx"
        max_seq_length = args.inside_max_seq_length
    elif args.use_outside:
        pre_trained_model_path = TRAINED_MODEL_PATH + "outside_model.onnx"
        max_seq_length = args.outside_max_seq_length
    if args.use_inside_outside_co_train:
        inside_pre_trained_model_path = TRAINED_MODEL_PATH + "inside_model_co_trained.onnx"
        inside_model = InsideOutsideStringClassifier(model_name_or_path=args.model_name_or_path, max_seq_length=args.inside_max_seq_length)
        inside_model.load_model(pre_trained_model_path=inside_pre_trained_model_path)
        outside_pre_trained_model_path = TRAINED_MODEL_PATH + "outside_model_co_trained.onnx"
        outside_model = InsideOutsideStringClassifier(model_name_or_path=args.model_name_or_path, max_seq_length=args.outside_max_seq_length)
        outside_model.load_model(pre_trained_model_path=outside_pre_trained_model_path)
    else:
        # model = InsideOutsideStringClassifier(model_name_or_path=args.model_name_or_path, max_seq_length=max_seq_length)
        # model.load_model(pre_trained_model_path=pre_trained_model_path)
        model = LightningModel.load_from_checkpoint(checkpoint_path=pre_trained_model_path)

    if args.use_inside or args.use_inside_self_train:
        predict_type = "inside"
    if args.use_outside:
        predict_type = "outside"
    with open(args.save_path, "w") as out_file:
        test_sentences = DataLoaderHelper(input_file_object=PTB_TEST_SENTENCES_WITHOUT_PUNCTUATION_PATH).read_lines()
        test_gold_file_path = PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH
        for test_index, test_sentence in enumerate(test_sentences):
            if args.use_inside_outside_co_train:
                best_parse = process_co_train_test_sample(
                    test_index, test_sentence, test_gold_file_path, inside_model=inside_model, outside_model=outside_model
                )
            else:
                best_parse = process_test_sample(test_index, test_sentence, test_gold_file_path, predict_type=predict_type, model=model,
                                                 scale_axis=args.scale_axis, predict_batch_size=args.predict_batch_size)
            out_file.write(best_parse + "\n")
if __name__ == "__main__":
main()
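
# Example invocation (illustrative; the script filename and output path are
# assumptions, not taken from this file):
#
#   python inference.py --use_inside --model_name_or_path roberta-base \
#       --save_path inside_model_predictions.txt --predict_batch_size 32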