Upload 2 files
Browse files- kadichat.py +532 -0
- requirements.txt +11 -0
kadichat.py
ADDED
@@ -0,0 +1,532 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This is a demo to show how to use OAuth2 to connect an application to Kadi.
|
3 |
+
|
4 |
+
Read Section "OAuth2 Tokens" in Kadi documents.
|
5 |
+
Ref: https://kadi.readthedocs.io/en/stable/httpapi/intro.html#oauth2-tokens
|
6 |
+
|
7 |
+
Notes:
|
8 |
+
1. register an application in Kadi (Setting->Applications)
|
9 |
+
- Name: KadiOAuthTest
|
10 |
+
- Website URL: http://127.0.0.1:8000
|
11 |
+
- Redirect URIs: http://localhost:8000/auth
|
12 |
+
|
13 |
+
And you will get Client ID and Client Secret, note them down and set in this file.
|
14 |
+
|
15 |
+
2. Start this app, and open browser with address "http://localhost:8000/"
|
16 |
+
|
17 |
+
"""
|
18 |
+
|
19 |
+
import json
|
20 |
+
|
21 |
+
import uvicorn
|
22 |
+
from fastapi import FastAPI, Depends
|
23 |
+
from starlette.responses import RedirectResponse
|
24 |
+
from starlette.middleware.sessions import SessionMiddleware
|
25 |
+
from authlib.integrations.starlette_client import OAuth, OAuthError
|
26 |
+
from fastapi import Request
|
27 |
+
import gradio as gr
|
28 |
+
import kadi_apy
|
29 |
+
from kadi_apy import KadiManager
|
30 |
+
from requests.compat import urljoin
|
31 |
+
from typing import List, Tuple
|
32 |
+
import pymupdf
|
33 |
+
from sentence_transformers import SentenceTransformer
|
34 |
+
import numpy as np
|
35 |
+
import faiss
|
36 |
+
from dotenv import load_dotenv
|
37 |
+
import os
|
38 |
+
|
39 |
+
# Kadi OAuth settings
|
40 |
+
load_dotenv()
|
41 |
+
KADI_CLIENT_ID = os.environ["KADI_CLIENT_ID"]
|
42 |
+
KADI_CLIENT_SECRET = os.environ["KADI_CLIENT_SECRET"]
|
43 |
+
SECRET_KEY = os.environ["SECRET_KEY"]
|
44 |
+
huggingfacehub_api_token = os.environ["huggingfacehub_api_token"]
|
45 |
+
|
46 |
+
from huggingface_hub import login
|
47 |
+
login(token=huggingfacehub_api_token)
|
48 |
+
|
49 |
+
# Set up OAuth
|
50 |
+
app = FastAPI()
|
51 |
+
oauth = OAuth()
|
52 |
+
|
53 |
+
# Set Kadi instance
|
54 |
+
instance = "my_kadi_demo_instance" # "demo kit instance"
|
55 |
+
host = "https://demo-kadi4mat.iam.kit.edu"
|
56 |
+
|
57 |
+
base_url = host
|
58 |
+
oauth.register(
|
59 |
+
name="kadi4mat",
|
60 |
+
client_id=KADI_CLIENT_ID,
|
61 |
+
client_secret=KADI_CLIENT_SECRET,
|
62 |
+
api_base_url=f"{base_url}/api",
|
63 |
+
access_token_url=f"{base_url}/oauth/token",
|
64 |
+
authorize_url=f"{base_url}/oauth/authorize",
|
65 |
+
access_token_params={
|
66 |
+
"client_id": KADI_CLIENT_ID,
|
67 |
+
"client_secret": KADI_CLIENT_SECRET,
|
68 |
+
},
|
69 |
+
)
|
70 |
+
|
71 |
+
# Global LLM client
|
72 |
+
from huggingface_hub import InferenceClient
|
73 |
+
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
|
74 |
+
|
75 |
+
|
76 |
+
embeddings_client = InferenceClient(model="sentence-transformers/all-mpnet-base-v2", token=huggingfacehub_api_token)
|
77 |
+
# embeddings_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", trust_remote_code=True) # unused
|
78 |
+
embeddings_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", trust_remote_code=True)
|
79 |
+
|
80 |
+
# Dependency to get the current user
|
81 |
+
def get_user(request: Request):
|
82 |
+
if "user_access_token" in request.session:
|
83 |
+
token = request.session["user_access_token"]
|
84 |
+
else:
|
85 |
+
token = None
|
86 |
+
return None
|
87 |
+
if token:
|
88 |
+
try:
|
89 |
+
manager = KadiManager(instance=instance, host=host, token=token)
|
90 |
+
user = manager.pat_user
|
91 |
+
return user.meta["displayname"]
|
92 |
+
except kadi_apy.lib.exceptions.KadiAPYRequestError as e:
|
93 |
+
print(e)
|
94 |
+
return None
|
95 |
+
return None # "Authed but Failed at getting user info!"
|
96 |
+
|
97 |
+
|
98 |
+
@app.get("/")
|
99 |
+
def public(request: Request, user=Depends(get_user)):
|
100 |
+
root_url = gr.route_utils.get_root_url(request, "/", None)
|
101 |
+
if user:
|
102 |
+
return RedirectResponse(url=f"{root_url}/gradio/")
|
103 |
+
else:
|
104 |
+
return RedirectResponse(url=f"{root_url}/main/")
|
105 |
+
|
106 |
+
|
107 |
+
@app.route("/logout")
|
108 |
+
async def logout(request: Request):
|
109 |
+
request.session.pop("user", None)
|
110 |
+
request.session.pop("user_id", None)
|
111 |
+
request.session.pop("user_access_token", None)
|
112 |
+
|
113 |
+
return RedirectResponse(url="/")
|
114 |
+
|
115 |
+
|
116 |
+
@app.route("/login")
|
117 |
+
async def login(request: Request):
|
118 |
+
root_url = gr.route_utils.get_root_url(request, "/login", None)
|
119 |
+
redirect_uri = request.url_for("auth") # f"{root_url}/auth"
|
120 |
+
return await oauth.kadi4mat.authorize_redirect(request, redirect_uri)
|
121 |
+
|
122 |
+
|
123 |
+
@app.route("/auth")
|
124 |
+
async def auth(request: Request):
|
125 |
+
# root_url = gr.route_utils.get_root_url(request, "/auth", None)
|
126 |
+
try:
|
127 |
+
access_token = await oauth.kadi4mat.authorize_access_token(request)
|
128 |
+
request.session["user_access_token"] = access_token["access_token"]
|
129 |
+
|
130 |
+
except OAuthError as e:
|
131 |
+
print("Error getting access token", e)
|
132 |
+
return RedirectResponse(url="/")
|
133 |
+
|
134 |
+
return RedirectResponse(url="/gradio")
|
135 |
+
|
136 |
+
|
137 |
+
def greet(request: gr.Request):
|
138 |
+
return f"Welcome to Kadichat, you're logged in as: {request.username}"
|
139 |
+
|
140 |
+
|
141 |
+
def get_files_in_record(record_id, user_token, top_k=10):
|
142 |
+
|
143 |
+
manager = KadiManager(instance=instance, host=host, pat=user_token)
|
144 |
+
|
145 |
+
try:
|
146 |
+
record = manager.record(identifier=record_id)
|
147 |
+
except kadi_apy.lib.exceptions.KadiAPYInputError as e:
|
148 |
+
raise gr.Error(e)
|
149 |
+
|
150 |
+
file_num = record.get_number_files()
|
151 |
+
|
152 |
+
per_page = 100 # default in kadi
|
153 |
+
not_divisible = file_num % per_page
|
154 |
+
if not_divisible:
|
155 |
+
page_num = file_num // per_page + 1
|
156 |
+
else:
|
157 |
+
page_num = file_num // per_page
|
158 |
+
|
159 |
+
file_names = []
|
160 |
+
for p in range(1, page_num + 1): # page starts at 1 in kadi
|
161 |
+
file_names.extend(
|
162 |
+
[
|
163 |
+
info["name"]
|
164 |
+
for info in record.get_filelist(page=p, per_page=per_page).json()[
|
165 |
+
"items"
|
166 |
+
]
|
167 |
+
]
|
168 |
+
)
|
169 |
+
|
170 |
+
assert file_num == len(
|
171 |
+
file_names
|
172 |
+
), "Number of files did not match, please check function get_all_file_names."
|
173 |
+
|
174 |
+
# return file_names[:top_k]
|
175 |
+
return gr.Dropdown(
|
176 |
+
choices=file_names[:top_k],
|
177 |
+
label="Select file",
|
178 |
+
info="Select (max. 3) files to chat with.",
|
179 |
+
multiselect=True,
|
180 |
+
max_choices=3,
|
181 |
+
interactive=True,
|
182 |
+
)
|
183 |
+
|
184 |
+
|
185 |
+
def get_all_records(user_token):
|
186 |
+
|
187 |
+
if not user_token:
|
188 |
+
return []
|
189 |
+
|
190 |
+
manager = KadiManager(instance=instance, host=host, pat=user_token)
|
191 |
+
|
192 |
+
host_api = manager.host if manager.host.endswith("/") else manager.host + "/"
|
193 |
+
searched_resource = "records"
|
194 |
+
endpoint = urljoin(
|
195 |
+
host_api, searched_resource
|
196 |
+
) # e.g https://demo-kadi4mat.iam.kit.edu/api/" + "records"
|
197 |
+
|
198 |
+
response = manager.search.search_resources("record", per_page=100)
|
199 |
+
parsed = json.loads(response.content)
|
200 |
+
|
201 |
+
total_pages = parsed["_pagination"]["total_pages"]
|
202 |
+
|
203 |
+
def get_page_records(parsed_content):
|
204 |
+
item_identifiers = []
|
205 |
+
items = parsed_content["items"]
|
206 |
+
for item in items:
|
207 |
+
item_identifiers.append(item["identifier"])
|
208 |
+
|
209 |
+
return item_identifiers
|
210 |
+
|
211 |
+
all_records_identifiers = []
|
212 |
+
for page in range(1, total_pages + 1):
|
213 |
+
page_endpoint = endpoint + f"?page={page}&per_page=100"
|
214 |
+
response = manager.make_request(page_endpoint)
|
215 |
+
parsed = json.loads(response.content)
|
216 |
+
all_records_identifiers.extend(get_page_records(parsed))
|
217 |
+
|
218 |
+
return gr.Dropdown(
|
219 |
+
choices=all_records_identifiers,
|
220 |
+
interactive=True,
|
221 |
+
label="Record Identifier",
|
222 |
+
info="Select record to get file list",
|
223 |
+
)
|
224 |
+
|
225 |
+
|
226 |
+
def _init_user_token(request: gr.Request):
|
227 |
+
user_token = request.request.session["user_access_token"]
|
228 |
+
return user_token
|
229 |
+
|
230 |
+
|
231 |
+
with gr.Blocks(theme=gr.themes.Ocean()) as login_demo:
|
232 |
+
gr.Markdown(
|
233 |
+
"""<br/><br/><br/><br/><br/><br/><br/><br/>
|
234 |
+
<center>
|
235 |
+
<h1>Welcome to KadiChat!</h1>
|
236 |
+
<br/><br/>
|
237 |
+
<img src="https://i.postimg.cc/qvsQCCLS/kadichat-logo.png" alt="Kadichat logo">
|
238 |
+
<br/><br/>
|
239 |
+
Chat with Record in Kadi.</center>
|
240 |
+
"""
|
241 |
+
)
|
242 |
+
# Note: kadichat-logo is hosted on https://postimage.io/
|
243 |
+
|
244 |
+
with gr.Row():
|
245 |
+
with gr.Column():
|
246 |
+
_btn_placeholder = gr.Button(visible=False)
|
247 |
+
with gr.Column():
|
248 |
+
btn = gr.Button("Sign in with Kadi (demo-instance)")
|
249 |
+
with gr.Column():
|
250 |
+
_btn_placeholder2 = gr.Button(visible=False)
|
251 |
+
|
252 |
+
gr.Markdown(
|
253 |
+
"""<br/><br/><br/><br/>
|
254 |
+
<center>
|
255 |
+
This demo shows how to use
|
256 |
+
<a href="https://kadi4mat.readthedocs.io/en/stable/httpapi/intro.html#oauth2-tokens">OAuth2</a>
|
257 |
+
to have access to Kadi.</center>
|
258 |
+
"""
|
259 |
+
)
|
260 |
+
_js_redirect = """
|
261 |
+
() => {
|
262 |
+
url = '/login' + window.location.search;
|
263 |
+
window.open(url, '_blank');
|
264 |
+
}
|
265 |
+
"""
|
266 |
+
btn.click(None, js=_js_redirect)
|
267 |
+
|
268 |
+
import tempfile
|
269 |
+
import os
|
270 |
+
import pymupdf
|
271 |
+
|
272 |
+
class SimpleRAG:
|
273 |
+
def __init__(self) -> None:
|
274 |
+
self.documents = []
|
275 |
+
self.embeddings_model = None
|
276 |
+
self.embeddings = None
|
277 |
+
self.index = None
|
278 |
+
#self.load_pdf("Brandt et al_2024_Kadi_info_page.pdf")
|
279 |
+
#self.build_vector_db()
|
280 |
+
|
281 |
+
def load_pdf(self, file_path: str) -> None:
|
282 |
+
"""Extracts text from a PDF file and stores it in the property documents by page."""
|
283 |
+
doc = pymupdf.open(file_path)
|
284 |
+
self.documents = []
|
285 |
+
for page_num in range(len(doc)):
|
286 |
+
page = doc[page_num]
|
287 |
+
text = page.get_text()
|
288 |
+
self.documents.append({"page": page_num + 1, "content": text})
|
289 |
+
print("PDF processed successfully!")
|
290 |
+
|
291 |
+
|
292 |
+
def build_vector_db(self) -> None:
|
293 |
+
"""Builds a vector database using the content of the PDF."""
|
294 |
+
if self.embeddings_model is None:
|
295 |
+
self.embeddings_model = SentenceTransformer("jinaai/jina-embeddings-v2-small-en", trust_remote_code=True) # jinaai/jina-embeddings-v2-base-de?
|
296 |
+
# Use embeddings_client
|
297 |
+
print("now doing embedding")
|
298 |
+
print("len of documents", len(self.documents))
|
299 |
+
import time
|
300 |
+
start =time.time()
|
301 |
+
#embedding_responses = embeddings_client.post(json={"inputs":[doc["content"] for doc in self.documents]}, task="feature-extraction")
|
302 |
+
#self.embeddings = np.array(json.loads(embedding_responses.decode()))
|
303 |
+
self.embeddings = self.embeddings_model.encode([doc["content"] for doc in self.documents], show_progress_bar=True)
|
304 |
+
end = time.time()
|
305 |
+
print("cost time", end-start)
|
306 |
+
self.index = faiss.IndexFlatL2(self.embeddings.shape[1])
|
307 |
+
self.index.add(np.array(self.embeddings))
|
308 |
+
print("Vector database built successfully!")
|
309 |
+
|
310 |
+
def search_documents(self, query: str, k: int = 4) -> List[str]:
|
311 |
+
"""Searches for relevant documents using vector similarity."""
|
312 |
+
# query_embedding = self.embeddings_model.encode([query], show_progress_bar=False)
|
313 |
+
embedding_responses = embeddings_client.post(json={"inputs": [query]}, task="feature-extraction")
|
314 |
+
query_embedding = json.loads(embedding_responses.decode())
|
315 |
+
D, I = self.index.search(np.array(query_embedding), k)
|
316 |
+
results = [self.documents[i]["content"] for i in I[0]]
|
317 |
+
return results if results else ["No relevant documents found."]
|
318 |
+
|
319 |
+
def chunk_text(text, chunk_size=2048, overlap_size=256, separators=["\n\n", "\n"]):
|
320 |
+
"""Chunk text into pieces of specified size with overlap, considering separators."""
|
321 |
+
|
322 |
+
# Split the text by the separators
|
323 |
+
for sep in separators:
|
324 |
+
text = text.replace(sep, "\n")
|
325 |
+
|
326 |
+
chunks = []
|
327 |
+
start = 0
|
328 |
+
|
329 |
+
while start < len(text):
|
330 |
+
# Determine the end of the chunk, accounting for overlap and the chunk size
|
331 |
+
end = min(len(text), start + chunk_size)
|
332 |
+
|
333 |
+
# Find a natural break point at the newline to avoid cutting words
|
334 |
+
if end < len(text):
|
335 |
+
while end > start and text[end] != '\n':
|
336 |
+
end -= 1
|
337 |
+
|
338 |
+
chunk = text[start:end].strip() # Strip trailing whitespace
|
339 |
+
chunks.append(chunk)
|
340 |
+
|
341 |
+
# Move the start position forward by the overlap size
|
342 |
+
start += chunk_size - overlap_size
|
343 |
+
|
344 |
+
return chunks
|
345 |
+
|
346 |
+
def load_and_chunk_pdf(file_path):
|
347 |
+
"""Extracts text from a PDF file and stores it in the property documents by chunks."""
|
348 |
+
|
349 |
+
with pymupdf.open(file_path) as pdf:
|
350 |
+
text = ""
|
351 |
+
for page in pdf:
|
352 |
+
text += page.get_text()
|
353 |
+
|
354 |
+
chunks = chunk_text(text)
|
355 |
+
documents = []
|
356 |
+
for chunk in chunks:
|
357 |
+
documents.append({"content": chunk, "metadata": pdf.metadata})
|
358 |
+
|
359 |
+
return documents
|
360 |
+
|
361 |
+
def load_pdf(file_path: str) -> None:
|
362 |
+
"""Extracts text from a PDF file and stores it in the property documents by page."""
|
363 |
+
doc = pymupdf.open(file_path)
|
364 |
+
documents = []
|
365 |
+
for page_num in range(len(doc)):
|
366 |
+
page = doc[page_num]
|
367 |
+
text = page.get_text()
|
368 |
+
documents.append({"page": page_num + 1, "content": text})
|
369 |
+
print("PDF processed successfully!")
|
370 |
+
return documents
|
371 |
+
|
372 |
+
def prepare_file_for_chat(record_id, file_names, token, progress=gr.Progress()):
|
373 |
+
if not file_names:
|
374 |
+
raise gr.Error("No file selected")
|
375 |
+
progress(0, desc="Starting")
|
376 |
+
# Create connection to kadi
|
377 |
+
manager = KadiManager(instance=instance, host=host, pat=token)
|
378 |
+
record = manager.record(identifier=record_id)
|
379 |
+
progress(0.2, desc="Loading files...")
|
380 |
+
# Parse files
|
381 |
+
documents = []
|
382 |
+
# Download
|
383 |
+
for file_name in file_names:
|
384 |
+
file_id = record.get_file_id(file_name)
|
385 |
+
with tempfile.TemporaryDirectory(prefix="tmp-kadichat-downloads-") as temp_dir:
|
386 |
+
print(temp_dir)
|
387 |
+
temp_file_location = os.path.join(temp_dir, file_name)
|
388 |
+
record.download_file(file_id, temp_file_location)
|
389 |
+
# parse document
|
390 |
+
docs = load_and_chunk_pdf(temp_file_location)
|
391 |
+
documents.extend(docs)
|
392 |
+
|
393 |
+
progress(0.4, desc="Embedding documents...")
|
394 |
+
user_rag = SimpleRAG()
|
395 |
+
user_rag.documents = documents
|
396 |
+
user_rag.embeddings_model = embeddings_model
|
397 |
+
user_rag.build_vector_db()
|
398 |
+
# print(documents[:2])
|
399 |
+
print("user rag created")
|
400 |
+
progress(1, desc="ready to chat")
|
401 |
+
return "ready to chat", user_rag
|
402 |
+
|
403 |
+
def preprocess_response(response: str) -> str:
|
404 |
+
"""Preprocesses the response to make it more polished."""
|
405 |
+
# response = response.strip()
|
406 |
+
# response = response.replace("\n\n", "\n")
|
407 |
+
# response = response.replace(" ,", ",")
|
408 |
+
# response = response.replace(" .", ".")
|
409 |
+
# response = " ".join(response.split())
|
410 |
+
# if not any(word in response.lower() for word in ["sorry", "apologize", "empathy"]):
|
411 |
+
# response = "I'm here to help. " + response
|
412 |
+
return response
|
413 |
+
|
414 |
+
|
415 |
+
def respond(message: str, history: List[Tuple[str, str]], user_session_rag):
|
416 |
+
|
417 |
+
# message is the current input query from user
|
418 |
+
# RAG
|
419 |
+
retrieved_docs = user_session_rag.search_documents(message)
|
420 |
+
context = "\n".join(retrieved_docs)
|
421 |
+
system_message = "You are an assistant to help user to answer question related to Kadi based on Relevant documents.\nRelevant documents: {}".format(context)
|
422 |
+
messages = [{"role": "assistant", "content": system_message}]
|
423 |
+
|
424 |
+
# Add history for conversational chat, TODO
|
425 |
+
# for val in history:
|
426 |
+
# #if val[0]:
|
427 |
+
# messages.append({"role": "user", "content": val[0]})
|
428 |
+
# #if val[1]:
|
429 |
+
# messages.append({"role": "assistant", "content": val[1]})
|
430 |
+
|
431 |
+
messages.append({"role": "user", "content": f"\nQuestion: {message}"})
|
432 |
+
|
433 |
+
print("-----------------")
|
434 |
+
print(messages)
|
435 |
+
print("-----------------")
|
436 |
+
# Get anwser from LLM
|
437 |
+
response = client.chat_completion(messages, max_tokens=2048, temperature=0.0) #, top_p=0.9)
|
438 |
+
response_content = "".join([choice.message['content'] for choice in response.choices if 'content' in choice.message])
|
439 |
+
|
440 |
+
# Process response
|
441 |
+
polished_response = preprocess_response(response_content)
|
442 |
+
|
443 |
+
history.append((message, polished_response))
|
444 |
+
return history, ""
|
445 |
+
|
446 |
+
|
447 |
+
app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)
|
448 |
+
app = gr.mount_gradio_app(app, login_demo, path="/main")
|
449 |
+
|
450 |
+
# Gradio interface
|
451 |
+
with gr.Blocks(theme=gr.themes.Ocean()) as main_demo:
|
452 |
+
|
453 |
+
# State for storing user token
|
454 |
+
_state_user_token = gr.State([])
|
455 |
+
|
456 |
+
user_session_rag = gr.State(
|
457 |
+
"placeholder", time_to_live=3600
|
458 |
+
) # clean state after 1h
|
459 |
+
|
460 |
+
with gr.Row():
|
461 |
+
with gr.Column(scale=7):
|
462 |
+
m = gr.Markdown("Welcome to Chatbot!")
|
463 |
+
main_demo.load(greet, None, m)
|
464 |
+
with gr.Column(scale=1):
|
465 |
+
gr.Button("Logout", link="/logout")
|
466 |
+
|
467 |
+
with gr.Tab("Main"):
|
468 |
+
with gr.Row():
|
469 |
+
with gr.Column(scale=7):
|
470 |
+
chatbot = gr.Chatbot()
|
471 |
+
|
472 |
+
with gr.Column(scale=3):
|
473 |
+
record_list = gr.Dropdown(label="Record Identifier")
|
474 |
+
record_file_dropdown = gr.Dropdown(
|
475 |
+
choices=[""],
|
476 |
+
label="Select file",
|
477 |
+
info="Select (max. 3) files to chat with.",
|
478 |
+
multiselect=True,
|
479 |
+
max_choices=3,
|
480 |
+
)
|
481 |
+
|
482 |
+
gr.Markdown(" " * 200)
|
483 |
+
# Use .then to ensure get token first
|
484 |
+
main_demo.load(_init_user_token, None, _state_user_token).then(
|
485 |
+
get_all_records, _state_user_token, record_list
|
486 |
+
)
|
487 |
+
|
488 |
+
parse_files = gr.Button("Parse files")
|
489 |
+
# message_box = gr.Markdown("")
|
490 |
+
message_box = gr.Textbox(label="", value="progress bar", interactive=False)
|
491 |
+
# Interactions
|
492 |
+
# Update file list after selecting record
|
493 |
+
record_list.select(
|
494 |
+
fn=get_files_in_record,
|
495 |
+
inputs=[record_list, _state_user_token],
|
496 |
+
outputs=record_file_dropdown,
|
497 |
+
)
|
498 |
+
# Prepare files for chatbot
|
499 |
+
parse_files.click(fn=prepare_file_for_chat, inputs=[record_list, record_file_dropdown, _state_user_token], outputs=[message_box, user_session_rag])
|
500 |
+
|
501 |
+
with gr.Row():
|
502 |
+
txt_input = gr.Textbox(
|
503 |
+
show_label=False,
|
504 |
+
placeholder="Type your question here...",
|
505 |
+
lines=1
|
506 |
+
)
|
507 |
+
submit_btn = gr.Button("Submit", scale=1)
|
508 |
+
refresh_btn = gr.Button("Refresh Chat", scale=1, variant="secondary")
|
509 |
+
|
510 |
+
example_questions = [
|
511 |
+
["Summarize the paper."],
|
512 |
+
["how to create record in kadi4mat?"],
|
513 |
+
]
|
514 |
+
|
515 |
+
gr.Examples(examples=example_questions, inputs=[txt_input])
|
516 |
+
|
517 |
+
txt_input.submit(fn=respond, inputs=[txt_input, chatbot, user_session_rag], outputs=[chatbot, txt_input])
|
518 |
+
submit_btn.click(fn=respond, inputs=[txt_input, chatbot, user_session_rag], outputs=[chatbot, txt_input])
|
519 |
+
refresh_btn.click(lambda: [], None, chatbot)
|
520 |
+
|
521 |
+
app = gr.mount_gradio_app(app, main_demo, path="/gradio", auth_dependency=get_user)
|
522 |
+
|
523 |
+
|
524 |
+
# def launch_gradio():
|
525 |
+
# login_demo.launch(share=True)
|
526 |
+
|
527 |
+
# import threading
|
528 |
+
|
529 |
+
if __name__ == "__main__":
|
530 |
+
# Launch Gradio with share=True in a separate thread
|
531 |
+
# threading.Thread(target=launch_gradio).start()
|
532 |
+
uvicorn.run(app, port=8000, host="localhost")
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
uvicorn
|
2 |
+
fastapi
|
3 |
+
authlib
|
4 |
+
httpx
|
5 |
+
gradio
|
6 |
+
pymupdf
|
7 |
+
sentence-transformers
|
8 |
+
faiss-cpu
|
9 |
+
python-dotenv
|
10 |
+
itsdangerous
|
11 |
+
kadi-apy
|