|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
|
class Simulator:
    """Base class for a model-backed dialogue simulator.

    Loads a causal language model and its tokenizer once at construction.
    Subclasses are expected to implement ``generate_query`` (user side) and
    ``generate_response`` (assistant side); both raise NotImplementedError here.
    """

    def __init__(self, model_name_or_path, generation_kwargs=None):
        """Load the tokenizer and model from ``model_name_or_path``.

        Passing ``device_map`` to ``from_pretrained`` automatically sets
        ``low_cpu_mem_usage=True``.

        :param model_name_or_path: model identifier or local path, forwarded to
            ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForCausalLM.from_pretrained``.
        :param generation_kwargs: optional dict whose entries are merged over the
            default decoding settings stored in ``self.generation_kwargs``.
            ``None`` (the default) keeps the defaults unchanged.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype="auto",
            device_map="auto",
        )
        # Put the model in evaluation mode; it is only used for generation.
        self.model.eval()

        # Greedy decoding defaults.
        # `temperature` is omitted because it has no effect when do_sample=False.
        # `max_length` is omitted because `max_new_tokens` takes precedence when
        # both are given. transformers emits a warning for each conflicting pair.
        self.generation_kwargs = dict(
            do_sample=False,
            max_new_tokens=10,
        )
        if generation_kwargs:
            # Copy entries so later mutation of the caller's dict has no effect.
            self.generation_kwargs.update(generation_kwargs)

    def generate_query(self, history):
        """User simulator: produce the next user turn.

        Must be overridden by subclasses.

        :param history: the conversation so far (format defined by subclasses).
        :return: the generated user query.
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def generate_response(self, input, history):
        """Assistant simulator: produce a reply to ``input``.

        Must be overridden by subclasses.

        :param input: the latest user message.
        :param history: the conversation so far (format defined by subclasses).
        :return: the generated assistant response.
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
|
|