"""Main entrypoint for the app."""

import os
import time
from queue import Queue
from threading import Thread
from timeit import default_timer as timer

import gradio as gr

from app_modules.init import app_init
from app_modules.llm_chat_chain import ChatChain
from app_modules.utils import print_llm_response

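# Initialize the LLM loader and the question-answering chain once at startup.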
llm_loader, qa_chain = app_init()

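# Runtime behavior is driven by environment variables:
#   SHARE_GRADIO_APP           - "true" creates a public Gradio share link
#   LLM_MODEL_TYPE             - "openai" selects the OpenAI backend
#   USE_ORCA_2_PROMPT_TEMPLATE - "true" chats with Orca-2 directly (no retrieval)
#   CHAT_HISTORY_ENABLED       - "true" passes prior turns back into the chain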
share_gradio_app = os.environ.get("SHARE_GRADIO_APP") == "true"
using_openai = os.environ.get("LLM_MODEL_TYPE") == "openai"
chat_with_orca_2 = (
    not using_openai and os.environ.get("USE_ORCA_2_PROMPT_TEMPLATE") == "true"
)
chat_history_enabled = (
    not chat_with_orca_2 and os.environ.get("CHAT_HISTORY_ENABLED") == "true"
)

model = (
    "OpenAI GPT-3.5"
    if using_openai
    else os.environ.get("HUGGINGFACE_MODEL_NAME_OR_PATH")
)
href = (
    "https://platform.openai.com/docs/models/gpt-3-5"
    if using_openai
    else f"https://huggingface.co./{model}"
)

if chat_with_orca_2:
    qa_chain = ChatChain(llm_loader)
    name = "Orca-2"
else:
    name = "AI Books"

title = f"Chat with {name}"
examples = (
    ["How to cook a fish?", "Who is the president of US now?"]
    if chat_with_orca_2
    else [
        "What's Machine Learning?",
        "What's Generative AI?",
        "What's Difference in Differences?",
        "What's Instrumental Variable?",
    ]
)
description = f"""\
<div align="left">
<p> Currently Running: <a href="{href}">{model}</a></p>
</div>
"""


def task(question, chat_history, q, result):
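    """Run the QA chain in a worker thread.

    The q queue is passed to the chain for streaming and lets predict() detect
    that generation has begun; the complete response is put on result.
    """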
    start = timer()
    inputs = {"question": question, "chat_history": chat_history}
    ret = qa_chain.call_chain(inputs, None, q)
    end = timer()

    print(f"Completed in {end - start:.3f}s")
    print_llm_response(ret)

    result.put(ret)


def predict(message, history):
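    """Gradio ChatInterface callback that streams the answer token by token.

    Generation runs on a background thread while this generator yields
    successively longer partial messages to the UI.
    """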
    print("predict:", message, history)

    chat_history = []
    if chat_history_enabled:
        for element in history:
            item = (element[0] or "", element[1] or "")
            chat_history.append(item)

    if not chat_history:
        qa_chain.reset()

    q = Queue()
    result = Queue()
    t = Thread(target=task, args=(message, chat_history, q, result))
    t.start()  # Starting the generation in a separate thread.

    partial_message = ""
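    # With chat history the chain makes an extra LLM call to condense the
    # follow-up question, so two streaming passes are consumed; otherwise one.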
    count = 2 if chat_history else 1

    while count > 0:
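        # Wait until the worker thread signals (via q) that generation started.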
        while q.empty():
            print("nothing generated yet - retry in 0.5s")
            time.sleep(0.5)

        for next_token in llm_loader.streamer:
            partial_message += next_token or ""
            yield partial_message

        if count == 2:
            partial_message += "\n\n"

        count -= 1

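    # For the retrieval QA chain, append the source documents as links.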
    if not chat_with_orca_2:
        partial_message += "\n\nSources:\n"
        ret = result.get()
        titles = []
        for doc in ret["source_documents"]:
            page = doc.metadata["page"] + 1
            url = f"{doc.metadata['url']}#page={page}"
            file_name = doc.metadata["source"].split("/")[-1]
            title = f"{file_name} Page: {page}"
            if title not in titles:
                titles.append(title)
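                # Markdown auto-numbers consecutive "1." items, so each
                # source renders with the correct ordinal.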
                partial_message += f"1. [{title}]({url})\n"

        yield partial_message
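
    # Make sure the worker thread has finished before the generator exits.
    t.join()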


# Setting up the Gradio chat interface.
gr.ChatInterface(
    predict,
    title=title,
    description=description,
    examples=examples,
).launch(
    share=share_gradio_app
)  # Launching the web interface.