Kadi-IAM commited on
Commit
0d3d29d
1 Parent(s): 2bf359b

Upload 2 files

Browse files
Files changed (2) hide show
  1. kadichat.py +532 -0
  2. requirements.txt +11 -0
kadichat.py ADDED
@@ -0,0 +1,532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This is a demo to show how to use OAuth2 to connect an application to Kadi.
3
+
4
+ Read Section "OAuth2 Tokens" in Kadi documents.
5
+ Ref: https://kadi.readthedocs.io/en/stable/httpapi/intro.html#oauth2-tokens
6
+
7
+ Notes:
8
+ 1. register an application in Kadi (Setting->Applications)
9
+ - Name: KadiOAuthTest
10
+ - Website URL: http://127.0.0.1:8000
11
+ - Redirect URIs: http://localhost:8000/auth
12
+
13
+ And you will get Client ID and Client Secret, note them down and set in this file.
14
+
15
+ 2. Start this app, and open browser with address "http://localhost:8000/"
16
+
17
+ """
18
+
19
+ import json
20
+
21
+ import uvicorn
22
+ from fastapi import FastAPI, Depends
23
+ from starlette.responses import RedirectResponse
24
+ from starlette.middleware.sessions import SessionMiddleware
25
+ from authlib.integrations.starlette_client import OAuth, OAuthError
26
+ from fastapi import Request
27
+ import gradio as gr
28
+ import kadi_apy
29
+ from kadi_apy import KadiManager
30
+ from requests.compat import urljoin
31
+ from typing import List, Tuple
32
+ import pymupdf
33
+ from sentence_transformers import SentenceTransformer
34
+ import numpy as np
35
+ import faiss
36
+ from dotenv import load_dotenv
37
+ import os
38
+
39
+ # Kadi OAuth settings
40
+ load_dotenv()
41
+ KADI_CLIENT_ID = os.environ["KADI_CLIENT_ID"]
42
+ KADI_CLIENT_SECRET = os.environ["KADI_CLIENT_SECRET"]
43
+ SECRET_KEY = os.environ["SECRET_KEY"]
44
+ huggingfacehub_api_token = os.environ["huggingfacehub_api_token"]
45
+
46
+ from huggingface_hub import login
47
+ login(token=huggingfacehub_api_token)
48
+
49
+ # Set up OAuth
50
+ app = FastAPI()
51
+ oauth = OAuth()
52
+
53
+ # Set Kadi instance
54
+ instance = "my_kadi_demo_instance" # "demo kit instance"
55
+ host = "https://demo-kadi4mat.iam.kit.edu"
56
+
57
+ base_url = host
58
+ oauth.register(
59
+ name="kadi4mat",
60
+ client_id=KADI_CLIENT_ID,
61
+ client_secret=KADI_CLIENT_SECRET,
62
+ api_base_url=f"{base_url}/api",
63
+ access_token_url=f"{base_url}/oauth/token",
64
+ authorize_url=f"{base_url}/oauth/authorize",
65
+ access_token_params={
66
+ "client_id": KADI_CLIENT_ID,
67
+ "client_secret": KADI_CLIENT_SECRET,
68
+ },
69
+ )
70
+
71
+ # Global LLM client
72
+ from huggingface_hub import InferenceClient
73
+ client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
74
+
75
+
76
+ embeddings_client = InferenceClient(model="sentence-transformers/all-mpnet-base-v2", token=huggingfacehub_api_token)
77
+ # embeddings_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", trust_remote_code=True) # unused
78
+ embeddings_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", trust_remote_code=True)
79
+
80
+ # Dependency to get the current user
81
+ def get_user(request: Request):
82
+ if "user_access_token" in request.session:
83
+ token = request.session["user_access_token"]
84
+ else:
85
+ token = None
86
+ return None
87
+ if token:
88
+ try:
89
+ manager = KadiManager(instance=instance, host=host, token=token)
90
+ user = manager.pat_user
91
+ return user.meta["displayname"]
92
+ except kadi_apy.lib.exceptions.KadiAPYRequestError as e:
93
+ print(e)
94
+ return None
95
+ return None # "Authed but Failed at getting user info!"
96
+
97
+
98
+ @app.get("/")
99
+ def public(request: Request, user=Depends(get_user)):
100
+ root_url = gr.route_utils.get_root_url(request, "/", None)
101
+ if user:
102
+ return RedirectResponse(url=f"{root_url}/gradio/")
103
+ else:
104
+ return RedirectResponse(url=f"{root_url}/main/")
105
+
106
+
107
+ @app.route("/logout")
108
+ async def logout(request: Request):
109
+ request.session.pop("user", None)
110
+ request.session.pop("user_id", None)
111
+ request.session.pop("user_access_token", None)
112
+
113
+ return RedirectResponse(url="/")
114
+
115
+
116
+ @app.route("/login")
117
+ async def login(request: Request):
118
+ root_url = gr.route_utils.get_root_url(request, "/login", None)
119
+ redirect_uri = request.url_for("auth") # f"{root_url}/auth"
120
+ return await oauth.kadi4mat.authorize_redirect(request, redirect_uri)
121
+
122
+
123
+ @app.route("/auth")
124
+ async def auth(request: Request):
125
+ # root_url = gr.route_utils.get_root_url(request, "/auth", None)
126
+ try:
127
+ access_token = await oauth.kadi4mat.authorize_access_token(request)
128
+ request.session["user_access_token"] = access_token["access_token"]
129
+
130
+ except OAuthError as e:
131
+ print("Error getting access token", e)
132
+ return RedirectResponse(url="/")
133
+
134
+ return RedirectResponse(url="/gradio")
135
+
136
+
137
+ def greet(request: gr.Request):
138
+ return f"Welcome to Kadichat, you're logged in as: {request.username}"
139
+
140
+
141
+ def get_files_in_record(record_id, user_token, top_k=10):
142
+
143
+ manager = KadiManager(instance=instance, host=host, pat=user_token)
144
+
145
+ try:
146
+ record = manager.record(identifier=record_id)
147
+ except kadi_apy.lib.exceptions.KadiAPYInputError as e:
148
+ raise gr.Error(e)
149
+
150
+ file_num = record.get_number_files()
151
+
152
+ per_page = 100 # default in kadi
153
+ not_divisible = file_num % per_page
154
+ if not_divisible:
155
+ page_num = file_num // per_page + 1
156
+ else:
157
+ page_num = file_num // per_page
158
+
159
+ file_names = []
160
+ for p in range(1, page_num + 1): # page starts at 1 in kadi
161
+ file_names.extend(
162
+ [
163
+ info["name"]
164
+ for info in record.get_filelist(page=p, per_page=per_page).json()[
165
+ "items"
166
+ ]
167
+ ]
168
+ )
169
+
170
+ assert file_num == len(
171
+ file_names
172
+ ), "Number of files did not match, please check function get_all_file_names."
173
+
174
+ # return file_names[:top_k]
175
+ return gr.Dropdown(
176
+ choices=file_names[:top_k],
177
+ label="Select file",
178
+ info="Select (max. 3) files to chat with.",
179
+ multiselect=True,
180
+ max_choices=3,
181
+ interactive=True,
182
+ )
183
+
184
+
185
+ def get_all_records(user_token):
186
+
187
+ if not user_token:
188
+ return []
189
+
190
+ manager = KadiManager(instance=instance, host=host, pat=user_token)
191
+
192
+ host_api = manager.host if manager.host.endswith("/") else manager.host + "/"
193
+ searched_resource = "records"
194
+ endpoint = urljoin(
195
+ host_api, searched_resource
196
+ ) # e.g https://demo-kadi4mat.iam.kit.edu/api/" + "records"
197
+
198
+ response = manager.search.search_resources("record", per_page=100)
199
+ parsed = json.loads(response.content)
200
+
201
+ total_pages = parsed["_pagination"]["total_pages"]
202
+
203
+ def get_page_records(parsed_content):
204
+ item_identifiers = []
205
+ items = parsed_content["items"]
206
+ for item in items:
207
+ item_identifiers.append(item["identifier"])
208
+
209
+ return item_identifiers
210
+
211
+ all_records_identifiers = []
212
+ for page in range(1, total_pages + 1):
213
+ page_endpoint = endpoint + f"?page={page}&per_page=100"
214
+ response = manager.make_request(page_endpoint)
215
+ parsed = json.loads(response.content)
216
+ all_records_identifiers.extend(get_page_records(parsed))
217
+
218
+ return gr.Dropdown(
219
+ choices=all_records_identifiers,
220
+ interactive=True,
221
+ label="Record Identifier",
222
+ info="Select record to get file list",
223
+ )
224
+
225
+
226
+ def _init_user_token(request: gr.Request):
227
+ user_token = request.request.session["user_access_token"]
228
+ return user_token
229
+
230
+
231
+ with gr.Blocks(theme=gr.themes.Ocean()) as login_demo:
232
+ gr.Markdown(
233
+ """<br/><br/><br/><br/><br/><br/><br/><br/>
234
+ <center>
235
+ <h1>Welcome to KadiChat!</h1>
236
+ <br/><br/>
237
+ <img src="https://i.postimg.cc/qvsQCCLS/kadichat-logo.png" alt="Kadichat logo">
238
+ <br/><br/>
239
+ Chat with Record in Kadi.</center>
240
+ """
241
+ )
242
+ # Note: kadichat-logo is hosted on https://postimage.io/
243
+
244
+ with gr.Row():
245
+ with gr.Column():
246
+ _btn_placeholder = gr.Button(visible=False)
247
+ with gr.Column():
248
+ btn = gr.Button("Sign in with Kadi (demo-instance)")
249
+ with gr.Column():
250
+ _btn_placeholder2 = gr.Button(visible=False)
251
+
252
+ gr.Markdown(
253
+ """<br/><br/><br/><br/>
254
+ <center>
255
+ This demo shows how to use
256
+ <a href="https://kadi4mat.readthedocs.io/en/stable/httpapi/intro.html#oauth2-tokens">OAuth2</a>
257
+ to have access to Kadi.</center>
258
+ """
259
+ )
260
+ _js_redirect = """
261
+ () => {
262
+ url = '/login' + window.location.search;
263
+ window.open(url, '_blank');
264
+ }
265
+ """
266
+ btn.click(None, js=_js_redirect)
267
+
268
+ import tempfile
269
+ import os
270
+ import pymupdf
271
+
272
+ class SimpleRAG:
273
+ def __init__(self) -> None:
274
+ self.documents = []
275
+ self.embeddings_model = None
276
+ self.embeddings = None
277
+ self.index = None
278
+ #self.load_pdf("Brandt et al_2024_Kadi_info_page.pdf")
279
+ #self.build_vector_db()
280
+
281
+ def load_pdf(self, file_path: str) -> None:
282
+ """Extracts text from a PDF file and stores it in the property documents by page."""
283
+ doc = pymupdf.open(file_path)
284
+ self.documents = []
285
+ for page_num in range(len(doc)):
286
+ page = doc[page_num]
287
+ text = page.get_text()
288
+ self.documents.append({"page": page_num + 1, "content": text})
289
+ print("PDF processed successfully!")
290
+
291
+
292
+ def build_vector_db(self) -> None:
293
+ """Builds a vector database using the content of the PDF."""
294
+ if self.embeddings_model is None:
295
+ self.embeddings_model = SentenceTransformer("jinaai/jina-embeddings-v2-small-en", trust_remote_code=True) # jinaai/jina-embeddings-v2-base-de?
296
+ # Use embeddings_client
297
+ print("now doing embedding")
298
+ print("len of documents", len(self.documents))
299
+ import time
300
+ start =time.time()
301
+ #embedding_responses = embeddings_client.post(json={"inputs":[doc["content"] for doc in self.documents]}, task="feature-extraction")
302
+ #self.embeddings = np.array(json.loads(embedding_responses.decode()))
303
+ self.embeddings = self.embeddings_model.encode([doc["content"] for doc in self.documents], show_progress_bar=True)
304
+ end = time.time()
305
+ print("cost time", end-start)
306
+ self.index = faiss.IndexFlatL2(self.embeddings.shape[1])
307
+ self.index.add(np.array(self.embeddings))
308
+ print("Vector database built successfully!")
309
+
310
+ def search_documents(self, query: str, k: int = 4) -> List[str]:
311
+ """Searches for relevant documents using vector similarity."""
312
+ # query_embedding = self.embeddings_model.encode([query], show_progress_bar=False)
313
+ embedding_responses = embeddings_client.post(json={"inputs": [query]}, task="feature-extraction")
314
+ query_embedding = json.loads(embedding_responses.decode())
315
+ D, I = self.index.search(np.array(query_embedding), k)
316
+ results = [self.documents[i]["content"] for i in I[0]]
317
+ return results if results else ["No relevant documents found."]
318
+
319
+ def chunk_text(text, chunk_size=2048, overlap_size=256, separators=["\n\n", "\n"]):
320
+ """Chunk text into pieces of specified size with overlap, considering separators."""
321
+
322
+ # Split the text by the separators
323
+ for sep in separators:
324
+ text = text.replace(sep, "\n")
325
+
326
+ chunks = []
327
+ start = 0
328
+
329
+ while start < len(text):
330
+ # Determine the end of the chunk, accounting for overlap and the chunk size
331
+ end = min(len(text), start + chunk_size)
332
+
333
+ # Find a natural break point at the newline to avoid cutting words
334
+ if end < len(text):
335
+ while end > start and text[end] != '\n':
336
+ end -= 1
337
+
338
+ chunk = text[start:end].strip() # Strip trailing whitespace
339
+ chunks.append(chunk)
340
+
341
+ # Move the start position forward by the overlap size
342
+ start += chunk_size - overlap_size
343
+
344
+ return chunks
345
+
346
+ def load_and_chunk_pdf(file_path):
347
+ """Extracts text from a PDF file and stores it in the property documents by chunks."""
348
+
349
+ with pymupdf.open(file_path) as pdf:
350
+ text = ""
351
+ for page in pdf:
352
+ text += page.get_text()
353
+
354
+ chunks = chunk_text(text)
355
+ documents = []
356
+ for chunk in chunks:
357
+ documents.append({"content": chunk, "metadata": pdf.metadata})
358
+
359
+ return documents
360
+
361
+ def load_pdf(file_path: str) -> None:
362
+ """Extracts text from a PDF file and stores it in the property documents by page."""
363
+ doc = pymupdf.open(file_path)
364
+ documents = []
365
+ for page_num in range(len(doc)):
366
+ page = doc[page_num]
367
+ text = page.get_text()
368
+ documents.append({"page": page_num + 1, "content": text})
369
+ print("PDF processed successfully!")
370
+ return documents
371
+
372
+ def prepare_file_for_chat(record_id, file_names, token, progress=gr.Progress()):
373
+ if not file_names:
374
+ raise gr.Error("No file selected")
375
+ progress(0, desc="Starting")
376
+ # Create connection to kadi
377
+ manager = KadiManager(instance=instance, host=host, pat=token)
378
+ record = manager.record(identifier=record_id)
379
+ progress(0.2, desc="Loading files...")
380
+ # Parse files
381
+ documents = []
382
+ # Download
383
+ for file_name in file_names:
384
+ file_id = record.get_file_id(file_name)
385
+ with tempfile.TemporaryDirectory(prefix="tmp-kadichat-downloads-") as temp_dir:
386
+ print(temp_dir)
387
+ temp_file_location = os.path.join(temp_dir, file_name)
388
+ record.download_file(file_id, temp_file_location)
389
+ # parse document
390
+ docs = load_and_chunk_pdf(temp_file_location)
391
+ documents.extend(docs)
392
+
393
+ progress(0.4, desc="Embedding documents...")
394
+ user_rag = SimpleRAG()
395
+ user_rag.documents = documents
396
+ user_rag.embeddings_model = embeddings_model
397
+ user_rag.build_vector_db()
398
+ # print(documents[:2])
399
+ print("user rag created")
400
+ progress(1, desc="ready to chat")
401
+ return "ready to chat", user_rag
402
+
403
+ def preprocess_response(response: str) -> str:
404
+ """Preprocesses the response to make it more polished."""
405
+ # response = response.strip()
406
+ # response = response.replace("\n\n", "\n")
407
+ # response = response.replace(" ,", ",")
408
+ # response = response.replace(" .", ".")
409
+ # response = " ".join(response.split())
410
+ # if not any(word in response.lower() for word in ["sorry", "apologize", "empathy"]):
411
+ # response = "I'm here to help. " + response
412
+ return response
413
+
414
+
415
+ def respond(message: str, history: List[Tuple[str, str]], user_session_rag):
416
+
417
+ # message is the current input query from user
418
+ # RAG
419
+ retrieved_docs = user_session_rag.search_documents(message)
420
+ context = "\n".join(retrieved_docs)
421
+ system_message = "You are an assistant to help user to answer question related to Kadi based on Relevant documents.\nRelevant documents: {}".format(context)
422
+ messages = [{"role": "assistant", "content": system_message}]
423
+
424
+ # Add history for conversational chat, TODO
425
+ # for val in history:
426
+ # #if val[0]:
427
+ # messages.append({"role": "user", "content": val[0]})
428
+ # #if val[1]:
429
+ # messages.append({"role": "assistant", "content": val[1]})
430
+
431
+ messages.append({"role": "user", "content": f"\nQuestion: {message}"})
432
+
433
+ print("-----------------")
434
+ print(messages)
435
+ print("-----------------")
436
+ # Get anwser from LLM
437
+ response = client.chat_completion(messages, max_tokens=2048, temperature=0.0) #, top_p=0.9)
438
+ response_content = "".join([choice.message['content'] for choice in response.choices if 'content' in choice.message])
439
+
440
+ # Process response
441
+ polished_response = preprocess_response(response_content)
442
+
443
+ history.append((message, polished_response))
444
+ return history, ""
445
+
446
+
447
+ app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)
448
+ app = gr.mount_gradio_app(app, login_demo, path="/main")
449
+
450
+ # Gradio interface
451
+ with gr.Blocks(theme=gr.themes.Ocean()) as main_demo:
452
+
453
+ # State for storing user token
454
+ _state_user_token = gr.State([])
455
+
456
+ user_session_rag = gr.State(
457
+ "placeholder", time_to_live=3600
458
+ ) # clean state after 1h
459
+
460
+ with gr.Row():
461
+ with gr.Column(scale=7):
462
+ m = gr.Markdown("Welcome to Chatbot!")
463
+ main_demo.load(greet, None, m)
464
+ with gr.Column(scale=1):
465
+ gr.Button("Logout", link="/logout")
466
+
467
+ with gr.Tab("Main"):
468
+ with gr.Row():
469
+ with gr.Column(scale=7):
470
+ chatbot = gr.Chatbot()
471
+
472
+ with gr.Column(scale=3):
473
+ record_list = gr.Dropdown(label="Record Identifier")
474
+ record_file_dropdown = gr.Dropdown(
475
+ choices=[""],
476
+ label="Select file",
477
+ info="Select (max. 3) files to chat with.",
478
+ multiselect=True,
479
+ max_choices=3,
480
+ )
481
+
482
+ gr.Markdown(" " * 200)
483
+ # Use .then to ensure get token first
484
+ main_demo.load(_init_user_token, None, _state_user_token).then(
485
+ get_all_records, _state_user_token, record_list
486
+ )
487
+
488
+ parse_files = gr.Button("Parse files")
489
+ # message_box = gr.Markdown("")
490
+ message_box = gr.Textbox(label="", value="progress bar", interactive=False)
491
+ # Interactions
492
+ # Update file list after selecting record
493
+ record_list.select(
494
+ fn=get_files_in_record,
495
+ inputs=[record_list, _state_user_token],
496
+ outputs=record_file_dropdown,
497
+ )
498
+ # Prepare files for chatbot
499
+ parse_files.click(fn=prepare_file_for_chat, inputs=[record_list, record_file_dropdown, _state_user_token], outputs=[message_box, user_session_rag])
500
+
501
+ with gr.Row():
502
+ txt_input = gr.Textbox(
503
+ show_label=False,
504
+ placeholder="Type your question here...",
505
+ lines=1
506
+ )
507
+ submit_btn = gr.Button("Submit", scale=1)
508
+ refresh_btn = gr.Button("Refresh Chat", scale=1, variant="secondary")
509
+
510
+ example_questions = [
511
+ ["Summarize the paper."],
512
+ ["how to create record in kadi4mat?"],
513
+ ]
514
+
515
+ gr.Examples(examples=example_questions, inputs=[txt_input])
516
+
517
+ txt_input.submit(fn=respond, inputs=[txt_input, chatbot, user_session_rag], outputs=[chatbot, txt_input])
518
+ submit_btn.click(fn=respond, inputs=[txt_input, chatbot, user_session_rag], outputs=[chatbot, txt_input])
519
+ refresh_btn.click(lambda: [], None, chatbot)
520
+
521
+ app = gr.mount_gradio_app(app, main_demo, path="/gradio", auth_dependency=get_user)
522
+
523
+
524
+ # def launch_gradio():
525
+ # login_demo.launch(share=True)
526
+
527
+ # import threading
528
+
529
+ if __name__ == "__main__":
530
+ # Launch Gradio with share=True in a separate thread
531
+ # threading.Thread(target=launch_gradio).start()
532
+ uvicorn.run(app, port=8000, host="localhost")
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ uvicorn
2
+ fastapi
3
+ authlib
4
+ httpx
5
+ gradio
6
+ pymupdf
7
+ sentence-transformers
8
+ faiss-cpu
9
+ python-dotenv
10
+ itsdangerous
11
+ kadi-apy