"""Conversation prompt templates and the ``Conversation`` history container."""

import base64
import dataclasses
from enum import auto, Enum
from io import BytesIO
from typing import Any, List, Optional, Tuple, Union

try:
    from PIL import Image
except ImportError:
    # Nothing in the visible part of this module uses Pillow; the name is kept
    # importable for backward compatibility but its absence is not fatal.
    Image = None


class SeparatorStyle(Enum):
    """Prompt layouts that ``Conversation.get_prompt`` knows how to render."""

    TWO = auto()
    PLAIN = auto()
    CHATML = auto()
    LLAMA_2 = auto()
    LLAMA_3 = auto()
    QWEN2 = auto()


@dataclasses.dataclass
class Conversation:
    """Conversation history plus the template settings used to render it.

    ``messages`` is a sequence of ``[role, payload]`` pairs.  A payload is a
    string, ``None`` (a turn that still has to be generated), or a tuple whose
    first element is the message text.
    """

    system: str
    roles: Union[List[str], Tuple[str, ...]]
    messages: List[List[Any]]
    # Index of the first message shown by ``to_gradio_chatbot``.
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.PLAIN
    sep: str = "###"
    sep2: Optional[str] = None
    version: str = "Unknown"

    tokenizer_id: str = ""
    tokenizer: Any = None
    # Stop criteria (the default one is EOS token)
    stop_str: Optional[Union[str, List[str]]] = None
    # Stops generation if meeting any token in this list
    stop_token_ids: Optional[List[int]] = None

    skip_next: bool = False

    @staticmethod
    def _text_of(payload):
        """Return the text of a message payload (first element of a tuple)."""
        if isinstance(payload, tuple):
            return payload[0]
        return payload

    def get_prompt(self):
        """Render the conversation as a prompt string for ``sep_style``.

        Returns:
            The prompt text.

        Raises:
            ValueError: if ``sep_style`` is not a supported style, or a tuple
                payload is used with the CHATML style.
            AssertionError: for LLAMA_2 when the first message is empty or not
                from ``roles[0]``; for QWEN2 when ``<|endoftext|>`` appears
                anywhere but at the end of a message.
        """
        style = self.sep_style
        if style == SeparatorStyle.TWO:
            return self._prompt_two()
        if style == SeparatorStyle.LLAMA_3:
            return self._prompt_llama_3()
        if style == SeparatorStyle.LLAMA_2:
            return self._prompt_llama_2()
        if style == SeparatorStyle.PLAIN:
            return self._prompt_plain()
        if style == SeparatorStyle.CHATML:
            return self._prompt_chatml()
        if style == SeparatorStyle.QWEN2:
            return self._prompt_qwen2()
        raise ValueError(f"Invalid style: {self.sep_style}")

    def _prompt_two(self):
        """Render ``ROLE: text`` turns, alternating ``sep`` and ``sep2``."""
        seps = [self.sep, self.sep2]
        parts = [self.system + seps[0]]
        for i, (role, message) in enumerate(self.messages):
            if message:
                parts.append(role + ": " + self._text_of(message) + seps[i % 2])
            else:
                # Open turn: leave the role header for the model to complete.
                parts.append(role + ":")
        return "".join(parts)

    def _prompt_llama_3(self):
        """Render header-tagged turns, each closed with ``sep2``."""
        parts = ["<|begin_of_text|>"]
        if len(self.system) > 0:
            parts.append(
                f"<|start_header_id|>system<|end_header_id|>\n\n{self.system}<|eot_id|>"
            )
        for role, message in self.messages:
            parts.append(f"<|start_header_id|>{role}<|end_header_id|>\n\n")
            if message:
                parts.append(self._text_of(message).strip() + self.sep2)
        return "".join(parts)

    def _prompt_llama_2(self):
        """Render ``[INST] ... [/INST]`` user turns followed by replies."""
        parts = []
        for i, (role, message) in enumerate(self.messages):
            if i == 0:
                assert message, "first message should not be none"
                assert role == self.roles[0], "first message should come from user"
            if not message:
                continue
            text = self._text_of(message)
            if i == 0 and len(self.system) > 0:
                # NOTE(review): Meta's Llama 2 chat format wraps the system
                # prompt in <<SYS>> / <</SYS>>; the markers here are "<>".
                # Confirm before changing - checkpoints depend on the exact text.
                text = f"<>\n{self.system}\n<>\n\n" + text
            if i % 2 == 0:
                parts.append(self.sep + f"[INST] {text} [/INST]")
            else:
                parts.append(" " + text + " " + self.sep2)
        prompt = "".join(parts)
        # Drop exactly one leading separator.  str.lstrip() would treat ``sep``
        # as a set of characters and could eat into the message itself.
        if self.sep and prompt.startswith(self.sep):
            prompt = prompt[len(self.sep):]
        return prompt

    def _prompt_plain(self):
        """Concatenate message texts, alternating ``sep`` and ``sep2``."""
        seps = [self.sep, self.sep2]
        parts = [self.system]
        for i, (_, message) in enumerate(self.messages):
            if message:
                parts.append(self._text_of(message) + seps[i % 2])
        return "".join(parts)

    def _prompt_chatml(self):
        """Render ``role\\ntext<sep>\\n`` turns; tuple payloads are rejected."""
        parts = [] if self.system == "" else [self.system + self.sep + "\n"]
        for role, message in self.messages:
            if message:
                if isinstance(message, tuple):
                    raise ValueError("Tuple not supported in CHATML")
                parts.append(role + "\n" + message + self.sep + "\n")
            else:
                parts.append(role + "\n")
        return "".join(parts)

    def _prompt_qwen2(self):
        """Render ``<|im_start|>role\\ntext<|im_end|>`` turns."""
        start = "<|im_start|>"
        end = "<|im_end|>\n"
        eot = "<|endoftext|>"
        parts = [start + "system\n" + self.system + end]
        for role, message in self.messages:
            if not message:
                parts.append(start + role + "\n")
                continue
            text = self._text_of(message)
            if text.endswith(eot):
                # Move the end-of-text marker behind the <|im_end|> terminator.
                text = text.replace(eot, "")
                parts.append(start + role + "\n" + text + end + eot)
            else:
                assert eot not in text, f"Invalid message: {text}"
                parts.append(start + role + "\n" + text + end)
        return "".join(parts)

    def append_message(self, role, message):
        """Append a ``[role, message]`` pair to the history."""
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        """Return ``[user_text, reply]`` pairs for messages from ``offset`` on.

        Even-indexed messages open a pair (tuple payloads are reduced to their
        text); odd-indexed messages fill in the reply of the latest pair.
        """
        pairs = []
        for i, (_, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                pairs.append([self._text_of(msg), None])
            else:
                pairs[-1][-1] = msg
        return pairs

    def copy(self):
        """Return a copy with an independent ``messages`` list.

        Every field is carried over (including the tokenizer and stop
        settings); only the message list and its ``[role, payload]`` pairs are
        rebuilt, so appending to the copy does not touch this instance.
        """
        return dataclasses.replace(
            self, messages=[[x, y] for x, y in self.messages]
        )

    def dict(self):
        """Return the template and history as a plain dictionary.

        If any payload is a tuple, the returned messages are copies with each
        tuple reduced to its text; otherwise ``self.messages`` is returned
        as is.
        """
        if any(isinstance(payload, tuple) for _, payload in self.messages):
            messages = [
                [role, self._text_of(payload)] for role, payload in self.messages
            ]
        else:
            messages = self.messages
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="",
)

conv_llama_2 = Conversation(
    system="You are a helpful language and speech assistant. "
    "You are able to understand the speech content that the user provides, "
    "and assist the user with a variety of tasks using natural language.",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="",
    sep2="",
)

conv_llama_3 = Conversation(
    system="You are a helpful language and speech assistant. "
    "You are able to understand the speech content that the user provides, "
    "and assist the user with a variety of tasks using natural language.",
    roles=("user", "assistant"),
    version="llama_v3",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.LLAMA_3,
    sep="",
    sep2="<|eot_id|>"
)

conv_qwen_v1 = Conversation(
    system="You are a helpful assistant.",
    roles=("user", "assistant"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.QWEN2,
)

conv_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="",
)

conv_qwen = Conversation(
    system="""<|im_start|>system
You are a helpful assistant.""",
    roles=("<|im_start|>user", "<|im_start|>assistant"),
    version="qwen",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.CHATML,
    sep="<|im_end|>",
)

default_conversation = conv_llama_3
conv_templates = {
    "v1": conv_vicuna_v1,
    "plain": conv_plain,
    "llama_2": conv_llama_2,
    "llama_3": conv_llama_3,
    'v1_qwen2': conv_qwen_v1,
    "qwen_1_5": conv_qwen,
}


if __name__ == "__main__":
    print(default_conversation.get_prompt())