CyranoB committed on
Commit
9c3709d
1 Parent(s): 48dbc73

Original code

Files changed (4)
  1. README.md +1 -1
  2. dotenv.sample +10 -0
  3. requirements.txt +11 -0
  4. search_agent.py +333 -0
README.md CHANGED
@@ -1,6 +1,6 @@
  # Simple Search Agent

- This is a simple search agent that accepts a question as input, searches the web for relevant information, and then uses the search results to generate an answer using a large language model (LLM).
+ This is a simple search agent that (kind of) does what [Perplexity AI](https://www.perplexity.ai/) does.

  ## How It Works

dotenv.sample ADDED
@@ -0,0 +1,10 @@
+
+ OPENAI_API_KEY=sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
+ ANTHROPIC_API_KEY=sk-ant-api03-XXXXXXXXXXXXXXXXXXX
+ GROQ_API_KEY=gsk_XXXXXXXXXXXXXXXXXXXXXXXX
+ CREDENTIALS_PROFILE_NAME=XXXXXXXXXXXXXXXX
+
+ LANGCHAIN_API_KEY=ls__XXXXXXXXXXXXXXXXXXXXXXXXXX
+ LANGCHAIN_TRACING_V2=true
+
+ BRAVE_SEARCH_API_KEY=BSXXXXXXXXXXXXXXXX
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ boto3
+ bs4
+ docopt
+ faiss-cpu
+ python-dotenv
+ langchain
+ langchain_community
+ langchain_openai
+ langchain_groq
+ langsmith
+ rich
search_agent.py ADDED
@@ -0,0 +1,333 @@
+ """search_agent.py
+
+ Usage:
+     search_agent.py
+         [--domain=domain]
+         [--provider=provider]
+         [--temperature=temp]
+         [--max_pages=num]
+         SEARCH_QUERY
+     search_agent.py --version
+
+ Options:
+     -h --help                          Show this screen.
+     --version                          Show version.
+     -d domain --domain=domain          Limit search to a specific domain
+     -t temp --temperature=temp         Set the temperature of the LLM [default: 0.0]
+     -p provider --provider=provider    Use a specific LLM (choices: bedrock,openai,groq) [default: openai]
+     -m num --max_pages=num             Max number of pages to retrieve [default: 10]
+
+ """
+
+ import json
+ import os
+ from concurrent.futures import ThreadPoolExecutor
+ from urllib.parse import quote
+
+ import requests
+ from bs4 import BeautifulSoup
+ from docopt import docopt
+ import dotenv
+
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.schema import SystemMessage, HumanMessage
+ from langchain.callbacks import LangChainTracer
+ from langchain_groq import ChatGroq
+ from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+ from langchain_community.vectorstores.faiss import FAISS
+ from langchain_community.chat_models.bedrock import BedrockChat
+ from langsmith import Client
+
+ from rich.console import Console
+ from rich.rule import Rule
+ from rich.markdown import Markdown
+
+
+ def get_chat_llm(provider, temperature=0.0):
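+     """Return a LangChain chat model for the selected provider (bedrock, openai, or groq)."""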
+     console.log(f"Using provider {provider} with temperature {temperature}")
+     match provider:
+         case 'bedrock':
+             chat_llm = BedrockChat(
+                 credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME'),
+                 model_id="anthropic.claude-3-sonnet-20240229-v1:0",
+                 model_kwargs={"temperature": temperature},
+             )
+         case 'openai':
+             chat_llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=temperature)
+         case 'groq':
+             chat_llm = ChatGroq(model_name='mixtral-8x7b-32768', temperature=temperature)
+         case _:
+             raise ValueError(f"Unknown LLM provider {provider}")
+     return chat_llm
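+
+
+ # Ask the LLM to rewrite the user's question as a compact keyword search query.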
+ def optimize_search_query(query):
+     messages = [
+         SystemMessage(
+             content="""
+                 You are a search query optimization specialist.
+                 Rewrite the user's question using only the most important keywords. Remove extra words.
+                 Tips:
+                 Identify the key concepts in the question
+                 Remove filler words like "how to", "what is", "I want to"
+                 Remove formatting instructions
+                 Remove style instructions (example: in the style of, engaging, short, long)
+                 Remove length instructions (example: essay, article, letter, blog, post, blogpost, etc)
+                 Keep it short, around 3-7 words total
+                 Put the most important keywords first
+                 Example:
+                 Question: How do I bake chocolate chip cookies from scratch?
+                 Search query: chocolate chip cookies recipe from scratch
+                 Example:
+                 Question: I would like you to show me a timeline of Marie Curie's life. Show results as a markdown table
+                 Search query: Marie Curie timeline
+                 Example:
+                 Question: I would like you to write a long article on NATO vs Russia. Use known geopolitical frameworks.
+                 Search query: geopolitics nato russia
+                 Example:
+                 Question: Write an engaging LinkedIn post about Andrew Ng
+                 Search query: Andrew Ng
+                 Example:
+                 Question: Write a short article about the solar system in the style of Carl Sagan
+                 Search query: solar system
+                 Example:
+                 Question: Should I use Kubernetes? Answer in the style of Gilfoyle from the TV show Silicon Valley
+                 Search query: Kubernetes decision
+                 Example:
+                 Question: biography of napoleon. include a table with the major events.
+                 Search query: napoleon biography events
+             """
+         ),
+         HumanMessage(
+             content=f"""
+                 Question: {query}
+                 Search query:
+             """
+         ),
+     ]
+
+     response = chat.invoke(messages, config={"callbacks": callbacks})
+     return response.content
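+
+
+ # Query the Brave Search API and normalize each result to title/link/snippet/favicon.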
+ def get_sources(query, max_pages=10, domain=None):
+     search_query = query
+     if domain:
+         search_query += f" site:{domain}"
+
+     url = f"https://api.search.brave.com/res/v1/web/search?q={quote(search_query)}&count={max_pages}"
+     headers = {
+         'Accept': 'application/json',
+         'Accept-Encoding': 'gzip',
+         'X-Subscription-Token': os.getenv("BRAVE_SEARCH_API_KEY")
+     }
+
+     try:
+         response = requests.get(url, headers=headers)
+
+         if response.status_code != 200:
+             raise Exception(f"HTTP error! status: {response.status_code}")
+
+         json_response = response.json()
+
+         if 'web' not in json_response or 'results' not in json_response['web']:
+             raise Exception('Invalid API response format')
+
+         final_results = [{
+             'title': result['title'],
+             'link': result['url'],
+             'snippet': result['description'],
+             'favicon': result.get('profile', {}).get('img', '')
+         } for result in json_response['web']['results']]
+
+         return final_results
+
+     except Exception as error:
+         # console.log('Error fetching search results:', error)
+         raise
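+
+
+ # Fetch a URL, returning None instead of raising when the request fails or times out.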
+ def fetch_with_timeout(url, timeout=8):
+     try:
+         response = requests.get(url, timeout=timeout)
+         response.raise_for_status()
+         return response
+     except requests.RequestException as error:
+         # console.log(f"Skipping {url}! Error: {error}")
+         return None
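+
+
+ # Strip scripts, styles, and page chrome from the HTML and keep only the visible body text.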
+ def extract_main_content(html):
+     try:
+         soup = BeautifulSoup(html, 'html.parser')
+         for element in soup(["script", "style", "head", "nav", "footer", "iframe", "img"]):
+             element.extract()
+         main_content = ' '.join(soup.body.get_text().split())
+         return main_content
+     except Exception as error:
+         # console.log(f"Error extracting main content: {error}")
+         return None
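+
+
+ # Download one source and replace its 'html' field with the extracted main text.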
+ def process_source(source):
+     response = fetch_with_timeout(source['link'], 8)
+     if response:
+         html = response.text
+         main_content = extract_main_content(html)
+         return {**source, 'html': main_content}
+     return None
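+
+
+ # Fetch all sources in parallel and drop the ones that could not be retrieved.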
+ def get_links_contents(sources):
+     with ThreadPoolExecutor() as executor:
+         results = list(executor.map(process_source, sources))
+
+     # Filter out None results
+     return [result for result in results if result is not None]
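+
+
+ # Note: a separate FAISS index is built for each fetched page, so every page contributes
+ # its own top-k most similar chunks rather than competing in one global index.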
+ def process_and_vectorize_content(
+     contents,
+     query,
+     text_chunk_size=1000,
+     text_chunk_overlap=200,
+     number_of_similarity_results=5
+ ):
+     """
+     Process and vectorize content using Langchain.
+
+     Args:
+         contents (list): List of dictionaries containing 'title', 'link', and 'html' keys.
+         query (str): Query string for similarity search.
+         text_chunk_size (int): Size of each text chunk.
+         text_chunk_overlap (int): Overlap between text chunks.
+         number_of_similarity_results (int): Number of most similar results to return.
+
+     Returns:
+         list: List of most similar documents.
+     """
+     documents = []
+
+     for content in contents:
+         if content['html']:
+             try:
+                 # Split text into chunks
+                 text_splitter = RecursiveCharacterTextSplitter(
+                     chunk_size=text_chunk_size,
+                     chunk_overlap=text_chunk_overlap
+                 )
+                 texts = text_splitter.split_text(content['html'])
+
+                 # Create metadata for each text chunk
+                 metadatas = [{'title': content['title'], 'link': content['link']} for _ in range(len(texts))]
+
+                 # Create vector store
+                 embeddings = OpenAIEmbeddings()
+                 docsearch = FAISS.from_texts(texts, embedding=embeddings, metadatas=metadatas)
+
+                 # Perform similarity search
+                 docs = docsearch.similarity_search(query, k=number_of_similarity_results)
+                 doc_dicts = [{'page_content': doc.page_content, 'metadata': doc.metadata} for doc in docs]
+                 documents.extend(doc_dicts)
+
+             except Exception as e:
+                 console.log(f"[gray]Error processing content for {content['link']}: {e}")
+
+     return documents
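+
+
+ # Have the LLM answer the question, grounded in the relevant extracts passed as JSON context.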
+ def answer_query_with_sources(query, relevant_docs):
+     messages = [
+         SystemMessage(
+             content="""
+                 You are an expert research assistant.
+                 You are provided with a Context in JSON format and a Question.
+
+                 Use RAG to answer the Question, providing references and links to the Context material you retrieve and use in your answer.
+                 When generating your answer, follow these steps:
+                 - Retrieve the most relevant context material from your knowledge base to help answer the question
+                 - Cite the references you use by including the title, author, publication, and a link to each source
+                 - Synthesize the retrieved information into a clear, informative answer to the question
+                 - Format your answer in Markdown, using heading levels 2-3 as needed
+                 - Include a "References" section at the end with the full citations and a link for each source you used
+
+                 Example of Context JSON entry:
+                 {
+                     "page_content": "This provides access to material related to ...",
+                     "metadata": {
+                         "title": "Introduction - Marie Curie: Topics in Chronicling America",
+                         "link": "https://guides.loc.gov/chronicling-america-marie-curie"
+                     }
+                 }
+             """
+         ),
+         HumanMessage(
+             content=f"""
+                 Context information is below.
+                 Context:
+                 ---------------------
+                 {json.dumps(relevant_docs, indent=2, ensure_ascii=False)}
+                 ---------------------
+                 Question: {query}
+                 Answer:
+             """
+         ),
+     ]
+
+     response = chat.invoke(messages, config={"callbacks": callbacks})
+     return response
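+
+
+ # Module-level setup: rich console, .env loading, and optional LangSmith tracing.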
+ console = Console()
+ dotenv.load_dotenv()
+
+ callbacks = []
+ if os.getenv("LANGCHAIN_API_KEY"):
+     callbacks.append(
+         LangChainTracer(
+             project_name="search agent",
+             client=Client(
+                 api_url="https://api.smith.langchain.com",
+             )
+         )
+     )
+
+
+ if __name__ == '__main__':
+     arguments = docopt(__doc__, version='Search Agent 0.1')
+     # print(arguments)
+
+     provider = arguments["--provider"]
+     temperature = float(arguments["--temperature"])
+     chat = get_chat_llm(provider, temperature)
+     query = arguments["SEARCH_QUERY"]
+
+     with console.status(f"[bold green]Optimizing query for search: {query}"):
+         optimized_search_query = optimize_search_query(query)
+     console.log(f"Optimized search query: [bold blue]{optimized_search_query}")
+
+     domain = arguments["--domain"]
+     max_pages = int(arguments["--max_pages"])
+     with console.status(f"[bold green]Searching sources using the optimized query: {optimized_search_query}"):
+         sources = get_sources(optimized_search_query, max_pages=max_pages, domain=domain)
+     console.log(f"Found {len(sources)} sources {'on ' + domain if domain else ''}")
+
+     with console.status(f"[bold green]Fetching content for {len(sources)} sources", spinner="growVertical"):
+         contents = get_links_contents(sources)
+     console.log(f"Managed to extract content from {len(contents)} sources")
+
+     with console.status(
+         f"[bold green]Processing {len(contents)} contents and finding relevant extracts",
+         spinner="dots8Bit"
+     ):
+         relevant_docs = process_and_vectorize_content(contents, query)
+     console.log(f"Filtered {len(relevant_docs)} relevant content extracts")
+
+     with console.status(f"[bold green]Querying LLM with {len(relevant_docs)} relevant extracts", spinner='dots8Bit'):
+         response = answer_query_with_sources(query, relevant_docs)
+
+     console.rule(f"[bold green]Response from {provider}")
+     console.print(Markdown(response.content))
+     console.rule("[bold green]")