import gradio as gr
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_groq import ChatGroq
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
import os
from langchain_community.tools.tavily_search import TavilySearchResults
from typing_extensions import TypedDict
from typing import List
from langchain.schema import Document
from langgraph.graph import END, StateGraph
# Environment setup: read API keys from the environment instead of hardcoding secrets
os.environ["TAVILY_API_KEY"] = os.getenv("TAVILY_API_KEY", "<your-tavily-api-key>")
# Model and embedding setup
embed_model = FastEmbedEmbeddings(model_name="BAAI/bge-base-en-v1.5")
llm = ChatGroq(temperature=0, model_name="llama3-8b-8192", api_key=os.getenv("GROQ_API_KEY", "<your-groq-api-key>"))
# Load documents from URLs
urls = ["https://lilianweng.github.io/posts/2023-06-23-agent/",
"https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
"https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/"]
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]
# Document splitting
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=512, chunk_overlap=0)
doc_splits = text_splitter.split_documents(docs_list)
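# Optional sanity check (illustrative): confirm the splitter produced chunks before indexing
# print(f"{len(doc_splits)} chunks from {len(docs_list)} documents")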
# Vectorstore setup
vectorstore = Chroma.from_documents(documents=doc_splits, embedding=embed_model, collection_name="local-rag")
retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
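# Optional retrieval smoke test -- a minimal sketch, the query below is illustrative:
# hits = retriever.invoke("What is an LLM agent?")
# print(len(hits), hits[0].page_content[:200])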
# Prompt templates
question_router_prompt = PromptTemplate(
template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an expert at routing a
user question to a vectorstore or web search. Use the vectorstore for questions on LLM agents,
prompt engineering, and adversarial attacks. Otherwise, use web-search. Give a binary choice 'web_search'
or 'vectorstore' based on the question. Return a JSON with a single key 'datasource' and no preamble.
Question to route: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
input_variables=["question"],
)
question_router = question_router_prompt | llm | JsonOutputParser()
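# Example router call (a sketch; the question is illustrative and the model's routing
# decision is not guaranteed, but the parsed output has this shape):
# question_router.invoke({"question": "What is prompt engineering?"})
# -> {"datasource": "vectorstore"}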
rag_chain_prompt = PromptTemplate(
template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question concisely. <|eot_id|><|start_header_id|>user<|end_header_id|>
Question: {question}
Context: {context}
Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
input_variables=["question", "document"],
)
# Chain
rag_chain = rag_chain_prompt | llm | StrOutputParser()
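# Example RAG call (a sketch; any list of Documents works as context):
# answer = rag_chain.invoke({"context": doc_splits[:2], "question": "What is task decomposition?"})
# print(answer)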
# Web search tool
web_search_tool = TavilySearchResults(max_results=3)  # 'max_results' caps returned hits; 'k' is not a TavilySearchResults field
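# Each Tavily hit is a dict with at least 'url' and 'content' keys; web_search() below
# joins the 'content' fields into one Document. Illustrative shape of a single hit:
# {"url": "https://example.com/article", "content": "snippet of the page text ..."}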
# Workflow node functions -- each takes the graph state and returns updated state keys
def retrieve(state):
    """Fetch the top-k documents for the question from the vectorstore."""
    question = state["question"]
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}

def generate(state):
    """Answer the question with the RAG chain over the retrieved documents."""
    question = state["question"]
    documents = state["documents"]
    generation = rag_chain.invoke({"context": documents, "question": question})
    return {"documents": documents, "question": question, "generation": generation}

def route_question(state):
    """Decide whether the question goes to web search or the vectorstore."""
    question = state["question"]
    source = question_router.invoke({"question": question})
    return "websearch" if source["datasource"] == "web_search" else "vectorstore"

def web_search(state):
    """Run a Tavily web search and append the results as a single Document."""
    question = state["question"]
    docs = web_search_tool.invoke({"query": question})
    web_results = Document(page_content="\n".join([d["content"] for d in docs]))
    documents = state.get("documents", [])
    documents.append(web_results)
    return {"documents": documents, "question": question}
# Graph state shared across all nodes
class GraphState(TypedDict):
    question: str
    generation: str
    documents: List[Document]

workflow = StateGraph(GraphState)
# Define the nodes
workflow.add_node("websearch", web_search)
workflow.add_node("retrieve", retrieve)
workflow.add_node("generate", generate)
workflow.set_conditional_entry_point(
route_question,
{
"websearch": "websearch",
"vectorstore": "retrieve",
},
)
workflow.add_edge("retrieve", "generate")
workflow.add_edge("websearch", "generate")
# Compile the app
app = workflow.compile()
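# The compiled graph can be exercised without the UI -- a minimal sketch (question is illustrative):
# for output in app.stream({"question": "What are adversarial attacks on LLMs?"}):
#     for node, state in output.items():
#         print(f"finished node: {node}")
# The state emitted by the 'generate' node carries the answer under the 'generation' key.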
# Gradio integration: run the workflow for each user question and update the chat history
def ask_question_conversation(history, question):
inputs = {"question": question}
generation_result = None
# Run the workflow and get the generation result
for output in app.stream(inputs):
for key, value in output.items():
generation_result = value.get("generation", "No generation found")
# Append the new question and response to the history
history.append((question, generation_result))
# Return the updated history to chatbot and clear the question textbox
return history, ""
# Gradio conversation UI
with gr.Blocks(css="""
#title {
font-size: 26px;
font-weight: bold;
text-align: center;
color: #4A90E2;
}
#subtitle {
font-size: 18px;
text-align: center;
margin-top: -15px;
color: #7D7D7D;
}
.gr-chatbot, .gr-textbox, .gr-button {
max-width: 600px;
margin: 0 auto;
}
.gr-chatbot {
height: 400px;
}
.gr-button {
display: block;
width: 100px;
margin: 20px auto;
background-color: #4A90E2;
color: white;
}
""") as demo:
gr.Markdown("<div id='title'>🤖 Multi-Agent Knowledge Assistant: Powered by RAG for Smart Answers!</div>")
chatbot = gr.Chatbot(label="Chat with AI Assistant")
question = gr.Textbox(label="Ask a Question", placeholder="Type your question here...", lines=1)
clear = gr.Button("Clear Chat")
# Submit action for the question textbox
question.submit(ask_question_conversation, [chatbot, question], [chatbot, question])
clear.click(lambda: [], None, chatbot) # Clear conversation history
demo.launch()