#DOCS
# https://langchain-ai.github.io/langgraph/reference/prebuilt/#langgraph.prebuilt.chat_agent_executor.create_react_agent
"""FastAPI service exposing a LangGraph ReAct agent over server-sent events.

The agent answers questions about a user's document collection through the
``query_documents`` tool.  Two streaming endpoints are provided:

* ``POST /chat``  - streams raw token / tool_start / tool_end events.
* ``POST /chat2`` - same, but with cleaned tool inputs and structured
  (JSON) search results in the ``tool_end`` event.
"""

import ast
import json
import logging
import logging.config
import os
import re
import uuid
from typing import Annotated, Any, Optional, Union

import requests
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from langchain_core.messages import (
    BaseMessage,
    HumanMessage,
    SystemMessage,
    trim_messages,
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import InjectedState, create_react_agent
from pydantic import BaseModel
from requests.exceptions import RequestException, Timeout
from sse_starlette.sse import EventSourceResponse

from document_rag_router import router as document_rag_router
from document_rag_router import QueryInput, query_collection, SearchResult, db

# Configure logging at application startup
logging.config.dictConfig({
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            "datefmt": "%Y-%m-%d %H:%M:%S",
        }
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "stream": "ext://sys.stdout",
            "formatter": "default",
            "level": "DEBUG",
        }
    },
    "root": {
        "level": "DEBUG",
        "handlers": ["console"]
    },
    "loggers": {
        "uvicorn": {"handlers": ["console"], "level": "DEBUG"},
        "fastapi": {"handlers": ["console"], "level": "DEBUG"}
    }
})

# Create logger instance
logger = logging.getLogger(__name__)

app = FastAPI()

app.include_router(document_rag_router)

# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed requests to this API. Left unchanged so existing
# clients keep working - restrict allow_origins before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Returned verbatim to the model / caller when the request carries no
# collection_id or user_id.
_MISSING_CONTEXT_ERROR = "Error: collection_id and user_id are required in the config"

# Matches the first "'key': 'value'" pair of a stringified dict.
_FIRST_PAIR_RE = re.compile(r"{\s*'([^']+)':\s*'([^']+)'")


def get_current_files():
    """Get list of files in current directory.

    Returns:
        The directory entries joined with ``", "``, or an
        ``"Error getting files: ..."`` string if the listing fails.
    """
    try:
        return ", ".join(os.listdir('.'))
    except OSError as e:
        return f"Error getting files: {str(e)}"


# The docstring of a @tool function is the description shown to the model,
# so it is kept word-for-word.
@tool
def get_user_age(name: str) -> str:
    """Use this tool to find the user's age."""
    if "bob" in name.lower():
        return "42 years old"
    return "41 years old"


def _query_context(config: RunnableConfig) -> tuple[Optional[str], Optional[str]]:
    """Return ``(collection_id, user_id)`` from the config's ``configurable`` section."""
    thread_config = config.get("configurable", {})
    return thread_config.get("collection_id"), thread_config.get("user_id")


async def _search_collection(query: str, collection_id: str, user_id: str):
    """Run ``query`` against the collection and return ``response.results``."""
    input_data = QueryInput(
        collection_id=collection_id,
        query=query,
        user_id=user_id,
        top_k=6
    )
    response = await query_collection(input_data)
    return response.results


# The docstring below is the tool description sent to the model; its wording
# is runtime behavior and is intentionally unchanged.
@tool
async def query_documents(
    query: str,
    config: RunnableConfig,
) -> str:
    """Use this tool to retrieve relevant data from the collection.

    Args:
        query: The search query to find relevant document passages
    """
    # Get collection_id and user_id from config
    collection_id, user_id = _query_context(config)
    if not collection_id or not user_id:
        return _MISSING_CONTEXT_ERROR

    try:
        hits = await _search_collection(query, collection_id, user_id)
        # Only the fields the model needs; the list is returned as its Python
        # repr, which clean_tool_response later parses with ast.literal_eval.
        results = [
            {
                "text": r.text,
                "distance": r.distance,
                "metadata": {
                    "document_id": r.metadata.get("document_id"),
                    "chunk_index": r.metadata.get("location", {}).get("chunk_index")
                }
            }
            for r in hits
        ]
        return str(results)
    except Exception as e:
        # Errors are reported to the model as text rather than raised, so the
        # agent can stop and ask the user instead of crashing the stream.
        logger.exception("query_documents failed")
        return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"


async def query_documents_raw(
    query: str,
    config: RunnableConfig,
) -> Union[list, str]:
    """Run a collection search and return the unformatted result objects.

    Unlike ``query_documents`` this is not an agent tool; ``/chat2`` calls it
    to obtain structured results for the client.

    Args:
        query: The search query to find relevant document passages.
        config: Config whose ``configurable`` section carries
            ``collection_id`` and ``user_id``.

    Returns:
        ``response.results`` from ``query_collection`` on success, or an
        error message string when the ids are missing or the query fails.
    """
    # Get collection_id and user_id from config
    collection_id, user_id = _query_context(config)
    if not collection_id or not user_id:
        return _MISSING_CONTEXT_ERROR

    try:
        return await _search_collection(query, collection_id, user_id)
    except Exception as e:
        logger.exception("query_documents_raw failed")
        return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"


memory = MemorySaver()
model = ChatOpenAI(model="gpt-4o-mini", streaming=True)

# Create a prompt template for formatting
prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful AI assistant. The current collection contains the following files: {collection_files}, use query_documents tool to answer user queries from the document. In case a summary is requested, create multiple queries for different plausible sections of the document"),
    ("placeholder", "{messages}"),
])

# Previous implementation, which fetched the file list from the external HTTP
# API; kept for reference.
# def get_collection_files(collection_id: str, user_id: str) -> str:
#     """
#     Synchronously get list of files in the specified collection using the external API
#     with proper timeout and error handling.
#     """
#     try:
#         url = "https://pvanand-documind-api-v2.hf.space/rag/get_collection_files"
#         params = {
#             "collection_id": collection_id,
#             "user_id": user_id
#         }
#         headers = {
#             'accept': 'application/json'
#         }
#         logger.debug(f"Requesting collection files for user {user_id}, collection {collection_id}")
#         # Set timeout to 5 seconds
#         response = requests.post(url, params=params, headers=headers, data='', timeout=5)
#         if response.status_code == 200:
#             logger.info(f"Successfully retrieved collection files: {response.text[:100]}...")
#             return response.text
#         else:
#             logger.error(f"API error (status {response.status_code}): {response.text}")
#             return f"Error fetching files (status {response.status_code})"
#     except Timeout:
#         logger.error("Timeout while fetching collection files")
#         return "Error: Request timed out"
#     except RequestException as e:
#         logger.error(f"Network error fetching collection files: {str(e)}")
#         return f"Error: Network issue - {str(e)}"
#     except Exception as e:
#         logger.error(f"Error fetching collection files: {str(e)}", exc_info=True)
#         return f"Error fetching files: {str(e)}"


def get_collection_files(collection_id: str, user_id: str) -> str:
    """Get list of files in the specified collection.

    Args:
        collection_id: Collection identifier.
        user_id: Owner of the collection.

    Returns:
        The distinct ``file_name`` values of the table joined with ``", "``,
        or an ``"Error getting files: ..."`` string on any failure.
    """
    try:
        # Tables are named "<user_id>_<collection_id>"
        collection_name = f"{user_id}_{collection_id}"

        # Open the table and convert to pandas
        table = db.open_table(collection_name)
        df = table.to_pandas()

        # Get unique file names; missing values are dropped and the rest are
        # coerced to str so one odd row cannot break the join.
        unique_files = df['file_name'].dropna().unique()
        logger.debug(
            "Collection %s: %d rows, %d distinct files",
            collection_name, len(df), len(unique_files),
        )

        # Join the file names into a string
        return ", ".join(str(name) for name in unique_files)
    except Exception as e:
        logger.error(f"Error getting collection files: {str(e)}")
        return f"Error getting files: {str(e)}"


def format_for_model(state: dict, config: Optional[RunnableConfig] = None) -> list[BaseMessage]:
    """
    Format the input state and config for the model.

    Args:
        state: The current state dictionary containing messages
        config: Optional RunnableConfig containing thread configuration

    Returns:
        Formatted messages for the model (the result of ``prompt.invoke``)
    """
    # Get collection_id and user_id from config instead of state
    thread_config = config.get("configurable", {}) if config else {}
    collection_id = thread_config.get("collection_id")
    user_id = thread_config.get("user_id")

    try:
        # Get files in the collection. get_collection_files reports its own
        # failures as a string, so the model still gets a prompt.
        if collection_id and user_id:
            collection_files = get_collection_files(collection_id, user_id)
        else:
            collection_files = "No files available"

        logger.info(
            "Fetching collection for userid %s and collection_id %s || Results: %s...",
            user_id, collection_id, collection_files[:100],
        )

        # Format using the prompt template
        return prompt.invoke({
            "collection_files": collection_files,
            "messages": state.get("messages", [])
        })
    except Exception as e:
        logger.error(f"Error in format_for_model: {str(e)}", exc_info=True)
        # Return a basic format if there's an error
        return prompt.invoke({
            "collection_files": "Error fetching files",
            "messages": state.get("messages", [])
        })


async def clean_tool_input(tool_input: str):
    """Extract the first ``'key': 'value'`` pair from a stringified dict.

    Args:
        tool_input: ``str()`` of the tool's input.

    Returns:
        ``{key: value}`` for the first single-quoted pair found, otherwise
        ``[tool_input]`` (a one-element list holding the untouched input).
    """
    match = _FIRST_PAIR_RE.search(tool_input)
    if match:
        key, value = match.groups()
        return {key: value}
    return [tool_input]


def _extract_query(tool_input: Any) -> str:
    """Return the ``query`` argument of a tool call, falling back to ``str(tool_input)``."""
    if isinstance(tool_input, dict):
        query = tool_input.get("query")
        if query is not None:
            return str(query)
    return str(tool_input)


async def clean_tool_response(tool_output: str):
    """Clean and extract relevant information from tool response if it contains query_documents.

    Args:
        tool_output: ``str()`` of a tool's output.

    Returns:
        * a list of ``{"text", "document_id"}`` dicts when the output mentions
          ``query_documents`` and contains a parseable ``[{...}]`` literal;
        * an error message string when that literal cannot be parsed;
        * ``tool_output`` unchanged otherwise.
    """
    if "query_documents" not in tool_output:
        return tool_output

    logger.debug("Cleaning tool output: %s", tool_output)

    # Locate the "[{ ... }]" list literal embedded in the output. Both ends
    # must be present and in order; otherwise there is nothing to parse.
    start = tool_output.find("[{")
    end = tool_output.rfind("}]")
    if start < 0 or end < start:
        return tool_output

    try:
        # ast.literal_eval only evaluates Python literals - never code
        results = ast.literal_eval(tool_output[start:end + 2])
        # Return only relevant fields
        return [
            {"text": r["text"], "document_id": r["metadata"]["document_id"]}
            for r in results
        ]
    except SyntaxError as e:
        logger.warning("Syntax error in parsing: %s", e)
        return f"Error parsing document results: {str(e)}"
    except Exception as e:
        logger.warning("General error: %s", e)
        return f"Error processing results: {str(e)}"


agent = create_react_agent(
    model,
    tools=[query_documents],
    checkpointer=memory,
    state_modifier=format_for_model,
)


class ChatInput(BaseModel):
    """Request body shared by the ``/chat`` and ``/chat2`` endpoints."""

    message: str
    # When omitted, a fresh thread id is generated for the request
    thread_id: Optional[str] = None
    collection_id: Optional[str] = None
    user_id: Optional[str] = None


def _build_agent_config(input_data: ChatInput) -> dict:
    """Build the agent config for a request, generating a thread id if none was sent."""
    return {
        "configurable": {
            "thread_id": input_data.thread_id or str(uuid.uuid4()),
            "collection_id": input_data.collection_id,
            "user_id": input_data.user_id
        }
    }


@app.post("/chat")
async def chat(input_data: ChatInput):
    """Stream the agent's reply as server-sent events.

    Each event's data is a JSON object whose ``type`` is ``token``,
    ``tool_start`` or ``tool_end``.
    """
    config = _build_agent_config(input_data)
    input_message = HumanMessage(content=input_data.message)

    async def generate():
        async for event in agent.astream_events(
            {"messages": [input_message]},
            config,
            version="v2"
        ):
            kind = event["event"]

            if kind == "on_chat_model_stream":
                content = event["data"]["chunk"].content
                # Skip empty chunks
                if content:
                    yield json.dumps({'type': 'token', 'content': content})

            elif kind == "on_tool_start":
                tool_input = str(event['data'].get('input', ''))
                yield json.dumps({'type': 'tool_start', 'tool': event['name'], 'input': tool_input})

            elif kind == "on_tool_end":
                tool_output = str(event['data'].get('output', ''))
                yield json.dumps({'type': 'tool_end', 'tool': event['name'], 'output': tool_output})

    return EventSourceResponse(
        generate(),
        media_type="text/event-stream"
    )


@app.post("/chat2")
async def chat2(input_data: ChatInput):
    """Stream the agent's reply as server-sent events with cleaned tool data.

    Same event types as ``/chat``; ``tool_start`` carries the parsed input
    under ``inputs`` and ``tool_end`` for ``query_documents`` carries the
    search results as a JSON-encoded list of ``text``/``distance``/``metadata``.
    """
    config = _build_agent_config(input_data)
    input_message = HumanMessage(content=input_data.message)

    async def generate():
        async for event in agent.astream_events(
            {"messages": [input_message]},
            config,
            version="v2"
        ):
            kind = event["event"]

            if kind == "on_chat_model_stream":
                content = event["data"]["chunk"].content
                # Skip empty chunks
                if content:
                    yield json.dumps({'type': 'token', 'content': content})

            elif kind == "on_tool_start":
                tool_name = event['name']
                tool_input = event['data'].get('input', '')
                clean_input = await clean_tool_input(str(tool_input))
                yield json.dumps({'type': 'tool_start', 'tool': tool_name, 'inputs': clean_input})

            elif kind == "on_tool_end":
                if "query_documents" in event['name']:
                    logger.debug("Tool end event: %s", event)
                    # The tool's own output is a stringified summary, so the
                    # search is re-run to get the full result objects. Use the
                    # actual query argument, not the repr of the args dict.
                    query = _extract_query(event['data'].get('input', ''))
                    raw_output = await query_documents_raw(query, config)

                    if isinstance(raw_output, str):
                        # query_documents_raw reports failures as a message
                        output = raw_output
                    else:
                        try:
                            output = json.dumps([
                                {
                                    "text": result.text,
                                    "distance": result.distance,
                                    "metadata": result.metadata
                                }
                                for result in raw_output
                            ])
                        except Exception:
                            # Best effort: fall back to the plain repr
                            logger.exception("Could not serialize search results")
                            output = str(raw_output)

                    yield json.dumps({'type': 'tool_end', 'tool': event['name'], 'output': output})
                else:
                    tool_name = event['name']
                    raw_output = str(event['data'].get('output', ''))
                    clean_output = await clean_tool_response(raw_output)
                    yield json.dumps({'type': 'tool_end', 'tool': tool_name, 'output': clean_output})

    return EventSourceResponse(
        generate(),
        media_type="text/event-stream"
    )


@app.get("/health")
async def health_check():
    """Liveness probe."""
    return {"status": "healthy"}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)