File size: 2,484 Bytes
8b1c859 e0169c8 8b1c859 34b78ab e0169c8 8b1c859 34b78ab 8b1c859 34b78ab cfc7185 34b78ab e0169c8 34b78ab cfc7185 34b78ab 8b1c859 34b78ab e0169c8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
from jinja2 import Environment, FileSystemLoader
from gradio_app.backend.ChatGptInteractor import *
from gradio_app.backend.HuggingfaceGenerator import HuggingfaceGenerator
# Jinja2 environment whose templates are loaded from 'gradio_app/templates'
# (presumably resolved relative to the process working directory - confirm).
env = Environment(loader=FileSystemLoader('gradio_app/templates'))
# Template that renders the context block; it is given a `documents` variable.
context_template = env.get_template('context_template.j2')
# Rendered once at import time with an empty document list; used as the
# opening system message in construct_openai_messages.
start_system_message = context_template.render(documents=[])
def construct_mistral_messages(context, history):
    """Build a user/assistant-only message list from a chat history.

    Args:
        context: Text placed in front of the pending user query.
        history: Iterable of ``(query, answer)`` pairs, oldest first. A pair
            whose answer is empty (``""`` or ``None``) is treated as the
            query that is still waiting for a reply.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts. Answered turns
        yield a "user" entry followed by an "assistant" entry; an unanswered
        turn yields a single "user" entry whose content is
        ``context + '\\n\\nQuery:\\n\\n' + query``.
    """
    messages = []
    for query, answer in history:
        if not answer:
            # Pending turn (the last message): the context is embedded in the
            # user message itself. Truthiness instead of len() so that a
            # None placeholder does not raise TypeError.
            messages.append({
                "role": "user",
                "content": context + f'\n\nQuery:\n\n{query}',
            })
            continue
        # A previously answered turn: replay both sides verbatim.
        messages.append({
            "role": "user",
            "content": query,
        })
        messages.append({
            "role": "assistant",
            "content": answer,
        })
    return messages
def construct_openai_messages(context, history):
    """Build a system/user/assistant message list from a chat history.

    The list always starts with a "system" entry holding the module-level
    ``start_system_message``.

    Args:
        context: Text sent as an extra "system" message immediately before
            the pending user query.
        history: Iterable of ``(query, answer)`` pairs, oldest first. A pair
            whose answer is empty (``""`` or ``None``) is treated as the
            query that is still waiting for a reply.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts. Answered turns
        yield a "user" entry followed by an "assistant" entry; an unanswered
        turn yields a "system" entry with ``context`` followed by the "user"
        entry.
    """
    messages = [
        {
            "role": "system",
            "content": start_system_message,
        },
    ]
    for query, answer in history:
        # Truthiness instead of len() so that a None placeholder for the
        # pending answer does not raise TypeError.
        pending = not answer
        if pending:
            # The last message: supply the context right before the query.
            messages.append({
                "role": "system",
                "content": context,
            })
        messages.append({
            "role": "user",
            "content": query,
        })
        if not pending:
            # One of the previous LLM answers.
            messages.append({
                "role": "assistant",
                "content": answer,
            })
    return messages
def get_message_constructor(llm_name):
    """Return the message-building function to use for ``llm_name``.

    Args:
        llm_name: Model identifier string.

    Returns:
        ``construct_openai_messages`` for 'gpt-3.5-turbo';
        ``construct_mistral_messages`` for the Mistral, Falcon and MiniChat
        identifiers listed below.

    Raises:
        ValueError: If ``llm_name`` is not one of the known identifiers.
    """
    if llm_name == 'gpt-3.5-turbo':
        return construct_openai_messages
    if llm_name in (
        'mistralai/Mistral-7B-Instruct-v0.1',
        'tiiuae/falcon-180B-chat',
        'GeneZC/MiniChat-3B',
    ):
        return construct_mistral_messages
    # Include the rejected value so a misconfigured name can be diagnosed
    # from the traceback alone; the original text is kept as the prefix.
    raise ValueError(f'Unknown LLM name: {llm_name!r}')
def get_llm_generator(llm_name):
    """Return the generation callable to use for ``llm_name``.

    A new interactor/generator object is constructed on every call and one
    of its bound methods is returned.

    Args:
        llm_name: Model identifier string.

    Returns:
        ``ChatGptInteractor(...).chat_completion`` for 'gpt-3.5-turbo', or
        ``HuggingfaceGenerator(...).generate`` for the Mistral, Falcon and
        MiniChat identifiers.

    Raises:
        ValueError: If ``llm_name`` is not one of the known identifiers.
    """
    if llm_name == 'gpt-3.5-turbo':
        interactor = ChatGptInteractor(
            model_name=llm_name, max_tokens=512, temperature=0, stream=True
        )
        return interactor.chat_completion
    if llm_name in (
        'mistralai/Mistral-7B-Instruct-v0.1',
        'tiiuae/falcon-180B-chat',
    ):
        generator = HuggingfaceGenerator(
            model_name=llm_name, temperature=0, max_new_tokens=512,
        )
        return generator.generate
    if llm_name == 'GeneZC/MiniChat-3B':
        # MiniChat gets a smaller token budget and an explicit stream=False;
        # NOTE(review): the reason is not visible in this file - confirm
        # against HuggingfaceGenerator.
        generator = HuggingfaceGenerator(
            model_name=llm_name, temperature=0, max_new_tokens=250, stream=False,
        )
        return generator.generate
    # Include the rejected value so a misconfigured name can be diagnosed
    # from the traceback alone; the original text is kept as the prefix.
    raise ValueError(f'Unknown LLM name: {llm_name!r}')
|