CyranoB committed
Commit: 9847233
Parent: 8c28786

Review mode

Files changed (7)
  1. README.md +2 -0
  2. copywriter.py +37 -5
  3. requirements.txt +5 -1
  4. search_agent.py +10 -8
  5. search_agent_ui.py +80 -18
  6. web_crawler.py +2 -1
  7. web_rag.py +76 -37
README.md CHANGED
@@ -10,6 +10,8 @@ pinned: false
 license: apache-2.0
 ---

+⚠️ **This project is a demonstration / proof-of-concept and is not intended for use in production environments. It is provided as-is, without warranty or guarantee of any kind. The code and any accompanying materials are for educational, testing, or evaluation purposes only.**⚠️
+
 # Simple Search Agent

 This Python project provides a search agent that can perform web searches, optimize search queries, fetch and process web content, and generate responses using a language model and the retrieved information.
copywriter.py CHANGED
@@ -7,7 +7,6 @@ from langchain.prompts.chat import (
 from langchain.prompts.prompt import PromptTemplate


-
 def get_comments_prompt(query, draft):
     system_message = SystemMessage(
         content="""
@@ -35,14 +34,11 @@ def get_comments_prompt(query, draft):
     )
     return [system_message, human_message]

-
 def generate_comments(chat_llm, query, draft, callbacks=[]):
     messages = get_comments_prompt(query, draft)
     response = chat_llm.invoke(messages, config={"callbacks": callbacks})
     return response.content

-
-
 def get_final_text_prompt(query, draft, comments):
     system_message = SystemMessage(
         content="""
@@ -74,4 +70,40 @@ def get_final_text_prompt(query, draft, comments):
 def generate_final_text(chat_llm, query, draft, comments, callbacks=[]):
     messages = get_final_text_prompt(query, draft, comments)
     response = chat_llm.invoke(messages, config={"callbacks": callbacks})
-    return response.content
+    return response.content
+
+
+def get_compare_texts_prompts(query, draft_text, final_text):
+    system_message = SystemMessage(
+        content="""
+        I want you to act as a writing quality evaluator.
+        I will provide you with the original user request and four texts.
+        Your task is to carefully analyze, compare the two texts across the following dimensions and grade each text 0 to 10:
+        1. Grammar and spelling - Which text has fewer grammatical errors and spelling mistakes?
+        2. Clarity and coherence - Which text is easier to understand and has a more logical flow of ideas? Evaluate how well each text conveys its main points.
+        3. Tone and style - Which text has a more appropriate and engaging tone and writing style for its intended purpose and audience?
+        4. Sticking to the request - Which text is more successful responding to the original user request. Consider the request, the style, the length, etc.
+        5. Overall effectiveness - Considering the above factors, which text is more successful overall at communicating its message and achieving its goals?
+
+        After comparing the texts on these criteria, clearly state which text you think is better and summarize the main reasons why.
+        Provide specific examples from each text to support your evaluation.
+        """
+    )
+    human_message = HumanMessage(
+        content=f"""
+        Original query: {query}
+        ------------------------
+        Text 1: {draft_text}
+        ------------------------
+        Text 2: {final_text}
+        ------------------------
+        Summary:
+        """
+    )
+    return [system_message, human_message]
+
+
+def compare_text(chat_llm, query, draft, final, callbacks=[]):
+    messages = get_compare_texts_prompts(query, draft_text=draft, final_text=final)
+    response = chat_llm.invoke(messages, config={"callbacks": callbacks})
+    return response.content
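The new `get_compare_texts_prompts` / `compare_text` helpers close the loop: draft, review, rewrite, then grade draft against rewrite. A minimal sketch of how the helpers might chain together (not part of the commit; the provider, query and draft below are placeholder assumptions, and an OpenAI API key is assumed to be configured):

```python
# Sketch only: chaining the copywriter helpers with a chat model from web_rag.
import copywriter as cw
import web_rag as wr

chat, _embeddings = wr.get_models("openai", model=None, temperature=0.1)

query = "Write a short article about the solar system"   # example request (assumption)
draft = "The solar system is made up of the Sun and ..."  # placeholder draft (assumption)

comments = cw.generate_comments(chat, query, draft)                  # reviewer feedback
final_text = cw.generate_final_text(chat, query, draft, comments)    # rewrite using the comments
report = cw.compare_text(chat, query, draft, final_text)             # new: grade draft vs. final
print(report)
```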
requirements.txt CHANGED
@@ -1,5 +1,7 @@
+anthropic
 boto3
 bs4
+chromedriver-py
 cohere
 docopt
 faiss-cpu
@@ -7,7 +9,7 @@ google-api-python-client
 pdfplumber
 python-dotenv
 langchain
-langchain-cohere
+langchain-aws
 langchain-fireworks
 langchain_core
 langchain_community
@@ -18,6 +20,8 @@ langsmith
 schema
 streamlit
 selenium
+tiktoken
+transformers
 rich
 trafilatura
 watchdog
search_agent.py CHANGED
@@ -8,6 +8,7 @@ Usage:
     [--temperature=temp]
     [--copywrite]
     [--max_pages=num]
+    [--max_extracts=num]
    [--output=text]
     SEARCH_QUERY
     search_agent.py --version
@@ -21,6 +22,7 @@ Options:
     -p provider --provider=provider Use a specific LLM (choices: bedrock,openai,groq,ollama,cohere,fireworks) [default: openai]
     -m model --model=model Use a specific model
     -n num --max_pages=num Max number of pages to retrieve [default: 10]
+    -e num --max_extracts=num Max number of page extract to consider [default: 5]
     -o text --output=text Output format (choices: text, markdown) [default: markdown]

 """
@@ -63,8 +65,6 @@ def get_selenium_driver():
     driver = webdriver.Chrome(options=chrome_options)
     return driver

-
-
 callbacks = []
 if os.getenv("LANGCHAIN_API_KEY"):
     callbacks.append(
@@ -90,14 +90,16 @@ if __name__ == '__main__':
     temperature = float(arguments["--temperature"])
     domain=arguments["--domain"]
     max_pages=arguments["--max_pages"]
+    max_extract=int(arguments["--max_extracts"])
     output=arguments["--output"]
     query = arguments["SEARCH_QUERY"]

     chat, embedding_model = wr.get_models(provider, model, temperature)
-    #console.log(f"Using {chat.model_name} on {provider}")

     with console.status(f"[bold green]Optimizing query for search: {query}"):
         optimize_search_query = wr.optimize_search_query(chat, query, callbacks=callbacks)
+        if len(optimize_search_query) < 3:
+            optimize_search_query = query
     console.log(f"Optimized search query: [bold blue]{optimize_search_query}")

     with console.status(
@@ -112,11 +114,11 @@ if __name__ == '__main__':
         contents = wc.get_links_contents(sources, get_selenium_driver)
     console.log(f"Managed to extract content from {len(contents)} sources")

-    with console.status(f"[bold green]Embeddubg {len(contents)} sources for content", spinner="growVertical"):
+    with console.status(f"[bold green]Embedding {len(contents)} sources for content", spinner="growVertical"):
         vector_store = wc.vectorize(contents, embedding_model)

-    with console.status("[bold green]Querying LLM relevant context", spinner='dots8Bit'):
-        draft = wr.query_rag(chat, query, optimize_search_query, vector_store, top_k = 5, callbacks=callbacks)
+    with console.status("[bold green]Writing content", spinner='dots8Bit'):
+        draft = wr.query_rag(chat, query, optimize_search_query, vector_store, top_k = max_extract, callbacks=callbacks)

     console.rule(f"[bold green]Response from {provider}")
     if output == "text":
@@ -129,7 +131,7 @@ if __name__ == '__main__':
     with console.status("[bold green]Getting comments from the reviewer", spinner="dots8Bit"):
         comments = cw.generate_comments(chat, query, draft, callbacks=callbacks)

-    console.rule(f"[bold green]Response from reviewer")
+    console.rule("[bold green]Response from reviewer")
     if output == "text":
         console.print(comments)
     else:
@@ -139,7 +141,7 @@ if __name__ == '__main__':
     with console.status("[bold green]Writing the final text", spinner="dots8Bit"):
         final_text = cw.generate_final_text(chat, query, draft, comments, callbacks=callbacks)

-    console.rule(f"[bold green]Final text")
+    console.rule("[bold green]Final text")
     if output == "text":
         console.print(final_text)
     else:
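With `--max_extracts` now wired through to `top_k`, the CLI pipeline reduces to a handful of calls. A minimal sketch of that flow outside the CLI (not part of the commit; provider, query and values are assumptions, and the OpenAI and Brave search API keys are assumed to be configured as for the script itself):

```python
# Sketch only: the search_agent.py pipeline reduced to its function calls.
import web_rag as wr
import web_crawler as wc

query = "Biography of Napoleon"     # example query (assumption)
max_pages, max_extract = 10, 5      # mirrors the --max_pages / --max_extracts defaults

chat, embedding_model = wr.get_models("openai", model=None, temperature=0.0)

optimized = wr.optimize_search_query(chat, query, callbacks=[])
if len(optimized) < 3:              # fallback added in this commit
    optimized = query

sources = wc.get_sources(optimized, max_pages=max_pages)
contents = wc.get_links_contents(sources)
vector_store = wc.vectorize(contents, embedding_model)

# top_k now comes from --max_extracts instead of the previous hard-coded 5
draft = wr.query_rag(chat, query, optimized, vector_store, top_k=max_extract, callbacks=[])
print(draft)
```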
search_agent_ui.py CHANGED
@@ -10,6 +10,7 @@ from langsmith.client import Client

 import web_rag as wr
 import web_crawler as wc
+import copywriter as cw

 dotenv.load_dotenv()

@@ -18,7 +19,6 @@ ls_tracer = LangChainTracer(
     client=Client()
 )

-
 class StreamHandler(BaseCallbackHandler):
     """Stream handler that appends tokens to container."""
     def __init__(self, container, initial_text=""):
@@ -28,11 +28,36 @@ class StreamHandler(BaseCallbackHandler):
     def on_llm_new_token(self, token: str, **kwargs):
         self.text += token
         self.container.markdown(self.text)
+

+def create_links_markdown(sources_list):
+    """
+    Create a markdown string for each source in the provided JSON.
+
+    Args:
+        sources_list (list): A list of dictionaries representing the sources.
+            Each dictionary should have 'title', 'link', and 'snippet' keys.
+
+    Returns:
+        str: A markdown string with a bullet point for each source,
+            including the title linked to the URL and the snippet.
+    """
+    markdown_list = []
+    for source in sources_list:
+        title = source['title']
+        link = source['link']
+        snippet = source['snippet']
+        markdown = f"- [{title}]({link})\n {snippet}"
+        markdown_list.append(markdown)
+    return "\n".join(markdown_list)
+
+st.set_page_config(layout="wide")
 st.title("🔍 Simple Search Agent 💬")

 if "providers" not in st.session_state:
     providers = []
+    if os.getenv("FIREWORKS_API_KEY"):
+        providers.append("fireworks")
     if os.getenv("COHERE_API_KEY"):
         providers.append("cohere")
     if os.getenv("OPENAI_API_KEY"):
@@ -41,22 +66,34 @@ if "providers" not in st.session_state:
         providers.append("groq")
     if os.getenv("OLLAMA_API_KEY"):
         providers.append("ollama")
-    if os.getenv("FIREWORKS_API_KEY"):
-        providers.append("fireworks")
     if os.getenv("CREDENTIALS_PROFILE_NAME"):
         providers.append("bedrock")
     st.session_state["providers"] = providers

-with st.sidebar:
-    st.write("Options")
-    model_provider = st.selectbox("🧠 Model provider 🧠", st.session_state["providers"])
-    temperature = st.slider("🌡️ Model temperature 🌡️", 0.0, 1.0, 0.1, help="The higher the more creative")
-    max_pages = st.slider("🔍 Max pages to retrieve 🔍", 1, 20, 15, help="How many web pages to retrive from the internet")
-    top_k_documents = st.slider("📄 How many doc extracts to consider 📄", 1, 20, 5, help="How many of the top extracts to consider")
+with st.sidebar.expander("Options", expanded=False):
+    model_provider = st.selectbox("Model provider 🧠", st.session_state["providers"])
+    temperature = st.slider("Model temperature 🌡️", 0.0, 1.0, 0.1, help="The higher the more creative")
+    max_pages = st.slider("Max pages to retrieve 🔍", 1, 20, 15, help="How many web pages to retrive from the internet")
+    top_k_documents = st.slider("Nbr of doc extracts to consider 📄", 1, 20, 5, help="How many of the top extracts to consider")
+    reviewer_mode = st.checkbox("Draft / Comment / Rewrite mode ✍️", value=False, help="First generate a write, then comments and then rewrite")
+
+with st.sidebar.expander("Links", expanded=False):
+    links_md = st.markdown("")
+
+if reviewer_mode:
+    with st.sidebar.expander("Answer review", expanded=False):
+        st.caption("Draft")
+        draft_md = st.markdown("")
+        st.divider()
+        st.caption("Comments")
+        comments_md = st.markdown("")
+        st.divider()
+        st.caption("Comparaison")
+        comparaison_md = st.markdown("")

 if "messages" not in st.session_state:
     st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
-
+
 for message in st.session_state.messages:
     st.chat_message(message["role"]).write(message["content"])
     if message["role"] == "assistant" and 'message_id' in message:
@@ -80,6 +117,7 @@ if prompt := st.chat_input("Enter you instructions..." ):
         st.write(f"I should search the web for: {optimize_search_query}")

         sources = wc.get_sources(optimize_search_query, max_pages=max_pages)
+        links_md.markdown(create_links_markdown(sources))

         st.write(f"I'll now retrieve the {len(sources)} webpages and documents I found")
         contents = wc.get_links_contents(sources)
@@ -87,18 +125,42 @@ if prompt := st.chat_input("Enter you instructions..." ):
         st.write( f"Reading through the {len(contents)} sources I managed to retrieve")
         vector_store = wc.vectorize(contents, embedding_model=embedding_model)
         st.write(f"I collected {vector_store.index.ntotal} chunk of data and I can now answer")
+
+
+        if reviewer_mode:
+            st.write("Creating a draft")
+            draft_prompt = wr.build_rag_prompt(
+                chat, prompt, optimize_search_query,
+                vector_store, top_k=top_k_documents, callbacks=[ls_tracer])
+            draft = chat.invoke(draft_prompt, stream=False, config={ "callbacks": [ls_tracer]})
+            draft_md.markdown(draft.content)
+            st.write("Sending draft for review")
+            comments = cw.generate_comments(chat, prompt, draft, callbacks=[ls_tracer])
+            comments_md.markdown(comments)
+            st.write("Reviewing comments and generating final answer")
+            rag_prompt = cw.get_final_text_prompt(prompt, draft, comments)
+        else:
+            rag_prompt = wr.build_rag_prompt(
+                chat, prompt, optimize_search_query, vector_store,
+                top_k=top_k_documents, callbacks=[ls_tracer]
+            )

-        rag_prompt = wr.build_rag_prompt(prompt, optimize_search_query, vector_store, top_k=5, callbacks=[ls_tracer])
     with st.chat_message("assistant"):
         st_cb = StreamHandler(st.empty())
         result = chat.invoke(rag_prompt, stream=True, config={ "callbacks": [st_cb, ls_tracer]})
         response = result.content.strip()
         message_id = f"{prompt}{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
         st.session_state.messages.append({"role": "assistant", "content": response})
-    if st.session_state.messages[-1]["role"] == "assistant":
-        st.download_button(
-            label="Download",
-            data=st.session_state.messages[-1]["content"],
-            file_name=f"{message_id}.txt",
-            mime="text/plain"
-        )
+
+        if st.session_state.messages[-1]["role"] == "assistant":
+            st.download_button(
+                label="Download",
+                data=st.session_state.messages[-1]["content"],
+                file_name=f"{message_id}.txt",
+                mime="text/plain"
+            )
+
+        if reviewer_mode:
+            compare_prompt = cw.get_compare_texts_prompts(prompt, draft_text=draft, final_text=response)
+            result = chat.invoke(compare_prompt, stream=False, config={ "callbacks": [ls_tracer]})
+            comparaison_md.markdown(result.content)
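The new `create_links_markdown` helper feeds the "Links" sidebar expander from the source list returned by `web_crawler.get_sources`. A small illustration of its output, assuming the function above is in scope; the sample entry is made up:

```python
# Sketch only: what create_links_markdown produces for a sources list.
sample_sources = [
    {"title": "Solar System - Wikipedia",
     "link": "https://en.wikipedia.org/wiki/Solar_System",
     "snippet": "The Solar System is the gravitationally bound system of the Sun..."},
]
print(create_links_markdown(sample_sources))
# - [Solar System - Wikipedia](https://en.wikipedia.org/wiki/Solar_System)
#  The Solar System is the gravitationally bound system of the Sun...
```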
web_crawler.py CHANGED
@@ -35,12 +35,13 @@ def get_sources(query, max_pages=10, domain=None):
     json_response = response.json()

     if 'web' not in json_response or 'results' not in json_response['web']:
+        print(response.text)
         raise Exception('Invalid API response format')

     final_results = [{
         'title': result['title'],
         'link': result['url'],
-        'snippet': result['description'],
+        'snippet': extract(result['description'], output_format='txt', include_tables=False, include_images=False, include_formatting=True),
         'favicon': result.get('profile', {}).get('img', '')
     } for result in json_response['web']['results']]

web_rag.py CHANGED
@@ -28,13 +28,14 @@ from langchain.prompts.chat import (
 from langchain.prompts.prompt import PromptTemplate
 from langchain.retrievers.multi_query import MultiQueryRetriever

+from langchain_aws import ChatBedrock
 from langchain_cohere.chat_models import ChatCohere
 from langchain_cohere.embeddings import CohereEmbeddings
 from langchain_fireworks.chat_models import ChatFireworks
-from langchain_groq import ChatGroq
+#from langchain_groq import ChatGroq
+from langchain_groq.chat_models import ChatGroq
 from langchain_openai import ChatOpenAI
 from langchain_openai.embeddings import OpenAIEmbeddings
-from langchain_community.chat_models.bedrock import BedrockChat
 from langchain_community.embeddings.bedrock import BedrockEmbeddings
 from langchain_community.chat_models.ollama import ChatOllama

@@ -44,15 +45,15 @@ def get_models(provider, model=None, temperature=0.0):
             credentials_profile_name=os.getenv('CREDENTIALS_PROFILE_NAME')
             if model is None:
                 model = "anthropic.claude-3-sonnet-20240229-v1:0"
-            chat_llm = BedrockChat(
+            chat_llm = ChatBedrock(
                 credentials_profile_name=credentials_profile_name,
                 model_id=model,
-                model_kwargs={"temperature": temperature, 'max_tokens': 8192 },
+                model_kwargs={"temperature": temperature, "max_tokens":4096 },
+            )
+            embedding_model = BedrockEmbeddings(
+                model_id='cohere.embed-multilingual-v3',
+                credentials_profile_name=credentials_profile_name
             )
-            #embedding_model = BedrockEmbeddings(
-            #    model_id='cohere.embed-multilingual-v3',
-            #    credentials_profile_name=credentials_profile_name
-            #)
             embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case 'openai':
             if model is None:
@@ -73,14 +74,17 @@ def get_models(provider, model=None, temperature=0.0):
             if model is None:
                 model = 'command-r-plus'
             chat_llm = ChatCohere(model=model, temperature=temperature)
-            embedding_model = CohereEmbeddings(model="embed-english-light-v3.0")
+            #embedding_model = CohereEmbeddings(model="embed-english-light-v3.0")
+            embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case 'fireworks':
             if model is None:
-                model = 'accounts/fireworks/models/mixtral-8x22b-instruct-preview'
-            chat_llm = ChatFireworks(model_name=model, temperature=temperature)
+                #model = 'accounts/fireworks/models/dbrx-instruct'
+                model = 'accounts/fireworks/models/llama-v3-70b-instruct'
+            chat_llm = ChatFireworks(model_name=model, temperature=temperature, max_tokens=8192)
             embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case _:
             raise ValueError(f"Unknown LLM provider {provider}")
+
     return chat_llm, embedding_model

@@ -96,12 +100,13 @@ def get_optimized_search_messages(query):
     """
     system_message = SystemMessage(
         content="""
-        I want you to act as a prompt optimizer for web search. I will provide you with a chat prompt, and your goal is to optimize it into a search string that will yield the most relevant and useful information from a search engine like Google.
+        I want you to act as a prompt optimizer for web search.
+        I will provide you with a chat prompt, and your goal is to optimize it into a search string that will yield the most relevant and useful information from a search engine like Google.
         To optimize the prompt:
-        Identify the key information being requested
-        Arrange the keywords into a concise search string
-        Keep it short, around 1 to 5 words total
-        Put the most important keywords first
+        - Identify the key information being requested
+        - Arrange the keywords into a concise search string
+        - Keep it short, around 1 to 5 words total
+        - Put the most important keywords first

         Some tips and things to be sure to remove:
         - Remove any conversational or instructional phrases
@@ -110,44 +115,44 @@ def get_optimized_search_messages(query):
         - Remove style instructions (exmaple: "in the style of", engaging, short, long)
        - Remove lenght instruction (example: essay, article, letter, etc)

-        Add "**" to the end of the search string to indicate the end of the query
+        You should answer only with the optimized search query and add "**" to the end of the search string to indicate the end of the query

         Example:
         Question: How do I bake chocolate chip cookies from scratch?
-        Search query: chocolate chip cookies recipe from scratch**
+        chocolate chip cookies recipe from scratch**
         Example:
         Question: I would like you to show me a timeline of Marie Curie's life. Show results as a markdown table
-        Search query: Marie Curie timeline**
+        Marie Curie timeline**
         Example:
         Question: I would like you to write a long article on NATO vs Russia. Use known geopolitical frameworks.
-        Search query: geopolitics nato russia**
+        geopolitics nato russia**
         Example:
         Question: Write an engaging LinkedIn post about Andrew Ng
-        Search query: Andrew Ng**
+        Andrew Ng**
         Example:
         Question: Write a short article about the solar system in the style of Carl Sagan
-        Search query: solar system**
+        solar system**
         Example:
         Question: Should I use Kubernetes? Answer in the style of Gilfoyle from the TV show Silicon Valley
-        Search query: Kubernetes decision**
+        Kubernetes decision**
         Example:
         Question: Biography of Napoleon. Include a table with the major events.
-        Search query: napoleon biography events**
+        napoleon biography events**
         Example:
         Question: Write a short article on the history of the United States. Include a table with the major events.
-        Search query: united states history events**
+        united states history events**
         Example:
         Question: Write a short article about the solar system in the style of donald trump
-        Search query: solar system**
+        solar system**
         Exmaple:
         Question: Write a short linkedin about how the "freakeconomics" book previsions didn't pan out
-        Search query: freakeconomics book predictions failed**
+        freakeconomics book predictions failed**
         """
     )
     human_message = HumanMessage(
         content=f"""
         Question: {query}
-        Search query:
+
         """
     )
     return [system_message, human_message]
@@ -230,15 +235,49 @@ def multi_query_rag(chat_llm, question, search_query, vectorstore, callbacks = [
     response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
     return response.content

-
-def build_rag_prompt(question, search_query, vectorstore, top_k = 10, callbacks = []):
-    unique_docs = vectorstore.similarity_search(
-        search_query, k=top_k, callbacks=callbacks, verbose=True)
-    context = format_docs(unique_docs)
-    prompt = get_rag_prompt_template().format(query=question, context=context)
-    return prompt
+def get_context_size(chat_llm):
+    if isinstance(chat_llm, ChatOpenAI):
+        if chat_llm.model_name.startswith("gpt-4"):
+            return 128000
+        else:
+            return 16385
+    if isinstance(chat_llm, ChatFireworks):
+        return 8192
+    if isinstance(chat_llm, ChatGroq):
+        return 37862
+    if isinstance(chat_llm, ChatOllama):
+        return 8192
+    if isinstance(chat_llm, ChatCohere):
+        return 128000
+    if isinstance(chat_llm, ChatBedrock):
+        if chat_llm.model_id.startswith("anthropic.claude-3"):
+            return 200000
+        if chat_llm.model_id.startswith("anthropic.claude"):
+            return 100000
+        if chat_llm.model_id.startswith("mistral"):
+            if chat_llm.model_id.startswith("mistral.mixtral-8x7b"):
+                return 4096
+            else:
+                return 8192
+    return 4096
+
+
+def build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
+    done = False
+    while not done:
+        unique_docs = vectorstore.similarity_search(
+            search_query, k=top_k, callbacks=callbacks, verbose=True)
+        context = format_docs(unique_docs)
+        prompt = get_rag_prompt_template().format(query=question, context=context)
+        nbr_tokens = chat_llm.get_num_tokens(prompt)
+        if top_k <= 1 or nbr_tokens <= get_context_size(chat_llm) - 768:
+            done = True
+        else:
+            top_k = int(top_k * 0.75)
+
+    return prompt

 def query_rag(chat_llm, question, search_query, vectorstore, top_k = 10, callbacks = []):
-    prompt = build_rag_prompt(question, search_query, vectorstore, top_k= top_k, callbacks = callbacks)
+    prompt = build_rag_prompt(chat_llm, question, search_query, vectorstore, top_k=top_k, callbacks = callbacks)
     response = chat_llm.invoke(prompt, config={"callbacks": callbacks})
-    return response.content
+    return response.content
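`build_rag_prompt` now takes the chat model as its first argument so it can shrink `top_k` by 25% per iteration until the prompt fits the model's context window (from `get_context_size`) minus a 768-token margin. The same trimming idea in isolation, using a stand-in per-extract token estimate instead of `chat_llm.get_num_tokens` (illustrative only, not part of the commit):

```python
# Sketch only: the top_k trimming behaviour of build_rag_prompt, with a fixed
# tokens-per-extract estimate standing in for a real token count of the prompt.
def trim_top_k(top_k, tokens_per_extract, context_size, margin=768):
    """Shrink top_k by 25% until the estimated prompt fits within the context budget."""
    while top_k > 1 and top_k * tokens_per_extract > context_size - margin:
        top_k = int(top_k * 0.75)
    return top_k

# e.g. 10 extracts of ~2,500 tokens each against a 16,385-token window
# (the gpt-3.5 figure from get_context_size): 10 -> 7 -> 5
print(trim_top_k(10, 2500, 16385))  # -> 5
```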