Copycats committed
Commit 8023460
Parent: 0a73afd

fix dependencies

Files changed (2):
  1. app.py +10 -38
  2. requirements.txt +3 -3
app.py CHANGED
@@ -15,36 +15,6 @@ def get_model():
 tokenizer, model = get_model()
 
 
-def predict_answer(qa_text_pair):
-    # Encoding
-    encodings = tokenizer(
-        qa_text_pair['question'], qa_text_pair['context'],
-        max_length=512,
-        truncation=True,
-        padding="max_length",
-        return_token_type_ids=False,
-        return_offsets_mapping=True
-    )
-    encodings = {key: torch.tensor([val]).to(device) for key, val in encodings.items()}
-
-    # Predict
-    with torch.no_grad():
-        pred = model(encodings['input_ids'], encodings['attention_mask'])
-        start_logits, end_logits = pred.start_logits, pred.end_logits
-        token_start_index, token_end_index = start_logits.argmax(dim=-1), end_logits.argmax(dim=-1)
-        pred_ids = encodings['input_ids'][0][token_start_index: token_end_index + 1]
-
-    # Answer start/end offset of context.
-    answer_start_offset = int(encodings['offset_mapping'][0][token_start_index][0][0])
-    answer_end_offset = int(encodings['offset_mapping'][0][token_end_index][0][1])
-    answer_offset = (answer_start_offset, answer_end_offset)
-
-    # Decoding
-    answer_text = tokenizer.decode(pred_ids)  # text
-    del encodings
-    return {'answer_text': answer_text, 'answer_offset': answer_offset}
-
-
 ## Title
 st.title('☁️ Bespin → QuestionAnswering')
 
@@ -89,24 +59,26 @@ if st.button("Submit", key='question'):
             max_length=512,
             truncation=True,
             padding="max_length",
-            return_token_type_ids=False
+            return_token_type_ids=False,
+            return_offsets_mapping=True
         )
         encodings = {key: torch.tensor([val]) for key, val in encodings.items()}
-        input_ids = encodings["input_ids"]
-        attention_mask = encodings["attention_mask"]
 
         # Predict
-        pred = model(input_ids, attention_mask=attention_mask)
-
+        pred = model(encodings["input_ids"], attention_mask=encodings["attention_mask"])
         start_logits, end_logits = pred.start_logits, pred.end_logits
         token_start_index, token_end_index = start_logits.argmax(dim=-1), end_logits.argmax(dim=-1)
-        pred_ids = input_ids[0][token_start_index: token_end_index + 1]
-
-        # Decoding
+        pred_ids = encodings["input_ids"][0][token_start_index: token_end_index + 1]
         prediction = tokenizer.decode(pred_ids)
 
+        # Offset
+        answer_start_offset = int(encodings['offset_mapping'][0][token_start_index][0][0])
+        answer_end_offset = int(encodings['offset_mapping'][0][token_end_index][0][1])
+        answer_offset = (answer_start_offset, answer_end_offset)
+
         # answer
         st.success(prediction)
 
+
     except Exception as e:
         st.error(e)
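
For readers following the diff: the commit deletes the unused predict_answer helper and inlines its offset-mapping logic under the Submit button. Below is a minimal, self-contained sketch of that pattern, encoding a question/context pair with return_offsets_mapping=True, taking the argmax of the start/end logits, and mapping the predicted token span back to character positions. The checkpoint name and example strings are illustrative assumptions, not taken from this repo.

# Sketch only; checkpoint and inputs are placeholder assumptions.
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

model_name = "deepset/roberta-base-squad2"  # assumed extractive-QA checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)

question = "What does the offset mapping store?"
context = "The offset mapping stores (start, end) character positions for every token."

# return_offsets_mapping=True (fast tokenizers only) adds per-token
# (char_start, char_end) pairs, which is how the app recovers where the
# decoded answer sits inside the original text.
encodings = tokenizer(
    question, context,
    max_length=512,
    truncation=True,
    padding="max_length",
    return_token_type_ids=False,
    return_offsets_mapping=True,
    return_tensors="pt",
)
offset_mapping = encodings.pop("offset_mapping")  # (1, seq_len, 2); not a model input

with torch.no_grad():
    pred = model(**encodings)

# Most likely start/end token indices of the answer span.
token_start = int(pred.start_logits.argmax(dim=-1))
token_end = int(pred.end_logits.argmax(dim=-1))

pred_ids = encodings["input_ids"][0][token_start:token_end + 1]
answer_text = tokenizer.decode(pred_ids, skip_special_tokens=True)

# Character offsets are relative to the segment the token came from
# (the context string, when the predicted span lands there).
answer_offset = (int(offset_mapping[0][token_start][0]),
                 int(offset_mapping[0][token_end][1]))
print(answer_text, answer_offset)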
requirements.txt CHANGED
@@ -1,3 +1,3 @@
-torch
-transformers==4.16.0
-streamlit
+torch==1.11.0
+transformers==4.20.0
+streamlit==1.10.0
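
The requirements change pins all three dependencies to exact versions (and bumps transformers from 4.16.0 to 4.20.0), which makes the Space rebuild reproducibly instead of picking up whatever is latest. As an illustration only, not part of this commit, a startup guard like the following could fail fast when an environment drifts from those pins:

# Illustrative sketch (not in this repo): verify pinned versions at startup.
from importlib.metadata import version

PINNED = {"torch": "1.11.0", "transformers": "4.20.0", "streamlit": "1.10.0"}

for package, expected in PINNED.items():
    installed = version(package)
    if installed.split("+")[0] != expected:  # strip local tags like +cu113
        raise RuntimeError(f"expected {package}=={expected}, found {installed}")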