|
|
|
from typing import Protocol, List, Tuple |
|
from transformers import AutoTokenizer


class PromptTemplate(Protocol):
    """Protocol for prompt templates."""

    def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], **kwargs) -> str:
        """Build a model-ready prompt from the retrieved context, the latest user input, and prior (user, assistant) turns."""
        ...


class LlamaPromptTemplate:
    """Builds a Llama 3 chat prompt by hand from the model's special tokens."""

    def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], max_history_turns: int = 1) -> str:
        system_message = f"Please assist based on the following context: {context}"
        prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>"

        # Replay only the most recent turns so the prompt stays within budget.
        for user_msg, assistant_msg in chat_history[-max_history_turns:]:
            prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_msg}<|eot_id|>"
            prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{assistant_msg}<|eot_id|>"

        # Current user turn, then an open assistant header so the model writes the reply.
        prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{user_input}<|eot_id|>"
        prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return prompt
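
# Illustrative note (not part of the original module): with an empty history,
# LlamaPromptTemplate().format("ctx", "Hi", []) returns a single string of Llama 3
# chat markup, shown here wrapped for readability:
#
#   <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n
#   Please assist based on the following context: ctx<|eot_id|>
#   <|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|>
#   <|start_header_id|>assistant<|end_header_id|>\n\n
#
# The prompt is left open at the assistant header so the model's generation becomes
# the reply.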
|
|
|
|
|
class TransformersPromptTemplate:
    """Delegates prompt construction to the tokenizer's built-in chat template."""

    def __init__(self, model_path: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

    def format(self, context: str, user_input: str, chat_history: List[Tuple[str, str]], **kwargs) -> str:
        messages = [
            {
                "role": "system",
                "content": f"Please assist based on the following context: {context}",
            }
        ]

        # Replay the full conversation as alternating user/assistant messages.
        for user_msg, assistant_msg in chat_history:
            messages.extend([
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": assistant_msg},
            ])

        messages.append({"role": "user", "content": user_input})

        # With tokenize=False this returns the rendered prompt string (not token ids);
        # add_generation_prompt=True appends the header that cues the assistant's reply.
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        return prompt
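

# Minimal usage sketch (illustrative, not part of the original module). The model id
# below is an assumed placeholder: any Hugging Face checkpoint whose tokenizer ships a
# chat template should work, and loading it needs the files locally or a download.
if __name__ == "__main__":
    history = [("What is RAG?", "Retrieval-augmented generation: retrieve, then generate.")]
    context = "RAG pipelines pass retrieved passages to the model as grounding context."
    question = "How does the retriever help?"

    # Hand-built Llama 3 markup; runs fully offline.
    print(LlamaPromptTemplate().format(context, question, history))

    # Tokenizer-driven template; the checkpoint name is an assumption, not a requirement.
    hf_template = TransformersPromptTemplate("meta-llama/Meta-Llama-3-8B-Instruct")
    print(hf_template.format(context, question, history))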
|
|