from transformers import AutoModelForCausalLM, AutoTokenizer


class Simulator:
    """Base class for a dialogue simulator backed by a causal LM.

    Loads a Hugging Face tokenizer and model and holds shared generation
    settings. Subclasses implement `generate_query` (user side) and
    `generate_response` (assistant side).
    """

    def __init__(self, model_name_or_path: str):
        """Load tokenizer and model from a Hub name or local path.

        NOTE: when `device_map` is passed, `low_cpu_mem_usage` is
        automatically set to True by transformers.

        :param model_name_or_path: model identifier or local directory.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype="auto",      # pick dtype from the checkpoint config
            device_map="auto",       # shard/place across available devices
        )
        # Inference only: disable dropout etc.
        self.model.eval()
        # Shared sampling settings for generate().
        # Fix: the original also set max_length=500, which conflicts with
        # max_new_tokens (transformers ignores max_length and warns when
        # max_new_tokens is given) — only max_new_tokens is kept.
        self.generation_kwargs = dict(
            do_sample=True,
            temperature=0.7,
            max_new_tokens=10,
        )

    def generate_query(self, history):
        """User simulator: produce the next user turn.

        :param history: conversation history (format defined by subclass).
        :raises NotImplementedError: must be overridden by subclasses.
        """
        raise NotImplementedError

    def generate_response(self, input, history):
        """Assistant simulator: produce the assistant's reply.

        :param input: latest user message. (Name shadows the `input`
            builtin; kept for backward compatibility with keyword callers.)
        :param history: conversation history (format defined by subclass).
        :raises NotImplementedError: must be overridden by subclasses.
        """
        raise NotImplementedError