AlexanderKazakov
commited on
Commit
·
34b78ab
1
Parent(s):
eeafaaa
add cross-encoder and HF API LLM
Browse files- gradio_app/app.py +47 -18
- gradio_app/backend/ChatGptInteractor.py +34 -32
- gradio_app/backend/HuggingfaceGenerator.py +44 -0
- gradio_app/backend/cross_encoder.py +32 -0
- gradio_app/backend/query_llm.py +41 -145
- settings.py +8 -2
gradio_app/app.py
CHANGED
@@ -13,7 +13,8 @@ import markdown
|
|
13 |
from jinja2 import Environment, FileSystemLoader
|
14 |
|
15 |
from gradio_app.backend.ChatGptInteractor import num_tokens_from_messages
|
16 |
-
from gradio_app.backend.
|
|
|
17 |
from gradio_app.backend.semantic_search import table, embedder
|
18 |
|
19 |
from settings import *
|
@@ -45,42 +46,52 @@ def add_text(history, text):
|
|
45 |
return history, gr.Textbox(value="", interactive=False)
|
46 |
|
47 |
|
48 |
-
def bot(history,
|
49 |
-
top_k_rank = 5
|
50 |
-
thresh_dist = 1.2
|
51 |
history[-1][1] = ""
|
52 |
query = history[-1][0]
|
53 |
|
54 |
if not query:
|
55 |
-
gr.
|
56 |
-
raise ValueError("Empty string was submitted")
|
57 |
|
58 |
logger.info('Retrieving documents...')
|
59 |
-
|
60 |
-
|
61 |
|
62 |
query_vec = embedder.embed(query)[0]
|
63 |
-
documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME)
|
|
|
|
|
64 |
thresh_dist = max(thresh_dist, min(d['_distance'] for d in documents))
|
65 |
documents = [d for d in documents if d['_distance'] <= thresh_dist]
|
66 |
documents = [doc[TEXT_COLUMN_NAME] for doc in documents]
|
67 |
|
68 |
-
|
69 |
-
logger.info(f'Finished Retrieving documents in {round(
|
70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
while len(documents) != 0:
|
72 |
context = context_template.render(documents=documents)
|
73 |
documents_html = [markdown.markdown(d) for d in documents]
|
74 |
context_html = context_html_template.render(documents=documents_html)
|
75 |
-
messages =
|
76 |
-
num_tokens = num_tokens_from_messages(messages,
|
77 |
-
if num_tokens + 512 < context_lengths[
|
78 |
break
|
79 |
documents.pop()
|
80 |
else:
|
81 |
raise gr.Error('Model context length exceeded, reload the page')
|
82 |
|
83 |
-
|
|
|
84 |
history[-1][1] += part
|
85 |
yield history, context_html
|
86 |
else:
|
@@ -110,7 +121,25 @@ with gr.Blocks() as demo:
|
|
110 |
)
|
111 |
txt_btn = gr.Button(value="Submit text", scale=1)
|
112 |
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
|
115 |
# Examples
|
116 |
gr.Examples(examples, input_textbox)
|
@@ -122,7 +151,7 @@ with gr.Blocks() as demo:
|
|
122 |
txt_msg = txt_btn.click(
|
123 |
add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False
|
124 |
).then(
|
125 |
-
bot, [chatbot,
|
126 |
)
|
127 |
|
128 |
# Turn it back on
|
@@ -130,7 +159,7 @@ with gr.Blocks() as demo:
|
|
130 |
|
131 |
# Turn off interactivity while generating if you hit enter
|
132 |
txt_msg = input_textbox.submit(add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False).then(
|
133 |
-
bot, [chatbot,
|
134 |
|
135 |
# Turn it back on
|
136 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [input_textbox], queue=False)
|
|
|
13 |
from jinja2 import Environment, FileSystemLoader
|
14 |
|
15 |
from gradio_app.backend.ChatGptInteractor import num_tokens_from_messages
|
16 |
+
from gradio_app.backend.cross_encoder import rerank_with_cross_encoder
|
17 |
+
from gradio_app.backend.query_llm import *
|
18 |
from gradio_app.backend.semantic_search import table, embedder
|
19 |
|
20 |
from settings import *
|
|
|
46 |
return history, gr.Textbox(value="", interactive=False)
|
47 |
|
48 |
|
49 |
+
def bot(history, llm, cross_enc):
|
|
|
|
|
50 |
history[-1][1] = ""
|
51 |
query = history[-1][0]
|
52 |
|
53 |
if not query:
|
54 |
+
raise gr.Error("Empty string was submitted")
|
|
|
55 |
|
56 |
logger.info('Retrieving documents...')
|
57 |
+
gr.Info('Start documents retrieval ...')
|
58 |
+
time = perf_counter()
|
59 |
|
60 |
query_vec = embedder.embed(query)[0]
|
61 |
+
documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME)
|
62 |
+
documents = documents.limit(TOP_K_RANK).to_list()
|
63 |
+
thresh_dist = thresh_distances[EMBED_NAME]
|
64 |
thresh_dist = max(thresh_dist, min(d['_distance'] for d in documents))
|
65 |
documents = [d for d in documents if d['_distance'] <= thresh_dist]
|
66 |
documents = [doc[TEXT_COLUMN_NAME] for doc in documents]
|
67 |
|
68 |
+
time = perf_counter() - time
|
69 |
+
logger.info(f'Finished Retrieving documents in {round(time, 2)} seconds...')
|
70 |
|
71 |
+
logger.info('Reranking documents...')
|
72 |
+
gr.Info('Start documents reranking ...')
|
73 |
+
time = perf_counter()
|
74 |
+
|
75 |
+
documents = rerank_with_cross_encoder(cross_enc, documents, query)
|
76 |
+
|
77 |
+
time = perf_counter() - time
|
78 |
+
logger.info(f'Finished Reranking documents in {round(time, 2)} seconds...')
|
79 |
+
|
80 |
+
msg_constructor = get_message_constructor(llm)
|
81 |
while len(documents) != 0:
|
82 |
context = context_template.render(documents=documents)
|
83 |
documents_html = [markdown.markdown(d) for d in documents]
|
84 |
context_html = context_html_template.render(documents=documents_html)
|
85 |
+
messages = msg_constructor(context, history)
|
86 |
+
num_tokens = num_tokens_from_messages(messages, 'gpt-3.5-turbo') # todo for HF, it is approximation
|
87 |
+
if num_tokens + 512 < context_lengths[llm]:
|
88 |
break
|
89 |
documents.pop()
|
90 |
else:
|
91 |
raise gr.Error('Model context length exceeded, reload the page')
|
92 |
|
93 |
+
llm_gen = get_llm_generator(llm)
|
94 |
+
for part in llm_gen(messages):
|
95 |
history[-1][1] += part
|
96 |
yield history, context_html
|
97 |
else:
|
|
|
121 |
)
|
122 |
txt_btn = gr.Button(value="Submit text", scale=1)
|
123 |
|
124 |
+
llm_name = gr.Radio(
|
125 |
+
choices=[
|
126 |
+
"gpt-3.5-turbo",
|
127 |
+
"mistralai/Mistral-7B-Instruct-v0.1",
|
128 |
+
"GeneZC/MiniChat-3B",
|
129 |
+
],
|
130 |
+
value="gpt-3.5-turbo",
|
131 |
+
label='LLM'
|
132 |
+
)
|
133 |
+
|
134 |
+
cross_enc_name = gr.Radio(
|
135 |
+
choices=[
|
136 |
+
None,
|
137 |
+
"cross-encoder/ms-marco-TinyBERT-L-2-v2",
|
138 |
+
"cross-encoder/ms-marco-MiniLM-L-12-v2",
|
139 |
+
],
|
140 |
+
value=None,
|
141 |
+
label='Cross-Encoder'
|
142 |
+
)
|
143 |
|
144 |
# Examples
|
145 |
gr.Examples(examples, input_textbox)
|
|
|
151 |
txt_msg = txt_btn.click(
|
152 |
add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False
|
153 |
).then(
|
154 |
+
bot, [chatbot, llm_name, cross_enc_name], [chatbot, context_html]
|
155 |
)
|
156 |
|
157 |
# Turn it back on
|
|
|
159 |
|
160 |
# Turn off interactivity while generating if you hit enter
|
161 |
txt_msg = input_textbox.submit(add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False).then(
|
162 |
+
bot, [chatbot, llm_name, cross_enc_name], [chatbot, context_html])
|
163 |
|
164 |
# Turn it back on
|
165 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [input_textbox], queue=False)
|
gradio_app/backend/ChatGptInteractor.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import time
|
2 |
|
3 |
import tiktoken
|
@@ -9,6 +10,10 @@ with open('data/openaikey.txt') as f:
|
|
9 |
openai.api_key = OPENAI_KEY
|
10 |
|
11 |
|
|
|
|
|
|
|
|
|
12 |
def num_tokens_from_messages(messages, model):
|
13 |
"""
|
14 |
Return the number of tokens used by a list of messages.
|
@@ -17,7 +22,7 @@ def num_tokens_from_messages(messages, model):
|
|
17 |
try:
|
18 |
encoding = tiktoken.encoding_for_model(model)
|
19 |
except KeyError:
|
20 |
-
|
21 |
encoding = tiktoken.get_encoding("cl100k_base")
|
22 |
if model in {
|
23 |
"gpt-3.5-turbo-0613",
|
@@ -33,10 +38,10 @@ def num_tokens_from_messages(messages, model):
|
|
33 |
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
|
34 |
tokens_per_name = -1 # if there's a name, the role is omitted
|
35 |
elif "gpt-3.5-turbo" in model:
|
36 |
-
#
|
37 |
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")
|
38 |
elif "gpt-4" in model:
|
39 |
-
#
|
40 |
return num_tokens_from_messages(messages, model="gpt-4-0613")
|
41 |
else:
|
42 |
raise NotImplementedError(
|
@@ -54,8 +59,11 @@ def num_tokens_from_messages(messages, model):
|
|
54 |
|
55 |
|
56 |
class ChatGptInteractor:
|
57 |
-
def __init__(self, model_name='gpt-3.5-turbo'):
|
58 |
self.model_name = model_name
|
|
|
|
|
|
|
59 |
self.tokenizer = tiktoken.encoding_for_model(self.model_name)
|
60 |
|
61 |
def chat_completion_simple(
|
@@ -63,15 +71,9 @@ class ChatGptInteractor:
|
|
63 |
*,
|
64 |
user_text,
|
65 |
system_text=None,
|
66 |
-
max_tokens=None,
|
67 |
-
temperature=None,
|
68 |
-
stream=False,
|
69 |
):
|
70 |
return self.chat_completion(
|
71 |
self._construct_messages_simple(user_text, system_text),
|
72 |
-
max_tokens=max_tokens,
|
73 |
-
temperature=temperature,
|
74 |
-
stream=stream,
|
75 |
)
|
76 |
|
77 |
def count_tokens_simple(self, *, user_text, system_text=None):
|
@@ -91,27 +93,17 @@ class ChatGptInteractor:
|
|
91 |
})
|
92 |
return messages
|
93 |
|
94 |
-
def chat_completion(
|
95 |
-
|
96 |
-
messages,
|
97 |
-
max_tokens=None,
|
98 |
-
temperature=None,
|
99 |
-
stream=False,
|
100 |
-
):
|
101 |
-
print(f'Sending request to {self.model_name} stream={stream} ...')
|
102 |
t1 = time.time()
|
103 |
-
completion = self._request(
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
stream=stream,
|
109 |
-
)
|
110 |
-
if stream:
|
111 |
-
return completion
|
112 |
t2 = time.time()
|
113 |
usage = completion['usage']
|
114 |
-
|
115 |
f'Received response: {usage["prompt_tokens"]} in + {usage["completion_tokens"]} out'
|
116 |
f' = {usage["total_tokens"]} total tokens. Time: {t2 - t1:3.1f} seconds'
|
117 |
)
|
@@ -121,14 +113,23 @@ class ChatGptInteractor:
|
|
121 |
def get_stream_text(stream_part):
|
122 |
return stream_part['choices'][0]['delta'].get('content', '')
|
123 |
|
|
|
|
|
|
|
|
|
|
|
124 |
def count_tokens(self, messages):
|
125 |
return num_tokens_from_messages(messages, self.model_name)
|
126 |
|
127 |
-
def _request(self,
|
128 |
for _ in range(5):
|
129 |
try:
|
130 |
completion = openai.ChatCompletion.create(
|
131 |
-
|
|
|
|
|
|
|
|
|
132 |
request_timeout=100.0,
|
133 |
)
|
134 |
return completion
|
@@ -164,7 +165,8 @@ if __name__ == '__main__':
|
|
164 |
print(cgi.chat_completion_simple(user_text=ut, system_text=st))
|
165 |
print('---')
|
166 |
|
167 |
-
|
168 |
-
|
|
|
169 |
print('\n---')
|
170 |
|
|
|
1 |
+
import logging
|
2 |
import time
|
3 |
|
4 |
import tiktoken
|
|
|
10 |
openai.api_key = OPENAI_KEY
|
11 |
|
12 |
|
13 |
+
logging.basicConfig(level=logging.INFO)
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
|
17 |
def num_tokens_from_messages(messages, model):
|
18 |
"""
|
19 |
Return the number of tokens used by a list of messages.
|
|
|
22 |
try:
|
23 |
encoding = tiktoken.encoding_for_model(model)
|
24 |
except KeyError:
|
25 |
+
logger.info("Warning: model not found. Using cl100k_base encoding.")
|
26 |
encoding = tiktoken.get_encoding("cl100k_base")
|
27 |
if model in {
|
28 |
"gpt-3.5-turbo-0613",
|
|
|
38 |
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
|
39 |
tokens_per_name = -1 # if there's a name, the role is omitted
|
40 |
elif "gpt-3.5-turbo" in model:
|
41 |
+
# logger.info()("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
|
42 |
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")
|
43 |
elif "gpt-4" in model:
|
44 |
+
# logger.info()("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
|
45 |
return num_tokens_from_messages(messages, model="gpt-4-0613")
|
46 |
else:
|
47 |
raise NotImplementedError(
|
|
|
59 |
|
60 |
|
61 |
class ChatGptInteractor:
|
62 |
+
def __init__(self, model_name='gpt-3.5-turbo', max_tokens=None, temperature=None, stream=False):
|
63 |
self.model_name = model_name
|
64 |
+
self.max_tokens = max_tokens
|
65 |
+
self.temperature = temperature
|
66 |
+
self.stream = stream
|
67 |
self.tokenizer = tiktoken.encoding_for_model(self.model_name)
|
68 |
|
69 |
def chat_completion_simple(
|
|
|
71 |
*,
|
72 |
user_text,
|
73 |
system_text=None,
|
|
|
|
|
|
|
74 |
):
|
75 |
return self.chat_completion(
|
76 |
self._construct_messages_simple(user_text, system_text),
|
|
|
|
|
|
|
77 |
)
|
78 |
|
79 |
def count_tokens_simple(self, *, user_text, system_text=None):
|
|
|
93 |
})
|
94 |
return messages
|
95 |
|
96 |
+
def chat_completion(self, messages):
|
97 |
+
logger.info(f'Sending request to {self.model_name} stream={self.stream} ...')
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
t1 = time.time()
|
99 |
+
completion = self._request(messages)
|
100 |
+
|
101 |
+
if self.stream:
|
102 |
+
return self._generator(completion)
|
103 |
+
|
|
|
|
|
|
|
|
|
104 |
t2 = time.time()
|
105 |
usage = completion['usage']
|
106 |
+
logger.info(
|
107 |
f'Received response: {usage["prompt_tokens"]} in + {usage["completion_tokens"]} out'
|
108 |
f' = {usage["total_tokens"]} total tokens. Time: {t2 - t1:3.1f} seconds'
|
109 |
)
|
|
|
113 |
def get_stream_text(stream_part):
|
114 |
return stream_part['choices'][0]['delta'].get('content', '')
|
115 |
|
116 |
+
@staticmethod
|
117 |
+
def _generator(completion):
|
118 |
+
for part in completion:
|
119 |
+
yield ChatGptInteractor.get_stream_text(part)
|
120 |
+
|
121 |
def count_tokens(self, messages):
|
122 |
return num_tokens_from_messages(messages, self.model_name)
|
123 |
|
124 |
+
def _request(self, messages):
|
125 |
for _ in range(5):
|
126 |
try:
|
127 |
completion = openai.ChatCompletion.create(
|
128 |
+
messages=messages,
|
129 |
+
model=self.model_name,
|
130 |
+
max_tokens=self.max_tokens,
|
131 |
+
temperature=self.temperature,
|
132 |
+
stream=self.stream,
|
133 |
request_timeout=100.0,
|
134 |
)
|
135 |
return completion
|
|
|
165 |
print(cgi.chat_completion_simple(user_text=ut, system_text=st))
|
166 |
print('---')
|
167 |
|
168 |
+
cgi = ChatGptInteractor(stream=True)
|
169 |
+
for part in cgi.chat_completion_simple(user_text=ut, system_text=st):
|
170 |
+
print(part, end='')
|
171 |
print('\n---')
|
172 |
|
gradio_app/backend/HuggingfaceGenerator.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
from huggingface_hub import InferenceClient
|
4 |
+
from transformers import AutoTokenizer
|
5 |
+
|
6 |
+
with open('data/hftoken.txt') as f:
|
7 |
+
HF_TOKEN = f.read().strip()
|
8 |
+
|
9 |
+
logging.basicConfig(level=logging.INFO)
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
# noinspection PyTypeChecker
|
14 |
+
class HuggingfaceGenerator:
|
15 |
+
def __init__(
|
16 |
+
self, model_name,
|
17 |
+
temperature: float = 0.9, max_new_tokens: int = 512,
|
18 |
+
top_p: float = None, repetition_penalty: float = None,
|
19 |
+
stream: bool = True,
|
20 |
+
):
|
21 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
22 |
+
self.hf_client = InferenceClient(model_name, token=HF_TOKEN)
|
23 |
+
self.stream = stream
|
24 |
+
|
25 |
+
self.generate_kwargs = {
|
26 |
+
'temperature': max(temperature, 0.1),
|
27 |
+
'max_new_tokens': max_new_tokens,
|
28 |
+
'top_p': top_p,
|
29 |
+
'repetition_penalty': repetition_penalty,
|
30 |
+
'do_sample': True,
|
31 |
+
'seed': 42,
|
32 |
+
}
|
33 |
+
|
34 |
+
def generate(self, messages):
|
35 |
+
formatted_prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
|
36 |
+
|
37 |
+
logger.info(f'Start HuggingFace generation, model {self.hf_client.model} ...')
|
38 |
+
stream = self.hf_client.text_generation(
|
39 |
+
formatted_prompt, **self.generate_kwargs,
|
40 |
+
stream=self.stream, details=True, return_full_text=not self.stream
|
41 |
+
)
|
42 |
+
|
43 |
+
for response in stream:
|
44 |
+
yield response.token.text
|
gradio_app/backend/cross_encoder.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
3 |
+
|
4 |
+
from settings import *
|
5 |
+
|
6 |
+
|
7 |
+
cross_encoder = None
|
8 |
+
cross_enc_tokenizer = None
|
9 |
+
|
10 |
+
|
11 |
+
@torch.no_grad()
|
12 |
+
def rerank_with_cross_encoder(cross_enc_name, documents, query):
|
13 |
+
if cross_enc_name is None or len(documents) <= 1:
|
14 |
+
return documents
|
15 |
+
|
16 |
+
global cross_encoder, cross_enc_tokenizer
|
17 |
+
if cross_encoder is None or cross_encoder.name_or_path != cross_enc_name:
|
18 |
+
cross_encoder = AutoModelForSequenceClassification.from_pretrained(cross_enc_name)
|
19 |
+
cross_encoder.eval()
|
20 |
+
cross_enc_tokenizer = AutoTokenizer.from_pretrained(cross_enc_name)
|
21 |
+
|
22 |
+
features = cross_enc_tokenizer(
|
23 |
+
[query] * len(documents), documents, padding=True, truncation=True, return_tensors="pt"
|
24 |
+
)
|
25 |
+
scores = cross_encoder(**features).logits.squeeze()
|
26 |
+
ranks = torch.argsort(scores, descending=True)
|
27 |
+
documents = [documents[i] for i in ranks[:TOP_K_RERANK]]
|
28 |
+
return documents
|
29 |
+
|
30 |
+
|
31 |
+
|
32 |
+
|
gradio_app/backend/query_llm.py
CHANGED
@@ -1,102 +1,30 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
|
3 |
-
from typing import Any, Dict, Generator, List
|
4 |
-
|
5 |
-
# from huggingface_hub import InferenceClient
|
6 |
-
# from transformers import AutoTokenizer
|
7 |
from jinja2 import Environment, FileSystemLoader
|
8 |
|
9 |
-
from settings import *
|
10 |
from gradio_app.backend.ChatGptInteractor import *
|
11 |
-
|
12 |
-
|
13 |
-
# tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
|
14 |
-
# HF_TOKEN = None
|
15 |
-
# hf_client = InferenceClient(LLM_NAME, token=HF_TOKEN)
|
16 |
-
|
17 |
-
|
18 |
-
def format_prompt(message: str, api_kind: str):
|
19 |
-
"""
|
20 |
-
Formats the given message using a chat template.
|
21 |
-
|
22 |
-
Args:
|
23 |
-
message (str): The user message to be formatted.
|
24 |
-
|
25 |
-
Returns:
|
26 |
-
str: Formatted message after applying the chat template.
|
27 |
-
"""
|
28 |
-
|
29 |
-
# Create a list of message dictionaries with role and content
|
30 |
-
messages: List[Dict[str, Any]] = [{'role': 'user', 'content': message}]
|
31 |
-
|
32 |
-
if api_kind == "openai":
|
33 |
-
return messages
|
34 |
-
elif api_kind == "hf":
|
35 |
-
return tokenizer.apply_chat_template(messages, tokenize=False)
|
36 |
-
else:
|
37 |
-
raise ValueError("API is not supported")
|
38 |
-
|
39 |
-
|
40 |
-
def generate_hf(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 512,
|
41 |
-
top_p: float = 0.6, repetition_penalty: float = 1.2) -> Generator[str, None, str]:
|
42 |
-
"""
|
43 |
-
Generate a sequence of tokens based on a given prompt and history using Mistral client.
|
44 |
-
|
45 |
-
Args:
|
46 |
-
prompt (str): The initial prompt for the text generation.
|
47 |
-
history (str): Context or history for the text generation.
|
48 |
-
temperature (float, optional): The softmax temperature for sampling. Defaults to 0.9.
|
49 |
-
max_new_tokens (int, optional): Maximum number of tokens to be generated. Defaults to 256.
|
50 |
-
top_p (float, optional): Nucleus sampling probability. Defaults to 0.95.
|
51 |
-
repetition_penalty (float, optional): Penalty for repeated tokens. Defaults to 1.0.
|
52 |
-
|
53 |
-
Returns:
|
54 |
-
Generator[str, None, str]: A generator yielding chunks of generated text.
|
55 |
-
Returns a final string if an error occurs.
|
56 |
-
"""
|
57 |
-
|
58 |
-
temperature = max(float(temperature), 1e-2) # Ensure temperature isn't too low
|
59 |
-
top_p = float(top_p)
|
60 |
-
|
61 |
-
generate_kwargs = {
|
62 |
-
'temperature': temperature,
|
63 |
-
'max_new_tokens': max_new_tokens,
|
64 |
-
'top_p': top_p,
|
65 |
-
'repetition_penalty': repetition_penalty,
|
66 |
-
'do_sample': True,
|
67 |
-
'seed': 42,
|
68 |
-
}
|
69 |
-
|
70 |
-
formatted_prompt = format_prompt(prompt, "hf")
|
71 |
-
|
72 |
-
try:
|
73 |
-
stream = hf_client.text_generation(formatted_prompt, **generate_kwargs,
|
74 |
-
stream=True, details=True, return_full_text=False)
|
75 |
-
output = ""
|
76 |
-
for response in stream:
|
77 |
-
output += response.token.text
|
78 |
-
yield output
|
79 |
-
|
80 |
-
except Exception as e:
|
81 |
-
if "Too Many Requests" in str(e):
|
82 |
-
print("ERROR: Too many requests on Mistral client")
|
83 |
-
gr.Warning("Unfortunately Mistral is unable to process")
|
84 |
-
return "Unfortunately, I am not able to process your request now."
|
85 |
-
elif "Authorization header is invalid" in str(e):
|
86 |
-
print("Authetification error:", str(e))
|
87 |
-
gr.Warning("Authentication error: HF token was either not provided or incorrect")
|
88 |
-
return "Authentication error"
|
89 |
-
else:
|
90 |
-
print("Unhandled Exception:", str(e))
|
91 |
-
gr.Warning("Unfortunately Mistral is unable to process")
|
92 |
-
return "I do not know what happened, but I couldn't understand you."
|
93 |
-
|
94 |
|
95 |
env = Environment(loader=FileSystemLoader('gradio_app/templates'))
|
96 |
context_template = env.get_template('context_template.j2')
|
97 |
start_system_message = context_template.render(documents=[])
|
98 |
|
99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
def construct_openai_messages(context, history):
|
101 |
messages = [
|
102 |
{
|
@@ -122,64 +50,32 @@ def construct_openai_messages(context, history):
|
|
122 |
return messages
|
123 |
|
124 |
|
125 |
-
def
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
def _generate_openai(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 512,
|
132 |
-
top_p: float = 0.6, repetition_penalty: float = 1.2) -> Generator[str, None, str]:
|
133 |
-
"""
|
134 |
-
Generate a sequence of tokens based on a given prompt and history using Mistral client.
|
135 |
|
136 |
-
Args:
|
137 |
-
prompt (str): The initial prompt for the text generation.
|
138 |
-
history (str): Context or history for the text generation.
|
139 |
-
temperature (float, optional): The softmax temperature for sampling. Defaults to 0.9.
|
140 |
-
max_new_tokens (int, optional): Maximum number of tokens to be generated. Defaults to 256.
|
141 |
-
top_p (float, optional): Nucleus sampling probability. Defaults to 0.95.
|
142 |
-
repetition_penalty (float, optional): Penalty for repeated tokens. Defaults to 1.0.
|
143 |
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
|
|
|
|
|
|
|
|
151 |
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
|
159 |
-
formatted_prompt = format_prompt(prompt, "openai")
|
160 |
|
161 |
-
try:
|
162 |
-
stream = openai.ChatCompletion.create(
|
163 |
-
model=LLM_NAME,
|
164 |
-
messages=formatted_prompt,
|
165 |
-
**generate_kwargs,
|
166 |
-
stream=True
|
167 |
-
)
|
168 |
-
output = ""
|
169 |
-
for chunk in stream:
|
170 |
-
output += chunk.choices[0].delta.get("content", "")
|
171 |
-
yield output
|
172 |
|
173 |
-
except Exception as e:
|
174 |
-
if "Too Many Requests" in str(e):
|
175 |
-
print("ERROR: Too many requests on OpenAI client")
|
176 |
-
gr.Warning("Unfortunately OpenAI is unable to process")
|
177 |
-
return "Unfortunately, I am not able to process your request now."
|
178 |
-
elif "You didn't provide an API key" in str(e):
|
179 |
-
print("Authetification error:", str(e))
|
180 |
-
gr.Warning("Authentication error: OpenAI key was either not provided or incorrect")
|
181 |
-
return "Authentication error"
|
182 |
-
else:
|
183 |
-
print("Unhandled Exception:", str(e))
|
184 |
-
gr.Warning("Unfortunately OpenAI is unable to process")
|
185 |
-
return "I do not know what happened, but I couldn't understand you."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from jinja2 import Environment, FileSystemLoader
|
2 |
|
|
|
3 |
from gradio_app.backend.ChatGptInteractor import *
|
4 |
+
from gradio_app.backend.HuggingfaceGenerator import HuggingfaceGenerator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
env = Environment(loader=FileSystemLoader('gradio_app/templates'))
|
7 |
context_template = env.get_template('context_template.j2')
|
8 |
start_system_message = context_template.render(documents=[])
|
9 |
|
10 |
|
11 |
+
def construct_mistral_messages(context, history):
|
12 |
+
messages = []
|
13 |
+
for q, a in history:
|
14 |
+
if len(a) == 0: # the last message
|
15 |
+
q = context + f'\n\nQuery:\n\n{q}'
|
16 |
+
messages.append({
|
17 |
+
"role": "user",
|
18 |
+
"content": q,
|
19 |
+
})
|
20 |
+
if len(a) != 0: # some of the previous LLM answers
|
21 |
+
messages.append({
|
22 |
+
"role": "assistant",
|
23 |
+
"content": a,
|
24 |
+
})
|
25 |
+
return messages
|
26 |
+
|
27 |
+
|
28 |
def construct_openai_messages(context, history):
|
29 |
messages = [
|
30 |
{
|
|
|
50 |
return messages
|
51 |
|
52 |
|
53 |
+
def get_message_constructor(llm_name):
|
54 |
+
if llm_name == 'gpt-3.5-turbo':
|
55 |
+
return construct_openai_messages
|
56 |
+
if llm_name in ['mistralai/Mistral-7B-Instruct-v0.1', "GeneZC/MiniChat-3B"]:
|
57 |
+
return construct_mistral_messages
|
58 |
+
raise ValueError('Unknown LLM name')
|
|
|
|
|
|
|
|
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
+
def get_llm_generator(llm_name):
|
62 |
+
if llm_name == 'gpt-3.5-turbo':
|
63 |
+
cgi = ChatGptInteractor(
|
64 |
+
model_name=llm_name, max_tokens=512, temperature=0, stream=True
|
65 |
+
)
|
66 |
+
return cgi.chat_completion
|
67 |
+
if llm_name == 'mistralai/Mistral-7B-Instruct-v0.1':
|
68 |
+
hfg = HuggingfaceGenerator(
|
69 |
+
model_name=llm_name, temperature=0, max_new_tokens=512,
|
70 |
+
)
|
71 |
+
return hfg.generate
|
72 |
|
73 |
+
if llm_name == "GeneZC/MiniChat-3B":
|
74 |
+
hfg = HuggingfaceGenerator(
|
75 |
+
model_name=llm_name, temperature=0, max_new_tokens=250, stream=False,
|
76 |
+
)
|
77 |
+
return hfg.generate
|
78 |
+
raise ValueError('Unknown LLM name')
|
79 |
|
|
|
80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
settings.py
CHANGED
@@ -5,11 +5,11 @@ VECTOR_COLUMN_NAME = "embedding"
|
|
5 |
TEXT_COLUMN_NAME = "text"
|
6 |
DOCUMENT_PATH_COLUMN_NAME = "document_path"
|
7 |
|
8 |
-
# LLM_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
|
9 |
-
LLM_NAME = "gpt-3.5-turbo"
|
10 |
# EMBED_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
11 |
EMBED_NAME = "text-embedding-ada-002"
|
12 |
|
|
|
|
|
13 |
|
14 |
emb_sizes = {
|
15 |
"sentence-transformers/all-MiniLM-L6-v2": 384,
|
@@ -17,8 +17,14 @@ emb_sizes = {
|
|
17 |
"text-embedding-ada-002": 1536,
|
18 |
}
|
19 |
|
|
|
|
|
|
|
|
|
|
|
20 |
context_lengths = {
|
21 |
"mistralai/Mistral-7B-Instruct-v0.1": 4096,
|
|
|
22 |
"gpt-3.5-turbo": 4096,
|
23 |
"sentence-transformers/all-MiniLM-L6-v2": 128,
|
24 |
"thenlper/gte-large": 512,
|
|
|
5 |
TEXT_COLUMN_NAME = "text"
|
6 |
DOCUMENT_PATH_COLUMN_NAME = "document_path"
|
7 |
|
|
|
|
|
8 |
# EMBED_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
9 |
EMBED_NAME = "text-embedding-ada-002"
|
10 |
|
11 |
+
TOP_K_RANK = 50
|
12 |
+
TOP_K_RERANK = 5
|
13 |
|
14 |
emb_sizes = {
|
15 |
"sentence-transformers/all-MiniLM-L6-v2": 384,
|
|
|
17 |
"text-embedding-ada-002": 1536,
|
18 |
}
|
19 |
|
20 |
+
thresh_distances = {
|
21 |
+
"sentence-transformers/all-MiniLM-L6-v2": 1.2,
|
22 |
+
"text-embedding-ada-002": 0.5,
|
23 |
+
}
|
24 |
+
|
25 |
context_lengths = {
|
26 |
"mistralai/Mistral-7B-Instruct-v0.1": 4096,
|
27 |
+
"GeneZC/MiniChat-3B": 4096,
|
28 |
"gpt-3.5-turbo": 4096,
|
29 |
"sentence-transformers/all-MiniLM-L6-v2": 128,
|
30 |
"thenlper/gte-large": 512,
|