Spaces:
Running
Running
Added web ui using streamlit
Browse files- dotenv.sample +16 -8
- requirements.txt +2 -0
- search_agent.py +10 -235
- search_agent_ui.py +60 -0
- web_crawler.py +150 -0
- messages.py → web_rag.py +97 -3
dotenv.sample
CHANGED
@@ -1,10 +1,18 @@
|
|
1 |
-
|
2 |
-
OPENAI_API_KEY=sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
|
3 |
-
ANTHROPIC_API_KEY=sk-ant-api03-XXXXXXXXXXXXXXXXXXX
|
4 |
-
GROQ_API_KEY=gsk_XXXXXXXXXXXXXXXXXXXXXXXX
|
5 |
-
CREDENTIALS_PROFILE_NAME=XXXXXXXXXXXXXXXX
|
6 |
-
|
7 |
-
LANGCHAIN_API_KEY=ls__XXXXXXXXXXXXXXXXXXXXXXXXXX
|
8 |
LANGCHAIN_TRACING_V2=true
|
|
|
|
|
9 |
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
LANGCHAIN_API_KEY=ls__XXXXXXXXXXXXXXXXXXX
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
LANGCHAIN_TRACING_V2=true
|
3 |
+
LANGCHAIN_PROJECT="search agent"
|
4 |
+
LANGCHAIN_ENDPOINT="https://api.smith.langchain.com"
|
5 |
|
6 |
+
|
7 |
+
OPENAI_API_KEY=sk-XXXXXXXXXXXXXXXXXXX
|
8 |
+
ANTHROPIC_API_KEY=sk-ant-api03-XXXXXXXXXXXXXXXXXXX
|
9 |
+
GROQ_API_KEY=gsk_XXXXXXXXXXXXXXXXXXX
|
10 |
+
CREDENTIALS_PROFILE_NAME=XXXXXXXXXXXXXXXXXXX
|
11 |
+
COHERE_API_KEY=XXXXXXXXXXXXXXXXXXX
|
12 |
+
TAVILY_API_KEY=tvly-XXXXXXXXXXXXXXXXXXX
|
13 |
+
BRAVE_SEARCH_API_KEY=XXXXXXXXXXXXXXXXXXX
|
14 |
+
SERPER_API_KEY=XXXXXXXXXXXXXXXXXXX
|
15 |
+
SERPAPI_API_KEY=XXXXXXXXXXXXXXXXXXX
|
16 |
+
WOLFRAM_ALPHA_APPID=XXXXXXXXXXXXXXXXXXX
|
17 |
+
GOOGLE_CSE_ID=XXXXXXXXXXXXXXXXXXX
|
18 |
+
GOOGLE_API_KEY=XXXXXXXXXXXXXXXXXXX
|
requirements.txt
CHANGED
@@ -14,5 +14,7 @@ langchain_experimental
|
|
14 |
langchain_openai
|
15 |
langchain_groq
|
16 |
langsmith
|
|
|
17 |
rich
|
18 |
trafilatura
|
|
|
|
14 |
langchain_openai
|
15 |
langchain_groq
|
16 |
langsmith
|
17 |
+
streamlit
|
18 |
rich
|
19 |
trafilatura
|
20 |
+
watchdog
|
search_agent.py
CHANGED
@@ -23,241 +23,20 @@ Options:
|
|
23 |
|
24 |
"""
|
25 |
|
26 |
-
import json
|
27 |
import os
|
28 |
-
import io
|
29 |
-
from concurrent.futures import ThreadPoolExecutor
|
30 |
-
from urllib.parse import quote
|
31 |
|
32 |
from docopt import docopt
|
33 |
import dotenv
|
34 |
-
import pdfplumber
|
35 |
-
from trafilatura import extract
|
36 |
|
37 |
-
from selenium import webdriver
|
38 |
-
from selenium.webdriver.chrome.options import Options
|
39 |
-
|
40 |
-
from langchain_core.documents.base import Document
|
41 |
-
from langchain_experimental.text_splitter import SemanticChunker
|
42 |
-
from langchain.retrievers.multi_query import MultiQueryRetriever
|
43 |
from langchain.callbacks import LangChainTracer
|
44 |
-
from langchain_cohere.chat_models import ChatCohere
|
45 |
-
from langchain_groq import ChatGroq
|
46 |
-
from langchain_openai import ChatOpenAI
|
47 |
-
from langchain_openai import OpenAIEmbeddings
|
48 |
-
from langchain_community.chat_models.bedrock import BedrockChat
|
49 |
-
from langchain_community.chat_models.ollama import ChatOllama
|
50 |
-
from langchain_community.vectorstores.faiss import FAISS
|
51 |
|
52 |
from langsmith import Client
|
53 |
|
54 |
-
import requests
|
55 |
-
|
56 |
from rich.console import Console
|
57 |
from rich.markdown import Markdown
|
58 |
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
def get_chat_llm(provider, model=None, temperature=0.0):
|
63 |
-
match provider:
|
64 |
-
case 'bedrock':
|
65 |
-
if model is None:
|
66 |
-
model = "anthropic.claude-3-sonnet-20240229-v1:0"
|
67 |
-
chat_llm = BedrockChat(
|
68 |
-
credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME'),
|
69 |
-
model_id=model,
|
70 |
-
model_kwargs={"temperature": temperature },
|
71 |
-
)
|
72 |
-
case 'openai':
|
73 |
-
if model is None:
|
74 |
-
model = "gpt-3.5-turbo"
|
75 |
-
chat_llm = ChatOpenAI(model_name=model, temperature=temperature)
|
76 |
-
case 'groq':
|
77 |
-
if model is None:
|
78 |
-
model = 'mixtral-8x7b-32768'
|
79 |
-
chat_llm = ChatGroq(model_name=model, temperature=temperature)
|
80 |
-
case 'ollama':
|
81 |
-
if model is None:
|
82 |
-
model = 'llama2'
|
83 |
-
chat_llm = ChatOllama(model=model, temperature=temperature)
|
84 |
-
case 'cohere':
|
85 |
-
if model is None:
|
86 |
-
model = 'command-r-plus'
|
87 |
-
chat_llm = ChatCohere(model=model, temperature=temperature)
|
88 |
-
case _:
|
89 |
-
raise ValueError(f"Unknown LLM provider {provider}")
|
90 |
-
|
91 |
-
console.log(f"Using {model} on {provider} with temperature {temperature}")
|
92 |
-
return chat_llm
|
93 |
-
|
94 |
-
def optimize_search_query(chat_llm, query):
|
95 |
-
messages = get_optimized_search_messages(query)
|
96 |
-
response = chat_llm.invoke(messages, config={"callbacks": callbacks})
|
97 |
-
optimized_search_query = response.content
|
98 |
-
return optimized_search_query.strip('"').split("**", 1)[0]
|
99 |
-
|
100 |
-
|
101 |
-
def get_sources(query, max_pages=10, domain=None):
|
102 |
-
search_query = query
|
103 |
-
if domain:
|
104 |
-
search_query += f" site:{domain}"
|
105 |
-
|
106 |
-
url = f"https://api.search.brave.com/res/v1/web/search?q={quote(search_query)}&count={max_pages}"
|
107 |
-
headers = {
|
108 |
-
'Accept': 'application/json',
|
109 |
-
'Accept-Encoding': 'gzip',
|
110 |
-
'X-Subscription-Token': os.getenv("BRAVE_SEARCH_API_KEY")
|
111 |
-
}
|
112 |
-
|
113 |
-
try:
|
114 |
-
response = requests.get(url, headers=headers, timeout=30)
|
115 |
-
|
116 |
-
if response.status_code != 200:
|
117 |
-
return []
|
118 |
-
|
119 |
-
json_response = response.json()
|
120 |
-
|
121 |
-
if 'web' not in json_response or 'results' not in json_response['web']:
|
122 |
-
raise Exception('Invalid API response format')
|
123 |
-
|
124 |
-
final_results = [{
|
125 |
-
'title': result['title'],
|
126 |
-
'link': result['url'],
|
127 |
-
'snippet': result['description'],
|
128 |
-
'favicon': result.get('profile', {}).get('img', '')
|
129 |
-
} for result in json_response['web']['results']]
|
130 |
-
|
131 |
-
return final_results
|
132 |
-
|
133 |
-
except Exception as error:
|
134 |
-
console.log('Error fetching search results:', error)
|
135 |
-
raise
|
136 |
-
|
137 |
-
def fetch_with_selenium(url, timeout=8):
|
138 |
-
chrome_options = Options()
|
139 |
-
chrome_options.add_argument("headless")
|
140 |
-
chrome_options.add_argument("--disable-extensions")
|
141 |
-
chrome_options.add_argument("--disable-gpu")
|
142 |
-
chrome_options.add_argument("--no-sandbox")
|
143 |
-
chrome_options.add_argument("--disable-dev-shm-usage")
|
144 |
-
chrome_options.add_argument("--remote-debugging-port=9222")
|
145 |
-
chrome_options.add_argument('--blink-settings=imagesEnabled=false')
|
146 |
-
chrome_options.add_argument("--window-size=1920,1080")
|
147 |
-
|
148 |
-
driver = webdriver.Chrome(options=chrome_options)
|
149 |
-
|
150 |
-
driver.get(url)
|
151 |
-
driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
|
152 |
-
html = driver.page_source
|
153 |
-
driver.quit()
|
154 |
-
|
155 |
-
return html
|
156 |
-
|
157 |
-
def fetch_with_timeout(url, timeout=8):
|
158 |
-
try:
|
159 |
-
response = requests.get(url, timeout=timeout)
|
160 |
-
response.raise_for_status()
|
161 |
-
return response
|
162 |
-
except requests.RequestException as error:
|
163 |
-
return None
|
164 |
-
|
165 |
-
|
166 |
-
def process_source(source):
|
167 |
-
url = source['link']
|
168 |
-
#console.log(f"Processing {url}")
|
169 |
-
response = fetch_with_timeout(url, 8)
|
170 |
-
if response:
|
171 |
-
content_type = response.headers.get('Content-Type')
|
172 |
-
if content_type:
|
173 |
-
if content_type.startswith('application/pdf'):
|
174 |
-
# The response is a PDF file
|
175 |
-
pdf_content = response.content
|
176 |
-
# Create a file-like object from the bytes
|
177 |
-
pdf_file = io.BytesIO(pdf_content)
|
178 |
-
# Extract text from PDF using pdfplumber
|
179 |
-
with pdfplumber.open(pdf_file) as pdf:
|
180 |
-
text = ""
|
181 |
-
for page in pdf.pages:
|
182 |
-
text += page.extract_text()
|
183 |
-
return {**source, 'page_content': text}
|
184 |
-
elif content_type.startswith('text/html'):
|
185 |
-
# The response is an HTML file
|
186 |
-
html = response.text
|
187 |
-
main_content = extract(html, output_format='txt', include_links=True)
|
188 |
-
return {**source, 'page_content': main_content}
|
189 |
-
else:
|
190 |
-
console.log(f"Skipping {url}! Unsupported content type: {content_type}")
|
191 |
-
return {**source, 'page_content': source['snippet']}
|
192 |
-
else:
|
193 |
-
console.log(f"Skipping {url}! No content type")
|
194 |
-
return {**source, 'page_content': source['snippet']}
|
195 |
-
return {**source, 'page_content': None}
|
196 |
-
|
197 |
-
def get_links_contents(sources):
|
198 |
-
with ThreadPoolExecutor() as executor:
|
199 |
-
results = list(executor.map(process_source, sources))
|
200 |
-
for result in results:
|
201 |
-
if result['page_content'] is None:
|
202 |
-
url = result['link']
|
203 |
-
console.log(f"Fetching with selenium {url}")
|
204 |
-
html = fetch_with_selenium(url, 8)
|
205 |
-
main_content = extract(html, output_format='txt', include_links=True)
|
206 |
-
if main_content:
|
207 |
-
result['page_content'] = main_content
|
208 |
-
|
209 |
-
# Filter out None results
|
210 |
-
return [result for result in results if result is not None]
|
211 |
-
|
212 |
-
def vectorize(contents, text_chunk_size=400,text_chunk_overlap=40):
|
213 |
-
documents = []
|
214 |
-
for content in contents:
|
215 |
-
try:
|
216 |
-
metadata = {'title': content['title'], 'source': content['link']}
|
217 |
-
doc = Document(page_content=content['page_content'], metadata=metadata)
|
218 |
-
documents.append(doc)
|
219 |
-
except Exception as e:
|
220 |
-
console.log(f"[gray]Error processing content for {content['link']}: {e}")
|
221 |
-
semantic_chunker = SemanticChunker(OpenAIEmbeddings(model="text-embedding-3-large"), breakpoint_threshold_type="percentile")
|
222 |
-
docs = semantic_chunker.split_documents(documents)
|
223 |
-
console.log(f"Vectorizing {len(docs)} document chunks")
|
224 |
-
embeddings = OpenAIEmbeddings()
|
225 |
-
store = FAISS.from_documents(docs, embeddings)
|
226 |
-
return store
|
227 |
-
|
228 |
-
def format_docs(docs):
|
229 |
-
formatted_docs = []
|
230 |
-
for d in docs:
|
231 |
-
content = d.page_content
|
232 |
-
title = d.metadata['title']
|
233 |
-
source = d.metadata['source']
|
234 |
-
doc = {"content": content, "title": title, "link": source}
|
235 |
-
formatted_docs.append(doc)
|
236 |
-
docs_as_json = json.dumps(formatted_docs, indent=2, ensure_ascii=False)
|
237 |
-
return docs_as_json
|
238 |
-
|
239 |
-
|
240 |
-
def multi_query_rag(chat_llm, question, search_query, vectorstore):
|
241 |
-
retriever_from_llm = MultiQueryRetriever.from_llm(
|
242 |
-
retriever=vectorstore.as_retriever(), llm=chat_llm, include_original=True,
|
243 |
-
)
|
244 |
-
unique_docs = retriever_from_llm.get_relevant_documents(
|
245 |
-
query=search_query, callbacks=callbacks, verbose=True
|
246 |
-
)
|
247 |
-
context = format_docs(unique_docs)
|
248 |
-
prompt = get_rag_prompt_template().format(query=question, context=context)
|
249 |
-
response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
|
250 |
-
return response.content
|
251 |
-
|
252 |
-
|
253 |
-
def query_rag(chat_llm, question, search_query, vectorstore):
|
254 |
-
unique_docs = vectorstore.similarity_search(search_query, k=15, callbacks=callbacks, verbose=True)
|
255 |
-
context = format_docs(unique_docs)
|
256 |
-
prompt = get_rag_prompt_template().format(query=question, context=context)
|
257 |
-
response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
|
258 |
-
return response.content
|
259 |
-
|
260 |
-
|
261 |
|
262 |
console = Console()
|
263 |
dotenv.load_dotenv()
|
@@ -265,12 +44,7 @@ dotenv.load_dotenv()
|
|
265 |
callbacks = []
|
266 |
if os.getenv("LANGCHAIN_API_KEY"):
|
267 |
callbacks.append(
|
268 |
-
LangChainTracer(
|
269 |
-
project_name="search agent",
|
270 |
-
client=Client(
|
271 |
-
api_url="https://api.smith.langchain.com",
|
272 |
-
)
|
273 |
-
)
|
274 |
)
|
275 |
|
276 |
if __name__ == '__main__':
|
@@ -284,29 +58,30 @@ if __name__ == '__main__':
|
|
284 |
output=arguments["--output"]
|
285 |
query = arguments["SEARCH_QUERY"]
|
286 |
|
287 |
-
chat = get_chat_llm(provider, model, temperature)
|
|
|
288 |
|
289 |
with console.status(f"[bold green]Optimizing query for search: {query}"):
|
290 |
-
optimize_search_query = optimize_search_query(chat, query)
|
291 |
console.log(f"Optimized search query: [bold blue]{optimize_search_query}")
|
292 |
|
293 |
with console.status(
|
294 |
f"[bold green]Searching sources using the optimized query: {optimize_search_query}"
|
295 |
):
|
296 |
-
sources = get_sources(optimize_search_query, max_pages=max_pages, domain=domain)
|
297 |
console.log(f"Found {len(sources)} sources {'on ' + domain if domain else ''}")
|
298 |
|
299 |
with console.status(
|
300 |
f"[bold green]Fetching content for {len(sources)} sources", spinner="growVertical"
|
301 |
):
|
302 |
-
contents = get_links_contents(sources)
|
303 |
console.log(f"Managed to extract content from {len(contents)} sources")
|
304 |
|
305 |
with console.status(f"[bold green]Embeddubg {len(contents)} sources for content", spinner="growVertical"):
|
306 |
-
vector_store = vectorize(contents)
|
307 |
|
308 |
with console.status("[bold green]Querying LLM relevant context", spinner='dots8Bit'):
|
309 |
-
respomse = query_rag(chat, query, optimize_search_query, vector_store)
|
310 |
|
311 |
console.rule(f"[bold green]Response from {provider}")
|
312 |
if output == "text":
|
|
|
23 |
|
24 |
"""
|
25 |
|
|
|
26 |
import os
|
|
|
|
|
|
|
27 |
|
28 |
from docopt import docopt
|
29 |
import dotenv
|
|
|
|
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
from langchain.callbacks import LangChainTracer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
from langsmith import Client
|
34 |
|
|
|
|
|
35 |
from rich.console import Console
|
36 |
from rich.markdown import Markdown
|
37 |
|
38 |
+
import web_rag as wr
|
39 |
+
import web_crawler as wc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
console = Console()
|
42 |
dotenv.load_dotenv()
|
|
|
44 |
callbacks = []
|
45 |
if os.getenv("LANGCHAIN_API_KEY"):
|
46 |
callbacks.append(
|
47 |
+
LangChainTracer(client=Client())
|
|
|
|
|
|
|
|
|
|
|
48 |
)
|
49 |
|
50 |
if __name__ == '__main__':
|
|
|
58 |
output=arguments["--output"]
|
59 |
query = arguments["SEARCH_QUERY"]
|
60 |
|
61 |
+
chat = wr.get_chat_llm(provider, model, temperature)
|
62 |
+
console.log(f"Using {chat.get_name} on {provider} with temperature {temperature}")
|
63 |
|
64 |
with console.status(f"[bold green]Optimizing query for search: {query}"):
|
65 |
+
optimize_search_query = wr.optimize_search_query(chat, query, callbacks=callbacks)
|
66 |
console.log(f"Optimized search query: [bold blue]{optimize_search_query}")
|
67 |
|
68 |
with console.status(
|
69 |
f"[bold green]Searching sources using the optimized query: {optimize_search_query}"
|
70 |
):
|
71 |
+
sources = wc.get_sources(optimize_search_query, max_pages=max_pages, domain=domain)
|
72 |
console.log(f"Found {len(sources)} sources {'on ' + domain if domain else ''}")
|
73 |
|
74 |
with console.status(
|
75 |
f"[bold green]Fetching content for {len(sources)} sources", spinner="growVertical"
|
76 |
):
|
77 |
+
contents = wc.get_links_contents(sources)
|
78 |
console.log(f"Managed to extract content from {len(contents)} sources")
|
79 |
|
80 |
with console.status(f"[bold green]Embeddubg {len(contents)} sources for content", spinner="growVertical"):
|
81 |
+
vector_store = wc.vectorize(contents)
|
82 |
|
83 |
with console.status("[bold green]Querying LLM relevant context", spinner='dots8Bit'):
|
84 |
+
respomse = wr.query_rag(chat, query, optimize_search_query, vector_store, callbacks=callbacks)
|
85 |
|
86 |
console.rule(f"[bold green]Response from {provider}")
|
87 |
if output == "text":
|
search_agent_ui.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dotenv
|
2 |
+
|
3 |
+
import streamlit as st
|
4 |
+
|
5 |
+
import web_rag as wr
|
6 |
+
import web_crawler as wc
|
7 |
+
|
8 |
+
from langchain_core.tracers.langchain import LangChainTracer
|
9 |
+
from langsmith.client import Client
|
10 |
+
|
11 |
+
dotenv.load_dotenv()
|
12 |
+
|
13 |
+
ls_tracer = LangChainTracer(
|
14 |
+
project_name="Search Agent UI",
|
15 |
+
client=Client()
|
16 |
+
)
|
17 |
+
|
18 |
+
chat = wr.get_chat_llm(provider="cohere")
|
19 |
+
|
20 |
+
st.title("🔍 Simple Search Agent 💬")
|
21 |
+
|
22 |
+
if "messages" not in st.session_state:
|
23 |
+
st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
|
24 |
+
|
25 |
+
for message in st.session_state.messages:
|
26 |
+
st.chat_message(message["role"]).write(message["content"])
|
27 |
+
|
28 |
+
if prompt := st.chat_input():
|
29 |
+
|
30 |
+
st.chat_message("user").write(prompt)
|
31 |
+
st.session_state.messages.append({"role": "user", "content": prompt})
|
32 |
+
|
33 |
+
message = "I first need to do some research"
|
34 |
+
st.chat_message("assistant").write(message)
|
35 |
+
st.session_state.messages.append({"role": "assistant", "content": message})
|
36 |
+
|
37 |
+
with st.spinner("Optimizing search query"):
|
38 |
+
optimize_search_query = wr.optimize_search_query(chat, query=prompt, callbacks=[ls_tracer])
|
39 |
+
|
40 |
+
message = f"I'll search the web for: {optimize_search_query}"
|
41 |
+
st.chat_message("assistant").write(message)
|
42 |
+
st.session_state.messages.append({"role": "assistant", "content": message})
|
43 |
+
|
44 |
+
|
45 |
+
with st.spinner(f"Searching the web for: {optimize_search_query}"):
|
46 |
+
sources = wc.get_sources(optimize_search_query)
|
47 |
+
|
48 |
+
with st.spinner(f"I'm now retrieveing the {len(sources)} webpages and documents I found (be patient)"):
|
49 |
+
contents = wc.get_links_contents(sources)
|
50 |
+
|
51 |
+
|
52 |
+
with st.spinner( f"Reading through the {len(contents)} sources I managed to retrieve"):
|
53 |
+
vector_store = wc.vectorize(contents)
|
54 |
+
|
55 |
+
with st.spinner( "Ok I have now enough information to answer"):
|
56 |
+
response = wr.query_rag(chat, prompt, optimize_search_query, vector_store, callbacks=[ls_tracer])
|
57 |
+
|
58 |
+
st.chat_message("assistant").write(response)
|
59 |
+
st.session_state.messages.append({"role": "assistant", "content": response})
|
60 |
+
|
web_crawler.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from concurrent.futures import ThreadPoolExecutor
|
2 |
+
from urllib.parse import quote
|
3 |
+
|
4 |
+
import os
|
5 |
+
import io
|
6 |
+
|
7 |
+
from trafilatura import extract
|
8 |
+
from selenium import webdriver
|
9 |
+
from selenium.webdriver.chrome.options import Options
|
10 |
+
from selenium.common.exceptions import TimeoutException
|
11 |
+
from langchain_core.documents.base import Document
|
12 |
+
from langchain_experimental.text_splitter import SemanticChunker
|
13 |
+
from langchain_openai import OpenAIEmbeddings
|
14 |
+
from langchain_community.vectorstores.faiss import FAISS
|
15 |
+
|
16 |
+
import requests
|
17 |
+
import pdfplumber
|
18 |
+
|
19 |
+
def get_sources(query, max_pages=10, domain=None):
|
20 |
+
search_query = query
|
21 |
+
if domain:
|
22 |
+
search_query += f" site:{domain}"
|
23 |
+
|
24 |
+
url = f"https://api.search.brave.com/res/v1/web/search?q={quote(search_query)}&count={max_pages}"
|
25 |
+
headers = {
|
26 |
+
'Accept': 'application/json',
|
27 |
+
'Accept-Encoding': 'gzip',
|
28 |
+
'X-Subscription-Token': os.getenv("BRAVE_SEARCH_API_KEY")
|
29 |
+
}
|
30 |
+
|
31 |
+
try:
|
32 |
+
response = requests.get(url, headers=headers, timeout=30)
|
33 |
+
|
34 |
+
if response.status_code != 200:
|
35 |
+
return []
|
36 |
+
|
37 |
+
json_response = response.json()
|
38 |
+
|
39 |
+
if 'web' not in json_response or 'results' not in json_response['web']:
|
40 |
+
raise Exception('Invalid API response format')
|
41 |
+
|
42 |
+
final_results = [{
|
43 |
+
'title': result['title'],
|
44 |
+
'link': result['url'],
|
45 |
+
'snippet': result['description'],
|
46 |
+
'favicon': result.get('profile', {}).get('img', '')
|
47 |
+
} for result in json_response['web']['results']]
|
48 |
+
|
49 |
+
return final_results
|
50 |
+
|
51 |
+
except Exception as error:
|
52 |
+
print('Error fetching search results:', error)
|
53 |
+
raise
|
54 |
+
|
55 |
+
def fetch_with_selenium(url, timeout=8):
|
56 |
+
chrome_options = Options()
|
57 |
+
chrome_options.add_argument("headless")
|
58 |
+
chrome_options.add_argument("--disable-extensions")
|
59 |
+
chrome_options.add_argument("--disable-gpu")
|
60 |
+
chrome_options.add_argument("--no-sandbox")
|
61 |
+
chrome_options.add_argument("--disable-dev-shm-usage")
|
62 |
+
chrome_options.add_argument("--remote-debugging-port=9222")
|
63 |
+
chrome_options.add_argument('--blink-settings=imagesEnabled=false')
|
64 |
+
chrome_options.add_argument("--window-size=1920,1080")
|
65 |
+
|
66 |
+
driver = webdriver.Chrome(options=chrome_options)
|
67 |
+
|
68 |
+
try:
|
69 |
+
driver.set_page_load_timeout(timeout)
|
70 |
+
driver.get(url)
|
71 |
+
driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
|
72 |
+
html = driver.page_source
|
73 |
+
except TimeoutException:
|
74 |
+
print(f"Page load timed out after {timeout} seconds.")
|
75 |
+
html = None
|
76 |
+
finally:
|
77 |
+
driver.quit()
|
78 |
+
|
79 |
+
return html
|
80 |
+
|
81 |
+
def fetch_with_timeout(url, timeout=8):
|
82 |
+
try:
|
83 |
+
response = requests.get(url, timeout=timeout)
|
84 |
+
response.raise_for_status()
|
85 |
+
return response
|
86 |
+
except requests.RequestException as error:
|
87 |
+
return None
|
88 |
+
|
89 |
+
|
90 |
+
def process_source(source):
|
91 |
+
url = source['link']
|
92 |
+
#console.log(f"Processing {url}")
|
93 |
+
response = fetch_with_timeout(url, 8)
|
94 |
+
if response:
|
95 |
+
content_type = response.headers.get('Content-Type')
|
96 |
+
if content_type:
|
97 |
+
if content_type.startswith('application/pdf'):
|
98 |
+
# The response is a PDF file
|
99 |
+
pdf_content = response.content
|
100 |
+
# Create a file-like object from the bytes
|
101 |
+
pdf_file = io.BytesIO(pdf_content)
|
102 |
+
# Extract text from PDF using pdfplumber
|
103 |
+
with pdfplumber.open(pdf_file) as pdf:
|
104 |
+
text = ""
|
105 |
+
for page in pdf.pages:
|
106 |
+
text += page.extract_text()
|
107 |
+
return {**source, 'page_content': text}
|
108 |
+
elif content_type.startswith('text/html'):
|
109 |
+
# The response is an HTML file
|
110 |
+
html = response.text
|
111 |
+
main_content = extract(html, output_format='txt', include_links=True)
|
112 |
+
return {**source, 'page_content': main_content}
|
113 |
+
else:
|
114 |
+
print(f"Skipping {url}! Unsupported content type: {content_type}")
|
115 |
+
return {**source, 'page_content': source['snippet']}
|
116 |
+
else:
|
117 |
+
print(f"Skipping {url}! No content type")
|
118 |
+
return {**source, 'page_content': source['snippet']}
|
119 |
+
return {**source, 'page_content': None}
|
120 |
+
|
121 |
+
def get_links_contents(sources):
|
122 |
+
with ThreadPoolExecutor() as executor:
|
123 |
+
results = list(executor.map(process_source, sources))
|
124 |
+
for result in results:
|
125 |
+
if result['page_content'] is None:
|
126 |
+
url = result['link']
|
127 |
+
print(f"Fetching with selenium {url}")
|
128 |
+
html = fetch_with_selenium(url, 8)
|
129 |
+
main_content = extract(html, output_format='txt', include_links=True)
|
130 |
+
if main_content:
|
131 |
+
result['page_content'] = main_content
|
132 |
+
|
133 |
+
# Filter out None results
|
134 |
+
return [result for result in results if result is not None]
|
135 |
+
|
136 |
+
def vectorize(contents):
|
137 |
+
documents = []
|
138 |
+
for content in contents:
|
139 |
+
try:
|
140 |
+
metadata = {'title': content['title'], 'source': content['link']}
|
141 |
+
doc = Document(page_content=content['page_content'], metadata=metadata)
|
142 |
+
documents.append(doc)
|
143 |
+
except Exception as e:
|
144 |
+
print(f"[gray]Error processing content for {content['link']}: {e}")
|
145 |
+
semantic_chunker = SemanticChunker(OpenAIEmbeddings(model="text-embedding-3-large"), breakpoint_threshold_type="percentile")
|
146 |
+
docs = semantic_chunker.split_documents(documents)
|
147 |
+
print(f"Vectorizing {len(docs)} document chunks")
|
148 |
+
embeddings = OpenAIEmbeddings()
|
149 |
+
store = FAISS.from_documents(docs, embeddings)
|
150 |
+
return store
|
messages.py → web_rag.py
RENAMED
@@ -1,8 +1,24 @@
|
|
1 |
"""
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
"""
|
5 |
-
|
|
|
6 |
from langchain.schema import SystemMessage, HumanMessage
|
7 |
from langchain.prompts.chat import (
|
8 |
HumanMessagePromptTemplate,
|
@@ -10,6 +26,44 @@ from langchain.prompts.chat import (
|
|
10 |
ChatPromptTemplate
|
11 |
)
|
12 |
from langchain.prompts.prompt import PromptTemplate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
def get_optimized_search_messages(query):
|
15 |
"""
|
@@ -76,6 +130,14 @@ def get_optimized_search_messages(query):
|
|
76 |
)
|
77 |
return [system_message, human_message]
|
78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
def get_rag_prompt_template():
|
80 |
"""
|
81 |
Get the prompt template for Retrieval-Augmented Generation (RAG).
|
@@ -121,3 +183,35 @@ def get_rag_prompt_template():
|
|
121 |
input_variables=["context", "query"],
|
122 |
messages=[system_prompt, human_prompt],
|
123 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
"""
|
2 |
+
Module for performing retrieval-augmented generation (RAG) using LangChain.
|
3 |
+
This module provides functions to optimize search queries, retrieve relevant documents,
|
4 |
+
and generate answers to questions using the retrieved context. It leverages the LangChain
|
5 |
+
library for building the RAG pipeline.
|
6 |
+
Functions:
|
7 |
+
- get_optimized_search_messages(query: str) -> list:
|
8 |
+
Generate optimized search messages for a given query.
|
9 |
+
- optimize_search_query(chat_llm, query: str, callbacks: list = []) -> str:
|
10 |
+
Optimize the search query using the chat language model.
|
11 |
+
- get_rag_prompt_template() -> ChatPromptTemplate:
|
12 |
+
Get the prompt template for retrieval-augmented generation (RAG).
|
13 |
+
- format_docs(docs: list) -> str:
|
14 |
+
Format the retrieved documents into a JSON string.
|
15 |
+
- multi_query_rag(chat_llm, question: str, search_query: str, vectorstore, callbacks: list = []) -> str:
|
16 |
+
Perform RAG using multiple queries to retrieve relevant documents.
|
17 |
+
- query_rag(chat_llm, question: str, search_query: str, vectorstore, callbacks: list = []) -> str:
|
18 |
+
Perform RAG using a single query to retrieve relevant documents.
|
19 |
"""
|
20 |
+
import os
|
21 |
+
import json
|
22 |
from langchain.schema import SystemMessage, HumanMessage
|
23 |
from langchain.prompts.chat import (
|
24 |
HumanMessagePromptTemplate,
|
|
|
26 |
ChatPromptTemplate
|
27 |
)
|
28 |
from langchain.prompts.prompt import PromptTemplate
|
29 |
+
from langchain.retrievers.multi_query import MultiQueryRetriever
|
30 |
+
|
31 |
+
from langchain_cohere.chat_models import ChatCohere
|
32 |
+
from langchain_groq import ChatGroq
|
33 |
+
from langchain_openai import ChatOpenAI
|
34 |
+
from langchain_community.chat_models.bedrock import BedrockChat
|
35 |
+
from langchain_community.chat_models.ollama import ChatOllama
|
36 |
+
|
37 |
+
def get_chat_llm(provider, model=None, temperature=0.0):
|
38 |
+
match provider:
|
39 |
+
case 'bedrock':
|
40 |
+
if model is None:
|
41 |
+
model = "anthropic.claude-3-sonnet-20240229-v1:0"
|
42 |
+
chat_llm = BedrockChat(
|
43 |
+
credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME'),
|
44 |
+
model_id=model,
|
45 |
+
model_kwargs={"temperature": temperature },
|
46 |
+
)
|
47 |
+
case 'openai':
|
48 |
+
if model is None:
|
49 |
+
model = "gpt-3.5-turbo"
|
50 |
+
chat_llm = ChatOpenAI(model_name=model, temperature=temperature)
|
51 |
+
case 'groq':
|
52 |
+
if model is None:
|
53 |
+
model = 'mixtral-8x7b-32768'
|
54 |
+
chat_llm = ChatGroq(model_name=model, temperature=temperature)
|
55 |
+
case 'ollama':
|
56 |
+
if model is None:
|
57 |
+
model = 'llama2'
|
58 |
+
chat_llm = ChatOllama(model=model, temperature=temperature)
|
59 |
+
case 'cohere':
|
60 |
+
if model is None:
|
61 |
+
model = 'command-r-plus'
|
62 |
+
chat_llm = ChatCohere(model=model, temperature=temperature)
|
63 |
+
case _:
|
64 |
+
raise ValueError(f"Unknown LLM provider {provider}")
|
65 |
+
return chat_llm
|
66 |
+
|
67 |
|
68 |
def get_optimized_search_messages(query):
|
69 |
"""
|
|
|
130 |
)
|
131 |
return [system_message, human_message]
|
132 |
|
133 |
+
|
134 |
+
def optimize_search_query(chat_llm, query, callbacks=[]):
|
135 |
+
messages = get_optimized_search_messages(query)
|
136 |
+
response = chat_llm.invoke(messages, config={"callbacks": callbacks})
|
137 |
+
optimized_search_query = response.content
|
138 |
+
return optimized_search_query.strip('"').split("**", 1)[0]
|
139 |
+
|
140 |
+
|
141 |
def get_rag_prompt_template():
|
142 |
"""
|
143 |
Get the prompt template for Retrieval-Augmented Generation (RAG).
|
|
|
183 |
input_variables=["context", "query"],
|
184 |
messages=[system_prompt, human_prompt],
|
185 |
)
|
186 |
+
|
187 |
+
def format_docs(docs):
|
188 |
+
formatted_docs = []
|
189 |
+
for d in docs:
|
190 |
+
content = d.page_content
|
191 |
+
title = d.metadata['title']
|
192 |
+
source = d.metadata['source']
|
193 |
+
doc = {"content": content, "title": title, "link": source}
|
194 |
+
formatted_docs.append(doc)
|
195 |
+
docs_as_json = json.dumps(formatted_docs, indent=2, ensure_ascii=False)
|
196 |
+
return docs_as_json
|
197 |
+
|
198 |
+
|
199 |
+
def multi_query_rag(chat_llm, question, search_query, vectorstore, callbacks = []):
|
200 |
+
retriever_from_llm = MultiQueryRetriever.from_llm(
|
201 |
+
retriever=vectorstore.as_retriever(), llm=chat_llm, include_original=True,
|
202 |
+
)
|
203 |
+
unique_docs = retriever_from_llm.get_relevant_documents(
|
204 |
+
query=search_query, callbacks=callbacks, verbose=True
|
205 |
+
)
|
206 |
+
context = format_docs(unique_docs)
|
207 |
+
prompt = get_rag_prompt_template().format(query=question, context=context)
|
208 |
+
response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
|
209 |
+
return response.content
|
210 |
+
|
211 |
+
|
212 |
+
def query_rag(chat_llm, question, search_query, vectorstore, callbacks = []):
|
213 |
+
unique_docs = vectorstore.similarity_search(search_query, k=15, callbacks=callbacks, verbose=True)
|
214 |
+
context = format_docs(unique_docs)
|
215 |
+
prompt = get_rag_prompt_template().format(query=question, context=context)
|
216 |
+
response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
|
217 |
+
return response.content
|