kimda committed
Commit
44f92d5
•
1 Parent(s): 73e079e

Create app.py

Files changed (1)
  1. app.py +31 -0
app.py ADDED
@@ -0,0 +1,31 @@
+ import gradio as gr
+
+ from transformers import PreTrainedTokenizerFast, BartForConditionalGeneration
+ # Import statements that start with "from transformers import" are,
+ # in many cases, AutoTokenizer / AutoModel, e.g.
+ # tokenizer = AutoTokenizer.from_pretrained("some model name or other")
+ # PreTrainedTokenizerFast: https://huggingface.co/docs/transformers/main_classes/tokenizer
+ # BART is an example of an encoder-decoder model
+
+ model_name = "ainize/kobart-news"
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(model_name)
+ model = BartForConditionalGeneration.from_pretrained(model_name)
+
+ # Take the original text and return a summary
+ def summ(txt):
+     input_ids = tokenizer.encode(txt, return_tensors="pt")
+     summary_text_ids = model.generate(
+         input_ids=input_ids,
+         bos_token_id=model.config.bos_token_id,
+         eos_token_id=model.config.eos_token_id,
+         length_penalty=2.0,
+         max_length=142,
+         min_length=56,
+         num_beams=4)
+     return tokenizer.decode(summary_text_ids[0], skip_special_tokens=True)
+
+ interface = gr.Interface(summ,
+                          [gr.Textbox(label="original text")],
+                          [gr.Textbox(label="summary")])
+
+ interface.launch(share=True)
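
The comment in app.py notes that transformers imports often go through the Auto* classes rather than model-specific ones. Below is a minimal sketch (not part of this commit) of that alternative loading path for the same checkpoint; AutoModelForSeq2SeqLM is an assumption here, chosen because kobart-news is an encoder-decoder checkpoint, and the sample text is only a placeholder.

# Sketch only: loading the same checkpoint via the Auto* classes.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

auto_tokenizer = AutoTokenizer.from_pretrained("ainize/kobart-news")
auto_model = AutoModelForSeq2SeqLM.from_pretrained("ainize/kobart-news")

sample = "..."  # placeholder: any Korean news article text
ids = auto_tokenizer.encode(sample, return_tensors="pt")
out = auto_model.generate(ids, num_beams=4, min_length=56, max_length=142, length_penalty=2.0)
print(auto_tokenizer.decode(out[0], skip_special_tokens=True))

Running app.py starts the Gradio UI; launch(share=True) also requests a temporary public share link in addition to the local server.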