JaeHyeong committed on
Commit
fb82652
•
1 Parent(s): 5457347
Files changed (2) hide show
  1. app.py +57 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import AutoModelForQuestionAnswering, AutoTokenizer
4
+
5
+
6
# `st.cache` is deprecated in favour of `st.cache_resource` for unserialisable
# global resources such as ML models. Prefer the new API and fall back to the
# legacy decorator only on Streamlit versions that predate it.
_cache_resource = (
    st.cache_resource
    if hasattr(st, "cache_resource")
    else st.cache(allow_output_mutation=True)
)


@_cache_resource
def get_model():
    """Load the fine-tuned MRC tokenizer and model from the Hugging Face Hub.

    The result is cached by Streamlit, so the (slow) download/load happens
    once instead of on every script rerun.

    Returns:
        A ``(tokenizer, model)`` tuple for the
        ``bespin-global/klue-bert-base-aihub-mrc`` checkpoint.
    """
    HUGGINGFACE_MODEL_PATH = "bespin-global/klue-bert-base-aihub-mrc"
    tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_PATH)
    model = AutoModelForQuestionAnswering.from_pretrained(HUGGINGFACE_MODEL_PATH)

    return tokenizer, model
14
+
15
# Load the tokenizer/model pair; get_model() is cached, so reruns are cheap.
tokenizer, model = get_model()

## Title
# NOTE(review): `anchor` sets the header's in-page anchor name, it does not
# turn the title into a hyperlink. Kept unchanged to preserve current behavior.
st.title('BespinGlobal - Machine Reading Comprehension', anchor='https://huggingface.co/bespin-global/klue-bert-base-aihub-mrc')

## Text
st.text('bespin-global/klue-bert-base-aihub-mrc 모델 성능 테스트 페이지 입니다.')

# Context input (the passage the answer is extracted from)
context = st.text_area("📑 Context HERE!", placeholder="Please input some context..", height=300, on_change=None)

# Question input
question = st.text_area("💡 Question HERE!", placeholder="Please input your question..")

if st.button("Submit", key='question'):
    if not context.strip() or not question.strip():
        # Running the model on empty text only produces a meaningless span.
        st.warning("Please provide both a context and a question.")
    else:
        try:
            # Progress spinner
            with st.spinner('Wait for it...'):
                # Encoding: a single (context, question) pair as a batch of one.
                encodings = tokenizer(
                    context,
                    question,
                    max_length=512,
                    truncation=True,
                    padding="max_length",
                    return_token_type_ids=False,
                    return_tensors="pt",
                )
                input_ids = encodings["input_ids"]
                attention_mask = encodings["attention_mask"]

                # Predict; no gradients are needed for inference.
                with torch.no_grad():
                    pred = model(input_ids, attention_mask=attention_mask)

                # Most likely start/end token positions as plain Python ints.
                token_start_index = pred.start_logits.argmax(dim=-1).item()
                token_end_index = pred.end_logits.argmax(dim=-1).item()
                pred_ids = input_ids[0][token_start_index:token_end_index + 1]

                # Decoding; drop [CLS]/[SEP]/[PAD] so they never reach the UI.
                prediction = tokenizer.decode(pred_ids, skip_special_tokens=True).strip()

            # Answer: an empty span (e.g. end before start) means no answer.
            if prediction:
                st.success(prediction)
            else:
                st.warning("No answer could be extracted from the given context.")

        except Exception as e:
            # Top-level UI boundary: surface the error instead of crashing.
            st.error(e)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ transformers
3
+ streamlit