RaghulDevaraj committed
Commit b0370b6 · verified · 1 Parent(s): 0716c9d

initial commit

Files changed (1)
  1. app.py +181 -60
app.py CHANGED
@@ -1,64 +1,185 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
  )
 
 
- if __name__ == "__main__":
-     demo.launch()
 
  import gradio as gr
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.document_loaders import WebBaseLoader
+ from langchain_community.vectorstores import Chroma
+ from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
+ from groq import Groq
+ from langchain_groq import ChatGroq
+ from langchain.prompts import PromptTemplate
+ from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
+ import os
+ from langchain_community.tools.tavily_search import TavilySearchResults
+ from typing_extensions import TypedDict
+ from typing import List
+ from langchain.schema import Document
+ from langgraph.graph import END, StateGraph
+
+ # Environment setup
+ os.environ['TAVILY_API_KEY'] = "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"
+
+ # Model and embedding setup
+ embed_model = FastEmbedEmbeddings(model_name="BAAI/bge-base-en-v1.5")
+ llm = ChatGroq(temperature=0, model_name="Llama3-8b-8192", api_key="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")
+
+ # Load documents from URLs
+ urls = ["https://lilianweng.github.io/posts/2023-06-23-agent/",
+         "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
+         "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/"]
+
+ docs = [WebBaseLoader(url).load() for url in urls]
+ docs_list = [item for sublist in docs for item in sublist]
+
+ # Document splitting
+ text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=512, chunk_overlap=0)
+ doc_splits = text_splitter.split_documents(docs_list)
+
+ # Vectorstore setup
+ vectorstore = Chroma.from_documents(documents=doc_splits, embedding=embed_model, collection_name="local-rag")
+ retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
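+ # k=2: each query pulls back the two most similar chunks from the vectorstore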
+
+ # Prompt templates
+ question_router_prompt = PromptTemplate(
+     template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an expert at routing a
+     user question to a vectorstore or web search. Use the vectorstore for questions on LLM agents,
+     prompt engineering, and adversarial attacks. Otherwise, use web-search. Give a binary choice 'web_search'
+     or 'vectorstore' based on the question. Return a JSON with a single key 'datasource' and no preamble.
+     Question to route: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
+     input_variables=["question"],
+ )
+
+ question_router = question_router_prompt | llm | JsonOutputParser()
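+ # The router chain pipes the filled prompt into the LLM and parses the reply
+ # into a dict like {"datasource": "web_search"} for route_question to inspect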
+
+ rag_chain_prompt = PromptTemplate(
+     template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an assistant for question-answering tasks.
+     Use the following pieces of retrieved context to answer the question concisely. <|eot_id|><|start_header_id|>user<|end_header_id|>
+     Question: {question}
+     Context: {context}
+     Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
+     input_variables=["question", "context"],
+ )
+
+ # Chain
+ rag_chain = rag_chain_prompt | llm | StrOutputParser()
+
+ # Web search tool
+ web_search_tool = TavilySearchResults(k=3)
+
+ # Workflow functions
+ def retrieve(state):
+     question = state["question"]
+     documents = retriever.invoke(question)
+     return {"documents": documents, "question": question}
+
+ def generate(state):
+     question = state["question"]
+     documents = state["documents"]
+     generation = rag_chain.invoke({"context": documents, "question": question})
+     return {"documents": documents, "question": question, "generation": generation}
+
+ def route_question(state):
+     question = state["question"]
+     source = question_router.invoke({"question": question})
+     return "websearch" if source['datasource'] == 'web_search' else "vectorstore"
+
+ def web_search(state):
+     question = state["question"]
+     docs = web_search_tool.invoke({"query": question})
+     web_results = Document(page_content="\n".join([d["content"] for d in docs]))
+     documents = state.get("documents", [])
+     documents.append(web_results)
+     return {"documents": documents, "question": question}
+
+ workflow = StateGraph(TypedDict("GraphState", {"question": str, "generation": str, "documents": List[Document]}))
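+ # GraphState is the shared state passed between nodes: each node returns a
+ # partial update that is merged into question/generation/documents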
+
+ # Define the nodes
+ workflow.add_node("websearch", web_search)
+ workflow.add_node("retrieve", retrieve)
+ workflow.add_node("generate", generate)
+
+ workflow.set_conditional_entry_point(
+     route_question,
+     {
+         "websearch": "websearch",
+         "vectorstore": "retrieve",
+     },
  )

+ workflow.add_edge("retrieve", "generate")
+ workflow.add_edge("websearch", "generate")
+
+ # Compile the app
+ app = workflow.compile()
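+ # The compiled app streams one {node_name: state} update per executed node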
+
+ # Gradio integration with Chatbot
+
+ # Updated ask_question_conversation function
+ def ask_question_conversation(history, question):
+     inputs = {"question": question}
+     generation_result = None
+
+     # Run the workflow and get the generation result
+     for output in app.stream(inputs):
+         for key, value in output.items():
+             generation_result = value.get("generation", "No generation found")
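+             # The last streamed update comes from the "generate" node, so
+             # generation_result ends up holding the final answer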
+
+     # Append the new question and response to the history
+     history.append((question, generation_result))
+
+     # Return the updated history to the chatbot and clear the question textbox
+     return history, ""
+
+ # Gradio conversation UI (earlier version, kept commented out)
+ '''
+ with gr.Blocks() as demo:
+     gr.Markdown("🤖 Multi-Agent Knowledge Assistant: Powered by RAG for Smart Answers!")
+
+     chatbot = gr.Chatbot(label="Chat with AI Assistant")
+     question = gr.Textbox(label="Your Question", placeholder="Ask your question here...")
+     clear = gr.Button("Clear Conversation")
+
+     # Submit action for the question textbox
+     question.submit(ask_question_conversation, [chatbot, question], [chatbot, question])
+     clear.click(lambda: [], None, chatbot)  # Clear conversation history
+
+     demo.launch()
+ '''
+
+ with gr.Blocks(css="""
+ #title {
+     font-size: 26px;
+     font-weight: bold;
+     text-align: center;
+     color: #4A90E2;
+ }
+ #subtitle {
+     font-size: 18px;
+     text-align: center;
+     margin-top: -15px;
+     color: #7D7D7D;
+ }
+ .gr-chatbot, .gr-textbox, .gr-button {
+     max-width: 600px;
+     margin: 0 auto;
+ }
+ .gr-chatbot {
+     height: 400px;
+ }
+ .gr-button {
+     display: block;
+     width: 100px;
+     margin: 20px auto;
+     background-color: #4A90E2;
+     color: white;
+ }
+ """) as demo:
+     gr.Markdown("<div id='title'>🤖 Multi-Agent Knowledge Assistant: Powered by RAG for Smart Answers!</div>")
+
+     chatbot = gr.Chatbot(label="Chat with AI Assistant")
+     question = gr.Textbox(label="Ask a Question", placeholder="Type your question here...", lines=1)
+     clear = gr.Button("Clear Chat")
+
+     # Submit action for the question textbox
+     question.submit(ask_question_conversation, [chatbot, question], [chatbot, question])
+     clear.click(lambda: [], None, chatbot)  # Clear conversation history

+ demo.launch()