self-chat / simulator.py
xu song
update
d72c532
raw
history blame
1.13 kB
from transformers import AutoModelForCausalLM, AutoTokenizer
class Simulator:
    """Base class for chat simulators backed by a causal language model.

    The constructor loads a tokenizer and a model and stores default
    generation settings on the instance. ``generate_query`` and
    ``generate_response`` only raise ``NotImplementedError`` here; they are
    presumably meant to be overridden by subclasses.
    """

    def __init__(self, model_name_or_path):
        """Load the tokenizer and model and set default generation options.

        When ``device_map`` is passed, ``low_cpu_mem_usage`` is automatically
        set to True.

        :param model_name_or_path: value forwarded unchanged to
            ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForCausalLM.from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype="auto",
            device_map="auto",
        )
        # Switch the model to evaluation mode; this class only does inference.
        self.model.eval()

        # Default generation options, kept value-for-value because code
        # outside this class may read or copy this attribute.
        # NOTE(review): `temperature` is set while `do_sample=False`, and
        # `max_length` is set alongside `max_new_tokens`; both pairs look
        # redundant -- confirm against the transformers generation docs
        # before trimming.
        self.generation_kwargs = dict(
            do_sample=False,
            temperature=0.7,
            max_length=500,
            max_new_tokens=10,
        )

    def generate_query(self, history):
        """User simulator.

        :param history: conversation history (format is not defined here).
        :return: nothing in this base class.
        :raises NotImplementedError: always.
        """
        raise NotImplementedError

    def generate_response(self, input, history):
        """Assistant simulator.

        The parameter name ``input`` shadows the builtin but is kept so that
        existing keyword-argument callers continue to work.

        :param input: the current input (format is not defined here).
        :param history: conversation history (format is not defined here).
        :return: nothing in this base class.
        :raises NotImplementedError: always.
        """
        raise NotImplementedError