File size: 2,484 Bytes
8b1c859
e0169c8
8b1c859
34b78ab
e0169c8
8b1c859
 
 
 
 
34b78ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b1c859
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34b78ab
 
 
cfc7185
34b78ab
 
e0169c8
 
34b78ab
 
 
 
 
 
cfc7185
34b78ab
 
 
 
8b1c859
34b78ab
 
 
 
 
 
e0169c8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from jinja2 import Environment, FileSystemLoader

from gradio_app.backend.ChatGptInteractor import *
from gradio_app.backend.HuggingfaceGenerator import HuggingfaceGenerator

# Jinja2 environment loading templates from 'gradio_app/templates'.
# FileSystemLoader resolves this relative path against the process's current
# working directory, so the app must be started from the repository root.
env = Environment(loader=FileSystemLoader('gradio_app/templates'))
# Template loaded at import time; presumably the same one used elsewhere to
# render retrieved documents into the `context` strings — verify with callers.
context_template = env.get_template('context_template.j2')
# Opening system prompt for OpenAI-style chats: the context template rendered
# with an empty document list (rendered once, at import time).
start_system_message = context_template.render(documents=[])


def construct_mistral_messages(context, history):
    """Build a user/assistant message list with the context inlined in the query.

    No "system" message is emitted. Instead, the retrieved ``context`` is
    prepended to the pending query, i.e. the history entry whose answer is
    still empty. This is presumably because the target models' chat templates
    do not accept a system role; confirm per model.

    Args:
        context: Text to prepend to the pending query.
        history: Iterable of ``(query, answer)`` pairs in chronological order.
            An answer that is empty (``""``) or ``None`` marks a turn the model
            has not answered yet.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts. Answered turns
        contribute a "user" and an "assistant" message. A pending turn
        contributes a single "user" message containing the context.
    """
    messages = []
    for query, answer in history:
        # `not answer` covers both "" and None. The previous `len(answer) == 0`
        # raised TypeError on a None placeholder for a not-yet-generated reply.
        if not answer:
            messages.append({
                "role": "user",
                "content": context + f'\n\nQuery:\n\n{query}',
            })
        else:
            messages.append({"role": "user", "content": query})
            messages.append({"role": "assistant", "content": answer})
    return messages


def construct_openai_messages(context, history):
    """Build an OpenAI-style chat message list from the conversation history.

    The list always opens with the module-level ``start_system_message``.
    The retrieved ``context`` is inserted as an additional "system" message
    immediately before the pending query, i.e. the history entry whose answer
    is still empty.

    Args:
        context: Text sent as a "system" message ahead of the pending query.
        history: Iterable of ``(query, answer)`` pairs in chronological order.
            An answer that is empty (``""``) or ``None`` marks a turn the model
            has not answered yet.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts.
    """
    messages = [{"role": "system", "content": start_system_message}]
    for query, answer in history:
        # `not answer` treats "" and None alike. The previous
        # `len(answer) == 0` raised TypeError on a None answer.
        pending = not answer
        if pending:
            messages.append({"role": "system", "content": context})
        messages.append({"role": "user", "content": query})
        if not pending:
            messages.append({"role": "assistant", "content": answer})
    return messages


def get_message_constructor(llm_name):
    """Return the history-to-messages builder suited to ``llm_name``.

    'gpt-3.5-turbo' uses the OpenAI builder (separate system messages). The
    listed Hugging Face models use the builder that inlines the context into
    the pending user query.

    Raises:
        ValueError: if ``llm_name`` is not a supported model.
    """
    inline_context_models = (
        'mistralai/Mistral-7B-Instruct-v0.1',
        "tiiuae/falcon-180B-chat",
        "GeneZC/MiniChat-3B",
    )
    if llm_name == 'gpt-3.5-turbo':
        builder = construct_openai_messages
    elif llm_name in inline_context_models:
        builder = construct_mistral_messages
    else:
        raise ValueError('Unknown LLM name')
    return builder


def get_llm_generator(llm_name):
    """Create a text-generation callable for ``llm_name``.

    - 'gpt-3.5-turbo': a streaming ChatGptInteractor (512 max tokens,
      temperature 0); returns its ``chat_completion`` method.
    - Mistral-7B-Instruct / Falcon-180B-chat: a HuggingfaceGenerator with
      temperature 0 and 512 new tokens; returns its ``generate`` method.
    - MiniChat-3B: a non-streaming HuggingfaceGenerator with temperature 0
      and 250 new tokens; returns its ``generate`` method.

    Raises:
        ValueError: if ``llm_name`` is not a supported model.
    """
    if llm_name == 'gpt-3.5-turbo':
        interactor = ChatGptInteractor(
            model_name=llm_name, max_tokens=512, temperature=0, stream=True
        )
        return interactor.chat_completion

    # Both remaining families share HuggingfaceGenerator; only the
    # generation settings differ.
    if llm_name in ('mistralai/Mistral-7B-Instruct-v0.1', "tiiuae/falcon-180B-chat"):
        generation_settings = {"temperature": 0, "max_new_tokens": 512}
    elif llm_name == "GeneZC/MiniChat-3B":
        generation_settings = {"temperature": 0, "max_new_tokens": 250, "stream": False}
    else:
        raise ValueError('Unknown LLM name')

    generator = HuggingfaceGenerator(model_name=llm_name, **generation_settings)
    return generator.generate