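"""Unit tests for prompt construction in src.prompter.

Covers training/inference prompt generation for single data points, and exact
expected prompt strings for several chat prompt types, with and without
conversation history.
"""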
import time

import pytest

from tests.utils import wrap_test_forked
from src.enums import source_prefix, source_postfix
from src.prompter import generate_prompt

example_data_point0 = dict(instruction="Summarize",
                           input="Ducks eat seeds by the lake, then swim in the lake where fish eat small animals.",
                           output="Ducks eat and swim at the lake.")

example_data_point1 = dict(instruction="Who is smarter, Einstein or Newton?",
                           output="Einstein.")

example_data_point2 = dict(input="Who is smarter, Einstein or Newton?",
                           output="Einstein.")

example_data_points = [example_data_point0, example_data_point1, example_data_point2]
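
# wrap_test_forked (from tests.utils) presumably isolates each test in a forked
# subprocess, so state such as the shared example dicts cannot leak between tests.
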
@wrap_test_forked
def test_train_prompt(prompt_type='instruct', data_point=0):
    """Training-time prompts keep the 'output' field of the data point."""
    example_data_point = example_data_points[data_point]
    return generate_prompt(example_data_point, prompt_type, '', False, False, False)


@wrap_test_forked
def test_test_prompt(prompt_type='instruct', data_point=0):
    """Inference-time prompts must not contain the 'output' field."""
    # copy first so popping 'output' does not mutate the shared example dicts
    example_data_point = example_data_points[data_point].copy()
    example_data_point.pop('output', None)
    return generate_prompt(example_data_point, prompt_type, '', False, False, False)


@wrap_test_forked
def test_test_prompt2(prompt_type='human_bot', data_point=0):
    """Same as test_test_prompt, but for the human_bot prompt type."""
    example_data_point = example_data_points[data_point].copy()
    example_data_point.pop('output', None)
    res = generate_prompt(example_data_point, prompt_type, '', False, False, False)
    print(res, flush=True)
    return res


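# Expected renderings of the same three-turn conversation ("Hello!" /
# "How are you?" / "Go to the market?") under each supported prompt type;
# test_prompt_with_context must reproduce these strings exactly.
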
prompt_fastchat = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Hello! ASSISTANT: Hi!</s>USER: How are you? ASSISTANT: I'm good</s>USER: Go to the market? ASSISTANT:"""

prompt_humanbot = """<human>: Hello!\n<bot>: Hi!\n<human>: How are you?\n<bot>: I'm good\n<human>: Go to the market?\n<bot>:"""

prompt_prompt_answer = "<|prompt|>Hello!<|endoftext|><|answer|>Hi!<|endoftext|><|prompt|>How are you?<|endoftext|><|answer|>I'm good<|endoftext|><|prompt|>Go to the market?<|endoftext|><|answer|>"

prompt_prompt_answer_openllama = "<|prompt|>Hello!</s><|answer|>Hi!</s><|prompt|>How are you?</s><|answer|>I'm good</s><|prompt|>Go to the market?</s><|answer|>"

prompt_mpt_instruct = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction
Hello!

### Response
Hi!

### Instruction
How are you?

### Response
I'm good

### Instruction
Go to the market?

### Response
"""

prompt_mpt_chat = """<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.
<|im_end|><|im_start|>user
Hello!<|im_end|><|im_start|>assistant
Hi!<|im_end|><|im_start|>user
How are you?<|im_end|><|im_start|>assistant
I'm good<|im_end|><|im_start|>user
Go to the market?<|im_end|><|im_start|>assistant
"""

prompt_falcon = """User: Hello!

Assistant: Hi!

User: How are you?

Assistant: I'm good

User: Go to the market?

Assistant:"""


@wrap_test_forked
@pytest.mark.parametrize("prompt_type,expected",
                         [
                             ('vicuna11', prompt_fastchat),
                             ('human_bot', prompt_humanbot),
                             ('prompt_answer', prompt_prompt_answer),
                             ('prompt_answer_openllama', prompt_prompt_answer_openllama),
                             ('mptinstruct', prompt_mpt_instruct),
                             ('mptchat', prompt_mpt_chat),
                             ('falcon', prompt_falcon),
                         ]
                         )
def test_prompt_with_context(prompt_type, expected):
    """Chat history plus a new instruction must render the exact expected prompt."""
    prompt_dict = None
    langchain_mode = 'Disabled'
    chat = True
    model_max_length = 2048
    memory_restriction_level = 0
    keep_sources_in_context1 = False
    iinput = ''
    stream_output = False
    debug = False

    from src.prompter import Prompter
    from src.gen import history_to_context

    t0 = time.time()
    history = [["Hello!", "Hi!"],
               ["How are you?", "I'm good"],
               ["Go to the market?", None],
               ]
    print("duration1: %s %s" % (prompt_type, time.time() - t0), flush=True)
    t0 = time.time()
    context = history_to_context(history, langchain_mode, prompt_type, prompt_dict, chat,
                                 model_max_length, memory_restriction_level,
                                 keep_sources_in_context1)
    print("duration2: %s %s" % (prompt_type, time.time() - t0), flush=True)
    t0 = time.time()
    # the last user turn becomes the new instruction; earlier turns form the context
    instruction = history[-1][0]

    prompter = Prompter(prompt_type, prompt_dict, debug=debug, chat=chat, stream_output=stream_output)
    print("duration3: %s %s" % (prompt_type, time.time() - t0), flush=True)
    t0 = time.time()
    data_point = dict(context=context, instruction=instruction, input=iinput)
    prompt = prompter.generate_prompt(data_point)
    print(prompt)
    print("duration4: %s %s" % (prompt_type, time.time() - t0), flush=True)
    assert prompt == expected
    # document-source markers must never leak into chat prompts
    assert prompt.find(source_prefix) == -1
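
# Example invocation for a single parametrized case (assuming this file lives at
# tests/test_prompter.py, per the repo layout implied by the imports); -s shows
# the printed prompts and timings:
#   pytest -s -k "test_prompt_with_context and human_bot" tests/test_prompter.py
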
prompt_fastchat1 = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Go to the market? ASSISTANT:"""

prompt_humanbot1 = """<human>: Go to the market?\n<bot>:"""

prompt_prompt_answer1 = "<|prompt|>Go to the market?<|endoftext|><|answer|>"

prompt_prompt_answer_openllama1 = "<|prompt|>Go to the market?</s><|answer|>"

prompt_mpt_instruct1 = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction
Go to the market?

### Response
"""

prompt_mpt_chat1 = """<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.
<|im_end|><|im_start|>user
Go to the market?<|im_end|><|im_start|>assistant
"""

prompt_falcon1 = """User: Go to the market?

Assistant:"""


@pytest.mark.parametrize("prompt_type,expected",
                         [
                             ('vicuna11', prompt_fastchat1),
                             ('human_bot', prompt_humanbot1),
                             ('prompt_answer', prompt_prompt_answer1),
                             ('prompt_answer_openllama', prompt_prompt_answer_openllama1),
                             ('mptinstruct', prompt_mpt_instruct1),
                             ('mptchat', prompt_mpt_chat1),
                             ('falcon', prompt_falcon1),
                         ]
                         )
@wrap_test_forked
def test_prompt_with_no_context(prompt_type, expected):
    """Without prior history, the prompt holds only the new instruction."""
    prompt_dict = None
    chat = True
    iinput = ''
    stream_output = False
    debug = False

    from src.prompter import Prompter
    context = ''
    instruction = "Go to the market?"

    prompter = Prompter(prompt_type, prompt_dict, debug=debug, chat=chat, stream_output=stream_output)
    data_point = dict(context=context, instruction=instruction, input=iinput)
    prompt = prompter.generate_prompt(data_point)
    print(prompt)
    assert prompt == expected
    assert prompt.find(source_prefix) == -1


@wrap_test_forked
def test_source():
    """Prompts built with source_prefix/source_postfix markers are detectable."""
    prompt = "Who are you?%s\nFOO\n%s" % (source_prefix, source_postfix)
    assert prompt.find(source_prefix) >= 0