Shredder commited on
Commit
8380634
1 Parent(s): 86b14f0

Upload predict.py

Browse files
Files changed (1) hide show
  1. predict.py +126 -0
predict.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import time
3
+ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
4
+ from multiprocessing import cpu_count
5
+
6
+ from transformers import (
7
+ AutoConfig,
8
+ AutoModelForQuestionAnswering,
9
+ AutoTokenizer,
10
+ squad_convert_examples_to_features
11
+ )
12
+
13
+ from transformers.data.processors.squad import SquadResult, SquadV2Processor, SquadExample
14
+ from transformers.data.metrics.squad_metrics import compute_predictions_logits
15
+
16
+
17
+ def run_prediction(question_texts, context_text, model_path, n_best_size=1):
18
+ max_seq_length = 512
19
+ doc_stride = 256
20
+ n_best_size = n_best_size
21
+ max_query_length = 64
22
+ max_answer_length = 512
23
+ do_lower_case = False
24
+ null_score_diff_threshold = 0.0
25
+
26
+ def to_list(tensor):
27
+ return tensor.detach().cpu().tolist()
28
+
29
+ config_class, model_class, tokenizer_class = (AutoConfig, AutoModelForQuestionAnswering, AutoTokenizer)
30
+ config = config_class.from_pretrained(model_path)
31
+ tokenizer = tokenizer_class.from_pretrained(model_path, do_lower_case=True, use_fast=False)
32
+ model = model_class.from_pretrained(model_path, config=config)
33
+
34
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
+ model.to(device)
36
+
37
+ processor = SquadV2Processor()
38
+ examples = []
39
+
40
+ timer = time.time()
41
+ for i, question_text in enumerate(question_texts):
42
+
43
+ example = SquadExample(
44
+ qas_id=str(i),
45
+ question_text=question_text,
46
+ context_text=context_text,
47
+ answer_text=None,
48
+ start_position_character=None,
49
+ title="Predict",
50
+ answers=None,
51
+ )
52
+
53
+ examples.append(example)
54
+ print(f'Created Squad Examples in {time.time()-timer} seconds')
55
+
56
+ print(f'Number of CPUs: {cpu_count()}')
57
+ timer = time.time()
58
+ features, dataset = squad_convert_examples_to_features(
59
+ examples=examples,
60
+ tokenizer=tokenizer,
61
+ max_seq_length=max_seq_length,
62
+ doc_stride=doc_stride,
63
+ max_query_length=max_query_length,
64
+ is_training=False,
65
+ return_dataset="pt",
66
+ threads=cpu_count(),
67
+ )
68
+ print(f'Converted Examples to Features in {time.time()-timer} seconds')
69
+
70
+ eval_sampler = SequentialSampler(dataset)
71
+ eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=10)
72
+
73
+ all_results = []
74
+
75
+ timer = time.time()
76
+ for batch in eval_dataloader:
77
+ model.eval()
78
+ batch = tuple(t.to(device) for t in batch)
79
+
80
+ with torch.no_grad():
81
+ inputs = {
82
+ "input_ids": batch[0],
83
+ "attention_mask": batch[1],
84
+ "token_type_ids": batch[2],
85
+ }
86
+
87
+ example_indices = batch[3]
88
+
89
+ outputs = model(**inputs)
90
+
91
+ for i, example_index in enumerate(example_indices):
92
+ eval_feature = features[example_index.item()]
93
+ unique_id = int(eval_feature.unique_id)
94
+
95
+ output = [to_list(output[i]) for output in outputs.to_tuple()]
96
+
97
+ start_logits, end_logits = output
98
+ result = SquadResult(unique_id, start_logits, end_logits)
99
+ all_results.append(result)
100
+ print(f'Model predictions completed in {time.time()-timer} seconds')
101
+
102
+ print(all_results)
103
+
104
+ output_nbest_file = None
105
+ if n_best_size > 1:
106
+ output_nbest_file = "nbest.json"
107
+
108
+ timer = time.time()
109
+ final_predictions = compute_predictions_logits(
110
+ all_examples=examples,
111
+ all_features=features,
112
+ all_results=all_results,
113
+ n_best_size=n_best_size,
114
+ max_answer_length=max_answer_length,
115
+ do_lower_case=do_lower_case,
116
+ output_prediction_file=None,
117
+ output_nbest_file=output_nbest_file,
118
+ output_null_log_odds_file=None,
119
+ verbose_logging=False,
120
+ version_2_with_negative=True,
121
+ null_score_diff_threshold=null_score_diff_threshold,
122
+ tokenizer=tokenizer
123
+ )
124
+ print(f'Logits converted to predictions in {time.time()-timer} seconds')
125
+
126
+ return final_predictions