# Spaces:
# Sleeping
# Sleeping
# (The three lines above are Hugging Face Spaces page-status text captured by
#  the scraper, not code; commented out so this file parses as Python.)
import gradio as gr | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_community.document_loaders import WebBaseLoader | |
from langchain_community.vectorstores import Chroma | |
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings | |
from groq import Groq | |
from langchain_groq import ChatGroq | |
from langchain.prompts import PromptTemplate | |
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser | |
import os | |
from langchain_community.tools.tavily_search import TavilySearchResults | |
from typing_extensions import TypedDict | |
from typing import List | |
from langchain.schema import Document | |
from langgraph.graph import END, StateGraph | |
# --- Environment & model setup ------------------------------------------------
# SECURITY: API keys were hard-coded in this file. Prefer supplying them via
# the process environment; the literals below remain only as fallbacks so the
# script stays runnable, but they are exposed in source control and should be
# rotated and removed.
os.environ.setdefault('TAVILY_API_KEY', "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")

# Embedding model used for indexing; retrieval reuses the same model via the
# vectorstore, which is required for meaningful similarity search.
embed_model = FastEmbedEmbeddings(model_name="BAAI/bge-base-en-v1.5")

# Deterministic (temperature=0) Groq-hosted Llama3 chat model used by every
# chain below (router, RAG answerer).
llm = ChatGroq(
    temperature=0,
    model_name="Llama3-8b-8192",
    api_key=os.environ.get("GROQ_API_KEY", "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"),
)
# --- Knowledge-base construction ----------------------------------------------
# Source articles indexed into the vectorstore (LLM agents, prompt engineering,
# adversarial attacks) -- the router sends questions on these topics here.
urls = ["https://lilianweng.github.io/posts/2023-06-23-agent/",
"https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
"https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/"]
# WebBaseLoader.load() returns a list of Documents per URL; flatten them into
# a single list. NOTE: this fetches the pages over the network at import time.
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]
# Split into chunks of up to 512 tokens (tiktoken-counted) with no overlap.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=512, chunk_overlap=0)
doc_splits = text_splitter.split_documents(docs_list)
# Index the chunks in a Chroma collection; retrieval returns the top-2 chunks.
vectorstore = Chroma.from_documents(documents=doc_splits, embedding=embed_model, collection_name="local-rag")
retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
# --- Routing chain ------------------------------------------------------------
# Classifies a question as answerable from the local vectorstore vs. needing a
# live web search. The <|...|> markers are Llama-3 chat-template tokens and are
# intentionally part of the prompt string. The model is instructed to emit JSON
# ({"datasource": ...}), which JsonOutputParser turns into a dict.
question_router_prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an expert at routing a
user question to a vectorstore or web search. Use the vectorstore for questions on LLM agents,
prompt engineering, and adversarial attacks. Otherwise, use web-search. Give a binary choice 'web_search'
or 'vectorstore' based on the question. Return a JSON with a single key 'datasource' and no preamble.
Question to route: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    input_variables=["question"],
)
# prompt -> LLM -> parsed dict, e.g. {"datasource": "web_search"}.
question_router = question_router_prompt | llm | JsonOutputParser()
# --- Answer-generation chain --------------------------------------------------
# Produces a concise answer grounded in the retrieved context. The <|...|>
# markers are Llama-3 chat-template tokens, kept verbatim in the prompt.
rag_chain_prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question concisely. <|eot_id|><|start_header_id|>user<|end_header_id|>
Question: {question}
Context: {context}
Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    # BUG FIX: the template references {question} and {context}, but the
    # variable list previously declared "document" instead of "context".
    input_variables=["question", "context"],
)
# prompt -> LLM -> plain string answer.
rag_chain = rag_chain_prompt | llm | StrOutputParser()
# Web-search fallback tool (Tavily).
# NOTE(review): TavilySearchResults' documented result-count parameter is
# `max_results`; depending on the installed version, `k=3` may be silently
# ignored and the default count used instead -- verify against the installed
# langchain_community release.
web_search_tool = TavilySearchResults(k=3)
# Workflow functions | |
def retrieve(state):
    """Vectorstore node: fetch the documents most relevant to the question."""
    query = state["question"]
    matches = retriever.invoke(query)
    return {"documents": matches, "question": query}
def generate(state):
    """Generation node: answer the question from the documents in state."""
    question, documents = state["question"], state["documents"]
    answer = rag_chain.invoke({"context": documents, "question": question})
    return {"documents": documents, "question": question, "generation": answer}
def route_question(state):
    """Entry router: choose the 'websearch' or 'vectorstore' branch."""
    verdict = question_router.invoke({"question": state["question"]})
    if verdict['datasource'] == 'web_search':
        return "websearch"
    return "vectorstore"
def web_search(state):
    """Web-search node: query Tavily and add the results as one Document.

    The hits' text is joined into a single Document appended to any documents
    already in state. FIX: returns a NEW list instead of appending to the list
    object held in state -- in-place mutation of graph state can leak results
    between runs or confuse LangGraph's state merging.
    """
    question = state["question"]
    hits = web_search_tool.invoke({"query": question})
    web_results = Document(page_content="\n".join(d["content"] for d in hits))
    documents = list(state.get("documents", [])) + [web_results]
    return {"documents": documents, "question": question}
# --- Graph assembly -----------------------------------------------------------
# State carried between nodes: the user question, the accumulated documents,
# and (after the generate node) the final answer.
class GraphState(TypedDict):
    question: str
    generation: str
    documents: List[Document]

workflow = StateGraph(GraphState)

# Nodes: web search, vectorstore retrieval, and answer generation.
workflow.add_node("websearch", web_search)
workflow.add_node("retrieve", retrieve)
workflow.add_node("generate", generate)

# The router picks the entry node; both branches feed the generator.
workflow.set_conditional_entry_point(
    route_question,
    {"websearch": "websearch", "vectorstore": "retrieve"},
)
workflow.add_edge("retrieve", "generate")
workflow.add_edge("websearch", "generate")

# Compile into a runnable app (used via app.stream below).
app = workflow.compile()
# Gradio integration with Chatbot | |
# Updated ask_question_conversation function | |
def ask_question_conversation(history, question):
    """Gradio callback: run the RAG graph on `question` and update the chat.

    Parameters:
        history: list of (user, assistant) tuples backing gr.Chatbot.
        question: the newly submitted user question.
    Returns:
        (updated history, "") -- the empty string clears the input textbox.
    """
    generation_result = "No generation found"
    # Stream node outputs and keep the latest value that actually carries a
    # generation. FIX: the previous code re-assigned the fallback string for
    # every node output, so a real answer survived only because "generate"
    # happened to stream last.
    for output in app.stream({"question": question}):
        for value in output.values():
            if "generation" in value:
                generation_result = value["generation"]
    history.append((question, generation_result))
    return history, ""
# Gradio conversation UI
# (Removed: a commented-out earlier Blocks layout that duplicated the styled
#  UI defined at the bottom of this file.)
# --- Styled Gradio UI ---------------------------------------------------------
# Blocks layout with inline CSS for the title, chatbot, textbox, and button.
with gr.Blocks(css="""
#title {
    font-size: 26px;
    font-weight: bold;
    text-align: center;
    color: #4A90E2;
}
#subtitle {
    font-size: 18px;
    text-align: center;
    margin-top: -15px;
    color: #7D7D7D;
}
.gr-chatbot, .gr-textbox, .gr-button {
    max-width: 600px;
    margin: 0 auto;
}
.gr-chatbot {
    height: 400px;
}
.gr-button {
    display: block;
    width: 100px;
    margin: 20px auto;
    background-color: #4A90E2;
    color: white;
}
""") as demo:
    gr.Markdown("<div id='title'>🤖 Multi-Agent Knowledge Assistant: Powered by RAG for Smart Answers!</div>")
    chatbot = gr.Chatbot(label="Chat with AI Assistant")
    question = gr.Textbox(label="Ask a Question", placeholder="Type your question here...", lines=1)
    clear = gr.Button("Clear Chat")
    # Enter in the textbox runs the graph; outputs update the chat history and
    # clear the textbox (ask_question_conversation returns (history, "")).
    question.submit(ask_question_conversation, [chatbot, question], [chatbot, question])
    clear.click(lambda: [], None, chatbot)  # Reset conversation history
demo.launch()