Spaces:
Runtime error
Runtime error
added generate.py
Browse files- generate.py +45 -0
generate.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
3 |
+
|
4 |
+
|
5 |
+
# Pick the GPU when one is visible; the model and every input tensor must
# live on the same device for generation to work.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Announce BEFORE the (potentially slow) download/load — the original printed
# this only after from_pretrained() had already finished, which made the
# status message meaningless.
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained("tuanle/VN-News-GPT2", cache_dir="cache/")
model = AutoModelForCausalLM.from_pretrained("tuanle/VN-News-GPT2", cache_dir="cache/").to(device)
print("Model is ready to serve...")
11 |
+
def generate(category, headline,
             min_len=60,
             max_len=768,
             num_beams=5,
             num_return_sequences=3,
             top_k=50,
             top_p=1):
    """Generate Vietnamese news article bodies for a category and headline.

    Builds the prompt in the format the fine-tuned model expects,
    ``<|startoftext|> {category} <|headline|> {headline}``, then samples
    ``num_return_sequences`` continuations and prints each one.

    Args:
        category: News category string used in the prompt.
        headline: Headline text the article body is conditioned on.
        min_len: Minimum length (tokens) of each generated sequence.
        max_len: Maximum length (tokens) of each generated sequence.
        num_beams: Number of beams for beam search. 1 means no beam search.
        num_return_sequences: How many candidate articles to generate.
        top_k: The number of highest probability vocabulary tokens to keep
            for top-k-filtering.
        top_p: If set to float < 1, only the most probable tokens with
            probabilities that add up to top_p or higher are kept for
            generation.

    Returns:
        list[str]: The decoded generated texts, in generation order.
    """
    text = f"<|startoftext|> {category} <|headline|> {headline}"
    input_ids = tokenizer.encode(text, return_tensors='pt').to(device)

    # Inference only: no_grad() stops autograd from recording the forward
    # passes inside generate(), which otherwise wastes a large amount of
    # memory for gradients that are never used.
    with torch.no_grad():
        sample_outputs = model.generate(input_ids,
                                        do_sample=True,
                                        max_length=max_len,
                                        min_length=min_len,
                                        # temperature = .8,
                                        top_k=top_k,
                                        top_p=top_p,
                                        num_beams=num_beams,
                                        early_stopping=True,
                                        no_repeat_ngram_size=2,
                                        num_return_sequences=num_return_sequences)

    outputs = []
    for i, sample_output in enumerate(sample_outputs):
        temp = tokenizer.decode(sample_output.tolist())
        print(f">> Generated text {i+1}\n\n{temp}")
        print('\n---')
        outputs.append(temp)
    return outputs
|