BAAI
/

AquilaChat-7B / chat_test.py
shunxing1234's picture
Update chat_test.py
7034db9
raw
history blame contribute delete
843 Bytes
# To use this code, first install the transformers fork below:
# https://github.com/shunxing1234/transformers
from transformers import AutoTokenizer, AquilaForCausalLM
import torch
from cyg_conversation import default_conversation, covert_prompt_to_input_ids_with_history
# Smoke test for BAAI/AquilaChat-7B: load the model, encode one chat prompt,
# sample a completion, and print the decoded text.

# Token budget used both for prompt encoding (max_token) and for the total
# generated sequence (max_length).
# NOTE(review): in Hugging Face generate(), max_length counts the prompt tokens too.
MAX_TOKENS = 512
# Token id handed to generate() as the end-of-sequence marker.
# NOTE(review): presumably AquilaChat's end-of-turn token id — confirm against the tokenizer.
EOS_TOKEN_ID = 100007

# The script was written for GPU index 4. Fall back to the default GPU, or to
# the CPU, instead of crashing on machines that do not have that device.
if torch.cuda.device_count() > 4:
    device = "cuda:4"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

tokenizer = AutoTokenizer.from_pretrained("BAAI/AquilaChat-7B")
model = AquilaForCausalLM.from_pretrained("BAAI/AquilaChat-7B")
model.eval()
model.to(device)

# Sanity check: print the tokenizer's vocabulary size.
vocab = tokenizer.vocab
print(len(vocab))

# Prompt (English): "Please give 10 reasons to travel to Beijing."
text = "请给出10个要到北京旅游的理由。"
tokens = covert_prompt_to_input_ids_with_history(text, history=[], tokenizer=tokenizer, max_token=MAX_TOKENS)
# Add a leading batch dimension of size 1, then move to the model's device.
tokens = torch.tensor(tokens).unsqueeze(0).to(device)
# do_sample=True samples rather than decoding greedily, so output varies between runs.
out = model.generate(tokens, do_sample=True, max_length=MAX_TOKENS, eos_token_id=EOS_TOKEN_ID)[0]
out = tokenizer.decode(out.cpu().tolist())
print(out)