Usage
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
def main() -> None:
    """Run an interactive console chat with the TeraSpace/dialofred model.

    Each line typed at the ``=>`` prompt is wrapped in the dialogue prompt
    format used by this model card. A reply is sampled with nucleus sampling
    and printed. End the session with Ctrl-D (EOF) or Ctrl-C.
    """
    # Fall back to CPU so the example still runs on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained("TeraSpace/dialofred")
    # Passing device_map requires the `accelerate` package to be installed.
    # Add torch_dtype=torch.bfloat16 to use less memory.
    model = AutoModelForSeq2SeqLM.from_pretrained(
        "TeraSpace/dialofred", device_map=device
    )

    while True:
        try:
            text_inp = input("=>")
        except (EOFError, KeyboardInterrupt):
            # End the chat cleanly instead of dumping a traceback.
            print()
            break

        # Prompt format from the model card: the user's line, then an open
        # reply slot marked by the <extra_id_0> sentinel for the model to fill.
        lm_text = f'<SC1>- {text_inp}\n- <extra_id_0>'
        input_ids = torch.tensor([tokenizer.encode(lm_text)]).to(model.device)

        # Nucleus sampling: top_k=0 disables top-k filtering, so only top_p
        # limits the candidate tokens. A previously suggested, more varied
        # setting was temperature=1.0, top_p=0.85. early_stopping is omitted
        # because it only affects beam search (num_beams > 1).
        outputs = model.generate(
            input_ids=input_ids,
            max_length=200,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=True,
            temperature=0.7,
            top_k=0,
            top_p=0.8,
        )
        # Position 0 of a seq2seq generation is the decoder start token.
        # Skip it, and drop special tokens (e.g. </s>) from the printed reply.
        print(tokenizer.decode(outputs[0][1:], skip_special_tokens=True))


if __name__ == "__main__":
    main()
- Downloads last month: 59
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social
visibility and check back later, or deploy to Inference Endpoints (dedicated)
instead.