# NOTE(review): removed web-scraper artifacts (file size, commit hash,
# line-number gutter) that made this file invalid Python.
from transformers import AutoModelForCausalLM, AutoTokenizer
class Simulator:
    """Base wrapper around a Hugging Face causal LM for dialogue simulation.

    Loads a tokenizer and model once at construction time and stores shared
    generation arguments. Subclasses must implement ``generate_query`` (the
    user-simulator side) and ``generate_response`` (the assistant side).
    """

    def __init__(self, model_name_or_path):
        """Load tokenizer and model and prepare generation defaults.

        :param model_name_or_path: HF hub id or local path passed to
            ``from_pretrained``.

        Note: when ``device_map`` is passed, ``low_cpu_mem_usage`` is
        automatically set to True by transformers.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype="auto",
            device_map="auto",
        )
        # Inference only — disable dropout etc.
        self.model.eval()
        # NOTE(review): with do_sample=False decoding is greedy, so
        # `temperature` is ignored by transformers (and may emit a warning);
        # kept here unchanged in case subclasses flip do_sample on.
        # `max_new_tokens` takes precedence over `max_length` when both are set.
        self.generation_kwargs = dict(
            do_sample=False,
            temperature=0.7,
            max_length=500,
            max_new_tokens=10,
        )
        # Removed dead code: an unused local `generation_kwargs` dict that was
        # never read, and a leftover debug `print(1)`.

    def generate_query(self, history):
        """User-simulator hook: produce the next user turn.

        :param history: conversation history so far.
        :return: the generated user query (defined by subclasses).
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def generate_response(self, input, history):
        """Assistant-simulator hook: produce the assistant reply.

        :param input: the current user input.
        :param history: conversation history so far.
        :return: the generated assistant response (defined by subclasses).
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
# NOTE(review): removed trailing scraper artifact ("|").