import gradio as gr
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_groq import ChatGroq
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
import os
from langchain_community.tools.tavily_search import TavilySearchResults
from typing_extensions import TypedDict
from typing import List
from langchain.schema import Document
from langgraph.graph import END, StateGraph

# Environment setup: secrets are read from the environment rather than
# hardcoded in source. Set TAVILY_API_KEY and GROQ_API_KEY before running.

# Model and embedding setup
embed_model = FastEmbedEmbeddings(model_name="BAAI/bge-base-en-v1.5")
llm = ChatGroq(temperature=0, model_name="llama3-8b-8192", api_key=os.environ["GROQ_API_KEY"])

# Load documents from URLs
urls = ["https://lilianweng.github.io/posts/2023-06-23-agent/",
        "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
        "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/"]

docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

# Document splitting
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=512, chunk_overlap=0)
doc_splits = text_splitter.split_documents(docs_list)

# Vectorstore setup
vectorstore = Chroma.from_documents(documents=doc_splits, embedding=embed_model, collection_name="local-rag")
retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
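
# Illustrative sanity check (hypothetical question text); uncomment to
# confirm the retriever returns relevant chunks before wiring up the graph:
# docs = retriever.invoke("What is task decomposition for LLM agents?")
# print(docs[0].page_content[:200])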

# Prompt templates
question_router_prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an expert at routing a 
    user question to a vectorstore or web search. Use the vectorstore for questions on LLM agents, 
    prompt engineering, and adversarial attacks. Otherwise, use web-search. Give a binary choice 'web_search' 
    or 'vectorstore' based on the question. Return a JSON with a single key 'datasource' and no preamble. 
    Question to route: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    input_variables=["question"],
)

question_router = question_router_prompt | llm | JsonOutputParser()
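
# Illustrative check (hypothetical question): per the prompt's JSON contract,
# the router should emit {"datasource": "vectorstore"} or {"datasource": "web_search"}:
# print(question_router.invoke({"question": "What is prompt engineering?"}))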

rag_chain_prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an assistant for question-answering tasks. 
    Use the following pieces of retrieved context to answer the question concisely. <|eot_id|><|start_header_id|>user<|end_header_id|>
    Question: {question} 
    Context: {context} 
    Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    input_variables=["question", "context"],
)

# Chain
rag_chain = rag_chain_prompt | llm | StrOutputParser()
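
# Illustrative usage (hypothetical inputs): the chain fills {context} and
# {question} and returns the answer as a plain string:
# answer = rag_chain.invoke({"context": retriever.invoke("agent memory"),
#                            "question": "How do LLM agents use memory?"})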

# Web search tool (TavilySearchResults takes max_results, not k)
web_search_tool = TavilySearchResults(max_results=3)

# Workflow functions
def retrieve(state):
    """Fetch the most relevant chunks from the vectorstore for the question."""
    question = state["question"]
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}

def generate(state):
    """Run the RAG chain over the retrieved documents to produce an answer."""
    question = state["question"]
    documents = state["documents"]
    generation = rag_chain.invoke({"context": documents, "question": question})
    return {"documents": documents, "question": question, "generation": generation}

def route_question(state):
    """Decide whether the question goes to web search or the vectorstore."""
    question = state["question"]
    source = question_router.invoke({"question": question})
    return "websearch" if source["datasource"] == "web_search" else "vectorstore"

def web_search(state):
    """Query Tavily and append the combined results as a single Document."""
    question = state["question"]
    docs = web_search_tool.invoke({"query": question})
    web_results = Document(page_content="\n".join([d["content"] for d in docs]))
    documents = state.get("documents", [])
    documents.append(web_results)
    return {"documents": documents, "question": question}
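
# Note (assumed response shape): TavilySearchResults.invoke returns a list of
# dicts that include a "content" key, which web_search() above joins into a
# single Document appended to the state's document list.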

# Graph state carried between workflow nodes
class GraphState(TypedDict):
    question: str
    generation: str
    documents: List[Document]

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("websearch", web_search)
workflow.add_node("retrieve", retrieve)
workflow.add_node("generate", generate)

workflow.set_conditional_entry_point(
    route_question,
    {
        "websearch": "websearch",
        "vectorstore": "retrieve",
    },
)

workflow.add_edge("retrieve", "generate")
workflow.add_edge("websearch", "generate")
workflow.add_edge("generate", END)  # terminate the graph after generation

# Compile the app
app = workflow.compile()
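
# Minimal sketch of calling the compiled graph directly, bypassing the UI
# (question text is a hypothetical example):
# final_state = app.invoke({"question": "What are adversarial attacks on LLMs?"})
# print(final_state["generation"])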

# Gradio integration: conversation handler for the Chatbot component
def ask_question_conversation(history, question):
    inputs = {"question": question}
    generation_result = None
    
    # Run the workflow and get the generation result
    for output in app.stream(inputs):
        for key, value in output.items():
            generation_result = value.get("generation", "No generation found")
    
    # Append the new question and response to the history
    history.append((question, generation_result))
    
    # Return the updated history to chatbot and clear the question textbox
    return history, ""

# Gradio conversation UI

with gr.Blocks(css="""
    #title { 
        font-size: 26px; 
        font-weight: bold; 
        text-align: center; 
        color: #4A90E2;
    }
    #subtitle {
        font-size: 18px; 
        text-align: center; 
        margin-top: -15px; 
        color: #7D7D7D;
    }
    .gr-chatbot, .gr-textbox, .gr-button {
        max-width: 600px;
        margin: 0 auto;
    }
    .gr-chatbot {
        height: 400px;
    }
    .gr-button {
        display: block;
        width: 100px;
        margin: 20px auto;
        background-color: #4A90E2;
        color: white;
    }
""") as demo:
    gr.Markdown("<div id='title'>🤖 Multi-Agent Knowledge Assistant: Powered by RAG for Smart Answers!</div>")
    
    chatbot = gr.Chatbot(label="Chat with AI Assistant")
    question = gr.Textbox(label="Ask a Question", placeholder="Type your question here...", lines=1)
    clear = gr.Button("Clear Chat")
    
    # Submit action for the question textbox
    question.submit(ask_question_conversation, [chatbot, question], [chatbot, question])
    clear.click(lambda: [], None, chatbot)  # Clear conversation history

demo.launch()