CyranoB committed
Commit d594a38
1 Parent(s): d21cce9

Added web ui using streamlit

dotenv.sample CHANGED
@@ -1,10 +1,18 @@
-
- OPENAI_API_KEY=sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
- ANTHROPIC_API_KEY=sk-ant-api03-XXXXXXXXXXXXXXXXXXX
- GROQ_API_KEY=gsk_XXXXXXXXXXXXXXXXXXXXXXXX
- CREDENTIALS_PROFILE_NAME=XXXXXXXXXXXXXXXX
-
- LANGCHAIN_API_KEY=ls__XXXXXXXXXXXXXXXXXXXXXXXXXX
+ LANGCHAIN_API_KEY=ls__XXXXXXXXXXXXXXXXXXX
  LANGCHAIN_TRACING_V2=true
+ LANGCHAIN_PROJECT="search agent"
+ LANGCHAIN_ENDPOINT="https://api.smith.langchain.com"
 
- BRAVE_SEARCH_API_KEY=BSXXXXXXXXXXXXXXXX
+
+ OPENAI_API_KEY=sk-XXXXXXXXXXXXXXXXXXX
+ ANTHROPIC_API_KEY=sk-ant-api03-XXXXXXXXXXXXXXXXXXX
+ GROQ_API_KEY=gsk_XXXXXXXXXXXXXXXXXXX
+ CREDENTIALS_PROFILE_NAME=XXXXXXXXXXXXXXXXXXX
+ COHERE_API_KEY=XXXXXXXXXXXXXXXXXXX
+ TAVILY_API_KEY=tvly-XXXXXXXXXXXXXXXXXXX
+ BRAVE_SEARCH_API_KEY=XXXXXXXXXXXXXXXXXXX
+ SERPER_API_KEY=XXXXXXXXXXXXXXXXXXX
+ SERPAPI_API_KEY=XXXXXXXXXXXXXXXXXXX
+ WOLFRAM_ALPHA_APPID=XXXXXXXXXXXXXXXXXXX
+ GOOGLE_CSE_ID=XXXXXXXXXXXXXXXXXXX
+ GOOGLE_API_KEY=XXXXXXXXXXXXXXXXXXX
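
These keys are consumed at startup through python-dotenv. A minimal sketch of how the new LangSmith variables are picked up, assuming a local .env file copied from dotenv.sample (the print statements are illustrative only):

```python
import os
import dotenv

# Load variables from a local .env file (copied from dotenv.sample).
dotenv.load_dotenv()

# Tracing is enabled only when a LangSmith key is present; the project
# and endpoint now come from the environment instead of being hard-coded.
if os.getenv("LANGCHAIN_API_KEY"):
    print("Tracing to project:", os.getenv("LANGCHAIN_PROJECT"))
    print("Endpoint:", os.getenv("LANGCHAIN_ENDPOINT"))
```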
requirements.txt CHANGED
@@ -14,5 +14,7 @@ langchain_experimental
  langchain_openai
  langchain_groq
  langsmith
+ streamlit
  rich
  trafilatura
+ watchdog
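
streamlit powers the new web UI; watchdog is the file watcher Streamlit recommends installing for faster auto-reloads during development.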
search_agent.py CHANGED
@@ -23,241 +23,20 @@ Options:
 
  """
 
- import json
  import os
- import io
- from concurrent.futures import ThreadPoolExecutor
- from urllib.parse import quote
 
  from docopt import docopt
  import dotenv
- import pdfplumber
- from trafilatura import extract
 
- from selenium import webdriver
- from selenium.webdriver.chrome.options import Options
-
- from langchain_core.documents.base import Document
- from langchain_experimental.text_splitter import SemanticChunker
- from langchain.retrievers.multi_query import MultiQueryRetriever
  from langchain.callbacks import LangChainTracer
- from langchain_cohere.chat_models import ChatCohere
- from langchain_groq import ChatGroq
- from langchain_openai import ChatOpenAI
- from langchain_openai import OpenAIEmbeddings
- from langchain_community.chat_models.bedrock import BedrockChat
- from langchain_community.chat_models.ollama import ChatOllama
- from langchain_community.vectorstores.faiss import FAISS
 
  from langsmith import Client
 
- import requests
-
  from rich.console import Console
  from rich.markdown import Markdown
 
- from messages import get_rag_prompt_template, get_optimized_search_messages
-
-
- def get_chat_llm(provider, model=None, temperature=0.0):
-     match provider:
-         case 'bedrock':
-             if model is None:
-                 model = "anthropic.claude-3-sonnet-20240229-v1:0"
-             chat_llm = BedrockChat(
-                 credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME'),
-                 model_id=model,
-                 model_kwargs={"temperature": temperature },
-             )
-         case 'openai':
-             if model is None:
-                 model = "gpt-3.5-turbo"
-             chat_llm = ChatOpenAI(model_name=model, temperature=temperature)
-         case 'groq':
-             if model is None:
-                 model = 'mixtral-8x7b-32768'
-             chat_llm = ChatGroq(model_name=model, temperature=temperature)
-         case 'ollama':
-             if model is None:
-                 model = 'llama2'
-             chat_llm = ChatOllama(model=model, temperature=temperature)
-         case 'cohere':
-             if model is None:
-                 model = 'command-r-plus'
-             chat_llm = ChatCohere(model=model, temperature=temperature)
-         case _:
-             raise ValueError(f"Unknown LLM provider {provider}")
-
-     console.log(f"Using {model} on {provider} with temperature {temperature}")
-     return chat_llm
-
- def optimize_search_query(chat_llm, query):
-     messages = get_optimized_search_messages(query)
-     response = chat_llm.invoke(messages, config={"callbacks": callbacks})
-     optimized_search_query = response.content
-     return optimized_search_query.strip('"').split("**", 1)[0]
-
-
- def get_sources(query, max_pages=10, domain=None):
-     search_query = query
-     if domain:
-         search_query += f" site:{domain}"
-
-     url = f"https://api.search.brave.com/res/v1/web/search?q={quote(search_query)}&count={max_pages}"
-     headers = {
-         'Accept': 'application/json',
-         'Accept-Encoding': 'gzip',
-         'X-Subscription-Token': os.getenv("BRAVE_SEARCH_API_KEY")
-     }
-
-     try:
-         response = requests.get(url, headers=headers, timeout=30)
-
-         if response.status_code != 200:
-             return []
-
-         json_response = response.json()
-
-         if 'web' not in json_response or 'results' not in json_response['web']:
-             raise Exception('Invalid API response format')
-
-         final_results = [{
-             'title': result['title'],
-             'link': result['url'],
-             'snippet': result['description'],
-             'favicon': result.get('profile', {}).get('img', '')
-         } for result in json_response['web']['results']]
-
-         return final_results
-
-     except Exception as error:
-         console.log('Error fetching search results:', error)
-         raise
-
- def fetch_with_selenium(url, timeout=8):
-     chrome_options = Options()
-     chrome_options.add_argument("headless")
-     chrome_options.add_argument("--disable-extensions")
-     chrome_options.add_argument("--disable-gpu")
-     chrome_options.add_argument("--no-sandbox")
-     chrome_options.add_argument("--disable-dev-shm-usage")
-     chrome_options.add_argument("--remote-debugging-port=9222")
-     chrome_options.add_argument('--blink-settings=imagesEnabled=false')
-     chrome_options.add_argument("--window-size=1920,1080")
-
-     driver = webdriver.Chrome(options=chrome_options)
-
-     driver.get(url)
-     driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
-     html = driver.page_source
-     driver.quit()
-
-     return html
-
- def fetch_with_timeout(url, timeout=8):
-     try:
-         response = requests.get(url, timeout=timeout)
-         response.raise_for_status()
-         return response
-     except requests.RequestException as error:
-         return None
-
-
- def process_source(source):
-     url = source['link']
-     #console.log(f"Processing {url}")
-     response = fetch_with_timeout(url, 8)
-     if response:
-         content_type = response.headers.get('Content-Type')
-         if content_type:
-             if content_type.startswith('application/pdf'):
-                 # The response is a PDF file
-                 pdf_content = response.content
-                 # Create a file-like object from the bytes
-                 pdf_file = io.BytesIO(pdf_content)
-                 # Extract text from PDF using pdfplumber
-                 with pdfplumber.open(pdf_file) as pdf:
-                     text = ""
-                     for page in pdf.pages:
-                         text += page.extract_text()
-                 return {**source, 'page_content': text}
-             elif content_type.startswith('text/html'):
-                 # The response is an HTML file
-                 html = response.text
-                 main_content = extract(html, output_format='txt', include_links=True)
-                 return {**source, 'page_content': main_content}
-             else:
-                 console.log(f"Skipping {url}! Unsupported content type: {content_type}")
-                 return {**source, 'page_content': source['snippet']}
-         else:
-             console.log(f"Skipping {url}! No content type")
-             return {**source, 'page_content': source['snippet']}
-     return {**source, 'page_content': None}
-
- def get_links_contents(sources):
-     with ThreadPoolExecutor() as executor:
-         results = list(executor.map(process_source, sources))
-     for result in results:
-         if result['page_content'] is None:
-             url = result['link']
-             console.log(f"Fetching with selenium {url}")
-             html = fetch_with_selenium(url, 8)
-             main_content = extract(html, output_format='txt', include_links=True)
-             if main_content:
-                 result['page_content'] = main_content
-
-     # Filter out None results
-     return [result for result in results if result is not None]
-
- def vectorize(contents, text_chunk_size=400, text_chunk_overlap=40):
-     documents = []
-     for content in contents:
-         try:
-             metadata = {'title': content['title'], 'source': content['link']}
-             doc = Document(page_content=content['page_content'], metadata=metadata)
-             documents.append(doc)
-         except Exception as e:
-             console.log(f"[gray]Error processing content for {content['link']}: {e}")
-     semantic_chunker = SemanticChunker(OpenAIEmbeddings(model="text-embedding-3-large"), breakpoint_threshold_type="percentile")
-     docs = semantic_chunker.split_documents(documents)
-     console.log(f"Vectorizing {len(docs)} document chunks")
-     embeddings = OpenAIEmbeddings()
-     store = FAISS.from_documents(docs, embeddings)
-     return store
-
- def format_docs(docs):
-     formatted_docs = []
-     for d in docs:
-         content = d.page_content
-         title = d.metadata['title']
-         source = d.metadata['source']
-         doc = {"content": content, "title": title, "link": source}
-         formatted_docs.append(doc)
-     docs_as_json = json.dumps(formatted_docs, indent=2, ensure_ascii=False)
-     return docs_as_json
-
-
- def multi_query_rag(chat_llm, question, search_query, vectorstore):
-     retriever_from_llm = MultiQueryRetriever.from_llm(
-         retriever=vectorstore.as_retriever(), llm=chat_llm, include_original=True,
-     )
-     unique_docs = retriever_from_llm.get_relevant_documents(
-         query=search_query, callbacks=callbacks, verbose=True
-     )
-     context = format_docs(unique_docs)
-     prompt = get_rag_prompt_template().format(query=question, context=context)
-     response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
-     return response.content
-
-
- def query_rag(chat_llm, question, search_query, vectorstore):
-     unique_docs = vectorstore.similarity_search(search_query, k=15, callbacks=callbacks, verbose=True)
-     context = format_docs(unique_docs)
-     prompt = get_rag_prompt_template().format(query=question, context=context)
-     response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
-     return response.content
-
-
+ import web_rag as wr
+ import web_crawler as wc
 
  console = Console()
  dotenv.load_dotenv()
@@ -265,12 +44,7 @@ dotenv.load_dotenv()
  callbacks = []
  if os.getenv("LANGCHAIN_API_KEY"):
      callbacks.append(
-         LangChainTracer(
-             project_name="search agent",
-             client=Client(
-                 api_url="https://api.smith.langchain.com",
-             )
-         )
+         LangChainTracer(client=Client())
      )
 
  if __name__ == '__main__':
@@ -284,29 +58,30 @@ if __name__ == '__main__':
      output=arguments["--output"]
      query = arguments["SEARCH_QUERY"]
 
-     chat = get_chat_llm(provider, model, temperature)
+     chat = wr.get_chat_llm(provider, model, temperature)
+     console.log(f"Using {chat.get_name()} on {provider} with temperature {temperature}")
 
      with console.status(f"[bold green]Optimizing query for search: {query}"):
-         optimize_search_query = optimize_search_query(chat, query)
+         optimize_search_query = wr.optimize_search_query(chat, query, callbacks=callbacks)
      console.log(f"Optimized search query: [bold blue]{optimize_search_query}")
 
      with console.status(
          f"[bold green]Searching sources using the optimized query: {optimize_search_query}"
      ):
-         sources = get_sources(optimize_search_query, max_pages=max_pages, domain=domain)
+         sources = wc.get_sources(optimize_search_query, max_pages=max_pages, domain=domain)
      console.log(f"Found {len(sources)} sources {'on ' + domain if domain else ''}")
 
      with console.status(
          f"[bold green]Fetching content for {len(sources)} sources", spinner="growVertical"
      ):
-         contents = get_links_contents(sources)
+         contents = wc.get_links_contents(sources)
      console.log(f"Managed to extract content from {len(contents)} sources")
 
      with console.status(f"[bold green]Embedding content from {len(contents)} sources", spinner="growVertical"):
-         vector_store = vectorize(contents)
+         vector_store = wc.vectorize(contents)
 
      with console.status("[bold green]Querying LLM relevant context", spinner='dots8Bit'):
-         response = query_rag(chat, query, optimize_search_query, vector_store)
+         response = wr.query_rag(chat, query, optimize_search_query, vector_store, callbacks=callbacks)
 
      console.rule(f"[bold green]Response from {provider}")
      if output == "text":
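
The tracer construction is now much shorter because the project name and endpoint moved into the environment (see LANGCHAIN_PROJECT and LANGCHAIN_ENDPOINT in dotenv.sample). A minimal sketch of the intended equivalence, assuming both variables are set:

```python
from langchain.callbacks import LangChainTracer
from langsmith import Client

# Before this commit: project and endpoint were passed explicitly.
explicit = LangChainTracer(
    project_name="search agent",
    client=Client(api_url="https://api.smith.langchain.com"),
)

# After: LangChainTracer falls back to LANGCHAIN_PROJECT and the langsmith
# Client falls back to LANGCHAIN_ENDPOINT, so defaults are enough.
implicit = LangChainTracer(client=Client())
```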
search_agent_ui.py ADDED
@@ -0,0 +1,60 @@
+ import dotenv
+
+ import streamlit as st
+
+ import web_rag as wr
+ import web_crawler as wc
+
+ from langchain_core.tracers.langchain import LangChainTracer
+ from langsmith.client import Client
+
+ dotenv.load_dotenv()
+
+ ls_tracer = LangChainTracer(
+     project_name="Search Agent UI",
+     client=Client()
+ )
+
+ chat = wr.get_chat_llm(provider="cohere")
+
+ st.title("🔍 Simple Search Agent 💬")
+
+ if "messages" not in st.session_state:
+     st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
+
+ for message in st.session_state.messages:
+     st.chat_message(message["role"]).write(message["content"])
+
+ if prompt := st.chat_input():
+
+     st.chat_message("user").write(prompt)
+     st.session_state.messages.append({"role": "user", "content": prompt})
+
+     message = "I first need to do some research"
+     st.chat_message("assistant").write(message)
+     st.session_state.messages.append({"role": "assistant", "content": message})
+
+     with st.spinner("Optimizing search query"):
+         optimize_search_query = wr.optimize_search_query(chat, query=prompt, callbacks=[ls_tracer])
+
+     message = f"I'll search the web for: {optimize_search_query}"
+     st.chat_message("assistant").write(message)
+     st.session_state.messages.append({"role": "assistant", "content": message})
+
+     with st.spinner(f"Searching the web for: {optimize_search_query}"):
+         sources = wc.get_sources(optimize_search_query)
+
+     with st.spinner(f"I'm now retrieving the {len(sources)} webpages and documents I found (be patient)"):
+         contents = wc.get_links_contents(sources)
+
+     with st.spinner(f"Reading through the {len(contents)} sources I managed to retrieve"):
+         vector_store = wc.vectorize(contents)
+
+     with st.spinner("OK, I now have enough information to answer"):
+         response = wr.query_rag(chat, prompt, optimize_search_query, vector_store, callbacks=[ls_tracer])
+
+     st.chat_message("assistant").write(response)
+     st.session_state.messages.append({"role": "assistant", "content": response})
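
Assuming a standard Streamlit setup, the UI would be launched with `streamlit run search_agent_ui.py`. Note that, unlike the CLI, the provider is hard-coded to cohere here rather than taken from a command-line option.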
web_crawler.py ADDED
@@ -0,0 +1,150 @@
+ from concurrent.futures import ThreadPoolExecutor
+ from urllib.parse import quote
+
+ import os
+ import io
+
+ from trafilatura import extract
+ from selenium import webdriver
+ from selenium.webdriver.chrome.options import Options
+ from selenium.common.exceptions import TimeoutException
+ from langchain_core.documents.base import Document
+ from langchain_experimental.text_splitter import SemanticChunker
+ from langchain_openai import OpenAIEmbeddings
+ from langchain_community.vectorstores.faiss import FAISS
+
+ import requests
+ import pdfplumber
+
+ def get_sources(query, max_pages=10, domain=None):
+     search_query = query
+     if domain:
+         search_query += f" site:{domain}"
+
+     url = f"https://api.search.brave.com/res/v1/web/search?q={quote(search_query)}&count={max_pages}"
+     headers = {
+         'Accept': 'application/json',
+         'Accept-Encoding': 'gzip',
+         'X-Subscription-Token': os.getenv("BRAVE_SEARCH_API_KEY")
+     }
+
+     try:
+         response = requests.get(url, headers=headers, timeout=30)
+
+         if response.status_code != 200:
+             return []
+
+         json_response = response.json()
+
+         if 'web' not in json_response or 'results' not in json_response['web']:
+             raise Exception('Invalid API response format')
+
+         final_results = [{
+             'title': result['title'],
+             'link': result['url'],
+             'snippet': result['description'],
+             'favicon': result.get('profile', {}).get('img', '')
+         } for result in json_response['web']['results']]
+
+         return final_results
+
+     except Exception as error:
+         print('Error fetching search results:', error)
+         raise
+
+ def fetch_with_selenium(url, timeout=8):
+     chrome_options = Options()
+     chrome_options.add_argument("headless")
+     chrome_options.add_argument("--disable-extensions")
+     chrome_options.add_argument("--disable-gpu")
+     chrome_options.add_argument("--no-sandbox")
+     chrome_options.add_argument("--disable-dev-shm-usage")
+     chrome_options.add_argument("--remote-debugging-port=9222")
+     chrome_options.add_argument('--blink-settings=imagesEnabled=false')
+     chrome_options.add_argument("--window-size=1920,1080")
+
+     driver = webdriver.Chrome(options=chrome_options)
+
+     try:
+         driver.set_page_load_timeout(timeout)
+         driver.get(url)
+         driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
+         html = driver.page_source
+     except TimeoutException:
+         print(f"Page load timed out after {timeout} seconds.")
+         html = None
+     finally:
+         driver.quit()
+
+     return html
+
+ def fetch_with_timeout(url, timeout=8):
+     try:
+         response = requests.get(url, timeout=timeout)
+         response.raise_for_status()
+         return response
+     except requests.RequestException:
+         return None
+
+
+ def process_source(source):
+     url = source['link']
+     response = fetch_with_timeout(url, 8)
+     if response:
+         content_type = response.headers.get('Content-Type')
+         if content_type:
+             if content_type.startswith('application/pdf'):
+                 # The response is a PDF file
+                 pdf_content = response.content
+                 # Create a file-like object from the bytes
+                 pdf_file = io.BytesIO(pdf_content)
+                 # Extract text from PDF using pdfplumber
+                 with pdfplumber.open(pdf_file) as pdf:
+                     text = ""
+                     for page in pdf.pages:
+                         text += page.extract_text()
+                 return {**source, 'page_content': text}
+             elif content_type.startswith('text/html'):
+                 # The response is an HTML file
+                 html = response.text
+                 main_content = extract(html, output_format='txt', include_links=True)
+                 return {**source, 'page_content': main_content}
+             else:
+                 print(f"Skipping {url}! Unsupported content type: {content_type}")
+                 return {**source, 'page_content': source['snippet']}
+         else:
+             print(f"Skipping {url}! No content type")
+             return {**source, 'page_content': source['snippet']}
+     return {**source, 'page_content': None}
+
+ def get_links_contents(sources):
+     with ThreadPoolExecutor() as executor:
+         results = list(executor.map(process_source, sources))
+     for result in results:
+         if result['page_content'] is None:
+             url = result['link']
+             print(f"Fetching with selenium {url}")
+             html = fetch_with_selenium(url, 8)
+             # Guard against a selenium timeout having returned None.
+             main_content = extract(html, output_format='txt', include_links=True) if html else None
+             if main_content:
+                 result['page_content'] = main_content
+
+     # Filter out None results
+     return [result for result in results if result is not None]
+
+ def vectorize(contents):
+     documents = []
+     for content in contents:
+         try:
+             metadata = {'title': content['title'], 'source': content['link']}
+             doc = Document(page_content=content['page_content'], metadata=metadata)
+             documents.append(doc)
+         except Exception as e:
+             print(f"Error processing content for {content['link']}: {e}")
+     semantic_chunker = SemanticChunker(OpenAIEmbeddings(model="text-embedding-3-large"), breakpoint_threshold_type="percentile")
+     docs = semantic_chunker.split_documents(documents)
+     print(f"Vectorizing {len(docs)} document chunks")
+     embeddings = OpenAIEmbeddings()
+     store = FAISS.from_documents(docs, embeddings)
+     return store
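
A minimal sketch of how the three public functions chain together, assuming BRAVE_SEARCH_API_KEY and OPENAI_API_KEY are set (the query strings are illustrative):

```python
import dotenv
import web_crawler as wc

# Requires BRAVE_SEARCH_API_KEY (search) and OPENAI_API_KEY (embeddings).
dotenv.load_dotenv()

# Search, fetch and clean each hit, then embed everything into FAISS.
sources = wc.get_sources("semantic chunking for RAG", max_pages=5)
contents = wc.get_links_contents(sources)
store = wc.vectorize(contents)

# The result is an ordinary FAISS vector store.
print(store.similarity_search("what is semantic chunking?", k=3))
```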
messages.py → web_rag.py RENAMED
@@ -1,8 +1,24 @@
  """
- This module provides functions for generating optimized search messages, RAG prompt templates,
- and messages for queries with relevant source documents using the LangChain library.
+ Module for performing retrieval-augmented generation (RAG) using LangChain.
+ This module provides functions to optimize search queries, retrieve relevant documents,
+ and generate answers to questions using the retrieved context. It leverages the LangChain
+ library for building the RAG pipeline.
+ Functions:
+ - get_optimized_search_messages(query: str) -> list:
+     Generate optimized search messages for a given query.
+ - optimize_search_query(chat_llm, query: str, callbacks: list = []) -> str:
+     Optimize the search query using the chat language model.
+ - get_rag_prompt_template() -> ChatPromptTemplate:
+     Get the prompt template for retrieval-augmented generation (RAG).
+ - format_docs(docs: list) -> str:
+     Format the retrieved documents into a JSON string.
+ - multi_query_rag(chat_llm, question: str, search_query: str, vectorstore, callbacks: list = []) -> str:
+     Perform RAG using multiple queries to retrieve relevant documents.
+ - query_rag(chat_llm, question: str, search_query: str, vectorstore, callbacks: list = []) -> str:
+     Perform RAG using a single query to retrieve relevant documents.
  """
-
+ import os
+ import json
  from langchain.schema import SystemMessage, HumanMessage
  from langchain.prompts.chat import (
      HumanMessagePromptTemplate,
@@ -10,6 +26,44 @@ from langchain.prompts.chat import (
      ChatPromptTemplate
  )
  from langchain.prompts.prompt import PromptTemplate
+ from langchain.retrievers.multi_query import MultiQueryRetriever
+
+ from langchain_cohere.chat_models import ChatCohere
+ from langchain_groq import ChatGroq
+ from langchain_openai import ChatOpenAI
+ from langchain_community.chat_models.bedrock import BedrockChat
+ from langchain_community.chat_models.ollama import ChatOllama
+
+ def get_chat_llm(provider, model=None, temperature=0.0):
+     match provider:
+         case 'bedrock':
+             if model is None:
+                 model = "anthropic.claude-3-sonnet-20240229-v1:0"
+             chat_llm = BedrockChat(
+                 credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME'),
+                 model_id=model,
+                 model_kwargs={"temperature": temperature},
+             )
+         case 'openai':
+             if model is None:
+                 model = "gpt-3.5-turbo"
+             chat_llm = ChatOpenAI(model_name=model, temperature=temperature)
+         case 'groq':
+             if model is None:
+                 model = 'mixtral-8x7b-32768'
+             chat_llm = ChatGroq(model_name=model, temperature=temperature)
+         case 'ollama':
+             if model is None:
+                 model = 'llama2'
+             chat_llm = ChatOllama(model=model, temperature=temperature)
+         case 'cohere':
+             if model is None:
+                 model = 'command-r-plus'
+             chat_llm = ChatCohere(model=model, temperature=temperature)
+         case _:
+             raise ValueError(f"Unknown LLM provider {provider}")
+     return chat_llm
+
 
  def get_optimized_search_messages(query):
      """
@@ -76,6 +130,14 @@ def get_optimized_search_messages(query):
      )
      return [system_message, human_message]
 
+
+ def optimize_search_query(chat_llm, query, callbacks=[]):
+     messages = get_optimized_search_messages(query)
+     response = chat_llm.invoke(messages, config={"callbacks": callbacks})
+     optimized_search_query = response.content
+     return optimized_search_query.strip('"').split("**", 1)[0]
+
+
  def get_rag_prompt_template():
      """
      Get the prompt template for Retrieval-Augmented Generation (RAG).
@@ -121,3 +183,35 @@ def get_rag_prompt_template():
      input_variables=["context", "query"],
      messages=[system_prompt, human_prompt],
  )
+
+ def format_docs(docs):
+     formatted_docs = []
+     for d in docs:
+         content = d.page_content
+         title = d.metadata['title']
+         source = d.metadata['source']
+         doc = {"content": content, "title": title, "link": source}
+         formatted_docs.append(doc)
+     docs_as_json = json.dumps(formatted_docs, indent=2, ensure_ascii=False)
+     return docs_as_json
+
+
+ def multi_query_rag(chat_llm, question, search_query, vectorstore, callbacks=[]):
+     retriever_from_llm = MultiQueryRetriever.from_llm(
+         retriever=vectorstore.as_retriever(), llm=chat_llm, include_original=True,
+     )
+     unique_docs = retriever_from_llm.get_relevant_documents(
+         query=search_query, callbacks=callbacks, verbose=True
+     )
+     context = format_docs(unique_docs)
+     prompt = get_rag_prompt_template().format(query=question, context=context)
+     response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
+     return response.content
+
+
+ def query_rag(chat_llm, question, search_query, vectorstore, callbacks=[]):
+     unique_docs = vectorstore.similarity_search(search_query, k=15, callbacks=callbacks, verbose=True)
+     context = format_docs(unique_docs)
+     prompt = get_rag_prompt_template().format(query=question, context=context)
+     response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
+     return response.content
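
A minimal sketch of the module in use, assuming API keys are in the environment (the provider and question are illustrative; the vector store would come from web_crawler.vectorize):

```python
import dotenv
import web_rag as wr

dotenv.load_dotenv()

# Any provider handled by get_chat_llm works; "openai" is just an example.
chat = wr.get_chat_llm(provider="openai")
question = "What is retrieval-augmented generation?"

# Rewrite the conversational question into a search-engine-friendly query.
search_query = wr.optimize_search_query(chat, question)

# Given a vector store built by web_crawler.vectorize(...):
# answer = wr.query_rag(chat, question, search_query, vector_store)
```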