## Generating Chinese poetry by topic.
```python
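import torch
from transformers import AutoModelWithLMHead, BertTokenizer

# The checkpoint uses a BERT-style Chinese vocabulary, hence BertTokenizer.
tokenizer = BertTokenizer.from_pretrained("gaochangkuan/model_dir")
# AutoModelWithLMHead emits a deprecation warning in transformers 4.x; kept as in the original card.
model = AutoModelWithLMHead.from_pretrained("gaochangkuan/model_dir")

prompt = "田园躬耕"       # topic only; the <s> / <sep> markers are added when encoding
length = 84               # max_length passed to generate()
stop_token = "</s>"
temperature = 1.2
repetition_penalty = 1.3
k = 30                    # top-k
p = 0.95                  # top-p (nucleus)
seed = 2020
no_cuda = False

# Use the GPU only if it is available and not disabled, and keep model and inputs on the same device.
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
model.to(device)
torch.manual_seed(seed)   # make sampling reproducible

prompt_text = prompt if prompt else input("Model prompt >>> ")
encoded_prompt = tokenizer.encode(
    "<s>" + prompt_text + "<sep>",
    add_special_tokens=False,  # use the <s>/<sep> markers instead of BERT's [CLS]/[SEP]
    return_tensors="pt",
).to(device)

output_sequences = model.generate(
    input_ids=encoded_prompt,
    max_length=length,
    min_length=10,
    do_sample=True,            # with num_beams > 1 this is beam search with sampling
    early_stopping=True,
    num_beams=10,
    temperature=temperature,
    top_k=k,
    top_p=p,
    repetition_penalty=repetition_penalty,
    bos_token_id=tokenizer.bos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    length_penalty=1.2,        # only affects beam search scoring
    no_repeat_ngram_size=2,
    num_return_sequences=1,
    decoder_start_token_id=tokenizer.bos_token_id,
)

generated_sequence = output_sequences[0].tolist()
text = tokenizer.decode(generated_sequence)
# Cut at the stop token only if it is present (str.find returns -1 otherwise).
if stop_token and stop_token in text:
    text = text[: text.find(stop_token)]
# Decoding puts spaces between characters; remove them and leftover markers.
print(text.replace(" ", "").replace("<pad>", "").replace("<s>", ""))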
```
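To check the prompt wrapping and output cleanup without downloading the checkpoint, that logic from the snippet can be factored into two small pure functions. This is a sketch under assumptions: `build_prompt` and `clean_generation` are hypothetical helper names introduced here, not part of the model repository or of transformers, and the sample string in the demo is synthetic, not real model output.

```python
def build_prompt(topic: str) -> str:
    """Wrap a topic in the <s> ... <sep> markers used when encoding the prompt (hypothetical helper)."""
    return "<s>" + topic + "<sep>"


def clean_generation(text: str, stop_token: str = "</s>") -> str:
    """Trim at the first stop token, if present, then strip padding, <s> markers and spaces.

    Hypothetical helper mirroring the post-processing in the model card snippet.
    """
    # Slice only when the stop token occurs; str.find would return -1 otherwise
    # and silently drop the last character.
    if stop_token and stop_token in text:
        text = text[: text.find(stop_token)]
    for marker in ("<pad>", "<s>"):
        text = text.replace(marker, "")
    # BERT-style decoding separates characters with spaces.
    return text.replace(" ", "")


if __name__ == "__main__":
    # Synthetic decoded string, for illustration only.
    raw = "<s> 田 园 躬 耕 <sep> 一 二 三 </s> <pad> <pad>"
    print(build_prompt("田园躬耕"))
    print(clean_generation(raw))
```

Keeping the `<sep>` marker in the cleaned text matches the original snippet, where it separates the topic from the generated poem.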