# Spaces: Build error
from argparse import ArgumentParser | |
from loguru import logger | |
from weakly_supervised_parser.settings import TRAINED_MODEL_PATH | |
from weakly_supervised_parser.utils.prepare_dataset import DataLoaderHelper | |
from weakly_supervised_parser.utils.populate_chart import PopulateCKYChart | |
from weakly_supervised_parser.tree.evaluate import calculate_F1_for_spans, tree_to_spans | |
from weakly_supervised_parser.model.trainer import InsideOutsideStringClassifier | |
from weakly_supervised_parser.settings import PTB_TEST_SENTENCES_WITHOUT_PUNCTUATION_PATH, PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH | |
from weakly_supervised_parser.model.span_classifier import LightningModel | |
class Predictor:
    """Produce the best parse for a single whitespace-tokenised sentence."""

    def __init__(self, sentence):
        """Keep the raw sentence and its tokens.

        Args:
            sentence: Sentence string; tokens are obtained with ``str.split()``.
        """
        self.sentence = sentence
        self.sentence_list = sentence.split()

    def obtain_best_parse(self, predict_type, model, scale_axis, predict_batch_size, return_df=False):
        """Fill the CKY chart for the sentence and return its best parse.

        Args:
            predict_type: Forwarded unchanged to ``PopulateCKYChart.fill_chart``.
            model: Forwarded unchanged to ``PopulateCKYChart.fill_chart``.
            scale_axis: Forwarded unchanged to ``PopulateCKYChart.fill_chart``.
            predict_batch_size: Forwarded unchanged to ``PopulateCKYChart.fill_chart``.
            return_df: When true, also return the third value produced by
                ``fill_chart``.

        Returns:
            The best parse, or ``(best_parse, df)`` when ``return_df`` is true.
        """
        unique_tokens_flag, span_scores, df = PopulateCKYChart(sentence=self.sentence).fill_chart(
            predict_type=predict_type,
            model=model,
            scale_axis=scale_axis,
            predict_batch_size=predict_batch_size,
        )

        if unique_tokens_flag:
            # Flat tree: every token gets its own "(S tok)" bracket under a single root.
            best_parse = "(S " + " ".join(["(S " + item + ")" for item in self.sentence_list]) + ")"
            # loguru formats messages with str.format(); without a "{}" placeholder
            # the positional argument is dropped and the tree never reaches the log.
            logger.info("BEST PARSE {}", best_parse)
        else:
            best_parse = PopulateCKYChart(sentence=self.sentence).best_parse_tree(span_scores)

        if return_df:
            return best_parse, df
        return best_parse
def process_test_sample(index, sentence, gold_file_path, predict_type, model, scale_axis, predict_batch_size, return_df=False):
    """Parse one test sentence, log its span F1 against the gold tree, and return the parse.

    Args:
        index: Position used to look up the gold tree in ``gold_file_path``.
        sentence: Sentence to parse.
        gold_file_path: Path handed to ``DataLoaderHelper`` for the gold trees.
        predict_type: Forwarded to ``Predictor.obtain_best_parse``.
        model: Forwarded to ``Predictor.obtain_best_parse``.
        scale_axis: Forwarded to ``Predictor.obtain_best_parse``.
        predict_batch_size: Forwarded to ``Predictor.obtain_best_parse``.
        return_df: When true, return ``(best_parse, df)`` instead of the parse alone.

    Returns:
        The best parse, optionally paired with the dataframe produced while
        filling the chart.
    """
    predictor = Predictor(sentence=sentence)
    # The dataframe is always requested; whether it is returned is decided at the end.
    predicted_tree, chart_df = predictor.obtain_best_parse(
        predict_type=predict_type,
        model=model,
        scale_axis=scale_axis,
        predict_batch_size=predict_batch_size,
        return_df=True,
    )

    gold_trees = DataLoaderHelper(input_file_object=gold_file_path)
    gold_spans = tree_to_spans(gold_trees[index])
    predicted_spans = tree_to_spans(predicted_tree)
    sentence_f1 = calculate_F1_for_spans(gold_spans, predicted_spans)

    # Scores under 25 F1 are reported at warning level so weak parses stand out.
    report = logger.warning if sentence_f1 < 25.0 else logger.info
    report(f"Index: {index} <> F1: {sentence_f1:.2f}")

    return (predicted_tree, chart_df) if return_df else predicted_tree
def process_co_train_test_sample(index, sentence, gold_file_path, inside_model, outside_model, return_df=False):
    """Parse one test sentence with combined inside and outside scores and log its F1.

    The inside and outside models are scored separately, their ``"scores"``
    columns are multiplied element-wise, and the combined frame is used to fill
    the chart from which the best tree is decoded.

    Args:
        index: Position used to look up the gold tree in ``gold_file_path``.
        sentence: Sentence to parse.
        gold_file_path: Path handed to ``DataLoaderHelper`` for the gold trees.
        inside_model: Model passed to ``compute_scores`` with ``predict_type="inside"``.
        outside_model: Model passed to ``compute_scores`` with ``predict_type="outside"``.
        return_df: When true, return ``(best_parse, df)`` as ``process_test_sample``
            does; previously this flag was accepted but ignored.

    Returns:
        The best parse, or ``(best_parse, df)`` when ``return_df`` is true, where
        ``df`` is the frame returned by ``fill_chart``.
    """
    _, df_inside = PopulateCKYChart(sentence=sentence).compute_scores(predict_type="inside", model=inside_model, return_df=True)
    _, df_outside = PopulateCKYChart(sentence=sentence).compute_scores(predict_type="outside", model=outside_model, return_df=True)

    # NOTE(review): the product assumes both frames list the same spans in the
    # same row order - confirm against compute_scores.
    df = df_inside.copy()
    df["scores"] = df_inside["scores"] * df_outside["scores"]

    _, span_scores, df = PopulateCKYChart(sentence=sentence).fill_chart(data=df)
    best_parse = PopulateCKYChart(sentence=sentence).best_parse_tree(span_scores)

    gold_standard = DataLoaderHelper(input_file_object=gold_file_path)
    sentence_f1 = calculate_F1_for_spans(tree_to_spans(gold_standard[index]), tree_to_spans(best_parse))
    # Scores under 25 F1 are reported at warning level so weak parses stand out.
    if sentence_f1 < 25.0:
        logger.warning(f"Index: {index} <> F1: {sentence_f1:.2f}")
    else:
        logger.info(f"Index: {index} <> F1: {sentence_f1:.2f}")

    if return_df:
        return best_parse, df
    return best_parse
def main():
    """Run inference over the PTB test sentences and write one parse per line.

    Exactly one of ``--use_inside``, ``--use_inside_self_train``,
    ``--use_outside`` or ``--use_inside_outside_co_train`` must be given.
    In co-training mode two ``InsideOutsideStringClassifier`` models are loaded;
    otherwise a single ``LightningModel`` checkpoint is loaded. The
    ``--inside_max_seq_length`` / ``--outside_max_seq_length`` options are only
    read in co-training mode. The resulting trees are written to ``--save_path``.
    """
    parser = ArgumentParser(description="Inference Pipeline for the Inside Outside String Classifier", add_help=True)

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--use_inside", action="store_true", help="Whether to predict using inside model")
    group.add_argument("--use_inside_self_train", action="store_true", help="Whether to predict using inside model with self-training")
    group.add_argument("--use_outside", action="store_true", help="Whether to predict using outside model")
    group.add_argument("--use_inside_outside_co_train", action="store_true", help="Whether to predict using inside-outside model with co-training")

    parser.add_argument("--model_name_or_path", type=str, default="roberta-base", help="Path to the model identifier from huggingface.co/models")
    parser.add_argument("--save_path", type=str, required=True, help="Path to save the final trees")
    # argparse checks ``choices`` after applying ``type``. Without ``type=int`` the
    # command-line string "1" never matched the integer 1, so every explicit value
    # was rejected. Omitting the flag still yields None (global scaling).
    parser.add_argument(
        "--scale_axis",
        type=int,
        choices=[1],
        default=None,
        help="Whether to scale axis globally (None) or sequentially (1) across batches during softmax computation",
    )
    parser.add_argument("--predict_batch_size", type=int, help="Batch size during inference")
    parser.add_argument(
        "--inside_max_seq_length", default=256, type=int, help="The maximum total input sequence length after tokenization for the inside model"
    )
    parser.add_argument(
        "--outside_max_seq_length", default=64, type=int, help="The maximum total input sequence length after tokenization for the outside model"
    )

    args = parser.parse_args()

    if args.use_inside_outside_co_train:
        # Resolve the co-trained models under TRAINED_MODEL_PATH like every other
        # model file; previously these were bare file names relative to the
        # current working directory.
        inside_model = InsideOutsideStringClassifier(model_name_or_path=args.model_name_or_path, max_seq_length=args.inside_max_seq_length)
        inside_model.load_model(pre_trained_model_path=TRAINED_MODEL_PATH + "inside_model_co_trained.onnx")

        outside_model = InsideOutsideStringClassifier(model_name_or_path=args.model_name_or_path, max_seq_length=args.outside_max_seq_length)
        outside_model.load_model(pre_trained_model_path=TRAINED_MODEL_PATH + "outside_model_co_trained.onnx")
    else:
        # The option group is required and mutually exclusive, so exactly one of
        # the three single-model flags is set here.
        if args.use_inside:
            model_file_name = "inside_model.ckpt"
            predict_type = "inside"
        elif args.use_inside_self_train:
            model_file_name = "inside_model_self_trained.onnx"
            predict_type = "inside"
        else:
            model_file_name = "outside_model.onnx"
            predict_type = "outside"

        # NOTE(review): the self-train and outside files carry an ".onnx" suffix
        # yet are loaded through LightningModel.load_from_checkpoint - confirm
        # these files really are Lightning checkpoints.
        model = LightningModel.load_from_checkpoint(checkpoint_path=TRAINED_MODEL_PATH + model_file_name)

    with open(args.save_path, "w") as out_file:
        test_sentences = DataLoaderHelper(input_file_object=PTB_TEST_SENTENCES_WITHOUT_PUNCTUATION_PATH).read_lines()
        test_gold_file_path = PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH

        for test_index, test_sentence in enumerate(test_sentences):
            if args.use_inside_outside_co_train:
                best_parse = process_co_train_test_sample(
                    test_index, test_sentence, test_gold_file_path, inside_model=inside_model, outside_model=outside_model
                )
            else:
                best_parse = process_test_sample(
                    test_index,
                    test_sentence,
                    test_gold_file_path,
                    predict_type=predict_type,
                    model=model,
                    scale_axis=args.scale_axis,
                    predict_batch_size=args.predict_batch_size,
                )

            out_file.write(best_parse + "\n")


if __name__ == "__main__":
    main()