haven-jeon committed • 54abea9
Parent: 9812bd8

Update example code

README.md CHANGED
@@ -14,23 +14,16 @@ import torch
 from transformers import PreTrainedTokenizerFast
 from transformers import BartForConditionalGeneration
 
-tokenizer = PreTrainedTokenizerFast.from_pretrained(
-    'gogamza/kobart-summarization')
-
+tokenizer = PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-summarization')
 model = BartForConditionalGeneration.from_pretrained('gogamza/kobart-summarization')
 
-text = "과거를 떠올려보자. 방송을 보던 우리의 …
+text = "과거를 떠올려보자. 방송을 보던 우리의 모습을. 독보적인 매체는 TV였다. 온 가족이 둘러앉아 TV를 봤다. 간혹 가족들끼리 뉴스와 드라마, 예능 프로그램을 둘러싸고 리모컨 쟁탈전이 벌어지기도 했다. 각자 선호하는 프로그램을 ‘본방’으로 보기 위한 싸움이었다. TV가 한 대인지 두 대인지 여부도 그래서 중요했다. 지금은 어떤가. ‘안방극장’이라는 말은 옛말이 됐다. TV가 없는 집도 많다. 미디어의 혜택을 누릴 수 있는 방법은 늘어났다. 각자의 방에서 각자의 휴대폰으로, 노트북으로, 태블릿으로 콘텐츠를 즐긴다."
 
 raw_input_ids = tokenizer.encode(text)
-input_ids = [tokenizer.bos_token_id] + …
-
-summary_ids = model.generate(torch.tensor([input_ids])…
-
-                             early_stopping=False,
-                             num_beams=5,
-                             repetition_penalty=1.0,
-                             eos_token_id=tokenizer.eos_token_id)
-summ_text = tokenizer.batch_decode(summary_ids.tolist(), skip_special_tokens=True)[0]
+input_ids = [tokenizer.bos_token_id] + raw_input_ids + [tokenizer.eos_token_id]
+
+summary_ids = model.generate(torch.tensor([input_ids]))
+tokenizer.decode(summary_ids.squeeze().tolist(), skip_special_tokens=True)
 ```
 
 
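For reference, here is the updated example assembled into one runnable snippet. This is a sketch based on the post-commit README: the `import torch` line comes from the unchanged context above the hunk, the input passage is shortened here for brevity (the README passes the full paragraph), and the variable `summary` plus the final `print` are added here for illustration and are not part of the commit.

```python
import torch
from transformers import PreTrainedTokenizerFast
from transformers import BartForConditionalGeneration

# Load the KoBART summarization tokenizer and model from the Hub.
tokenizer = PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-summarization')
model = BartForConditionalGeneration.from_pretrained('gogamza/kobart-summarization')

# Korean input passage (shortened; the README uses the full paragraph about
# family TV viewing giving way to phones, laptops, and tablets).
text = "과거를 떠올려보자. 방송을 보던 우리의 모습을. 독보적인 매체는 TV였다. 온 가족이 둘러앉아 TV를 봤다."

# Encode the text, then wrap it with BOS/EOS token ids as BART expects.
raw_input_ids = tokenizer.encode(text)
input_ids = [tokenizer.bos_token_id] + raw_input_ids + [tokenizer.eos_token_id]

# Generate with the library defaults; the old example passed num_beams=5,
# early_stopping=False, repetition_penalty=1.0, and eos_token_id explicitly.
summary_ids = model.generate(torch.tensor([input_ids]))

# Decode the single generated sequence, dropping special tokens.
summary = tokenizer.decode(summary_ids.squeeze().tolist(), skip_special_tokens=True)
print(summary)
```

Relying on `generate()` defaults keeps the example short; callers who want the previous behavior can still pass `num_beams`, `early_stopping`, and the other decoding arguments explicitly.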