# Spaces:
# Build error
# Build error
import torch | |
import string | |
import streamlit as st | |
from transformers import GPT2LMHeadModel, PreTrainedTokenizerFast | |
def get_model():
    """Load the pretrained KoGPT2 language model for inference.

    Returns:
        The ``skt/kogpt2-base-v2`` ``GPT2LMHeadModel`` switched to evaluation
        mode via ``eval()`` (disables dropout), ready for ``generate()``.
    """
    model = GPT2LMHeadModel.from_pretrained('skt/kogpt2-base-v2')
    model.eval()
    return model
# Fast tokenizer for the same checkpoint as get_model(). The special tokens
# are passed explicitly; note that '</s>' is used as both the BOS and the EOS
# marker here.
tokenizer = PreTrainedTokenizerFast.from_pretrained("skt/kogpt2-base-v2",
                                                    bos_token='</s>',
                                                    eos_token='</s>',
                                                    unk_token='<unk>',
                                                    pad_token='<pad>',
                                                    mask_token='<mask>')
# Prompt pre-filled in the input box ("Why are modern people always anxious?").
default_text = "현대인들은 왜 항상 불안해 할까?"
# NOTE(review): N_SENT is not referenced anywhere else in this script;
# confirm it is unused before removing it.
N_SENT = 3

# Streamlit re-executes this whole script on every widget interaction, so
# loading the model directly would re-read the checkpoint each time. Wrap the
# loader in st.cache_resource (Streamlit >= 1.18) so one model instance is
# reused across reruns; on older Streamlit fall back to a plain call.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is not None:
    model = _cache_resource(get_model)()
else:
    model = get_model()
# Page header. Korean headings: "모델" = Model, "샘플링 방법" = Sampling
# method, "최대 출력 길이" = maximum output length.
st.title("KoGPT2 Demo Page(ver 2.0)")
st.markdown("""
### 모델

| Model | # of params | Type | # of layers | # of heads | ffn_dim | hidden_dims |
|--------------|:----:|:-------:|--------:|--------:|--------:|--------------:|
| `KoGPT2` | 125M | Decoder | 12 | 12 | 3072 | 768 |

### 샘플링 방법

- greedy sampling
- 최대 출력 길이 : 128/1,024

## Conditional Generation
""")

# Prompt input (pre-filled with default_text), echoed back on the page.
text = st.text_area("Input Text:", value=default_text)
st.write(text)
# Sentence-ending characters used to trim the model output to whole sentences.
punct = ('!', '?', '.')


def _trim_to_last_sentence(sentence, marks=punct):
    """Return *sentence* cut just after its last character found in *marks*.

    Text that is empty, already ends with one of *marks*, or contains none of
    them is returned unchanged, so a reply without any punctuation is kept
    whole instead of being reduced to its first character.
    """
    if not sentence or sentence[-1] in marks:
        return sentence
    cut = max((sentence.rfind(mark) for mark in marks), default=-1)
    return sentence if cut < 0 else sentence[:cut + 1]


if text:
    st.markdown("## Predict")
    with st.spinner('processing..'):
        print(f'input > {text}')
        input_ids = tokenizer(text)['input_ids']
        # max_length=128 matches the "128/1,024" limit shown in the page
        # header; repetition_penalty > 1 discourages repeated tokens.
        gen_ids = model.generate(torch.tensor([input_ids]),
                                 max_length=128,
                                 repetition_penalty=2.0)
        # NOTE(review): for a decoder-only model, generate() is expected to
        # return the prompt tokens as well, so the text shown starts with the
        # input; confirm whether only the continuation should be displayed.
        generated = tokenizer.decode(gen_ids[0, :].tolist()).strip()
        # Drop a trailing, unfinished sentence fragment.
        generated = _trim_to_last_sentence(generated)
        print(f'KoGPT > {generated}')
    st.write(generated)