Update app.py
app.py CHANGED
@@ -13,11 +13,9 @@ import numpy as np
 import pandas as pd
 import streamlit as st
 
-from vllm import LLM
 from datetime import datetime
 from streamlit import _bottom
 from pydantic import BaseModel
-from outlines.models import VLLM
 from streamlit_pills import pills
 from dense_embed import embed_text
 from ppt_chunker import ppt_chunker
@@ -103,12 +101,14 @@ icon_to_types = {
 def generate_answer(query: str,
                     client: QdrantClient,
                     collection_name: str,
-                    llm,
                     dense_model: AsyncEmbeddingEngine,
                     sparse_model: SparseTextEmbedding,
                     past_messages: str,
                     search_strategy,
-                    documents_only: bool
+                    documents_only: bool,
+                    gen_text,
+                    gen_context_choice,
+                    gen_question_choice
                     ):
     sparse_embeddings = list(sparse_model.query_embed(query))[0].as_object()
 
@@ -130,13 +130,9 @@ def generate_answer(query: str,
     contents, metadatas = [list(t) for t in zip(*docs)]
 
     context = "\n".join(contents)
-
-    regex = build_regex_from_schema(schema, r"[\n ]?")
-    gen_text = outlines.generate.regex(llm, regex)
 
-    gen_choice = outlines.generate.choice(llm, choices=['Yes', 'No'])
     prompt = route_llm(context, query)
-    action = gen_choice(prompt, max_tokens=2, sampling_params=SamplingParams(temperature=0))
+    action = gen_context_choice(prompt, max_tokens=2, sampling_params=SamplingParams(temperature=0))
     print(f'Choice: {action}')
 
     if action == 'Yes':
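The routing step above is Outlines' choice-constrained decoding: the generator's FSM only lets the model emit one of the listed strings, so the router's answer is machine-checkable by construction. A minimal standalone sketch of the pattern, using a placeholder model id rather than the Mistral Nemo checkpoint from this commit:

import outlines
from vllm import SamplingParams

# Placeholder model; app.py builds `llm` once inside load_models_and_documents().
model = outlines.models.vllm("Qwen/Qwen2.5-0.5B-Instruct", max_model_len=2048)

# Compiles an FSM that can only produce 'Yes' or 'No'.
router = outlines.generate.choice(model, choices=['Yes', 'No'])

# temperature=0 keeps the routing decision deterministic.
action = router('Answer Yes or No: is the sky blue?',
                max_tokens=2,
                sampling_params=SamplingParams(temperature=0))
assert action in ('Yes', 'No')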
@@ -155,9 +151,8 @@ def generate_answer(query: str,
     if documents_only == False:
         answer = f'Documents Based :\n\n{answer}'
     else:
-        gen_choice = outlines.generate.choice(llm, choices=['Domain-Specific Question', 'General Question'])
         prompt = question_type_prompt(query)
-        action = gen_choice(prompt, max_tokens=3, sampling_params=SamplingParams(temperature=0))
+        action = gen_question_choice(prompt, max_tokens=3, sampling_params=SamplingParams(temperature=0))
         print(f'Choice 2: {action}')
 
         if action == 'General Question':
@@ -210,6 +205,7 @@ def collect_files(directory, pattern):
 
     return array
 
+@st.cache_resource(show_spinner=False)
 def load_models_and_documents():
     container = st.empty()
 
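@st.cache_resource is what makes the move away from st.session_state possible: Streamlit re-executes the whole script on every interaction, and the decorator memoizes the loader so the vLLM engine, Qdrant client, and generators are built once per server process and handed back on every rerun. Illustrative names below, not from app.py:

import streamlit as st

@st.cache_resource(show_spinner=False)
def load_expensive_resource():
    print('loading...')  # printed once per process, not on every rerun
    return {'engine': object()}

a = load_expensive_resource()
b = load_expensive_resource()
assert a is b  # reruns receive the same cached instance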
@@ -240,17 +236,22 @@ def load_models_and_documents():
 
     st.write('Downloading and Loading Mistral Nemo quantized with GPTQ and using Outlines + vLLM Engine as backend...')
 
-    llm = LLM(
-        "shuyuej/Mistral-Nemo-Instruct-2407-GPTQ",
+    llm = outlines.models.vllm(
+        model_name="shuyuej/Mistral-Nemo-Instruct-2407-GPTQ",
         tensor_parallel_size=1,
         enforce_eager=True,
-        gpu_memory_utilization=
+        gpu_memory_utilization=1,
         max_model_len=8192,
         dtype=torch.float16,
         max_num_seqs=64,
         quantization="gptq"
     )
-    model = VLLM(llm)
+
+    regex = build_regex_from_schema(schema, r"[\n ]?")
+    gen_text = outlines.generate.regex(llm, regex)
+
+    gen_context_choice = outlines.generate.choice(llm, choices=['Yes', 'No'])
+    gen_question_choice = outlines.generate.choice(llm, choices=['Domain-Specific Question', 'General Question'])
 
     st.write('Downloading NLTK Packages...')
 
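With the generators hoisted here, the regex and choice FSMs are compiled once at load time next to the model instead of on every generate_answer call; building an Outlines generator indexes the pattern against the whole tokenizer vocabulary, which is far too slow for the request path. A sketch of the schema-to-regex step with a toy schema standing in for app.py's `schema` (the outlines.fsm.json_schema import path matches Outlines releases of this era, an assumption here):

import json
from outlines.fsm.json_schema import build_regex_from_schema

toy_schema = json.dumps({
    'type': 'object',
    'properties': {'answer': {'type': 'string'}},
    'required': ['answer'],
})

# r"[\n ]?" mirrors the commit: at most one space or newline between JSON tokens.
regex = build_regex_from_schema(toy_schema, r"[\n ]?")
print(regex[:80])

# app.py then compiles it once: gen_text = outlines.generate.regex(llm, regex)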
@@ -441,7 +442,19 @@ def load_models_and_documents():
 
     st.write('Building FSM Index for Agentic Behaviour of our AI...')
 
-    answer = generate_answer('aggro', client, collection_name, llm, dense_model, sparse_model, '', 'Exact Search', False)
+    answer = generate_answer(
+        'aggro',
+        client,
+        collection_name,
+        dense_model,
+        sparse_model,
+        '',
+        'Exact Search',
+        False,
+        gen_text,
+        gen_context_choice,
+        gen_question_choice
+    )
 
     status.update(
         label="Processing Complete!", state="complete", expanded=False
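The warm-up call above exists only for its side effect: running generate_answer('aggro', ...) inside the cached loader forces every generator to build its FSM index before the first real query, so no user pays that one-time compilation cost. A leaner equivalent (hypothetical helper, not in the commit) would poke each generator directly:

def warm_up(generators):
    # One dummy call per generator is enough to trigger FSM compilation.
    for gen in generators:
        gen('warm up', max_tokens=1)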
@@ -450,14 +463,13 @@ def load_models_and_documents():
     time.sleep(5)
     container.empty()
 
-    return client, collection_name, llm, model, dense_model, sparse_model
+    return client, collection_name, llm, dense_model, sparse_model, gen_text, gen_context_choice, gen_question_choice
 
 
 if __name__ == '__main__':
     st.set_page_config(page_title="Multipurpose AI Agent",layout="wide", initial_sidebar_state='auto')
 
-
-    st.session_state.client, st.session_state.collection_name, st.session_state.llm, st.session_state.model, st.session_state.dense_model, st.session_state.sparse_model = load_models_and_documents()
+    client, collection_name, llm, dense_model, sparse_model, gen_text, gen_context_choice, gen_question_choice = load_models_and_documents()
 
     styles = {
         "nav": {
@@ -530,7 +542,7 @@ if __name__ == '__main__':
                 os.remove(os.path.join(embeddings_path, name + '_dense.npz'))
                 os.remove(os.path.join(embeddings_path, name + '_sparse.npz'))
 
-                st.session_state.client.delete(
+                client.delete(
                     collection_name=collection_name,
                     points_selector=Filter(
                         must=[
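client.delete with points_selector=Filter(...) removes every point whose payload matches the filter, rather than deleting by explicit ids; that is how all chunks belonging to a removed file disappear in one call. A self-contained sketch against an in-memory instance, with an assumed payload field name:

from qdrant_client import QdrantClient
from qdrant_client.models import Distance, FieldCondition, Filter, MatchValue, VectorParams

client = QdrantClient(':memory:')
client.create_collection('documents', vectors_config=VectorParams(size=4, distance=Distance.COSINE))

# Drop every point whose payload field 'metadata.file_name' equals 'deck.pptx'.
client.delete(
    collection_name='documents',
    points_selector=Filter(must=[
        FieldCondition(key='metadata.file_name', match=MatchValue(value='deck.pptx')),
    ]),
)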
@@ -610,10 +622,10 @@ if __name__ == '__main__':
                 engine='pyarrow'
             )
 
-            documents, ids = ppt_chunker(uploaded_file, st.session_state.llm)
+            documents, ids = ppt_chunker(uploaded_file, llm)
 
-            dense_embeddings, tokens_count = asyncio.run(embed_text(st.session_state.dense_model[0], documents))
-            sparse_embeddings = [s for s in st.session_state.sparse_model.embed(documents, 32)]
+            dense_embeddings, tokens_count = asyncio.run(embed_text(dense_model[0], documents))
+            sparse_embeddings = [s for s in sparse_model.embed(documents, 32)]
 
             embeddings_path = os.path.join(os.getenv('HF_HOME'), 'embeddings')
 
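Each chunk is embedded twice for hybrid retrieval: densely through the local embed_text helper (an infinity AsyncEmbeddingEngine underneath, hence dense_model[0]) and sparsely through fastembed, where embed's second argument is the batch size. Note that llm has to be among the values load_models_and_documents returns for ppt_chunker(uploaded_file, llm) to resolve at top level (reflected in the return hunk above). The sparse half in isolation, with fastembed's SPLADE-style model chosen purely for illustration:

from fastembed import SparseTextEmbedding

sparse_model = SparseTextEmbedding('prithivida/Splade_PP_en_v1')
documents = ['aggro decks apply early pressure', 'control decks stabilise late']

# embed() yields SparseEmbedding objects (term indices plus weights); 32 is the batch size.
sparse_embeddings = [s for s in sparse_model.embed(documents, 32)]
print(sparse_embeddings[0].indices[:5], sparse_embeddings[0].values[:5])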
@@ -638,8 +650,8 @@ if __name__ == '__main__':
 
             payload_docs = [{ 'text': documents[i], 'metadata': metadata } for i, metadata in enumerate(metadatas_list)]
 
-            st.session_state.client.upsert(
-                collection_name=st.session_state.collection_name,
+            client.upsert(
+                collection_name=collection_name,
                 points=Batch(
                     ids=ids,
                     payloads=payload_docs,
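Batch upsert writes ids, payloads, and vectors column-wise in a single request; for a hybrid collection the vectors argument is a dict keyed by vector name, with the sparse side expressed as SparseVector(indices=..., values=...). The collection layout below is an assumption for the sketch, not read out of app.py:

from qdrant_client import QdrantClient
from qdrant_client.models import (Batch, Distance, SparseVector,
                                  SparseVectorParams, VectorParams)

client = QdrantClient(':memory:')
client.create_collection(
    'documents',
    vectors_config={'dense': VectorParams(size=4, distance=Distance.COSINE)},
    sparse_vectors_config={'sparse': SparseVectorParams()},
)

client.upsert(
    collection_name='documents',
    points=Batch(
        ids=[1],
        payloads=[{'text': 'aggro decks apply early pressure',
                   'metadata': {'file_name': 'deck.pptx'}}],
        vectors={
            'dense': [[0.1, 0.2, 0.3, 0.4]],
            'sparse': [SparseVector(indices=[7, 42], values=[0.8, 0.3])],
        },
    ),
)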
@@ -687,7 +699,7 @@ if __name__ == '__main__':
                 use_container_width=True,
                 hide_index=True,
                 on_change=on_change_data_editor,
-                args=(st.session_state.client, st.session_state.collection_name),
+                args=(client, collection_name),
                 key='key_data_editor',
                 column_config={
                     'icon': st.column_config.ImageColumn(
@@ -812,17 +824,20 @@ if __name__ == '__main__':
         with st.chat_message(message["role"]):
             st.markdown(message["content"])
 
-    def generate_conv_title(
+    def generate_conv_title(generator):
         st.session_state.local_user_input = st.session_state.user_input
         print(f'USER INPUT : {st.session_state.user_input}')
         st.session_state.user_input = " "
         if st.session_state.chat_id == 'New Conversation':
-            output = 
-
-
+            output = json.loads(
+                generator(
+                    build_prompt_conv(st.session_state.local_user_input),
+                    max_tokens=10,
+                    sampling_params=SamplingParams(temperature=0, top_p=0.9, top_k=10)
+                )
             )
-            print(f'OUTPUT : {output
-            st.session_state.chat_id = output
+            print(f'OUTPUT : {output}')
+            st.session_state.chat_id = output
             st.session_state.messages = []
 
             torch.cuda.empty_cache()
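The title generator reuses gen_text, the regex-constrained generator threaded in through args, so the decoder can only produce strings matching the schema's regex and the bare json.loads parses cleanly, provided max_tokens is large enough for the JSON object to close (the commit allows 10 tokens). A minimal sketch of the pattern with a placeholder prompt:

import json

def title_for(generator, user_input: str):
    # `generator` is a regex-constrained Outlines generator (gen_text in app.py).
    # Constrained decoding keeps the output on the schema's grammar, so the
    # parse can only fail by truncation, never by malformed JSON.
    raw = generator(f'Give a short JSON title for: {user_input}', max_tokens=64)
    return json.loads(raw)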
@@ -842,7 +857,7 @@ if __name__ == '__main__':
                     key='user_input',
                     placeholder='Message Video Game Assistant',
                     label_visibility='collapsed',
-                    args=(
+                    args=(gen_text, )
                     ):
         if prompt != ('Exact Search : ' or 'Explain Further : '):
             st.chat_message("user").markdown(st.session_state.local_user_input)
@@ -850,14 +865,16 @@ if __name__ == '__main__':
 
             ai_response = generate_answer(
                 st.session_state.local_user_input,
-                st.session_state.client,
-                st.session_state.collection_name,
-                st.session_state.llm,
-                st.session_state.dense_model,
-                st.session_state.sparse_model,
+                client,
+                collection_name,
+                dense_model,
+                sparse_model,
                 "\n".join([f'{msg["role"]}: {msg["content"]}' for msg in st.session_state.messages]),
                 st.session_state.search_strategy,
-                st.session_state.documents_only
+                st.session_state.documents_only,
+                gen_text,
+                gen_context_choice,
+                gen_question_choice
             )
 
             with st.chat_message("assistant"):
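Taken together, the commit settles on one state model: heavyweight shared handles (engine, client, generators) come back from the cached loader as plain locals on every rerun, while anything conversation-scoped keeps living in st.session_state. In miniature:

import streamlit as st

# Shared, process-wide objects: fetch from the cached loader on each rerun
# (see the @st.cache_resource note above).
# client, ..., gen_text, ... = load_models_and_documents()

# Per-user, per-conversation state: st.session_state.
if 'messages' not in st.session_state:
    st.session_state.messages = []
st.session_state.messages.append({'role': 'user', 'content': 'hi'})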