haven-jeon commited on
Commit
54abea9
โ€ข
1 Parent(s): 9812bd8

Update example code

Browse files
Files changed (1) hide show
  1. README.md +6 -13
README.md CHANGED
@@ -14,23 +14,16 @@ import torch
14
  from transformers import PreTrainedTokenizerFast
15
  from transformers import BartForConditionalGeneration
16
 
17
- tokenizer = PreTrainedTokenizerFast.from_pretrained(
18
- 'gogamza/kobart-summarization')
19
-
20
  model = BartForConditionalGeneration.from_pretrained('gogamza/kobart-summarization')
21
 
22
- text = "๊ณผ๊ฑฐ๋ฅผ ๋– ์˜ฌ๋ ค๋ณด์ž. ๋ฐฉ์†ก์„ ๋ณด๋˜ ์šฐ๋ฆฌ์˜ ๋ชจ์Šต์„..."
23
 
24
  raw_input_ids = tokenizer.encode(text)
25
- input_ids = [tokenizer.bos_token_id] + \\
26
- raw_input_ids + [tokenizer.eos_token_id]
27
- summary_ids = model.generate(torch.tensor([input_ids]),
28
- max_length=150,
29
- early_stopping=False,
30
- num_beams=5,
31
- repetition_penalty=1.0,
32
- eos_token_id=tokenizer.eos_token_id)
33
- summ_text = tokenizer.batch_decode(summary_ids.tolist(), skip_special_tokens=True)[0]
34
  ```
35
 
36
 
 
14
  from transformers import PreTrainedTokenizerFast
15
  from transformers import BartForConditionalGeneration
16
 
17
+ tokenizer = PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-summarization')
 
 
18
  model = BartForConditionalGeneration.from_pretrained('gogamza/kobart-summarization')
19
 
20
+ text = "๊ณผ๊ฑฐ๋ฅผ ๋– ์˜ฌ๋ ค๋ณด์ž. ๋ฐฉ์†ก์„ ๋ณด๋˜ ์šฐ๋ฆฌ์˜ ๋ชจ์Šต์„. ๋…๋ณด์ ์ธ ๋งค์ฒด๋Š” TV์˜€๋‹ค. ์˜จ ๊ฐ€์กฑ์ด ๋‘˜๋Ÿฌ์•‰์•„ TV๋ฅผ ๋ดค๋‹ค. ๊ฐ„ํ˜น ๊ฐ€์กฑ๋“ค๋ผ๋ฆฌ ๋‰ด์Šค์™€ ๋“œ๋ผ๋งˆ, ์˜ˆ๋Šฅ ํ”„๋กœ๊ทธ๋žจ์„ ๋‘˜๋Ÿฌ์‹ธ๊ณ  ๋ฆฌ๋ชจ์ปจ ์Ÿํƒˆ์ „์ด ๋ฒŒ์–ด์ง€๊ธฐ๋„ ํ–ˆ๋‹ค. ๊ฐ์ž ์„ ํ˜ธํ•˜๋Š” ํ”„๋กœ๊ทธ๋žจ์„ โ€˜๋ณธ๋ฐฉโ€™์œผ๋กœ ๋ณด๊ธฐ ์œ„ํ•œ ์‹ธ์›€์ด์—ˆ๋‹ค. TV๊ฐ€ ํ•œ ๋Œ€์ธ์ง€ ๋‘ ๋Œ€์ธ์ง€ ์—ฌ๋ถ€๋„ ๊ทธ๋ž˜์„œ ์ค‘์š”ํ–ˆ๋‹ค. ์ง€๊ธˆ์€ ์–ด๋–ค๊ฐ€. โ€˜์•ˆ๋ฐฉ๊ทน์žฅโ€™์ด๋ผ๋Š” ๋ง์€ ์˜›๋ง์ด ๋๋‹ค. TV๊ฐ€ ์—†๋Š” ์ง‘๋„ ๋งŽ๋‹ค. ๋ฏธ๋””์–ด์˜ ํ˜œ ํƒ์„ ๋ˆ„๋ฆด ์ˆ˜ ์žˆ๋Š” ๋ฐฉ๋ฒ•์€ ๋Š˜์–ด๋‚ฌ๋‹ค. ๊ฐ์ž์˜ ๋ฐฉ์—์„œ ๊ฐ์ž์˜ ํœด๋Œ€ํฐ์œผ๋กœ, ๋…ธํŠธ๋ถ์œผ๋กœ, ํƒœ๋ธ”๋ฆฟ์œผ๋กœ ์ฝ˜ํ…์ธ  ๋ฅผ ์ฆ๊ธด๋‹ค."
21
 
22
  raw_input_ids = tokenizer.encode(text)
23
+ input_ids = [tokenizer.bos_token_id] + raw_input_ids + [tokenizer.eos_token_id]
24
+
25
+ summary_ids = model.generate(torch.tensor([input_ids]))
26
+ tokenizer.decode(summary_ids.squeeze().tolist(), skip_special_tokens=True)
 
 
 
 
 
27
  ```
28
 
29