# Spaces: Runtime error
import torch | |
import streamlit as st | |
from transformers.models.bart import BartForConditionalGeneration | |
from transformers import PreTrainedTokenizerFast | |
@st.cache_resource
def load_model():
    """Load the fine-tuned BART summarization model.

    Streamlit re-executes this script from top to bottom on every widget
    interaction. Without caching, the checkpoint would be re-instantiated
    on each rerun; ``st.cache_resource`` keeps one shared instance for the
    lifetime of the server process instead.

    Returns:
        The ``BartForConditionalGeneration`` instance loaded from the
        ``LeeJang/news-summarization-v2`` checkpoint.
    """
    model = BartForConditionalGeneration.from_pretrained('LeeJang/news-summarization-v2')
    return model
@st.cache_resource
def _load_tokenizer():
    """Load the KoBART tokenizer once and reuse it across Streamlit reruns."""
    return PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-base-v1')


# Maximum number of space-separated words kept from the pasted article.
MAX_INPUT_WORDS = 501

model = load_model()
tokenizer = _load_tokenizer()

st.title("2문장 뉴스 요약기")
text = st.text_area("뉴스 입력:")

# Echo the article exactly as the user entered it.
st.markdown("## 뉴스 원문")
st.write(text)

# Flatten newlines and trim BEFORE the emptiness check, so that input made
# only of whitespace is not passed to the model as an empty token sequence.
text = text.replace('\n', ' ').strip()
if text:
    words = text.split(' ')
    if len(words) > MAX_INPUT_WORDS:
        text = ' '.join(words[:MAX_INPUT_WORDS])

    st.markdown("## 요약 결과")
    with st.spinner('processing..'):
        input_ids = tokenizer.encode(text)
        # The word cap above does not bound the subword count, so also clip
        # the token ids to the number of positions the model config declares.
        max_positions = getattr(model.config, 'max_position_embeddings', None)
        if max_positions:
            input_ids = input_ids[:max_positions]
        # generate() expects a batch dimension: shape (1, sequence_length).
        input_ids = torch.tensor(input_ids).unsqueeze(0)
        # NOTE(review): eos_token_id=1 is hard-coded; presumably the KoBART
        # end-of-sequence id -- confirm against tokenizer.eos_token_id.
        output = model.generate(input_ids, eos_token_id=1, max_length=512, num_beams=5)
        output = tokenizer.decode(output[0], skip_special_tokens=True)
    st.write(output)