devve1 committed on
Commit cbc8d85
1 Parent(s): c371e5e

Update app.py

Files changed (1):
  app.py  +56 -39
app.py CHANGED
@@ -13,11 +13,9 @@ import numpy as np
 import pandas as pd
 import streamlit as st
 
-from vllm import LLM
 from datetime import datetime
 from streamlit import _bottom
 from pydantic import BaseModel
-from outlines.models import VLLM
 from streamlit_pills import pills
 from dense_embed import embed_text
 from ppt_chunker import ppt_chunker
@@ -103,12 +101,14 @@ icon_to_types = {
 def generate_answer(query: str,
                     client: QdrantClient,
                     collection_name: str,
-                    llm,
                     dense_model: AsyncEmbeddingEngine,
                     sparse_model: SparseTextEmbedding,
                     past_messages: str,
                     search_strategy,
-                    documents_only: bool
+                    documents_only: bool,
+                    gen_text,
+                    gen_context_choice,
+                    gen_question_choice
                     ):
     sparse_embeddings = list(sparse_model.query_embed(query))[0].as_object()
 
@@ -130,13 +130,9 @@ def generate_answer(query: str,
     contents, metadatas = [list(t) for t in zip(*docs)]
 
     context = "\n".join(contents)
-
-    regex = build_regex_from_schema(schema, r"[\n ]?")
-    gen_text = outlines.generate.regex(llm, regex)
 
-    gen_choice = outlines.generate.choice(llm, choices=['Yes', 'No'])
     prompt = route_llm(context, query)
-    action = gen_choice(prompt, max_tokens=2, sampling_params=SamplingParams(temperature=0))
+    action = gen_context_choice(prompt, max_tokens=2, sampling_params=SamplingParams(temperature=0))
     print(f'Choice: {action}')
 
     if action == 'Yes':
@@ -155,9 +151,8 @@ def generate_answer(query: str,
         if documents_only == False:
             answer = f'Documents Based :\n\n{answer}'
     else:
-        gen_choice = outlines.generate.choice(llm, choices=['Domain-Specific Question', 'General Question'])
         prompt = question_type_prompt(query)
-        action = gen_choice(prompt, max_tokens=3, sampling_params=SamplingParams(temperature=0))
+        action = gen_question_choice(prompt, max_tokens=3, sampling_params=SamplingParams(temperature=0))
         print(f'Choice 2: {action}')
 
         if action == 'General Question':
@@ -210,6 +205,7 @@ def collect_files(directory, pattern):
 
     return array
 
+@st.cache_resource(show_spinner=False)
 def load_models_and_documents():
     container = st.empty()
 
@@ -240,17 +236,22 @@ def load_models_and_documents():
 
     st.write('Downloading and Loading Mistral Nemo quantized with GPTQ and using Outlines + vLLM Engine as backend...')
 
-    llm = LLM(
-        model="shuyuej/Mistral-Nemo-Instruct-2407-GPTQ",
+    llm = outlines.models.vllm(
+        model_name="shuyuej/Mistral-Nemo-Instruct-2407-GPTQ",
         tensor_parallel_size=1,
         enforce_eager=True,
-        gpu_memory_utilization=0.9,
+        gpu_memory_utilization=1,
         max_model_len=8192,
         dtype=torch.float16,
         max_num_seqs=64,
         quantization="gptq"
     )
-    model = VLLM(llm)
+
+    regex = build_regex_from_schema(schema, r"[\n ]?")
+    gen_text = outlines.generate.regex(llm, regex)
+
+    gen_context_choice = outlines.generate.choice(llm, choices=['Yes', 'No'])
+    gen_question_choice = outlines.generate.choice(llm, choices=['Domain-Specific Question', 'General Question'])
 
     st.write('Downloading NLTK Packages...')
 
@@ -441,7 +442,19 @@ def load_models_and_documents():
 
     st.write('Building FSM Index for Agentic Behaviour of our AI...')
 
-    answer = generate_answer('aggro', client, collection_name, model, dense_model, sparse_model, '', 'Exact Search', False)
+    answer = generate_answer(
+        'aggro',
+        client,
+        collection_name,
+        dense_model,
+        sparse_model,
+        '',
+        'Exact Search',
+        False,
+        gen_text,
+        gen_context_choice,
+        gen_question_choice
+    )
 
     status.update(
         label="Processing Complete!", state="complete", expanded=False
@@ -450,14 +463,13 @@ def load_models_and_documents():
     time.sleep(5)
     container.empty()
 
-    return client, collection_name, llm, model, dense_model, sparse_model
+    return client, collection_name, dense_model, sparse_model, gen_text, gen_context_choice, gen_question_choice
 
 
 if __name__ == '__main__':
     st.set_page_config(page_title="Multipurpose AI Agent",layout="wide", initial_sidebar_state='auto')
 
-    if 'client' not in st.session_state:
-        st.session_state.client, st.session_state.collection_name, st.session_state.llm, st.session_state.model, st.session_state.dense_model, st.session_state.sparse_model = load_models_and_documents()
+    client, collection_name, dense_model, sparse_model, gen_text, gen_context_choice, gen_question_choice = load_models_and_documents()
 
     styles = {
         "nav": {
@@ -530,7 +542,7 @@ if __name__ == '__main__':
     os.remove(os.path.join(embeddings_path, name + '_dense.npz'))
     os.remove(os.path.join(embeddings_path, name + '_sparse.npz'))
 
-    st.session_state.client.delete(
+    client.delete(
         collection_name=collection_name,
         points_selector=Filter(
             must=[
@@ -610,10 +622,10 @@ if __name__ == '__main__':
         engine='pyarrow'
     )
 
-    documents, ids = ppt_chunker(uploaded_file, st.session_state.llm)
+    documents, ids = ppt_chunker(uploaded_file, llm)
 
-    dense_embeddings, tokens_count = asyncio.run(embed_text(st.session_state.dense_model[0], documents))
-    sparse_embeddings = [s for s in st.session_state.sparse_model.embed(documents, 32)]
+    dense_embeddings, tokens_count = asyncio.run(embed_text(dense_model[0], documents))
+    sparse_embeddings = [s for s in sparse_model.embed(documents, 32)]
 
     embeddings_path = os.path.join(os.getenv('HF_HOME'), 'embeddings')
 
@@ -638,8 +650,8 @@ if __name__ == '__main__':
 
     payload_docs = [{ 'text': documents[i], 'metadata': metadata } for i, metadata in enumerate(metadatas_list)]
 
-    st.session_state.client.upsert(
-        collection_name=st.session_state.collection_name,
+    client.upsert(
+        collection_name=collection_name,
         points=Batch(
             ids=ids,
             payloads=payload_docs,
@@ -687,7 +699,7 @@ if __name__ == '__main__':
         use_container_width=True,
         hide_index=True,
         on_change=on_change_data_editor,
-        args=(st.session_state.client, st.session_state.collection_name),
+        args=(client, collection_name),
         key='key_data_editor',
         column_config={
             'icon': st.column_config.ImageColumn(
@@ -812,17 +824,20 @@ if __name__ == '__main__':
     with st.chat_message(message["role"]):
         st.markdown(message["content"])
 
-    def generate_conv_title(llm):
+    def generate_conv_title(generator):
         st.session_state.local_user_input = st.session_state.user_input
         print(f'USER INPUT : {st.session_state.user_input}')
         st.session_state.user_input = " "
         if st.session_state.chat_id == 'New Conversation':
-            output = llm.chat(
-                build_prompt_conv(st.session_state.local_user_input),
-                SamplingParams(temperature=0, top_p=0.9, max_tokens=10, top_k=10)
+            output = json.loads(
+                generator(
+                    build_prompt_conv(st.session_state.local_user_input),
+                    max_tokens=10,
+                    sampling_params=SamplingParams(temperature=0, top_p=0.9, top_k=10)
+                )
             )
-            print(f'OUTPUT : {output[0].outputs[0].text}')
-            st.session_state.chat_id = output[0].outputs[0].text.replace('"', '')
+            print(f'OUTPUT : {output}')
+            st.session_state.chat_id = output
             st.session_state.messages = []
 
     torch.cuda.empty_cache()
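generate_conv_title now receives a constrained Outlines generator instead of the raw vLLM engine, so the result is a plain string expected to be valid JSON that can go straight through json.loads, replacing the manual output[0].outputs[0].text unwrapping. A sketch of that idea, with a hand-written pattern standing in for the schema-derived regex the app builds at load time (model name, pattern, and prompt are all illustrative):

import json
import outlines

model = outlines.models.transformers("distilgpt2")  # stand-in backend

# Constrain the output to a tiny JSON object so json.loads cannot fail.
title_regex = r'\{"title": "[A-Za-z0-9 ]{1,40}"\}'
gen_title = outlines.generate.regex(model, title_regex)

raw = gen_title('Return a short JSON title for: "how do I parry attacks?"')
title = json.loads(raw)["title"]
print(title)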
@@ -842,7 +857,7 @@ if __name__ == '__main__':
         key='user_input',
         placeholder='Message Video Game Assistant',
         label_visibility='collapsed',
-        args=(st.session_state.llm, )
+        args=(gen_text, )
     ):
         if prompt != ('Exact Search : ' or 'Explain Further : '):
             st.chat_message("user").markdown(st.session_state.local_user_input)
@@ -850,14 +865,16 @@ if __name__ == '__main__':
 
             ai_response = generate_answer(
                 st.session_state.local_user_input,
-                st.session_state.client,
-                st.session_state.collection_name,
-                st.session_state.model,
-                st.session_state.dense_model,
-                st.session_state.sparse_model,
+                client,
+                collection_name,
+                dense_model,
+                sparse_model,
                 "\n".join([f'{msg["role"]}: {msg["content"]}' for msg in st.session_state.messages]),
                 st.session_state.search_strategy,
-                st.session_state.documents_only
+                st.session_state.documents_only,
+                gen_text,
+                gen_context_choice,
+                gen_question_choice
             )
 
             with st.chat_message("assistant"):
 