import gradio as gr
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_groq import ChatGroq
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
import os
from langchain_community.tools.tavily_search import TavilySearchResults
from typing_extensions import TypedDict
from typing import List
from langchain.schema import Document
from langgraph.graph import END, StateGraph

# Environment setup
os.environ["TAVILY_API_KEY"] = "<your-tavily-api-key>"  # replace with your Tavily API key

# Model and embedding setup
embed_model = FastEmbedEmbeddings(model_name="BAAI/bge-base-en-v1.5")
llm = ChatGroq(
    temperature=0,
    model_name="Llama3-8b-8192",
    api_key="<your-groq-api-key>",  # replace with your Groq API key
)

# Load documents from URLs
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

# Document splitting
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=512, chunk_overlap=0)
doc_splits = text_splitter.split_documents(docs_list)

# Vectorstore setup
vectorstore = Chroma.from_documents(documents=doc_splits, embedding=embed_model, collection_name="local-rag")
retriever = vectorstore.as_retriever(search_kwargs={"k": 2})

# Prompt templates
question_router_prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are an expert at routing a user question to a vectorstore or web search.
Use the vectorstore for questions on LLM agents, prompt engineering, and adversarial attacks.
Otherwise, use web search. Give a binary choice 'web_search' or 'vectorstore' based on the question.
Return a JSON with a single key 'datasource' and no preamble.
Question to route: {question}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    input_variables=["question"],
)
question_router = question_router_prompt | llm | JsonOutputParser()

rag_chain_prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question concisely.
<|eot_id|><|start_header_id|>user<|end_header_id|>
Question: {question}
Context: {context}
Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    input_variables=["question", "context"],
)

# Chain
rag_chain = rag_chain_prompt | llm | StrOutputParser()

# Web search tool
web_search_tool = TavilySearchResults(max_results=3)

# Workflow functions
def retrieve(state):
    """Fetch the top-k chunks from the vectorstore for the current question."""
    question = state["question"]
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}


def generate(state):
    """Run the RAG chain over the retrieved documents."""
    question = state["question"]
    documents = state["documents"]
    generation = rag_chain.invoke({"context": documents, "question": question})
    return {"documents": documents, "question": question, "generation": generation}


def route_question(state):
    """Decide whether the question goes to web search or the vectorstore."""
    question = state["question"]
    source = question_router.invoke({"question": question})
    return "websearch" if source["datasource"] == "web_search" else "vectorstore"


def web_search(state):
    """Query Tavily and append the results as a single Document."""
    question = state["question"]
    docs = web_search_tool.invoke({"query": question})
    web_results = Document(page_content="\n".join([d["content"] for d in docs]))
    documents = state.get("documents", [])
    documents.append(web_results)
    return {"documents": documents, "question": question}


# Shared graph state
class GraphState(TypedDict):
    question: str
    generation: str
    documents: List[Document]


workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("websearch", web_search)
workflow.add_node("retrieve", retrieve)
workflow.add_node("generate", generate)

# Route each incoming question to web search or the vectorstore
workflow.set_conditional_entry_point(
    route_question,
    {
        "websearch": "websearch",
        "vectorstore": "retrieve",
    },
)
workflow.add_edge("retrieve", "generate")
workflow.add_edge("websearch", "generate")
workflow.add_edge("generate", END)

# Compile the app
app = workflow.compile()


# Gradio integration with Chatbot
def ask_question_conversation(history, question):
    inputs = {"question": question}
    generation_result = None

    # Run the workflow and get the generation result
    for output in app.stream(inputs):
        for key, value in output.items():
            generation_result = value.get("generation", "No generation found")

    # Append the new question and response to the history
    history.append((question, generation_result))

    # Return the updated history to the chatbot and clear the question textbox
    return history, ""


# Gradio conversation UI (unstyled version, kept for reference)
'''
with gr.Blocks() as demo:
    gr.Markdown("🤖 Multi-Agent Knowledge Assistant: Powered by RAG for Smart Answers!")
    chatbot = gr.Chatbot(label="Chat with AI Assistant")
    question = gr.Textbox(label="Your Question", placeholder="Ask your question here...")
    clear = gr.Button("Clear Conversation")

    # Submit action for the question textbox
    question.submit(ask_question_conversation, [chatbot, question], [chatbot, question])
    clear.click(lambda: [], None, chatbot)  # Clear conversation history

demo.launch()
'''

# Styled conversation UI (same widgets as the unstyled version above;
# elem_id ties the title to the #title CSS rule)
with gr.Blocks(css="""
    #title { font-size: 26px; font-weight: bold; text-align: center; color: #4A90E2; }
    #subtitle { font-size: 18px; text-align: center; margin-top: -15px; color: #7D7D7D; }
    .gr-chatbot, .gr-textbox, .gr-button { max-width: 600px; margin: 0 auto; }
    .gr-chatbot { height: 400px; }
    .gr-button { display: block; width: 100px; margin: 20px auto; background-color: #4A90E2; color: white; }
""") as demo:
    gr.Markdown("🤖 Multi-Agent Knowledge Assistant: Powered by RAG for Smart Answers!", elem_id="title")
    chatbot = gr.Chatbot(label="Chat with AI Assistant")
    question = gr.Textbox(label="Your Question", placeholder="Ask your question here...")
    clear = gr.Button("Clear Conversation")

    # Submit action for the question textbox
    question.submit(ask_question_conversation, [chatbot, question], [chatbot, question])
    clear.click(lambda: [], None, chatbot)  # Clear conversation history

demo.launch()
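
# A minimal sanity check, assuming the API keys above are set: the compiled
# LangGraph app is a Runnable, so it can also be exercised directly, without
# the Gradio UI (do this before demo.launch(), since launch() blocks):
#
#   result = app.invoke({"question": "What are adversarial attacks on LLMs?"})
#   print(result["generation"])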