""" This is a demo to show how to use OAuth2 to connect an application to Kadi. Read Section "OAuth2 Tokens" in Kadi documents. Ref: https://kadi.readthedocs.io/en/stable/httpapi/intro.html#oauth2-tokens Notes: 1. register an application in Kadi (Setting->Applications) - Name: KadiOAuthTest - Website URL: http://127.0.0.1:7860 - Redirect URIs: http://localhost:7860/auth And you will get Client ID and Client Secret, note them down and set in this file. 2. Start this app, and open browser with address "http://localhost:7860/" - if you are starting this app on Huggingface, use "start.py" instead. """ import json import uvicorn import gradio as gr import kadi_apy import pymupdf import numpy as np import faiss import os import tempfile import pymupdf from fastapi import FastAPI, Depends from starlette.responses import RedirectResponse from starlette.middleware.sessions import SessionMiddleware from authlib.integrations.starlette_client import OAuth, OAuthError from fastapi import Request from kadi_apy import KadiManager from requests.compat import urljoin from typing import List, Tuple from sentence_transformers import SentenceTransformer from dotenv import load_dotenv # Kadi OAuth settings load_dotenv() KADI_CLIENT_ID = os.environ["KADI_CLIENT_ID"] KADI_CLIENT_SECRET = os.environ["KADI_CLIENT_SECRET"] SECRET_KEY = os.environ["SECRET_KEY"] huggingfacehub_api_token = os.environ["huggingfacehub_api_token"] from huggingface_hub import login login(token=huggingfacehub_api_token) # Set up OAuth app = FastAPI() oauth = OAuth() # Set Kadi instance instance = "my_instance" # "demo kit instance" host = "https://demo-kadi4mat.iam.kit.edu" # Register oauth base_url = host oauth.register( name="kadi4mat", client_id=KADI_CLIENT_ID, client_secret=KADI_CLIENT_SECRET, api_base_url=f"{base_url}/api", access_token_url=f"{base_url}/oauth/token", authorize_url=f"{base_url}/oauth/authorize", access_token_params={ "client_id": KADI_CLIENT_ID, "client_secret": KADI_CLIENT_SECRET, }, ) # Global LLM client from huggingface_hub import InferenceClient client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct") # Mixed-usage of huggingface client and local model for showing 2 possibilities embeddings_client = InferenceClient( model="sentence-transformers/all-mpnet-base-v2", token=huggingfacehub_api_token ) embeddings_model = SentenceTransformer( "sentence-transformers/all-mpnet-base-v2", trust_remote_code=True ) # Dependency to get the current user def get_user(request: Request): """Validate and get user information.""" if "user_access_token" in request.session: token = request.session["user_access_token"] else: token = None return None if token: try: manager = KadiManager(instance=instance, host=host, token=token) user = manager.pat_user return user.meta["displayname"] except kadi_apy.lib.exceptions.KadiAPYRequestError as e: print(e) return None return None # "Authed but Failed at getting user info!" @app.get("/") def public(request: Request, user=Depends(get_user)): """Main extrance of app.""" root_url = gr.route_utils.get_root_url(request, "/", None) # print("root url", root_url) if user: return RedirectResponse(url=f"{root_url}/gradio/") else: return RedirectResponse(url=f"{root_url}/main/") # Logout @app.route("/logout") async def logout(request: Request): request.session.pop("user", None) request.session.pop("user_id", None) request.session.pop("user_access_token", None) return RedirectResponse(url="/") # Login @app.route("/login") async def login(request: Request): root_url = gr.route_utils.get_root_url(request, "/login", None) redirect_uri = request.url_for("auth") # f"{root_url}/auth" redirect_uri = redirect_uri.replace(scheme="https") # required by Kadi # print("-----------in login") # print("root_urlt", root_url) # print("redirect_uri", redirect_uri) # print("request", request) return await oauth.kadi4mat.authorize_redirect(request, redirect_uri) # Get auth @app.route("/auth") async def auth(request: Request): root_url = gr.route_utils.get_root_url(request, "/auth", None) # print("*****+ in auth") # print("root_urlt", root_url) # print("request", request) try: access_token = await oauth.kadi4mat.authorize_access_token(request) request.session["user_access_token"] = access_token["access_token"] except OAuthError as e: print("Error getting access token", e) return RedirectResponse(url="/") return RedirectResponse(url="/gradio") def greet(request: gr.Request): """Show greeting message.""" return f"Welcome to Kadichat, you're logged in as: {request.username}" def get_files_in_record(record_id, user_token, top_k=10): """Get all file list within one record.""" manager = KadiManager(instance=instance, host=host, pat=user_token) try: record = manager.record(identifier=record_id) except kadi_apy.lib.exceptions.KadiAPYInputError as e: raise gr.Error(e) file_num = record.get_number_files() per_page = 100 # default in kadi not_divisible = file_num % per_page if not_divisible: page_num = file_num // per_page + 1 else: page_num = file_num // per_page file_names = [] for p in range(1, page_num + 1): # page starts at 1 in kadi file_names.extend( [ info["name"] for info in record.get_filelist(page=p, per_page=per_page).json()[ "items" ] ] ) assert file_num == len( file_names ), "Number of files did not match, please check function get_all_file_names." # return file_names[:top_k] return gr.Dropdown( choices=file_names[:top_k], label="Select file", info="Select (max. 3) files to chat with.", multiselect=True, max_choices=3, interactive=True, ) def get_all_records(user_token): """Get all record list in Kadi.""" if not user_token: return [] manager = KadiManager(instance=instance, host=host, pat=user_token) host_api = manager.host if manager.host.endswith("/") else manager.host + "/" searched_resource = "records" endpoint = urljoin( host_api, searched_resource ) # e.g https://demo-kadi4mat.iam.kit.edu/api/" + "records" response = manager.search.search_resources("record", per_page=100) parsed = json.loads(response.content) total_pages = parsed["_pagination"]["total_pages"] def get_page_records(parsed_content): item_identifiers = [] items = parsed_content["items"] for item in items: item_identifiers.append(item["identifier"]) return item_identifiers all_records_identifiers = [] for page in range(1, total_pages + 1): page_endpoint = endpoint + f"?page={page}&per_page=100" response = manager.make_request(page_endpoint) parsed = json.loads(response.content) all_records_identifiers.extend(get_page_records(parsed)) return gr.Dropdown( choices=all_records_identifiers, interactive=True, label="Record Identifier", info="Select record to get file list", ) def _init_user_token(request: gr.Request): """Init user token.""" user_token = request.request.session["user_access_token"] return user_token # Landing page for login with gr.Blocks() as login_demo: gr.Markdown( """







Welcome to KadiChat!



Kadichat logo

Chat with Record in Kadi.
""" ) # Note: kadichat-logo is hosted on https://postimage.io/ with gr.Row(): with gr.Column(): _btn_placeholder = gr.Button(visible=False) with gr.Column(): btn = gr.Button("Sign in with Kadi (demo-instance)") with gr.Column(): _btn_placeholder2 = gr.Button(visible=False) gr.Markdown( """



This demo shows how to use OAuth2 to have access to Kadi.
""" ) _js_redirect = """ () => { url = '/login' + window.location.search; window.open(url, '_blank'); } """ btn.click(None, js=_js_redirect) # A simple RAG implementation class SimpleRAG: def __init__(self) -> None: self.documents = [] self.embeddings_model = None self.embeddings = None self.index = None # self.load_pdf("Brandt et al_2024_Kadi_info_page.pdf") # self.build_vector_db() def load_pdf(self, file_path: str) -> None: """Extracts text from a PDF file and stores it in the property documents by page.""" doc = pymupdf.open(file_path) self.documents = [] for page_num in range(len(doc)): page = doc[page_num] text = page.get_text() self.documents.append({"page": page_num + 1, "content": text}) # print("PDF processed successfully!") def build_vector_db(self) -> None: """Builds a vector database using the content of the PDF.""" if self.embeddings_model is None: self.embeddings_model = SentenceTransformer( "sentence-transformers/all-mpnet-base-v2", trust_remote_code=True ) # jinaai/jina-embeddings-v2-base-de? # Use local model # print("now doing embedding") # print("len of documents", len(self.documents)) # embedding_responses = embeddings_client.post(json={"inputs":[doc["content"] for doc in self.documents]}, task="feature-extraction") # self.embeddings = np.array(json.loads(embedding_responses.decode())) self.embeddings = self.embeddings_model.encode( [doc["content"] for doc in self.documents], show_progress_bar=True ) self.index = faiss.IndexFlatL2(self.embeddings.shape[1]) self.index.add(np.array(self.embeddings)) print("Vector database built successfully!") def search_documents(self, query: str, k: int = 4) -> List[str]: """Searches for relevant documents using vector similarity.""" # Use embeddings_client # query_embedding = self.embeddings_model.encode([query], show_progress_bar=False) embedding_responses = embeddings_client.post( json={"inputs": [query]}, task="feature-extraction" ) query_embedding = json.loads(embedding_responses.decode()) D, I = self.index.search(np.array(query_embedding), k) results = [self.documents[i]["content"] for i in I[0]] return results if results else ["No relevant documents found."] def chunk_text(text, chunk_size=2048, overlap_size=256, separators=["\n\n", "\n"]): """Chunk text into pieces of specified size with overlap, considering separators.""" # Split the text by the separators for sep in separators: text = text.replace(sep, "\n") chunks = [] start = 0 while start < len(text): # Determine the end of the chunk, accounting for overlap and the chunk size end = min(len(text), start + chunk_size) # Find a natural break point at the newline to avoid cutting words if end < len(text): while end > start and text[end] != "\n": end -= 1 chunk = text[start:end].strip() # Strip trailing whitespace chunks.append(chunk) # Move the start position forward by the overlap size start += chunk_size - overlap_size return chunks def load_and_chunk_pdf(file_path): """Extracts text from a PDF file and stores it in the property documents by chunks.""" with pymupdf.open(file_path) as pdf: text = "" for page in pdf: text += page.get_text() chunks = chunk_text(text) documents = [] for chunk in chunks: documents.append({"content": chunk, "metadata": pdf.metadata}) return documents def load_pdf(file_path): """Extracts text from a PDF file and stores it in the property documents by page.""" doc = pymupdf.open(file_path) documents = [] for page_num in range(len(doc)): page = doc[page_num] text = page.get_text() documents.append({"page": page_num + 1, "content": text}) print("PDF processed successfully!") return documents def prepare_file_for_chat(record_id, file_names, token, progress=gr.Progress()): """Parse file and prepare RAG.""" if not file_names: raise gr.Error("No file selected") progress(0, desc="Starting") # Create connection to kadi manager = KadiManager(instance=instance, host=host, pat=token) record = manager.record(identifier=record_id) progress(0.2, desc="Loading files...") # Parse files documents = [] # Download for file_name in file_names: file_id = record.get_file_id(file_name) with tempfile.TemporaryDirectory(prefix="tmp-kadichat-downloads-") as temp_dir: print(temp_dir) temp_file_location = os.path.join(temp_dir, file_name) record.download_file(file_id, temp_file_location) # parse document docs = load_and_chunk_pdf(temp_file_location) documents.extend(docs) progress(0.4, desc="Embedding documents...") user_rag = SimpleRAG() user_rag.documents = documents user_rag.embeddings_model = embeddings_model user_rag.build_vector_db() # print(documents[:2]) print("user rag created") progress(1, desc="ready to chat") return "ready to chat", user_rag def preprocess_response(response: str) -> str: """Preprocesses the response to make it more polished.""" # Placeholder for preprocessing # response = response.strip() # response = response.replace("\n\n", "\n") # response = response.replace(" ,", ",") # response = response.replace(" .", ".") # response = " ".join(response.split()) # if not any(word in response.lower() for word in ["sorry", "apologize", "empathy"]): # response = "I'm here to help. " + response return response def respond(message: str, history: List[Tuple[str, str]], user_session_rag): """Get respond from LLMs.""" # message is the current input query from user # RAG retrieved_docs = user_session_rag.search_documents(message) context = "\n".join(retrieved_docs) system_message = "You are an assistant to help user to answer question related to Kadi based on Relevant documents.\nRelevant documents: {}".format( context ) messages = [{"role": "assistant", "content": system_message}] # Add history for conversational chat, TODO # for val in history: # #if val[0]: # messages.append({"role": "user", "content": val[0]}) # #if val[1]: # messages.append({"role": "assistant", "content": val[1]}) messages.append({"role": "user", "content": f"\nQuestion: {message}"}) # print("-----------------") # print(messages) # print("-----------------") # Get anwser from LLM response = client.chat_completion( messages, max_tokens=2048, temperature=0.0 ) # , top_p=0.9) response_content = "".join( [ choice.message["content"] for choice in response.choices if "content" in choice.message ] ) # Process response polished_response = preprocess_response(response_content) history.append((message, polished_response)) return history, "" app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY) app = gr.mount_gradio_app(app, login_demo, path="/main") # Gradio interface with gr.Blocks() as main_demo: # State for storing user token _state_user_token = gr.State([]) # State for user rag user_session_rag = gr.State("placeholder") with gr.Row(): with gr.Column(scale=7): m = gr.Markdown("Welcome to Chatbot!") main_demo.load(greet, None, m) with gr.Column(scale=1): gr.Button("Logout", link="/logout") with gr.Tab("Main"): with gr.Row(): with gr.Column(scale=7): chatbot = gr.Chatbot() with gr.Column(scale=3): record_list = gr.Dropdown(label="Record Identifier") record_file_dropdown = gr.Dropdown( choices=[""], label="Select file", info="Select (max. 3) files to chat with.", multiselect=True, max_choices=3, ) gr.Markdown(" " * 200) # Use .then to ensure get token first main_demo.load(_init_user_token, None, _state_user_token).then( get_all_records, _state_user_token, record_list ) parse_files = gr.Button("Parse files") # message_box = gr.Markdown("") message_box = gr.Textbox( label="", value="progress bar", interactive=False ) # Interactions # Update file list after selecting record record_list.select( fn=get_files_in_record, inputs=[record_list, _state_user_token], outputs=record_file_dropdown, ) # Prepare files for chatbot parse_files.click( fn=prepare_file_for_chat, inputs=[record_list, record_file_dropdown, _state_user_token], outputs=[message_box, user_session_rag], ) with gr.Row(): txt_input = gr.Textbox( show_label=False, placeholder="Type your question here...", lines=1 ) submit_btn = gr.Button("Submit", scale=1) refresh_btn = gr.Button("Refresh Chat", scale=1, variant="secondary") example_questions = [ ["Summarize the paper."], ["how to create record in kadi4mat?"], ] gr.Examples(examples=example_questions, inputs=[txt_input]) # Actions txt_input.submit( fn=respond, inputs=[txt_input, chatbot, user_session_rag], outputs=[chatbot, txt_input], ) submit_btn.click( fn=respond, inputs=[txt_input, chatbot, user_session_rag], outputs=[chatbot, txt_input], ) refresh_btn.click(lambda: [], None, chatbot) app = gr.mount_gradio_app(app, main_demo, path="/gradio", auth_dependency=get_user) if __name__ == "__main__": uvicorn.run(app, port=7860, host="0.0.0.0")