Eddie Pick commited on
Commit
7402de3
1 Parent(s): 6f80de5

Changed provider/model delimiter to ':'

Browse files
Files changed (3) hide show
  1. README.md +7 -3
  2. models.py +9 -5
  3. search_agent.py +7 -0
README.md CHANGED
@@ -72,7 +72,7 @@ python search_agent.py [OPTIONS] SEARCH_QUERY
72
  - `-c`, `--copywrite`: First produce a draft, review it, and rewrite for a final text.
73
  - `-d DOMAIN`, `--domain=DOMAIN`: Limit search to a specific domain.
74
  - `-t TEMP`, `--temperature=TEMP`: Set the temperature of the LLM [default: 0.0].
75
- - `-m MODEL`, `--model=MODEL`: Use a specific model [default: openai/gpt-4o-mini].
76
  - `-e MODEL`, `--embedding_model=MODEL`: Use a specific embedding model [default: same provider as model].
77
  - `-n NUM`, `--max_pages=NUM`: Max number of pages to retrieve [default: 10].
78
  - `-x NUM`, `--max_extracts=NUM`: Max number of page extracts to consider [default: 7].
@@ -82,11 +82,15 @@ python search_agent.py [OPTIONS] SEARCH_QUERY
82
  ### Examples
83
 
84
  ```bash
85
- python search_agent.py -m openai/gpt-4o-mini "Write a linked post about the current state of M&A for startups. Write in the style of Russ from Silicon Valley TV show."
86
  ```
87
 
88
  ```bash
89
- python search_agent.py -m openai -e ollama -t 0.7 -n 20 -x 15 "Write a linked post about the state of M&A for startups in 2024. Write in the style of Russ from TV show Silicon Valley" -s
 
 
 
 
90
  ```
91
 
92
  ## License
 
72
  - `-c`, `--copywrite`: First produce a draft, review it, and rewrite for a final text.
73
  - `-d DOMAIN`, `--domain=DOMAIN`: Limit search to a specific domain.
74
  - `-t TEMP`, `--temperature=TEMP`: Set the temperature of the LLM [default: 0.0].
75
+ - `-m MODEL`, `--model=MODEL`: Use a specific model [default: openai:gpt-4o-mini].
76
  - `-e MODEL`, `--embedding_model=MODEL`: Use a specific embedding model [default: same provider as model].
77
  - `-n NUM`, `--max_pages=NUM`: Max number of pages to retrieve [default: 10].
78
  - `-x NUM`, `--max_extracts=NUM`: Max number of page extracts to consider [default: 7].
 
82
  ### Examples
83
 
84
  ```bash
85
+ python search_agent.py -m openai:gpt-4o-mini "Write a linked post about the current state of M&A for startups. Write in the style of Russ from Silicon Valley TV show."
86
  ```
87
 
88
  ```bash
89
+ python search_agent.py -m groq:llama-3.1-70b-versatile -e ollama:nomic-embed-text:latest -t 0.7 -n 20 -x 15 "Write a linked post about the state of M&A for startups in 2024. Write in the style of Russ from TV show Silicon Valley" -s
90
+ ```
91
+
92
+ ```bash
93
+ python search_agent.py -m groq -e openai "Write an engaging long linked post about the state of M&A for startups in 2024"
94
  ```
95
 
96
  ## License
models.py CHANGED
@@ -28,10 +28,14 @@ from langchain_community.chat_models import ChatPerplexity
28
  from langchain_together import ChatTogether
29
  from langchain_together.embeddings import TogetherEmbeddings
30
 
31
-
 
 
 
 
32
 
33
  def get_model(provider_model, temperature=0.0):
34
- provider, model = (provider_model.rstrip('/').split('/') + [None])[:2]
35
  match provider:
36
  case 'bedrock':
37
  if model is None:
@@ -76,8 +80,8 @@ def get_model(provider_model, temperature=0.0):
76
  return chat_llm
77
 
78
 
79
- def get_embedding_model(provider_embedding_model):
80
- provider, model = (provider_embedding_model.rstrip('/').split('/') + [None])[:2]
81
  match provider:
82
  case 'bedrock':
83
  if model is None:
@@ -224,7 +228,7 @@ class TestGetModel(unittest.TestCase):
224
  @patch('models.ChatGroq')
225
  def test_groq_model(self, mock_groq):
226
  result = get_model('groq')
227
- mock_groq.assert_called_once_with(model_name='llama2-70b-4096', temperature=0.0)
228
  self.assertEqual(result, mock_groq.return_value)
229
 
230
  @patch('models.ChatOllama')
 
28
  from langchain_together import ChatTogether
29
  from langchain_together.embeddings import TogetherEmbeddings
30
 
31
+ def split_provider_model(provider_model):
32
+ parts = provider_model.split(':', 1)
33
+ provider = parts[0]
34
+ model = parts[1] if len(parts) > 1 else None
35
+ return provider, model
36
 
37
  def get_model(provider_model, temperature=0.0):
38
+ provider, model = split_provider_model(provider_model)
39
  match provider:
40
  case 'bedrock':
41
  if model is None:
 
80
  return chat_llm
81
 
82
 
83
+ def get_embedding_model(provider_model):
84
+ provider, model = split_provider_model(provider_model)
85
  match provider:
86
  case 'bedrock':
87
  if model is None:
 
228
  @patch('models.ChatGroq')
229
  def test_groq_model(self, mock_groq):
230
  result = get_model('groq')
231
+ mock_groq.assert_called_once_with(model_name='llama-3.1-8b-instant', temperature=0.0)
232
  self.assertEqual(result, mock_groq.return_value)
233
 
234
  @patch('models.ChatOllama')
search_agent.py CHANGED
@@ -12,6 +12,7 @@ Usage:
12
  [--max_extracts=num]
13
  [--use_selenium]
14
  [--output=text]
 
15
  SEARCH_QUERY
16
  search_agent.py --version
17
 
@@ -27,6 +28,7 @@ Options:
27
  -x num --max_extracts=num Max number of page extract to consider [default: 7]
28
  -s --use_selenium Use selenium to fetch content from the web [default: False]
29
  -o text --output=text Output format (choices: text, markdown) [default: markdown]
 
30
 
31
  """
32
 
@@ -80,6 +82,7 @@ if os.getenv("LANGCHAIN_API_KEY"):
80
  )
81
  @traceable(run_type="tool", name="search_agent")
82
  def main(arguments):
 
83
  copywrite_mode = arguments["--copywrite"]
84
  model = arguments["--model"]
85
  embedding_model = arguments["--embedding_model"]
@@ -98,6 +101,10 @@ def main(arguments):
98
  else:
99
  embedding_model = md.get_embedding_model(embedding_model)
100
 
 
 
 
 
101
  with console.status(f"[bold green]Optimizing query for search: {query}"):
102
  optimize_search_query = wr.optimize_search_query(chat, query)
103
  if len(optimize_search_query) < 3:
 
12
  [--max_extracts=num]
13
  [--use_selenium]
14
  [--output=text]
15
+ [--verbose]
16
  SEARCH_QUERY
17
  search_agent.py --version
18
 
 
28
  -x num --max_extracts=num Max number of page extract to consider [default: 7]
29
  -s --use_selenium Use selenium to fetch content from the web [default: False]
30
  -o text --output=text Output format (choices: text, markdown) [default: markdown]
31
+ -v --verbose Print verbose output [default: False]
32
 
33
  """
34
 
 
82
  )
83
  @traceable(run_type="tool", name="search_agent")
84
  def main(arguments):
85
+ verbose = arguments["--verbose"]
86
  copywrite_mode = arguments["--copywrite"]
87
  model = arguments["--model"]
88
  embedding_model = arguments["--embedding_model"]
 
101
  else:
102
  embedding_model = md.get_embedding_model(embedding_model)
103
 
104
+ if verbose:
105
+ console.log(f"Using model: {chat.model_name}")
106
+ console.log(f"Using embedding model: { embedding_model.model}")
107
+
108
  with console.status(f"[bold green]Optimizing query for search: {query}"):
109
  optimize_search_query = wr.optimize_search_query(chat, query)
110
  if len(optimize_search_query) < 3: