Spaces:
Build error
Build error
haven-jeon
committed on
Commit
•
96484a5
1
Parent(s):
23b2f5d
fix tokenizer
Browse files
app.py
CHANGED
@@ -1,10 +1,7 @@
|
|
1 |
import torch
|
2 |
import string
|
3 |
import streamlit as st
|
4 |
-
from transformers import GPT2LMHeadModel
|
5 |
-
from tokenizers import Tokenizer
|
6 |
-
|
7 |
-
|
8 |
|
9 |
|
10 |
@st.cache
|
@@ -13,7 +10,13 @@ def get_model():
|
|
13 |
model.eval()
|
14 |
return model
|
15 |
|
16 |
-
tokenizer =
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
default_text = "νλμΈλ€μ μ νμ λΆμν΄ ν κΉ?"
|
19 |
|
@@ -39,26 +42,16 @@ st.markdown("""
|
|
39 |
|
40 |
text = st.text_area("Input Text:", value=default_text)
|
41 |
st.write(text)
|
42 |
-
st.markdown("""
|
43 |
-
> *νμ¬ 2core μΈμ€ν΄μ€μμ μμΈ‘μ΄ μ§νλμ΄ λ€μ λ릴 μ μμ*
|
44 |
-
""")
|
45 |
punct = ('!', '?', '.')
|
46 |
|
47 |
if text:
|
48 |
st.markdown("## Predict")
|
49 |
with st.spinner('processing..'):
|
50 |
print(f'input > {text}')
|
51 |
-
input_ids = tokenizer
|
52 |
gen_ids = model.generate(torch.tensor([input_ids]),
|
53 |
max_length=128,
|
54 |
-
repetition_penalty=2.0
|
55 |
-
# num_beams=2,
|
56 |
-
# length_penalty=1.0,
|
57 |
-
use_cache=True,
|
58 |
-
pad_token_id=tokenizer.token_to_id('<pad>'),
|
59 |
-
eos_token_id=tokenizer.token_to_id('</s>'),
|
60 |
-
bos_token_id=tokenizer.token_to_id('</s>'),
|
61 |
-
bad_words_ids=[[tokenizer.token_to_id('<unk>')] ])
|
62 |
generated = tokenizer.decode(gen_ids[0,:].tolist()).strip()
|
63 |
if generated != '' and generated[-1] not in punct:
|
64 |
for i in reversed(range(len(generated))):
|
|
|
1 |
import torch
|
2 |
import string
|
3 |
import streamlit as st
|
4 |
+
from transformers import GPT2LMHeadModel, PreTrainedTokenizerFast
|
|
|
|
|
|
|
5 |
|
6 |
|
7 |
@st.cache
|
|
|
10 |
model.eval()
|
11 |
return model
|
12 |
|
13 |
+
tokenizer = PreTrainedTokenizerFast.from_pretrained("skt/kogpt2-base-v2",
|
14 |
+
bos_token='</s>',
|
15 |
+
eos_token='</s>',
|
16 |
+
unk_token='<unk>',
|
17 |
+
pad_token='<pad>',
|
18 |
+
mask_token='<mask>')
|
19 |
+
|
20 |
|
21 |
default_text = "νλμΈλ€μ μ νμ λΆμν΄ ν κΉ?"
|
22 |
|
|
|
42 |
|
43 |
text = st.text_area("Input Text:", value=default_text)
|
44 |
st.write(text)
|
|
|
|
|
|
|
45 |
punct = ('!', '?', '.')
|
46 |
|
47 |
if text:
|
48 |
st.markdown("## Predict")
|
49 |
with st.spinner('processing..'):
|
50 |
print(f'input > {text}')
|
51 |
+
input_ids = tokenizer(text)['input_ids']
|
52 |
gen_ids = model.generate(torch.tensor([input_ids]),
|
53 |
max_length=128,
|
54 |
+
repetition_penalty=2.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
generated = tokenizer.decode(gen_ids[0,:].tolist()).strip()
|
56 |
if generated != '' and generated[-1] not in punct:
|
57 |
for i in reversed(range(len(generated))):
|