CyranoB committed
Commit df527c8 • 1 Parent(s): bda01ad

Added Fireworks.ai as provider. Better Streamlit UI.

Files changed (7):
  1. README.md +1 -1
  2. dotenv.sample +1 -0
  3. requirements.txt +1 -0
  4. search_agent.py +4 -4
  5. search_agent_ui.py +32 -14
  6. web_crawler.py +2 -2
  7. web_rag.py +21 -3
README.md CHANGED
@@ -15,7 +15,7 @@ license: apache-2.0
  This Python project provides a search agent that can perform web searches, optimize search queries, fetch and process web content, and generate responses using a language model and the retrieved information.
  Does a bit what [Perplexity AI](https://www.perplexity.ai/) does.
  
- The Streamlit GUI hosted on 🤗 Sapces is [available to test](https://huggingface.co/spaces/CyranoB/search_agent)
+ The Streamlit GUI hosted on 🤗 Spaces is [available to test](https://huggingface.co/spaces/CyranoB/search_agent)
  
  This Python script and Streamlit GUI are a basic search agent that utilizes the LangChain library to perform optimized web searches, retrieve relevant content, and generate informative answers to user queries. The script supports multiple language models and providers, including OpenAI, Anthropic, and Groq.
  
dotenv.sample CHANGED
@@ -6,6 +6,7 @@ LANGCHAIN_ENDPOINT="https://api.smith.langchain.com"
  
  OPENAI_API_KEY=sk-XXXXXXXXXXXXXXXXXXX
  ANTHROPIC_API_KEY=sk-ant-api03-XXXXXXXXXXXXXXXXXXX
+ FIREWORKS_API_KEY=XXXXXXXXXXXXXXXXXXX
  GROQ_API_KEY=gsk_XXXXXXXXXXXXXXXXXXX
  CREDENTIALS_PROFILE_NAME=XXXXXXXXXXXXXXXXXXX
  COHERE_API_KEY=XXXXXXXXXXXXXXXXXXX
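For reference, a minimal sketch of how the new key is consumed, assuming the python-dotenv flow this project already uses (`dotenv.load_dotenv()`). As far as I know, `langchain-fireworks` reads `FIREWORKS_API_KEY` from the environment, so no explicit `api_key` argument is needed once the variable is set:

```python
import os

import dotenv

# Load variables from a local .env file (copied from dotenv.sample).
dotenv.load_dotenv()

# langchain-fireworks picks up FIREWORKS_API_KEY from the environment;
# fail early if it was not provided.
assert os.getenv("FIREWORKS_API_KEY"), "FIREWORKS_API_KEY is not set"
```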
requirements.txt CHANGED
@@ -8,6 +8,7 @@ pdfplumber
  python-dotenv
  langchain
  langchain-cohere
+ langchain-fireworks
  langchain_core
  langchain_community
  langchain_experimental
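A quick smoke test for the new dependency, assuming `pip install -r requirements.txt` has been run; the import path is the one web_rag.py uses below:

```python
# Resolves only if langchain-fireworks installed correctly;
# web_rag.py imports this same class.
from langchain_fireworks.chat_models import ChatFireworks

print(ChatFireworks.__name__)
```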
search_agent.py CHANGED
@@ -16,7 +16,7 @@ Options:
      --version                        Show version.
  -d domain --domain=domain            Limit search to a specific domain
  -t temp --temperature=temp           Set the temperature of the LLM [default: 0.0]
- -p provider --provider=provider      Use a specific LLM (choices: bedrock,openai,groq,ollama,cohere) [default: openai]
+ -p provider --provider=provider      Use a specific LLM (choices: bedrock,openai,groq,ollama,cohere,fireworks) [default: openai]
  -m model --model=model               Use a specific model
  -n num --max_pages=num               Max number of pages to retrieve [default: 10]
  -o text --output=text                Output format (choices: text, markdown) [default: markdown]
@@ -78,8 +78,8 @@ if __name__ == '__main__':
      output = arguments["--output"]
      query = arguments["SEARCH_QUERY"]
  
-     chat = wr.get_chat_llm(provider, model, temperature)
-     console.log(f"Using {chat.model} on {provider} with temperature {temperature}")
+     chat, embedding_model = wr.get_models(provider, model, temperature)
+     #console.log(f"Using {chat.model_name} on {provider}")
  
      with console.status(f"[bold green]Optimizing query for search: {query}"):
          optimize_search_query = wr.optimize_search_query(chat, query, callbacks=callbacks)
@@ -98,7 +98,7 @@ if __name__ == '__main__':
      console.log(f"Managed to extract content from {len(contents)} sources")
  
      with console.status(f"[bold green]Embedding {len(contents)} sources for content", spinner="growVertical"):
-         vector_store = wc.vectorize(contents)
+         vector_store = wc.vectorize(contents, embedding_model)
  
      with console.status("[bold green]Querying LLM relevant context", spinner='dots8Bit'):
          response = wr.query_rag(chat, query, optimize_search_query, vector_store, top_k=5, callbacks=callbacks)
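Taken together, these changes replace the single `get_chat_llm()` call with a `(chat, embedding_model)` pair that the CLI threads through to `vectorize()`. A condensed sketch of the new flow, assuming the `callbacks` arguments are optional, with the `console.status` wrappers and docopt parsing omitted and a placeholder query string:

```python
import web_crawler as wc
import web_rag as wr

query = "What is retrieval-augmented generation?"  # placeholder SEARCH_QUERY

# The provider now determines both the chat LLM and the embedding model.
chat, embedding_model = wr.get_models("fireworks", model=None, temperature=0.0)

optimized_query = wr.optimize_search_query(chat, query)
sources = wc.get_sources(optimized_query, max_pages=10)
contents = wc.get_links_contents(sources)

# vectorize() now takes the embedding model instead of hardcoding OpenAI.
vector_store = wc.vectorize(contents, embedding_model)
response = wr.query_rag(chat, query, optimized_query, vector_store, top_k=5)
```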
search_agent_ui.py CHANGED
@@ -1,4 +1,5 @@
  import datetime
+ import os
  
  import dotenv
  import streamlit as st
@@ -13,11 +14,13 @@ import web_crawler as wc
  dotenv.load_dotenv()
  
  ls_tracer = LangChainTracer(
-     project_name="Search Agent UI",
+     project_name=os.getenv("LANGSMITH_PROJECT_NAME"),
      client=Client()
  )
  
+ 
  class StreamHandler(BaseCallbackHandler):
+     """Stream handler that appends tokens to container."""
      def __init__(self, container, initial_text=""):
          self.container = container
          self.text = initial_text
@@ -26,16 +29,34 @@ class StreamHandler(BaseCallbackHandler):
          self.text += token
          self.container.markdown(self.text)
  
- chat = wr.get_chat_llm(provider="cohere")
- 
  st.title("🔍 Simple Search Agent 💬")
  
+ if "providers" not in st.session_state:
+     providers = []
+     if os.getenv("COHERE_API_KEY"):
+         providers.append("cohere")
+     if os.getenv("OPENAI_API_KEY"):
+         providers.append("openai")
+     if os.getenv("GROQ_API_KEY"):
+         providers.append("groq")
+     if os.getenv("OLLAMA_API_KEY"):
+         providers.append("ollama")
+     if os.getenv("FIREWORKS_API_KEY"):
+         providers.append("fireworks")
+     if os.getenv("CREDENTIALS_PROFILE_NAME"):
+         providers.append("bedrock")
+     st.session_state["providers"] = providers
+ 
+ with st.sidebar:
+     st.write("Options")
+     model_provider = st.selectbox("🧠 Model provider 🧠", st.session_state["providers"])
+     temperature = st.slider("🌡️ Model temperature 🌡️", 0.0, 1.0, 0.1, help="The higher the more creative")
+     max_pages = st.slider("🔍 Max pages to retrieve 🔍", 1, 20, 15, help="How many web pages to retrieve from the internet")
+     top_k_documents = st.slider("📄 How many document extracts to consider 📄", 1, 20, 5, help="How many of the top extracts to consider")
+ 
  if "messages" not in st.session_state:
      st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
  
- if "input_disabled" not in st.session_state:
-     st.session_state["input_disabled"] = False
- 
  for message in st.session_state.messages:
      st.chat_message(message["role"]).write(message["content"])
      if message["role"] == "assistant" and 'message_id' in message:
@@ -46,26 +67,25 @@ for message in st.session_state.messages:
              mime="text/plain"
          )
  
- if prompt := st.chat_input("Enter your instructions...", disabled=st.session_state["input_disabled"]):
- 
-     st.session_state["input_disabled"] = True
- 
+ if prompt := st.chat_input("Enter your instructions..."):
      st.chat_message("user").write(prompt)
      st.session_state.messages.append({"role": "user", "content": prompt})
  
+     chat, embedding_model = wr.get_models(model_provider, temperature=temperature)
+ 
      with st.status("Thinking", expanded=True):
          st.write("I first need to do some research")
  
          optimize_search_query = wr.optimize_search_query(chat, query=prompt, callbacks=[ls_tracer])
          st.write(f"I should search the web for: {optimize_search_query}")
  
-         sources = wc.get_sources(optimize_search_query, max_pages=20)
+         sources = wc.get_sources(optimize_search_query, max_pages=max_pages)
  
          st.write(f"I'll now retrieve the {len(sources)} webpages and documents I found")
          contents = wc.get_links_contents(sources)
  
          st.write(f"Reading through the {len(contents)} sources I managed to retrieve")
-         vector_store = wc.vectorize(contents)
+         vector_store = wc.vectorize(contents, embedding_model=embedding_model)
          st.write(f"I collected {vector_store.index.ntotal} chunks of data and I can now answer")
  
          rag_prompt = wr.build_rag_prompt(prompt, optimize_search_query, vector_store, top_k=5, callbacks=[ls_tracer])
@@ -82,5 +102,3 @@ if prompt := st.chat_input("Enter your instructions...", disabled=st.session_state["input_disabled"]):
          file_name=f"{message_id}.txt",
          mime="text/plain"
      )
- st.session_state["input_disabled"] = False
- 
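The sidebar code above probes one environment variable per provider. A hypothetical refactoring, not part of this commit, that expresses the same provider-to-credential mapping as data, so adding a provider later only takes one table entry:

```python
import os

# Hypothetical helper: maps each provider to the environment variable
# whose presence enables it in the sidebar selectbox.
PROVIDER_ENV_VARS = {
    "cohere": "COHERE_API_KEY",
    "openai": "OPENAI_API_KEY",
    "groq": "GROQ_API_KEY",
    "ollama": "OLLAMA_API_KEY",
    "fireworks": "FIREWORKS_API_KEY",
    "bedrock": "CREDENTIALS_PROFILE_NAME",
}

def available_providers() -> list[str]:
    """Return the providers whose credentials are present in the environment."""
    return [name for name, var in PROVIDER_ENV_VARS.items() if os.getenv(var)]
```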
 
 
web_crawler.py CHANGED
@@ -124,7 +124,7 @@ def get_links_contents(sources, get_driver_func=None):
          result['page_content'] = main_content
      return results
  
- def vectorize(contents):
+ def vectorize(contents, embedding_model):
      documents = []
      for content in contents:
          try:
@@ -135,7 +135,7 @@ def vectorize(contents):
              documents.append(doc)
          except Exception as e:
              print(f"[gray]Error processing content for {content['link']}: {e}")
-     semantic_chunker = SemanticChunker(OpenAIEmbeddings(model="text-embedding-3-large"), breakpoint_threshold_type="percentile")
+     semantic_chunker = SemanticChunker(embedding_model, breakpoint_threshold_type="percentile")
      docs = semantic_chunker.split_documents(documents)
      embeddings = OpenAIEmbeddings()
      store = FAISS.from_documents(docs, embeddings)
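With the new signature, the semantic chunker uses whichever embedding model the caller injects; note that in this commit the FAISS store in the unchanged context lines still builds its own `OpenAIEmbeddings()`. A minimal call-site sketch, assuming content dicts shaped like those produced by `get_links_contents()`:

```python
from langchain_openai import OpenAIEmbeddings

import web_crawler as wc

# Hypothetical content entry; real ones come from wc.get_links_contents().
contents = [{
    "link": "https://example.com",
    "page_content": "Some extracted page text to be chunked and embedded.",
}]

# The embedding model is now injected by the caller (see web_rag.get_models)
# instead of being hardcoded to text-embedding-3-large inside vectorize().
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
vector_store = wc.vectorize(contents, embedding_model)
print(vector_store.index.ntotal)  # number of chunks indexed in FAISS
```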
web_rag.py CHANGED
@@ -29,40 +29,58 @@ from langchain.prompts.prompt import PromptTemplate
  from langchain.retrievers.multi_query import MultiQueryRetriever
  
  from langchain_cohere.chat_models import ChatCohere
+ from langchain_cohere.embeddings import CohereEmbeddings
+ from langchain_fireworks.chat_models import ChatFireworks
  from langchain_groq import ChatGroq
  from langchain_openai import ChatOpenAI
+ from langchain_openai.embeddings import OpenAIEmbeddings
  from langchain_community.chat_models.bedrock import BedrockChat
+ from langchain_community.embeddings.bedrock import BedrockEmbeddings
  from langchain_community.chat_models.ollama import ChatOllama
  
- def get_chat_llm(provider, model=None, temperature=0.0):
+ def get_models(provider, model=None, temperature=0.0):
      match provider:
          case 'bedrock':
+             credentials_profile_name = os.getenv('CREDENTIALS_PROFILE_NAME')
              if model is None:
                  model = "anthropic.claude-3-sonnet-20240229-v1:0"
              chat_llm = BedrockChat(
-                 credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME'),
+                 credentials_profile_name=credentials_profile_name,
                  model_id=model,
                  model_kwargs={"temperature": temperature},
              )
+             embedding_model = BedrockEmbeddings(
+                 model_id='cohere.embed-multilingual-v3',
+                 credentials_profile_name=credentials_profile_name
+             )
          case 'openai':
              if model is None:
                  model = "gpt-3.5-turbo"
              chat_llm = ChatOpenAI(model_name=model, temperature=temperature)
+             embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
          case 'groq':
              if model is None:
                  model = 'mixtral-8x7b-32768'
              chat_llm = ChatGroq(model_name=model, temperature=temperature)
+             embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
          case 'ollama':
              if model is None:
                  model = 'llama2'
              chat_llm = ChatOllama(model=model, temperature=temperature)
+             embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
          case 'cohere':
              if model is None:
                  model = 'command-r-plus'
              chat_llm = ChatCohere(model=model, temperature=temperature)
+             embedding_model = CohereEmbeddings(model="embed-english-light-v3.0")
+         case 'fireworks':
+             if model is None:
+                 model = 'accounts/fireworks/models/mixtral-8x22b-instruct-preview'
+             chat_llm = ChatFireworks(model_name=model, temperature=temperature)
+             embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
          case _:
              raise ValueError(f"Unknown LLM provider {provider}")
-     return chat_llm
+     return chat_llm, embedding_model
  
  
  def get_optimized_search_messages(query):
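The renamed `get_models()` returns the chat model and embedding model as a pair, so each caller keeps the two consistent per provider. A short usage sketch for the new Fireworks branch, relying only on the defaults visible above:

```python
import web_rag as wr

# For 'fireworks', the defaults above select mixtral-8x22b-instruct-preview
# for chat and fall back to OpenAI text-embedding-3-small for embeddings.
chat, embedding_model = wr.get_models("fireworks", temperature=0.0)

# Unknown providers fail fast with a ValueError instead of failing later.
try:
    wr.get_models("not-a-provider")
except ValueError as err:
    print(err)  # Unknown LLM provider not-a-provider
```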