Spaces:
Running
Running
Original code
Browse files- README.md +1 -1
- dotenv.sample +10 -0
- requirements.txt +11 -0
- search_agent.py +333 -0
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
# Simple Search Agent
|
2 |
|
3 |
-
This is a simple search agent that
|
4 |
|
5 |
## How It Works
|
6 |
|
|
|
1 |
# Simple Search Agent
|
2 |
|
3 |
+
This is a simple search agent that (kind of) does what [Perplexity AI](https://www.perplexity.ai/) does.
|
4 |
|
5 |
## How It Works
|
6 |
|
dotenv.sample
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
OPENAI_API_KEY=sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
3 |
+
ANTHROPIC_API_KEY=sk-ant-api03-XXXXXXXXXXXXXXXXXXX
|
4 |
+
GROQ_API_KEY=gsk_XXXXXXXXXXXXXXXXXXXXXXXX
|
5 |
+
CREDENTIALS_PROFILE_NAME=XXXXXXXXXXXXXXXX
|
6 |
+
|
7 |
+
LANGCHAIN_API_KEY=ls__XXXXXXXXXXXXXXXXXXXXXXXXXX
|
8 |
+
LANGCHAIN_TRACING_V2=true
|
9 |
+
|
10 |
+
BRAVE_SEARCH_API_KEY=BSXXXXXXXXXXXXXXXX
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
boto3
|
2 |
+
bs4
|
3 |
+
docopt
|
4 |
+
faiss-cpu
|
5 |
+
python-dotenv
|
6 |
+
langchain
|
7 |
+
langchain_community
|
8 |
+
langchain_openai
|
9 |
+
langchain_groq
|
10 |
+
langsmith
|
11 |
+
rich
|
search_agent.py
ADDED
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""search_agent.py
|
2 |
+
|
3 |
+
Usage:
|
4 |
+
search_agent.py
|
5 |
+
[--domain=domain]
|
6 |
+
[--provider=provider]
|
7 |
+
[--temperature=temp]
|
8 |
+
[--max_pages=num]
|
9 |
+
SEARCH_QUERY
|
10 |
+
search_agent.py --version
|
11 |
+
|
12 |
+
Options:
|
13 |
+
-h --help Show this screen.
|
14 |
+
--version Show version.
|
15 |
+
-d domain --domain=domain Limit search to a specific domain
|
16 |
+
-t temp --temperature=temp Set the temperature of the LLM [default: 0.0]
|
17 |
+
-p provider --provider=provider Use a specific LLM (choices: bedrock,openai,groq) [default: openai]
|
18 |
+
-m num --max_pages=num Max number of pages to retrieve [default: 10]
|
19 |
+
|
20 |
+
"""
|
21 |
+
|
22 |
+
import json
|
23 |
+
import os
|
24 |
+
from concurrent.futures import ThreadPoolExecutor
|
25 |
+
from urllib.parse import quote
|
26 |
+
|
27 |
+
from bs4 import BeautifulSoup
|
28 |
+
from docopt import docopt
|
29 |
+
import dotenv
|
30 |
+
|
31 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
32 |
+
from langchain.schema import SystemMessage, HumanMessage
|
33 |
+
from langchain.callbacks import LangChainTracer
|
34 |
+
from langchain_groq import ChatGroq
|
35 |
+
from langchain_openai import ChatOpenAI
|
36 |
+
from langchain_openai import OpenAIEmbeddings
|
37 |
+
from langchain_community.vectorstores.faiss import FAISS
|
38 |
+
from langchain_community.chat_models.bedrock import BedrockChat
|
39 |
+
from langsmith import Client
|
40 |
+
|
41 |
+
import requests
|
42 |
+
|
43 |
+
from rich.console import Console
|
44 |
+
from rich.rule import Rule
|
45 |
+
from rich.markdown import Markdown
|
46 |
+
|
47 |
+
|
48 |
+
def get_chat_llm(provider, temperature=0.0):
|
49 |
+
console.log(f"Using provider {provider} with temperature {temperature}")
|
50 |
+
match provider:
|
51 |
+
case 'bedrock':
|
52 |
+
chat_llm = BedrockChat(
|
53 |
+
credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME'),
|
54 |
+
model_id="anthropic.claude-3-sonnet-20240229-v1:0",
|
55 |
+
model_kwargs={"temperature": temperature },
|
56 |
+
)
|
57 |
+
case 'openai':
|
58 |
+
chat_llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=temperature)
|
59 |
+
case 'groq':
|
60 |
+
chat_llm = ChatGroq(model_name = 'mixtral-8x7b-32768', temperature=temperature)
|
61 |
+
case _:
|
62 |
+
raise ValueError(f"Unknown LLM provider {provider}")
|
63 |
+
return chat_llm
|
64 |
+
|
65 |
+
def optimize_search_query(query):
|
66 |
+
messages = [
|
67 |
+
SystemMessage(
|
68 |
+
content="""
|
69 |
+
You are a serach query optimizer specialist.
|
70 |
+
Rewrite the user's question using only the most important keywords. Remove extra words.
|
71 |
+
Tips:
|
72 |
+
Identify the key concepts in the question
|
73 |
+
Remove filler words like "how to", "what is", "I want to"
|
74 |
+
Removed style such as "in the style of", "engaging", "short", "long"
|
75 |
+
Remove lenght instruction (example: essay, article, letter, blog, post, blogpost, etc)
|
76 |
+
Keep it short, around 3-7 words total
|
77 |
+
Put the most important keywords first
|
78 |
+
Remove formatting instructions
|
79 |
+
Remove style instructions (exmaple: in the style of, engaging, short, long)
|
80 |
+
Remove lenght instruction (example: essay, article, letter, etc)
|
81 |
+
Example:
|
82 |
+
Question: How do I bake chocolate chip cookies from scratch?
|
83 |
+
Search query: chocolate chip cookies recipe from scratch
|
84 |
+
Example:
|
85 |
+
Question: I would like you to show me a time line of Marie Curie life. Show results as a markdown table
|
86 |
+
Search query: Marie Curie timeline
|
87 |
+
Example:
|
88 |
+
Question: I would like you to write a long article on nato vs russia. Use know geopolical frameworks.
|
89 |
+
Search query: geopolitics nato russia
|
90 |
+
Example:
|
91 |
+
Question: Write a engaging linkedin post about Andrew Ng
|
92 |
+
Search query: Andrew Ng
|
93 |
+
Example:
|
94 |
+
Question: Write a short artible about the solar system in the style of Carl Sagan
|
95 |
+
Search query: solar system
|
96 |
+
Example:
|
97 |
+
Question: Should I use Kubernetes? Answer in the style of Gilfoyde from the TV show Silicon Valley
|
98 |
+
Search query: Kubernetes decision
|
99 |
+
Example:
|
100 |
+
Question: biography of napoleon. include a table with the major events.
|
101 |
+
Search query: napoleon biography events
|
102 |
+
"""
|
103 |
+
),
|
104 |
+
HumanMessage(
|
105 |
+
content=f"""
|
106 |
+
Questions: {query}
|
107 |
+
Search query:
|
108 |
+
"""
|
109 |
+
),
|
110 |
+
]
|
111 |
+
|
112 |
+
response = chat.invoke(messages, config={"callbacks": callbacks})
|
113 |
+
return response.content
|
114 |
+
|
115 |
+
|
116 |
+
def get_sources(query, max_pages=10, domain=None):
|
117 |
+
search_query = query
|
118 |
+
if domain:
|
119 |
+
search_query += f" site:{domain}"
|
120 |
+
|
121 |
+
url = f"https://api.search.brave.com/res/v1/web/search?q={quote(search_query)}&count={max_pages}"
|
122 |
+
headers = {
|
123 |
+
'Accept': 'application/json',
|
124 |
+
'Accept-Encoding': 'gzip',
|
125 |
+
'X-Subscription-Token': os.getenv("BRAVE_SEARCH_API_KEY")
|
126 |
+
}
|
127 |
+
|
128 |
+
try:
|
129 |
+
response = requests.get(url, headers=headers)
|
130 |
+
|
131 |
+
if response.status_code != 200:
|
132 |
+
raise Exception(f"HTTP error! status: {response.status_code}")
|
133 |
+
|
134 |
+
json_response = response.json()
|
135 |
+
|
136 |
+
if 'web' not in json_response or 'results' not in json_response['web']:
|
137 |
+
raise Exception('Invalid API response format')
|
138 |
+
|
139 |
+
final_results = [{
|
140 |
+
'title': result['title'],
|
141 |
+
'link': result['url'],
|
142 |
+
'snippet': result['description'],
|
143 |
+
'favicon': result.get('profile', {}).get('img', '')
|
144 |
+
} for result in json_response['web']['results']]
|
145 |
+
|
146 |
+
return final_results
|
147 |
+
|
148 |
+
except Exception as error:
|
149 |
+
#console.log('Error fetching search results:', error)
|
150 |
+
raise
|
151 |
+
|
152 |
+
|
153 |
+
|
154 |
+
def fetch_with_timeout(url, timeout=8):
|
155 |
+
try:
|
156 |
+
response = requests.get(url, timeout=timeout)
|
157 |
+
response.raise_for_status()
|
158 |
+
return response
|
159 |
+
except requests.RequestException as error:
|
160 |
+
#console.log(f"Skipping {url}! Error: {error}")
|
161 |
+
return None
|
162 |
+
|
163 |
+
def extract_main_content(html):
|
164 |
+
try:
|
165 |
+
soup = BeautifulSoup(html, 'html.parser')
|
166 |
+
for element in soup(["script", "style", "head", "nav", "footer", "iframe", "img"]):
|
167 |
+
element.extract()
|
168 |
+
main_content = ' '.join(soup.body.get_text().split())
|
169 |
+
return main_content
|
170 |
+
except Exception as error:
|
171 |
+
#console.log(f"Error extracting main content: {error}")
|
172 |
+
return None
|
173 |
+
|
174 |
+
def process_source(source):
|
175 |
+
response = fetch_with_timeout(source['link'], 8)
|
176 |
+
if response:
|
177 |
+
html = response.text
|
178 |
+
main_content = extract_main_content(html)
|
179 |
+
return {**source, 'html': main_content}
|
180 |
+
return None
|
181 |
+
|
182 |
+
def get_links_contents(sources):
|
183 |
+
with ThreadPoolExecutor() as executor:
|
184 |
+
results = list(executor.map(process_source, sources))
|
185 |
+
|
186 |
+
# Filter out None results
|
187 |
+
return [result for result in results if result is not None]
|
188 |
+
|
189 |
+
def process_and_vectorize_content(
|
190 |
+
contents,
|
191 |
+
query,
|
192 |
+
text_chunk_size=1000,
|
193 |
+
text_chunk_overlap=200,
|
194 |
+
number_of_similarity_results=5
|
195 |
+
):
|
196 |
+
"""
|
197 |
+
Process and vectorize content using Langchain.
|
198 |
+
|
199 |
+
Args:
|
200 |
+
contents (list): List of dictionaries containing 'title', 'link', and 'html' keys.
|
201 |
+
query (str): Query string for similarity search.
|
202 |
+
text_chunk_size (int): Size of each text chunk.
|
203 |
+
text_chunk_overlap (int): Overlap between text chunks.
|
204 |
+
number_of_similarity_results (int): Number of most similar results to return.
|
205 |
+
|
206 |
+
Returns:
|
207 |
+
list: List of most similar documents.
|
208 |
+
"""
|
209 |
+
documents = []
|
210 |
+
|
211 |
+
for content in contents:
|
212 |
+
if content['html']:
|
213 |
+
try:
|
214 |
+
# Split text into chunks
|
215 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
216 |
+
chunk_size=text_chunk_size,
|
217 |
+
chunk_overlap=text_chunk_overlap
|
218 |
+
)
|
219 |
+
texts = text_splitter.split_text(content['html'])
|
220 |
+
|
221 |
+
# Create metadata for each text chunk
|
222 |
+
metadatas = [{'title': content['title'], 'link': content['link']} for _ in range(len(texts))]
|
223 |
+
|
224 |
+
# Create vector store
|
225 |
+
embeddings = OpenAIEmbeddings()
|
226 |
+
docsearch = FAISS.from_texts(texts, embedding=embeddings, metadatas=metadatas)
|
227 |
+
|
228 |
+
# Perform similarity search
|
229 |
+
docs = docsearch.similarity_search(query, k=number_of_similarity_results)
|
230 |
+
doc_dicts = [{'page_content': doc.page_content, 'metadata': doc.metadata} for doc in docs]
|
231 |
+
documents.extend(doc_dicts)
|
232 |
+
|
233 |
+
except Exception as e:
|
234 |
+
console.log(f"[gray]Error processing content for {content['link']}: {e}")
|
235 |
+
|
236 |
+
|
237 |
+
return documents
|
238 |
+
|
239 |
+
|
240 |
+
def answer_query_with_sources(query, relevant_docs):
|
241 |
+
messages = [
|
242 |
+
SystemMessage(
|
243 |
+
content="""
|
244 |
+
You are an expert research assistant.
|
245 |
+
You are provided with a Context in JSON format and a Question.
|
246 |
+
|
247 |
+
Use RAG to answer the Question, providing references and links to the Context material you retrieve and use in your answer:
|
248 |
+
When generating your answer, follow these steps:
|
249 |
+
- Retrieve the most relevant context material from your knowledge base to help answer the question
|
250 |
+
- Cite the references you use by including the title, author, publication, and a link to each source
|
251 |
+
- Synthesize the retrieved information into a clear, informative answer to the question
|
252 |
+
- Format your answer in Markdown, using heading levels 2-3 as needed
|
253 |
+
- Include a "References" section at the end with the full citations and link for each source you used
|
254 |
+
|
255 |
+
|
256 |
+
Example of Context JSON entry:
|
257 |
+
{
|
258 |
+
"page_content": "This provides access to material related to ...",
|
259 |
+
"metadata": {
|
260 |
+
"title": "Introduction - Marie Curie: Topics in Chronicling America",
|
261 |
+
"link": "https://guides.loc.gov/chronicling-america-marie-curie"
|
262 |
+
}
|
263 |
+
}
|
264 |
+
|
265 |
+
"""
|
266 |
+
),
|
267 |
+
HumanMessage(
|
268 |
+
content= f"""
|
269 |
+
Context information is below.
|
270 |
+
Context:
|
271 |
+
---------------------
|
272 |
+
{json.dumps(relevant_docs, indent=2, ensure_ascii=False)}
|
273 |
+
---------------------
|
274 |
+
Question: {query}
|
275 |
+
Answer:
|
276 |
+
"""
|
277 |
+
),
|
278 |
+
]
|
279 |
+
|
280 |
+
response = chat.invoke(messages, config={"callbacks": callbacks})
|
281 |
+
return response
|
282 |
+
|
283 |
+
console = Console()
|
284 |
+
dotenv.load_dotenv()
|
285 |
+
|
286 |
+
callbacks = []
|
287 |
+
if(os.getenv("LANGCHAIN_API_KEY")):
|
288 |
+
callbacks.append(
|
289 |
+
LangChainTracer(
|
290 |
+
project_name="search agent",
|
291 |
+
client=Client(
|
292 |
+
api_url="https://api.smith.langchain.com",
|
293 |
+
)
|
294 |
+
)
|
295 |
+
)
|
296 |
+
|
297 |
+
if __name__ == '__main__':
|
298 |
+
arguments = docopt(__doc__, version='Search Agent 0.1')
|
299 |
+
#print(arguments)
|
300 |
+
|
301 |
+
|
302 |
+
provider = arguments["--provider"]
|
303 |
+
temperature = float(arguments["--temperature"])
|
304 |
+
chat = get_chat_llm(provider, temperature)
|
305 |
+
query = arguments["SEARCH_QUERY"]
|
306 |
+
|
307 |
+
with console.status(f"[bold green]Optimizing query for search: {query}"):
|
308 |
+
optimize_search_query = optimize_search_query(query)
|
309 |
+
console.log(f"Optimized search query: [bold blue]{optimize_search_query}")
|
310 |
+
|
311 |
+
domain=arguments["--domain"]
|
312 |
+
max_pages=arguments["--max_pages"]
|
313 |
+
with console.status(f"[bold green]Searching sources using the optimized query: {optimize_search_query}"):
|
314 |
+
sources = get_sources(optimize_search_query, max_pages=max_pages, domain=domain)
|
315 |
+
console.log(f"Found {len(sources)} sources {'on ' + domain if domain else ''}")
|
316 |
+
|
317 |
+
with console.status(f"[bold green]Fetching content for {len(sources)} sources", spinner="growVertical"):
|
318 |
+
contents = get_links_contents(sources)
|
319 |
+
console.log(f"Managed to extract content from {len(contents)} sources")
|
320 |
+
|
321 |
+
with console.status(
|
322 |
+
f"[bold green]Processing {len(contents)} contents and finding relevant extracts",
|
323 |
+
spinner="dots8Bit"
|
324 |
+
):
|
325 |
+
relevant_docs = process_and_vectorize_content(contents, query)
|
326 |
+
console.log(f"Filtered {len(relevant_docs)} relevant content extracts")
|
327 |
+
|
328 |
+
with console.status(f"[bold green]Querying LLM with {len(relevant_docs)} relevant extracts", spinner='dots8Bit'):
|
329 |
+
respomse = answer_query_with_sources(query, relevant_docs)
|
330 |
+
|
331 |
+
console.rule(f"[bold green]Response from {provider}")
|
332 |
+
console.print(Markdown(respomse.content))
|
333 |
+
console.rule("[bold green]")
|