Update app.py
app.py (changed)
@@ -16,6 +16,7 @@ import numpy as np
 import pandas as pd
 import streamlit as st
 from vllm import LLM
+import json
 from numpy import ndarray
 from datetime import datetime
 from typing import List, Dict
@@ -72,7 +73,8 @@ from qdrant_client.models import (
 global_state_documents_only = False

 class Question(BaseModel):
-    answer: str
+    answer: str
+schema = json.dumps(Question.model_json_schema())

 icon_to_types = {
     'ppt':('data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADAAAAAwCAYAAABXAvmHAAAACXBIWXMAAAsTAAALEwEAmpwYAAAC4ElEQVR4nO2YS2gTQRzGF/VgPVuPpsXSQ3eshdbXqVDMEtlN1UN8HqTZYouCVJTsiIfgTcWbD0o0mp3GKq1oC55E2AbbKn3dakISCVjwUNtLlZ1a24xs0oqG2uxmJ2si+8F32NnL95v/fx67DGPLli1b/5UUkauJtDsvDHm5XkXkphSvc17xcksqgksYSXMYwSksS71Yhudx6MouphSk+Ju3RLzcGUV0jg6JHFnPGMH1LcMRVZZOkz7P5n8TXuRcisjF/xY8LwDKWpVhDIcgZ1nwiXPCNkXkHuYLrhcAr4EgKUD6LlUUNfxIp3OHIjon9YY3AoCzHl8IXq0sWvgh0RkzEr4AAKK1FHWIbNsYm/lCAfBqJchj/1ZqAEZ6nhIAURHsNhQy0OToCjQ5vj1ochAzDu6vJgOte00DYATJYsh32AiA6fC/Q9AAUJEU1X1O0Aq/ZhoAOAPhO1nWABhJw2UNoMowjdG16oIBcrWy/IPMx6Pk9eV2iyoACQ75OkwDxF4+JdEXT8jMaCTznF5ZJoNtR60BkKWwaYDgwZpfY/FXzzNj0/3IGgAEJ6gCjN29ma3KwDOLAKRZ0wBhrpGglnoy2HaMLHyeyYy98XVaAqAiuGgaIFcf+ns2XMRJAdD0ommA8Xu3yNidG+Td7etk4Gxr3m2ULgA7S3UN6DFlgIlyB+gpcwBQ+EFWqGmFTwggnTqyp6p8AXjwNm/4kgZwg+N6Ab7SCv9oXxUdAB5ME4/eD5rGnRdpQGjhIy21dADcdS0MDSUFEKC8qxAdi/c+Q0ufPAcqEjw7biHA+1Szg95vFU1xV0NlgmdjRZ95no3GhNrtVMP/CQHGijnzcVdDcX4t5rRTdzF6PkW7bTbSR373Ia3cpsPzYJrabmNU2h6ddNedSvBgWDvyDcx2WjthEwJ7gviZTUwpKOaur9YuXQmBDSd5djLJg7kkD76v+ot2Jc68E0CHrruNLVu2bDGlrJ8c/urSuEn7XgAAAABJRU5ErkJggg==',
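The new module-level schema string is what drives the constrained generation added further down. For reference, a minimal sketch of what model_json_schema() (pydantic v2's replacement for schema()) produces for this one-field model; the exact key order can vary between pydantic releases:

    import json
    from pydantic import BaseModel

    class Question(BaseModel):
        answer: str

    schema = json.dumps(Question.model_json_schema())
    print(schema)
    # {"properties": {"answer": {"title": "Answer", "type": "string"}},
    #  "required": ["answer"], "title": "Question", "type": "object"}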
@@ -224,8 +226,9 @@ def main(query: str, client: QdrantClient, collection_name: str, llm, dense_mode

     context = "\n".join(contents)
     print(f'Context : \n + {context}')
-
-
+
+    regex = build_regex_from_schema(schema, r"[\n ]?")
+    gen_text = outlines.generate.regex(llm, regex)

     gen_choice = outlines.generate.choice(llm, choices=['Yes', 'No'])
     prompt = route_llm(context, query)
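build_regex_from_schema compiles the JSON schema into a regular expression whose language is exactly the set of valid Question objects, and outlines.generate.regex masks the vLLM token logits so the model can only emit strings matching it; the r"[\n ]?" argument restricts inter-token whitespace, keeping outputs compact. A quick sketch of the guarantee this buys, under the same setup as app.py (note the import path of build_regex_from_schema has moved between outlines releases):

    import re
    import outlines
    from outlines.fsm.json_schema import build_regex_from_schema  # location varies by outlines version

    regex = build_regex_from_schema(schema, r"[\n ]?")
    gen_text = outlines.generate.regex(llm, regex)   # llm: the vLLM model built elsewhere in app.py

    out = gen_text("Answer as JSON: what is 2 + 2?", max_tokens=32)
    assert re.fullmatch(regex, out)                  # every completion matches the schema regex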
@@ -242,9 +245,8 @@ def main(query: str, client: QdrantClient, collection_name: str, llm, dense_mode
         result_metadatas = "\n\n".join(f'{value}' for value in filtered_metadatas)

         prompt = answer_with_context(context, query)
-        answer = gen_text(prompt, max_tokens=300, sampling_params=SamplingParams(temperature=0))
-
-        answer = f"{answer.answer}\n\n\nSource(s) :\n\n{result_metadatas}"
+        answer = json.loads(gen_text(prompt, max_tokens=300, sampling_params=SamplingParams(temperature=0)))['answer']
+        answer = f"{answer}\n\n\nSource(s) :\n\n{result_metadatas}"

         if not st.session_state.documents_only:
             answer = f'Documents Based :\n\n{answer}'
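Because generation is regex-constrained, json.loads on the raw completion cannot fail on malformed output, so unwrapping ['answer'] directly is safe. The same generate-parse-unwrap pattern now appears at four call sites; a hypothetical helper, not part of this commit, would keep each one to a single line:

    import json
    from vllm import SamplingParams

    def generate_answer(prompt: str, max_tokens: int, **sampling) -> str:
        # gen_text is the schema-constrained generator built in main();
        # its output is guaranteed to parse as a Question object.
        raw = gen_text(prompt, max_tokens=max_tokens,
                       sampling_params=SamplingParams(**sampling))
        return json.loads(raw)['answer']

    # e.g.: answer = generate_answer(prompt, 300, temperature=0)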
@@ -255,15 +257,15 @@ def main(query: str, client: QdrantClient, collection_name: str, llm, dense_mode
         print(f'Choice 2: {action}')
         if action == 'General Question':
             prompt = open_query_prompt(past_messages, query)
-            answer = gen_text(prompt, max_tokens=300, sampling_params=SamplingParams(temperature=0.6, top_p=0.9, top_k=10))
+            answer = json.loads(gen_text(prompt, max_tokens=300, sampling_params=SamplingParams(temperature=0.6, top_p=0.9, top_k=10)))['answer']
         else:
             print(f'GLOBAL STATE : {global_state_documents_only}')
             if global_state_documents_only:
                 prompt = idk(query)
-                answer = gen_text(prompt, max_tokens=128, sampling_params=SamplingParams(temperature=0.6, top_p=0.9, top_k=10))
+                answer = json.loads(gen_text(prompt, max_tokens=128, sampling_params=SamplingParams(temperature=0.6, top_p=0.9, top_k=10)))['answer']
             else:
                 prompt = self_knowledge(query)
-                answer = gen_text(prompt, max_tokens=300, sampling_params=SamplingParams(temperature=0.6, top_p=0.9, top_k=10))
+                answer = json.loads(gen_text(prompt, max_tokens=300, sampling_params=SamplingParams(temperature=0.6, top_p=0.9, top_k=10)))['answer']
             answer = f'Internal Knowledge :\n\n{answer}'

     torch.cuda.empty_cache()
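Note the split in sampling settings: the document-grounded path uses temperature=0 (greedy decoding, reproducible answers from retrieved context), while the general-question, refusal, and self-knowledge branches sample with temperature=0.6, top_p=0.9, top_k=10 for more varied phrasing. A hypothetical tidy-up, not in the commit, that names the two profiles once instead of repeating the literals at each call site:

    from vllm import SamplingParams

    GREEDY = SamplingParams(temperature=0)                           # document-grounded answers
    CREATIVE = SamplingParams(temperature=0.6, top_p=0.9, top_k=10)  # open-ended branches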