AlexanderKazakov committed
Commit 8b1c859 · Parent(s): 360f505

improve interface and cut documents to fit the context length

Files changed:
- gradio_app/app.py +69 -55
- gradio_app/backend/ChatGptInteractor.py +170 -0
- gradio_app/backend/query_llm.py +55 -19
- gradio_app/templates/context_html_template.j2 +95 -0
- gradio_app/templates/context_template.j2 +20 -0
- settings.py +7 -1
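
The core of the change to gradio_app/app.py below is a trimming loop: retrieved documents are dropped from the end of the ranked list until the rendered prompt, plus a 512-token reserve for the answer, fits the model's context length. A minimal standalone sketch of that idea (the function and parameter names here are illustrative, not part of the commit):

# Illustrative sketch only: drop retrieved documents until the prompt fits
# the model's context length, keeping a reserve for the generated answer.
def fit_documents_to_context(documents, render_prompt, count_tokens,
                             context_length=4096, reserved_for_answer=512):
    docs = list(documents)
    while docs:
        prompt = render_prompt(docs)
        if count_tokens(prompt) + reserved_for_answer < context_length:
            return docs, prompt
        docs.pop()  # discard the lowest-ranked document and try again
    raise ValueError("even a single document does not fit the context length")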
gradio_app/app.py
CHANGED
@@ -11,8 +11,9 @@ from time import perf_counter
 import gradio as gr
 from jinja2 import Environment, FileSystemLoader
 
-from backend.
-from backend.
+from gradio_app.backend.ChatGptInteractor import num_tokens_from_messages
+from gradio_app.backend.query_llm import generate_hf, generate_openai, construct_openai_messages
+from gradio_app.backend.semantic_search import table, embedder
 
 from settings import *
 
@@ -24,23 +25,29 @@ logger = logging.getLogger(__name__)
 env = Environment(loader=FileSystemLoader('gradio_app/templates'))
 
 # Load the templates directly from the environment
-
-
+context_template = env.get_template('context_template.j2')
+context_html_template = env.get_template('context_html_template.j2')
 
 # Examples
-examples = [
-
-
+examples = [
+    'What is BERT?',
+    'Tell me about BERT deep learning model',
+    'What is the capital of China?',
+    'Why is the sky blue?',
+    'Who won the mens world cup in 2014?',
+]
 
 
 def add_text(history, text):
     history = [] if history is None else history
-    history = history + [(text,
+    history = history + [(text, "")]
     return history, gr.Textbox(value="", interactive=False)
 
 
 def bot(history, api_kind):
-    top_k_rank =
+    top_k_rank = 5
+    thresh_dist = 1.2
+    history[-1][1] = ""
     query = history[-1][0]
 
     if not query:
@@ -53,71 +60,78 @@ def bot(history, api_kind):
 
     query_vec = embedder.encode(query)
     documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_k_rank).to_list()
+    thresh_dist = max(thresh_dist, min(d['_distance'] for d in documents))
+    documents = [d for d in documents if d['_distance'] <= thresh_dist]
     documents = [doc[TEXT_COLUMN_NAME] for doc in documents]
 
     document_time = perf_counter() - document_start
     logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
 
-
-
-
-
-
-
-
-
-    elif api_kind is None:
-        gr.Warning("API name was not provided")
-        raise ValueError("API name was not provided")
+    while len(documents) != 0:
+        context = context_template.render(documents=documents)
+        context_html = context_html_template.render(documents=documents)
+        messages = construct_openai_messages(context, history)
+        num_tokens = num_tokens_from_messages(messages, OPENAI_LLM_NAME)
+        if num_tokens + 512 < context_lengths[OPENAI_LLM_NAME]:
+            break
+        documents.pop()
     else:
-        gr.
-        raise ValueError(f"API {api_kind} is not supported")
+        raise gr.Error('Model context length exceeded, reload the page')
 
-
-
-    history
-
+    for part in generate_openai(messages):
+        history[-1][1] += part
+        yield history, context_html
+    else:
+        print('Finished generation stream.')
 
 
 with gr.Blocks() as demo:
-    chatbot = gr.Chatbot(
-        [],
-        elem_id="chatbot",
-        avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
-                       'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
-        bubble_full_width=False,
-        show_copy_button=True,
-        show_share_button=True,
-    )
-
     with gr.Row():
-
-
-
-
-
-
-
-
-
+        with gr.Column():
+            chatbot = gr.Chatbot(
+                [],
+                elem_id="chatbot",
+                avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
+                               'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
+                bubble_full_width=False,
+                show_copy_button=True,
+                show_share_button=True,
+                height=600,
+            )
+
+            with gr.Row():
+                input_textbox = gr.Textbox(
+                    scale=3,
+                    show_label=False,
+                    placeholder="Enter text and press enter",
+                    container=False,
+                )
+                txt_btn = gr.Button(value="Submit text", scale=1)
+
+            api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="OpenAI", label='Backend')
+
+            # Examples
+            gr.Examples(examples, input_textbox)
+
+        with gr.Column():
+            context_html = gr.HTML()
 
-    prompt_html = gr.HTML()
     # Turn off interactivity while generating if you click
-    txt_msg = txt_btn.click(
-
+    txt_msg = txt_btn.click(
+        add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False
+    ).then(
+        bot, [chatbot, api_kind], [chatbot, context_html]
+    )
 
     # Turn it back on
-    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [
+    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [input_textbox], queue=False)
 
     # Turn off interactivity while generating if you hit enter
-    txt_msg =
-        bot, [chatbot, api_kind], [chatbot,
+    txt_msg = input_textbox.submit(add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False).then(
+        bot, [chatbot, api_kind], [chatbot, context_html])
 
     # Turn it back on
-    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [
-
-    # Examples
-    gr.Examples(examples, txt)
+    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [input_textbox], queue=False)
 
 demo.queue()
 demo.launch(debug=True)
gradio_app/backend/ChatGptInteractor.py
ADDED
@@ -0,0 +1,170 @@
+import time
+
+import tiktoken
+import openai
+
+
+with open('data/openaikey.txt') as f:
+    OPENAI_KEY = f.read().strip()
+openai.api_key = OPENAI_KEY
+
+
+def num_tokens_from_messages(messages, model):
+    """
+    Return the number of tokens used by a list of messages.
+    https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+    """
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        print("Warning: model not found. Using cl100k_base encoding.")
+        encoding = tiktoken.get_encoding("cl100k_base")
+    if model in {
+        "gpt-3.5-turbo-0613",
+        "gpt-3.5-turbo-16k-0613",
+        "gpt-4-0314",
+        "gpt-4-32k-0314",
+        "gpt-4-0613",
+        "gpt-4-32k-0613",
+    }:
+        tokens_per_message = 3
+        tokens_per_name = 1
+    elif model == "gpt-3.5-turbo-0301":
+        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
+        tokens_per_name = -1  # if there's a name, the role is omitted
+    elif "gpt-3.5-turbo" in model:
+        # print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
+        return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")
+    elif "gpt-4" in model:
+        # print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
+        return num_tokens_from_messages(messages, model="gpt-4-0613")
+    else:
+        raise NotImplementedError(
+            f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
+        )
+    num_tokens = 0
+    for message in messages:
+        num_tokens += tokens_per_message
+        for key, value in message.items():
+            num_tokens += len(encoding.encode(value, disallowed_special=()))
+            if key == "name":
+                num_tokens += tokens_per_name
+    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+    return num_tokens
+
+
+class ChatGptInteractor:
+    def __init__(self, model_name='gpt-3.5-turbo'):
+        self.model_name = model_name
+        self.tokenizer = tiktoken.encoding_for_model(self.model_name)
+
+    def chat_completion_simple(
+            self,
+            *,
+            user_text,
+            system_text=None,
+            max_tokens=None,
+            temperature=None,
+            stream=False,
+    ):
+        return self.chat_completion(
+            self._construct_messages_simple(user_text, system_text),
+            max_tokens=max_tokens,
+            temperature=temperature,
+            stream=stream,
+        )
+
+    def count_tokens_simple(self, *, user_text, system_text=None):
+        return self.count_tokens(self._construct_messages_simple(user_text, system_text))
+
+    @staticmethod
+    def _construct_messages_simple(user_text, system_text=None):
+        messages = []
+        if system_text is not None:
+            messages.append({
+                "role": "system",
+                "content": system_text
+            })
+        messages.append({
+            "role": "user",
+            "content": user_text
+        })
+        return messages
+
+    def chat_completion(
+            self,
+            messages,
+            max_tokens=None,
+            temperature=None,
+            stream=False,
+    ):
+        print(f'Sending request to {self.model_name} stream={stream} ...')
+        t1 = time.time()
+        completion = self._request(
+            model=self.model_name,
+            messages=messages,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            stream=stream,
+        )
+        if stream:
+            return completion
+        t2 = time.time()
+        usage = completion['usage']
+        print(
+            f'Received response: {usage["prompt_tokens"]} in + {usage["completion_tokens"]} out'
+            f' = {usage["total_tokens"]} total tokens. Time: {t2 - t1:3.1f} seconds'
+        )
+        return completion.choices[0].message['content']
+
+    @staticmethod
+    def get_stream_text(stream_part):
+        return stream_part['choices'][0]['delta'].get('content', '')
+
+    def count_tokens(self, messages):
+        return num_tokens_from_messages(messages, self.model_name)
+
+    def _request(self, *args, **kwargs):
+        for _ in range(5):
+            try:
+                completion = openai.ChatCompletion.create(
+                    *args, **kwargs,
+                    request_timeout=100.0,
+                )
+                return completion
+            except (openai.error.Timeout, openai.error.ServiceUnavailableError):
+                continue
+        raise RuntimeError('Failed to connect to OpenAI (timeout error)')
+
+
+if __name__ == '__main__':
+    cgi = ChatGptInteractor()
+
+    for txt in [
+        "Hello World!",
+        "Hello",
+        " World!",
+        "World!",
+        "World",
+        "!",
+        " ",
+        " ",
+        " ",
+        " ",
+        "\n",
+        "\n\t",
+    ]:
+        print(f'`{txt}` | {cgi.tokenizer.encode(txt)}')
+
+    st = 'You are a helpful assistant and an experienced programmer, ' \
+         'answering questions exactly in two rhymed sentences'
+    ut = 'Explain the principle of recursion in programming'
+    print('Count tokens:', cgi.count_tokens_simple(user_text=ut, system_text=st))
+
+    print(cgi.chat_completion_simple(user_text=ut, system_text=st))
+    print('---')
+
+    for part in cgi.chat_completion_simple(user_text=ut, system_text=st, stream=True):
+        print(cgi.get_stream_text(part), end='')
+    print('\n---')
+
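
The num_tokens_from_messages helper added above is what app.py uses to decide whether the current document set still fits the context window. A quick sketch of calling it directly (illustrative only; it requires the tiktoken package and an importable gradio_app package):

from gradio_app.backend.ChatGptInteractor import num_tokens_from_messages

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is BERT?"},
]
# Prints the per-message overhead plus content tokens, counted with the model's encoding
print(num_tokens_from_messages(messages, "gpt-3.5-turbo"))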
gradio_app/backend/query_llm.py
CHANGED
@@ -1,22 +1,20 @@
-import openai
 import gradio as gr
 
-from os import getenv
 from typing import Any, Dict, Generator, List
 
 from huggingface_hub import InferenceClient
 from transformers import AutoTokenizer
+from jinja2 import Environment, FileSystemLoader
 
 from settings import *
+from gradio_app.backend.ChatGptInteractor import *
 
 
-tokenizer = AutoTokenizer.from_pretrained(
+tokenizer = AutoTokenizer.from_pretrained(HF_LLM_NAME)
 
-
-HF_TOKEN = getenv("HUGGING_FACE_HUB_TOKEN")
+HF_TOKEN = None
 
-
-hf_client = InferenceClient(LLM_NAME, token=HF_TOKEN)
+hf_client = InferenceClient(HF_LLM_NAME, token=HF_TOKEN)
 
 
 def format_prompt(message: str, api_kind: str):
@@ -42,7 +40,7 @@ def format_prompt(message: str, api_kind: str):
 
 
 def generate_hf(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 512,
-
+                top_p: float = 0.6, repetition_penalty: float = 1.2) -> Generator[str, None, str]:
     """
     Generate a sequence of tokens based on a given prompt and history using Mistral client.
 
@@ -69,13 +67,13 @@ def generate_hf(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 512,
         'repetition_penalty': repetition_penalty,
         'do_sample': True,
         'seed': 42,
-
-
+    }
+
     formatted_prompt = format_prompt(prompt, "hf")
 
     try:
         stream = hf_client.text_generation(formatted_prompt, **generate_kwargs,
-
+                                           stream=True, details=True, return_full_text=False)
         output = ""
         for response in stream:
             output += response.token.text
@@ -96,8 +94,44 @@ def generate_hf(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 512,
     return "I do not know what happened, but I couldn't understand you."
 
 
-
-
+env = Environment(loader=FileSystemLoader('gradio_app/templates'))
+context_template = env.get_template('context_template.j2')
+start_system_message = context_template.render(documents=[])
+
+
+def construct_openai_messages(context, history):
+    messages = [
+        {
+            "role": "system",
+            "content": start_system_message,
+        },
+    ]
+    for q, a in history:
+        if len(a) == 0:  # the last message
+            messages.append({
+                "role": "system",
+                "content": context,
+            })
+        messages.append({
+            "role": "user",
+            "content": q,
+        })
+        if len(a) != 0:  # some of the previous LLM answers
+            messages.append({
+                "role": "assistant",
+                "content": a,
+            })
+    return messages
+
+
+def generate_openai(messages):
+    cgi = ChatGptInteractor(model_name=OPENAI_LLM_NAME)
+    for part in cgi.chat_completion(messages, max_tokens=512, temperature=0, stream=True):
+        yield cgi.get_stream_text(part)
+
+
+def _generate_openai(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 512,
+                     top_p: float = 0.6, repetition_penalty: float = 1.2) -> Generator[str, None, str]:
     """
     Generate a sequence of tokens based on a given prompt and history using Mistral client.
 
@@ -116,21 +150,23 @@ def generate_openai(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 512,
 
     temperature = max(float(temperature), 1e-2)  # Ensure temperature isn't too low
     top_p = float(top_p)
-
+
     generate_kwargs = {
         'temperature': temperature,
         'max_tokens': max_new_tokens,
        'top_p': top_p,
        'frequency_penalty': max(-2., min(repetition_penalty, 2.)),
-
+    }
 
     formatted_prompt = format_prompt(prompt, "openai")
 
     try:
-        stream = openai.ChatCompletion.create(
-
-
-
+        stream = openai.ChatCompletion.create(
+            model=OPENAI_LLM_NAME,
+            messages=formatted_prompt,
+            **generate_kwargs,
+            stream=True
+        )
        output = ""
        for chunk in stream:
            output += chunk.choices[0].delta.get("content", "")
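
construct_openai_messages, added above, keeps the instruction system prompt first, replays earlier question/answer pairs, and injects the rendered document context as a second system message right before the still-unanswered user turn. A rough usage sketch (the history contents here are invented for illustration, and generate_openai needs a valid key in data/openaikey.txt):

from gradio_app.backend.query_llm import construct_openai_messages, generate_openai

history = [
    ("What is BERT?", "BERT is a bidirectional transformer encoder ..."),  # earlier, answered turn
    ("How was it pretrained?", ""),                                        # current turn, no answer yet
]
messages = construct_openai_messages("(rendered context_template output)", history)
# Resulting order: system instructions, user Q1, assistant A1, system context, user Q2
for part in generate_openai(messages):
    print(part, end="")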
gradio_app/templates/context_html_template.j2
ADDED
@@ -0,0 +1,95 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Information Page</title>
+    <link rel="stylesheet" href="https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap">
+    <link rel="stylesheet" href="https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;600&display=swap">
+    <style>
+        * {
+            font-family: "Source Sans Pro";
+        }
+
+        .instructions > * {
+            color: #111 !important;
+        }
+
+        details.doc-box * {
+            color: #111 !important;
+        }
+
+        .dark {
+            background: #111;
+            color: white;
+        }
+
+        .doc-box {
+            padding: 10px;
+            margin-top: 10px;
+            background-color: #baecc2;
+            border-radius: 6px;
+            color: #111 !important;
+            max-width: 700px;
+            box-shadow: rgba(0, 0, 0, 0.2) 0px 1px 2px 0px;
+        }
+
+        .doc-full {
+            margin: 10px 14px;
+            line-height: 1.6rem;
+        }
+
+        .instructions {
+            color: #111 !important;
+            background: #b7bdfd;
+            display: block;
+            border-radius: 6px;
+            padding: 6px 10px;
+            line-height: 1.6rem;
+            max-width: 700px;
+            box-shadow: rgba(0, 0, 0, 0.2) 0px 1px 2px 0px;
+        }
+
+        .query {
+            color: #111 !important;
+            background: #ffbcbc;
+            display: block;
+            border-radius: 6px;
+            padding: 6px 10px;
+            line-height: 1.6rem;
+            max-width: 700px;
+            box-shadow: rgba(0, 0, 0, 0.2) 0px 1px 2px 0px;
+        }
+    </style>
+</head>
+<body>
+<div class="prose svelte-1ybaih5" id="context_html">
+    <h2>Context:</h2>
+    {% for doc in documents %}
+    <details class="doc-box">
+        <summary>
+            <b>Doc {{ loop.index }}:</b> <span class="doc-short">{{ doc[:1000] }}...</span>
+        </summary>
+        <div class="doc-full">{{ doc }}</div>
+    </details>
+    {% endfor %}
+</div>
+
+<script>
+    document.addEventListener("DOMContentLoaded", function() {
+        const detailsElements = document.querySelectorAll('.doc-box');
+
+        detailsElements.forEach(detail => {
+            detail.addEventListener('toggle', function() {
+                const docShort = this.querySelector('.doc-short');
+                if (this.open) {
+                    docShort.style.display = 'none';
+                } else {
+                    docShort.style.display = 'inline';
+                }
+            });
+        });
+    });
+</script>
+</body>
+</html>
gradio_app/templates/context_template.j2
ADDED
@@ -0,0 +1,20 @@
+You are a helpful assistant.
+
+You answer questions based only on the provided information.
+
+If there is no relevant information in the context, just say "No relevant information".
+
+You must not make up an answer! Use only provided context!
+
+In each answer, you must provide a precise citation from the given context in double quotes.
+
+Citation is mandatory in the answer!
+
+Context:
+
+{% for doc in documents %}
+---
+
+{{ doc }}
+
+{% endfor %}
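
This plain-text template doubles as the system prompt: query_llm.py renders it with an empty document list to produce start_system_message, and app.py re-renders it with the retrieved documents to build the per-query context. A small rendering sketch (assuming it is run from the repository root so the template path resolves):

from jinja2 import Environment, FileSystemLoader

env = Environment(loader=FileSystemLoader('gradio_app/templates'))
context_template = env.get_template('context_template.j2')

# With no documents only the instruction text is produced; with documents,
# each one is appended under "Context:", separated by "---" lines.
print(context_template.render(documents=[]))
print(context_template.render(documents=["BERT is a deep bidirectional transformer ..."]))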
settings.py
CHANGED
@@ -5,4 +5,10 @@ LANCEDB_DIRECTORY = "data/lancedb"
 LANCEDB_TABLE_NAME = "table"
 VECTOR_COLUMN_NAME = "embedding"
 TEXT_COLUMN_NAME = "text"
-
+HF_LLM_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
+OPENAI_LLM_NAME = "gpt-3.5-turbo"
+
+context_lengths = {
+    "mistralai/Mistral-7B-Instruct-v0.1": 4096,
+    "gpt-3.5-turbo": 4096,
+}