sasan committed

Commit 0f04201 · 1 Parent(s): fea02f6

chore: Update vehicle speed and destination handling functions
kitt/core/__init__.py CHANGED

@@ -21,13 +21,13 @@ voices = [
         "Attenborough",
         neutral=f"{file_full_path}/audio/attenborough/neutral.wav",
         angry=None,
-        speed=1.1,
+        speed=1.2,
     ),
     Voice(
         "Rick",
         neutral=f"{file_full_path}/audio/rick/neutral.wav",
         angry=None,
-        speed=1.1,
+        speed=1.2,
     ),
     Voice(
         "Freeman",
@@ -45,7 +45,7 @@ voices = [
         "Darth Wader",
         neutral=f"{file_full_path}/audio/darth/neutral.wav",
         angry=None,
-        speed=1.1,
+        speed=1.15,
     ),
 ]
kitt/core/legacy.py ADDED

@@ -0,0 +1,122 @@
+import uuid
+import json
+import re
+from loguru import logger
+
+def use_tool(tool_call, tools):
+    func_name = tool_call["name"]
+    kwargs = tool_call["arguments"]
+    for tool in tools:
+        if tool.name == func_name:
+            return tool.invoke(input=kwargs)
+    raise ValueError(f"Tool {func_name} not found.")
+
+
+def parse_tool_calls(text):
+    logger.debug(f"Start parsing tool_calls: {text}")
+    pattern = r"<tool_call>\s*(\{.*?\})\s*</tool_call>"
+
+    if not text.startswith("<tool_call>"):
+        if "<tool_call>" in text:
+            raise ValueError("<text_and_tool_call>")
+
+        if "<tool_response>" in text:
+            raise ValueError("<tool_response>")
+        return [], []
+
+    matches = re.findall(pattern, text, re.DOTALL)
+    tool_calls = []
+    errors = []
+    for match in matches:
+        try:
+            tool_call = json.loads(match)
+            tool_calls.append(tool_call)
+        except json.JSONDecodeError as e:
+            errors.append(f"Invalid JSON in tool call: {e}")
+
+    logger.debug(f"Tool calls: {tool_calls}, errors: {errors}")
+    return tool_calls, errors
+
+def process_response(user_query, res, history, tools, depth):
+    """Returns True if the response contains tool calls, False otherwise."""
+    logger.debug(f"Processing response: {res}")
+    tool_results = f"Agent iteration {depth} to assist with user query: {user_query}\n"
+    tool_call_id = uuid.uuid4().hex
+    try:
+        tool_calls, errors = parse_tool_calls(res)
+    except ValueError as e:
+        if "<text_and_tool_call>" in str(e):
+            tool_results += "<tool_response>If you need to call a tool your response must be wrapped in <tool_call></tool_call>. Try again, you are great.</tool_response>"
+            history.add_message(
+                ToolMessage(content=tool_results, tool_call_id=tool_call_id)
+            )
+            return True, [], []
+        if "<tool_response>" in str(e):
+            tool_results += "<tool_response>Tool results are not allowed in the response.</tool_response>"
+            history.add_message(
+                ToolMessage(content=tool_results, tool_call_id=tool_call_id)
+            )
+            return True, [], []
+    # TODO: Handle errors
+    if not tool_calls:
+        logger.debug("No tool calls found in response.")
+        return False, tool_calls, errors
+    # tool_results = ""
+
+    for tool_call in tool_calls:
+        # TODO: Extra Validation
+        # Call the function
+        try:
+            result = use_tool(tool_call, tools)
+            logger.debug(f"Tool call {tool_call} result: {result}")
+            if isinstance(result, tuple):
+                result = result[1]
+            tool_results += f"<tool_response>\n{result}\n</tool_response>\n"
+        except Exception as e:
+            logger.error(f"Error calling tool: {e}")
+        # Currently only to mimic OpenAI's behavior
+        # But it could be used for tracking function calls
+
+    tool_results = tool_results.strip()
+    print(f"Tool results: {tool_results}")
+    history.add_message(ToolMessage(content=tool_results, tool_call_id=tool_call_id))
+    return True, tool_calls, errors
+
+
+def process_query(
+    user_query: str,
+    history: ChatMessageHistory,
+    user_preferences,
+    tools,
+    backend="ollama",
+):
+    # Add vehicle status to the history
+    user_query_status = f"consider the vehicle status:\n{vehicle_status()[0]}\nwhen responding to the following query:\n{user_query}"
+    history.add_message(HumanMessage(content=user_query_status))
+    for depth in range(10):
+        # out = run_inference_step(depth, history, tools, schema_json)
+        out = run_inference_step(
+            depth,
+            history,
+            tools,
+            schema_json,
+            user_preferences=user_preferences,
+            backend=backend,
+        )
+        logger.info(f"Inference step result:\n{out}")
+        history.add_message(AIMessage(content=out))
+        to_continue, tool_calls, errors = process_response(
+            user_query, out, history, tools, depth
+        )
+        if errors:
+            history.add_message(AIMessage(content=f"Errors in tool calls: {errors}"))
+
+        if not to_continue:
+            print(f"This is the answer, no more iterations: {out}")
+            return out
+        # Otherwise, tools result is already added to history, we just need to continue the loop.
+    # If we get here something went wrong.
+    history.add_message(
+        AIMessage(content="Sorry, I am not sure how to help you with that.")
+    )
+    return "Sorry, I am not sure how to help you with that."
kitt/core/model.py CHANGED

@@ -1,18 +1,22 @@
+import ast
 import json
 import re
 import uuid
+from enum import Enum
+from typing import List
+import xml.etree.ElementTree as ET

 from langchain.memory import ChatMessageHistory
 from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
-from langchain_core.utils.function_calling import convert_to_openai_function
-import ollama
+from langchain_core.utils.function_calling import convert_to_openai_tool
+from langchain.tools.base import StructuredTool
 from ollama import Client
 from pydantic import BaseModel
 from loguru import logger

-
 from kitt.skills import vehicle_status
 from kitt.skills.common import config
+from .validator import validate_function_call_schema


 class FunctionCall(BaseModel):
@@ -28,14 +32,52 @@ class FunctionCall(BaseModel):
     """The name of the function to call."""


+class ResponseType(Enum):
+    TOOL_CALL = "tool_call"
+    TEXT = "text"
+
+
+class AssistantResponse(BaseModel):
+    tool_calls: List[FunctionCall]
+    """The tool call to make to get the response."""
+
+    response_type: ResponseType = (
+        ResponseType.TOOL_CALL
+    )  # The type of response to make to the user. Either 'tool_call' or 'text'.
+    """The type of response to make to the user. Either 'tool_call' or 'text'."""
+
+    response: str
+
+
 schema_json = json.loads(FunctionCall.schema_json())
+# schema_json = json.loads(AssistantResponse.schema_json())
+
 HRMS_SYSTEM_PROMPT = """<|im_start|>system
-You are a function calling AI agent. Your name is KITT. You are embodied in a Car. You know where you are, where you are going, and the current date and time. You can call functions to help with user queries.
-You can call only one function at a time and analyse data you get from function response.
+You are a helpful assistant that answers in JSON. Here's the json schema you must adhere to:
+<schema>
+{schema}
+</schema><|im_end|>"""
+
+
+HRMS_SYSTEM_PROMPT = """<|im_start|>system
+Role:
+Your name is KITT. You are embodied in a Car. The user is a human who is a passenger in the car. You have autonomy to use the tools available to you to assist the user.
+You are the AI assistant in the car. From the information in <car_status></car_status> you know where you are, the destination, and the current date and time.
+You are witty, helpful, and have a good sense of humor. You are a function calling AI agent with self-recursion.
 You are provided with function signatures within <tools></tools> XML tags.
+User preferences are provided in <user_preferences></user_preferences> XML tags. Use them if needed.

+<car_status>
+{car_status}
+</car_status>
+
+<user_preferences>
+{user_preferences}
+</user_preferences>
+
+Objective:
 You may use agentic frameworks for reasoning and planning to help with user query.
-Please call a function and wait for function results to be provided to you in the next iteration.
+Please call one or two functions at a time; the function results will be provided to you immediately. Try to answer the user query with as little back and forth as possible.
 Don't make assumptions about what values to plug into function arguments.
 Once you have called a function, results will be fed back to you within <tool_response></tool_response> XML tags.
 Don't make assumptions about tool results if <tool_response> XML tags are not present since function hasn't been executed yet.
@@ -44,49 +86,38 @@ At each iteration please continue adding your analysis to previous summary.
 Your final response should directly answer the user query. Don't tell what you are doing, just do it.


+Tools:
 Here are the available tools:
 <tools> {tools} </tools>
-If the provided function signatures doesn't have the function you must call, you may write executable python code in markdown syntax and call code_interpreter() function as follows:
-<tool_call>
-{{"arguments": {{"code_markdown": <python-code>, "name": "code_interpreter"}}}}
-</tool_call>
 Make sure that the json object above with code markdown block is parseable with json.loads() and the XML block with XML ElementTree.
 When using tools, ensure to only use the tools provided and not make up any data and do not provide any explanation as to which tool you are using and why.

-Example 1:
-User: How is the weather?
-Assistant:
-<tool_call>
-{{"arguments": {{"location": ""}}, "name": "get_weather"}}
-</tool_call>
-
-Example 2:
-User: Is there a Spa nearby?
-Assistant:
-<tool_call>
-{{"arguments": {{"search_query": "Spa"}}, "name": "search_points_of_interests"}}
-</tool_call>
-
-Example 3:
-User: How long will it take to get to the destination?
-Assistant:
-<tool_call>
-{{"arguments": {{"destination": ""}}, "name": "calculate_route"}}
-
 When asked for the weather or points of interest, use the appropriate tool with the current location of the car. Unless the user provides a location, then use that location.
 Always assume user wants to travel by car.

+Schema:
 Use the following pydantic model json schema for each tool call you will make:
 {schema}

+Instructions:
 At the very first turn you don't have <tool_results> so you shouldn't make up the results.
 Please keep a running summary with analysis of previous function results and summaries from previous iterations.
 Do not stop calling functions until the task has been accomplished or you've reached max iteration of 10.
+Calling multiple functions at once can overload the system and increase cost, so call one function at a time please.
 If you plan to continue with analysis, always call another function.
-For each function call return a valid json object (using doulbe quotes) with function name and arguments within <tool_call></tool_call> XML tags as follows:
+For each function call return a valid json object (using double quotes) with function name and arguments within <tool_call></tool_call> XML tags as follows:
+<tool_call>
+{{"arguments": <args-dict>, "name": <function-name>}}
+</tool_call>
+If there is more than one function call, return multiple <tool_call></tool_call> XML tags, for example:
 <tool_call>
 {{"arguments": <args-dict>, "name": <function-name>}}
 </tool_call>
+<tool_call>
+{{"arguments": <args-dict>, "name": <function-name>}}
+</tool_call>
+You have to open and close the XML tags for each function call.
+
 <|im_end|>"""
 AI_PREAMBLE = """
 <|im_start|>assistant
@@ -103,6 +134,32 @@ HRMS_TEMPLATE_TOOL_RESULT = """
 <|im_end|>"""


+"""
+Below are a few examples, but they are not exhaustive. You can call any tool as long as it is within the <tools></tools> XML tags. Also examples are simplified and don't include all the tags you will see in the conversation.
+Example 1:
+User: How is the weather?
+Assistant:
+<tool_call>
+{{"arguments": {{"location": ""}}, "name": "get_weather"}}
+</tool_call>
+
+Example 2:
+User: Is there a Spa nearby?
+Assistant:
+<tool_call>
+{{"arguments": {{"search_query": "Spa"}}, "name": "search_points_of_interest"}}
+</tool_call>
+
+
+Example 3:
+User: How long will it take to get to the destination?
+Assistant:
+<tool_call>
+{{"arguments": {{"destination": ""}}, "name": "calculate_route"}}
+</tool_call>
+"""
+
+
 def append_message(prompt, h):
     if h.type == "human":
         prompt += HRMS_TEMPLATE_USER.format(user_input=h.content)
@@ -113,7 +170,7 @@ def append_message(prompt, h):
     return prompt


-def get_prompt(template, history, tools, schema, car_status=None):
+def get_prompt(template, history, tools, schema, user_preferences, car_status=None):
     if not car_status:
         # car_status = vehicle.dict()
         car_status = vehicle_status()[0]
@@ -124,6 +181,7 @@ def get_prompt(template, history, tools, schema, car_status=None):
         "schema": schema,
         "tools": tools,
         "car_status": car_status,
+        "user_preferences": user_preferences,
     }

     prompt = template.format(**kwargs).replace("{{", "{").replace("}}", "}")
@@ -137,99 +195,31 @@ def get_prompt(template, history, tools, schema, car_status=None):
     return prompt


-def use_tool(tool_call, tools):
-    func_name = tool_call["name"]
-    kwargs = tool_call["arguments"]
-    for tool in tools:
-        if tool.name == func_name:
-            return tool.invoke(input=kwargs)
-    return None
-
-
-def parse_tool_calls(text):
-    logger.debug(f"Start parsing tool_calls: {text}")
-    pattern = r"<tool_call>\s*(\{.*?\})\s*</tool_call>"
-
-    if not text.startswith("<tool_call>"):
-        if "<tool_call>" in text:
-            raise ValueError("<text_and_tool_call>")
-        return [], []
-
-    matches = re.findall(pattern, text, re.DOTALL)
-    tool_calls = []
-    errors = []
-    for match in matches:
-        try:
-            tool_call = json.loads(match)
-            tool_calls.append(tool_call)
-        except json.JSONDecodeError as e:
-            errors.append(f"Invalid JSON in tool call: {e}")
-
-    logger.debug(f"Tool calls: {tool_calls}, errors: {errors}")
-    return tool_calls, errors
-
-
-def process_response(user_query, res, history, tools, depth):
-    """Returns True if the response contains tool calls, False otherwise."""
-    logger.debug(f"Processing response: {res}")
-    tool_results = f"Agent iteration {depth} to assist with user query: {user_query}\n"
-    tool_call_id = uuid.uuid4().hex
-    try:
-        tool_calls, errors = parse_tool_calls(res)
-    except ValueError as e:
-        if "<text_and_tool_call>" in str(e):
-            tool_results += f"A mix of text and tool_call was found, you must either answer the query in a short sentence or use tool_call not both. Try again, this time only using tool_call."
-            history.add_message(
-                ToolMessage(content=tool_results, tool_call_id=tool_call_id)
-            )
-            return True, [], []
-    # TODO: Handle errors
-    if not tool_calls:
-        return False, tool_calls, errors
-    # tool_results = ""
-
-    for tool_call in tool_calls:
-        # TODO: Extra Validation
-        # Call the function
-        try:
-            result = use_tool(tool_call, tools)
-            if isinstance(result, tuple):
-                result = result[1]
-            tool_results += f"<tool_response>\n{result}\n</tool_response>\n"
-        except Exception as e:
-            print(e)
-        # Currently only to mimic OpneAI's behavior
-        # But it could be used for tracking function calls
-
-    tool_results = tool_results.strip()
-    print(f"Tool results: {tool_results}")
-    history.add_message(ToolMessage(content=tool_results, tool_call_id=tool_call_id))
-    return True, tool_calls, errors


 def run_inference_ollama(prompt):
     data = {
-        "prompt": prompt
-        + "\nThis is the first turn and you don't have <tool_results> to analyze yet"
-        + AI_PREAMBLE,
+        "prompt": prompt,
         # "streaming": False,
         # "model": "smangrul/llama-3-8b-instruct-function-calling",
         # "model": "elvee/hermes-2-pro-llama-3:8b-Q5_K_M",
         # "model": "NousResearch/Hermes-2-Pro-Llama-3-8B",
-        # "model": "interstellarninja/hermes-2-pro-llama-3-8b",
-        "model": "dolphin-llama3:8b",
+        "model": "interstellarninja/hermes-2-pro-llama-3-8b",
+        # "model": "dolphin-llama3:8b",
         # "model": "dolphin-llama3:70b",
         "raw": True,
         "options": {
-            "temperature": 0.8,
+            "temperature": 0.7,
             # "max_tokens": 1500,
             "num_predict": 1500,
             # "mirostat": 1,
             # "mirostat_tau": 2,
-            "repeat_penalty": 1.1,
+            "repeat_penalty": 1.2,
             "top_k": 25,
             "top_p": 0.5,
             "num_ctx": 8000,
+            # "stop": ["<|im_end|>"]
            # "num_predict": 1500,
            # "max_tokens": 1500,
        },
@@ -248,14 +238,26 @@ def run_inference_ollama(prompt):


 def run_inference_step(
-    depth, history, tools, schema_json, dry_run=False, backend="ollama"
+    depth, history, tools, schema_json, user_preferences, backend="ollama"
 ):
     # If we decide to call a function, we need to generate the prompt for the model
     # based on the history of the conversation so far.
     # not break the loop
-    openai_tools = [convert_to_openai_function(tool) for tool in tools]
-    prompt = get_prompt(HRMS_SYSTEM_PROMPT, history, openai_tools, schema_json)
-    print(f"Prompt is:{prompt + AI_PREAMBLE}\n------------------\n")
+    openai_tools = [convert_to_openai_tool(tool) for tool in tools]
+    prompt = get_prompt(
+        HRMS_SYSTEM_PROMPT,
+        history,
+        openai_tools,
+        schema_json,
+        user_preferences=user_preferences,
+    )
+    logger.debug(f"History is: {history.messages}")
+
+    # if depth == 0:
+    #     prompt += "\nThis is the first turn and you don't have <tool_results> to analyze yet."
+    prompt += AI_PREAMBLE
+
+    logger.info(f"Prompt is:\n{prompt}")

     if backend == "ollama":
         output = run_inference_ollama(prompt)
@@ -272,9 +274,7 @@ def run_inference_replicate(prompt):
     replicate = Client(api_token=config.REPLICATE_API_KEY)

     input = {
-        "prompt": prompt
-        + "\nThis is the first turn and you don't have <tool_results> to analyze yet"
-        + AI_PREAMBLE,
+        "prompt": prompt,
         "temperature": 0.5,
         "system_prompt": "",
         "max_new_tokens": 1024,
@@ -283,41 +283,241 @@ def run_inference_replicate(prompt):
     }

     output = replicate.run(
-        "mikeei/dolphin-2.9-llama3-8b-gguf:0f79fb14c45ae2b92e1f07d872dceed3afafcacd903258df487d3bec9e393cb2",
+        # "mikeei/dolphin-2.9-llama3-8b-gguf:0f79fb14c45ae2b92e1f07d872dceed3afafcacd903258df487d3bec9e393cb2",
+        "sasan-j/hermes-2-pro-llama-3-8b:28b1dc16f47d9df68d9839418282315d5e78d9e2ab3fa6ff15728c76ae71a6d6",
         input=input,
     )
     out = "".join(output)

+    logger.debug(f"Response from Replicate:\nOut:{out}")
+
     return out


-def process_query(
-    user_query: str,
-    history: ChatMessageHistory,
-    user_preferences,
-    tools,
-    backend="ollama",
-):
-    # Add vehicle status to the history
-    user_query_status = f"Given that:\n{vehicle_status()[0]}\nUser preferences:\n{user_preferences}\nAnswer the following:\n{user_query}"
-    history.add_message(HumanMessage(content=user_query_status))
-    for depth in range(10):
-        # out = run_inference_step(depth, history, tools, schema_json)
-        out = run_inference_step(depth, history, tools, schema_json, backend=backend)
-        print(f"Inference step result:\n{out}\n------------------\n")
-        history.add_message(AIMessage(content=out))
-        to_continue, tool_calls, errors = process_response(
-            user_query, out, history, tools, depth
-        )
-        if errors:
-            history.add_message(AIMessage(content=f"Errors in tool calls: {errors}"))
-
-        if not to_continue:
-            print(f"This is the answer, no more iterations: {out}")
-            return out
-        # Otherwise, tools result is already added to history, we just need to continue the loop.
-    # If we get here something went wrong.
-    history.add_message(
-        AIMessage(content="Sorry, I am not sure how to help you with that.")
-    )
-    return "Sorry, I am not sure how to help you with that."
+def run_inference(prompt, backend="ollama"):
+    prompt += AI_PREAMBLE
+
+    logger.info(f"Prompt is:\n{prompt}")
+
+    if backend == "ollama":
+        output = run_inference_ollama(prompt)
+    else:
+        output = run_inference_replicate(prompt)
+
+    logger.debug(f"Response from model: {output}")
+    return output
+
+
+def validate_and_extract_tool_calls(assistant_content):
+    validation_result = False
+    tool_calls = []
+    error_message = None
+
+    try:
+        # wrap content in root element
+        xml_root_element = f"<root>{assistant_content}</root>"
+        root = ET.fromstring(xml_root_element)
+
+        # extract JSON data
+        for element in root.findall(".//tool_call"):
+            json_data = None
+            try:
+                json_text = element.text.strip()
+
+                try:
+                    # Prioritize json.loads for better error handling
+                    json_data = json.loads(json_text)
+                except json.JSONDecodeError as json_err:
+                    try:
+                        # Fallback to ast.literal_eval if json.loads fails
+                        json_data = ast.literal_eval(json_text)
+                    except (SyntaxError, ValueError) as eval_err:
+                        error_message = (
+                            f"JSON parsing failed with both json.loads and ast.literal_eval:\n"
+                            f"- JSON Decode Error: {json_err}\n"
+                            f"- Fallback Syntax/Value Error: {eval_err}\n"
+                            f"- Problematic JSON text: {json_text}"
+                        )
+                        logger.error(error_message)
+                        continue
+            except Exception as e:
+                error_message = f"Cannot strip text: {e}"
+                logger.error(error_message)
+
+            if json_data is not None:
+                tool_calls.append(json_data)
+                validation_result = True
+
+    except ET.ParseError as err:
+        error_message = f"XML Parse Error: {err}"
+        logger.error(f"XML Parse Error: {err}")
+
+    # Return default values if no valid data is extracted
+    return validation_result, tool_calls, error_message
+
+
+def execute_function_call(tool_call, functions):
+    function_name = tool_call.get("name")
+    for tool in functions:
+        if tool.name == function_name:
+            function_to_call = tool
+            break
+    else:
+        raise ValueError(f"Function {function_name} not found.")
+    function_args = tool_call.get("arguments", {})
+
+    logger.info(f"Invoking function call {function_name} ...")
+    if isinstance(function_to_call, StructuredTool):
+        function_response = function_to_call.invoke(input=function_args)
+    else:
+        function_response = function_to_call(*function_args.values())
+    results_dict = f'{{"name": "{function_name}", "content": {function_response}}}'
+    return results_dict
+
+
+def process_completion_and_validate(completion):
+
+    # I think I don't need this.
+    # assistant_message = get_assistant_message(completion, eos_token="<|im_end|>")
+    assistant_message = completion.strip()
+
+    if assistant_message:
+        validation, tool_calls, error_message = validate_and_extract_tool_calls(
+            assistant_message
+        )
+
+        if validation:
+            logger.info(f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")
+            return tool_calls, assistant_message, error_message
+        else:
+            tool_calls = None
+            return tool_calls, assistant_message, error_message
+    else:
+        logger.warning("Assistant message is None")
+        raise ValueError("Assistant message is None")
+
+
+UNRESOLVED_MSG = "I'm sorry, I'm not sure how to help you with that."
+
+
+def get_assistant_message(completion, eos_token):
+    """define and match pattern to find the assistant message"""
+    completion = completion.strip()
+    assistant_pattern = re.compile(
+        r"<\|im_start\|>\s*assistant((?:(?!<\|im_start\|>\s*assistant).)*)$", re.DOTALL
+    )
+    assistant_match = assistant_pattern.search(completion)
+    if assistant_match:
+        assistant_content = assistant_match.group(1).strip()
+        return assistant_content.replace(eos_token, "")
+    else:
+        assistant_content = None
+        logger.info("No match found for the assistant pattern")
+        return assistant_content
+
+
+def generate_function_call(
+    query, history, user_preferences, tools, functions, backend, max_depth=5
+) -> str:
+    """
+    Largely taken from https://github.com/NousResearch/Hermes-Function-Calling
+    """
+
+    try:
+        depth = 0
+        # user_message = f"{query}\nThis is the first turn and you don't have <tool_results> to analyze yet"
+        user_message = f"{query}"
+        # chat = [{"role": "user", "content": user_message}]
+        history.add_message(HumanMessage(content=user_message))
+
+        # openai_tools = [convert_to_openai_function(tool) for tool in tools]
+        prompt = get_prompt(
+            HRMS_SYSTEM_PROMPT,
+            history,
+            tools,
+            schema_json,
+            user_preferences=user_preferences,
+        )
+        logger.debug(f"History is: {history.json()}")
+
+        # if depth == 0:
+        #     prompt += "\nThis is the first turn and you don't have <tool_results> to analyze yet."
+        completion = run_inference(prompt, backend=backend)
+
+        def recursive_loop(prompt, completion, depth) -> str:
+            nonlocal max_depth
+            tool_calls, assistant_message, error_message = (
+                process_completion_and_validate(completion)
+            )
+            # prompt.append({"role": "assistant", "content": assistant_message})
+            history.add_message(AIMessage(content=assistant_message))
+
+            tool_message = (
+                f"Agent iteration {depth} to assist with user query: {query}\n"
+            )
+            if tool_calls:
+                logger.info(f"Assistant Message:\n{assistant_message}")
+                for tool_call in tool_calls:
+                    validation, message = validate_function_call_schema(
+                        tool_call, tools
+                    )
+                    if validation:
+                        try:
+                            function_response = execute_function_call(
+                                tool_call, functions=functions
+                            )
+                            tool_message += f"<tool_response>\n{function_response}\n</tool_response>\n"
+                            logger.info(
+                                f"Here's the response from the function call: {tool_call.get('name')}\n{function_response}"
+                            )
+                        except Exception as e:
+                            logger.warning(f"Could not execute function: {e}")
+                            tool_message += f"<tool_response>\nThere was an error when executing the function: {tool_call.get('name')}\nHere's the error traceback: {e}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
+                    else:
+                        logger.error(message)
+                        tool_message += f"<tool_response>\nThere was an error validating function call against function signature: {tool_call.get('name')}\nHere's the error traceback: {message}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
+                # prompt.append({"role": "tool", "content": tool_message})
+                history.add_message(
+                    ToolMessage(content=tool_message, tool_call_id=uuid.uuid4().hex)
+                )
+
+                depth += 1
+                if depth >= max_depth:
+                    logger.warning(
+                        f"Maximum recursion depth reached ({max_depth}). Stopping recursion."
+                    )
+                    return UNRESOLVED_MSG
+
+                prompt = get_prompt(
+                    HRMS_SYSTEM_PROMPT,
+                    history,
+                    tools,
+                    schema_json,
+                    user_preferences=user_preferences,
+                )
+                completion = run_inference(prompt, backend=backend)
+                return recursive_loop(prompt, completion, depth)
+            elif error_message:
+                logger.info(f"Assistant Message:\n{assistant_message}")
+                tool_message += f"<tool_response>\nThere was an error parsing function calls\n Here's the error stack trace: {error_message}\nPlease call the function again with correct syntax\n</tool_response>"
+                # record the parse error as a tool message so the next iteration can see it
+                history.add_message(
+                    ToolMessage(content=tool_message, tool_call_id=uuid.uuid4().hex)
+                )
+
+                depth += 1
+                if depth >= max_depth:
+                    logger.warning(
+                        f"Maximum recursion depth reached ({max_depth}). Stopping recursion."
+                    )
+                    return UNRESOLVED_MSG
+
+                completion = run_inference(prompt, backend=backend)
+                return recursive_loop(prompt, completion, depth)
+            else:
+                logger.info(f"Assistant Message:\n{assistant_message}")
+                return assistant_message
+
+        return recursive_loop(prompt, completion, depth)  # noqa
+
+    except Exception as e:
+        logger.error(f"Exception occurred: {e}")
+        return UNRESOLVED_MSG
+        # raise e
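As a quick sanity check on the new XML-based extraction path, here is a small sketch with hypothetical inputs; the function itself is exactly as added above:

# Sketch: exercising validate_and_extract_tool_calls from kitt/core/model.py.
from kitt.core.model import validate_and_extract_tool_calls

completion = '<tool_call>\n{"arguments": {"location": "Luxembourg"}, "name": "get_weather"}\n</tool_call>'
ok, tool_calls, error = validate_and_extract_tool_calls(completion)
# ok is True; tool_calls == [{"arguments": {"location": "Luxembourg"}, "name": "get_weather"}]

# Single-quoted pseudo-JSON is rescued by the ast.literal_eval fallback:
ok, tool_calls, error = validate_and_extract_tool_calls(
    "<tool_call>{'arguments': {'location': 'here'}, 'name': 'get_weather'}</tool_call>"
)
# ok is still True: json.loads fails, ast.literal_eval succeeds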
kitt/core/schema.py ADDED

@@ -0,0 +1,23 @@
+from pydantic import BaseModel
+from typing import List, Dict, Literal, Optional
+
+class FunctionCall(BaseModel):
+    arguments: dict
+    """
+    The arguments to call the function with, as generated by the model in JSON
+    format. Note that the model does not always generate valid JSON, and may
+    hallucinate parameters not defined by your function schema. Validate the
+    arguments in your code before calling your function.
+    """
+
+    name: str
+    """The name of the function to call."""
+
+class FunctionDefinition(BaseModel):
+    name: str
+    description: Optional[str] = None
+    parameters: Optional[Dict[str, object]] = None
+
+class FunctionSignature(BaseModel):
+    function: FunctionDefinition
+    type: Literal["function"]
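These models mirror the OpenAI tool format, so a dict produced by convert_to_openai_tool should validate directly; a minimal sketch with a hypothetical tool definition:

# Sketch: parsing an OpenAI-style tool dict with the new schema models.
from kitt.core.schema import FunctionSignature

signature = FunctionSignature(**{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather in a specified location.",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    },
})
# signature.function.name == "get_weather"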
kitt/core/utils.py CHANGED

@@ -1,3 +1,5 @@
+import json
+import re
 from typing import List, Tuple, Optional, Union


@@ -33,3 +35,27 @@ def plot_route(points, vehicle: Union[tuple[float, float], None] = None):
     fig.update_geos(fitbounds="locations")
     fig.update_layout(margin={"r": 20, "t": 20, "l": 20, "b": 20})
     return fig
+
+
+def extract_json_from_markdown(text):
+    """
+    Extracts the JSON string from the given text using a regular expression pattern.
+
+    Args:
+        text (str): The input text containing the JSON string.
+
+    Returns:
+        dict: The JSON data loaded from the extracted string, or None if the JSON string is not found.
+    """
+    json_pattern = r'```json\r?\n(.*?)\r?\n```'
+    match = re.search(json_pattern, text, re.DOTALL)
+    if match:
+        json_string = match.group(1)
+        try:
+            data = json.loads(json_string)
+            return data
+        except json.JSONDecodeError as e:
+            print(f"Error decoding JSON string: {e}")
+    else:
+        print("JSON string not found in the text.")
+    return None
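A short usage sketch for the new helper, with a made-up model reply:

# Sketch: pulling a JSON payload out of a fenced ```json block.
from kitt.core.utils import extract_json_from_markdown

text = 'Here is the call:\n```json\n{"name": "get_weather", "arguments": {}}\n```'
data = extract_json_from_markdown(text)
# data == {"name": "get_weather", "arguments": {}}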
kitt/core/validator.py ADDED

@@ -0,0 +1,134 @@
+import ast
+import json
+from jsonschema import validate
+# jsonschema raises its own ValidationError, distinct from pydantic's
+from jsonschema.exceptions import ValidationError as JSONSchemaValidationError
+from pydantic import ValidationError
+from loguru import logger
+
+from .utils import extract_json_from_markdown
+from .schema import FunctionCall, FunctionSignature
+
+def validate_function_call_schema(call, signatures):
+    try:
+        call_data = FunctionCall(**call)
+    except ValidationError as e:
+        return False, str(e)
+
+    for signature in signatures:
+        try:
+            signature_data = FunctionSignature(**signature)
+            if signature_data.function.name == call_data.name:
+                # Validate types in function arguments
+                for arg_name, arg_schema in signature_data.function.parameters.get('properties', {}).items():
+                    if arg_name in call_data.arguments:
+                        call_arg_value = call_data.arguments[arg_name]
+                        if call_arg_value:
+                            try:
+                                validate_argument_type(arg_name, call_arg_value, arg_schema)
+                            except Exception as arg_validation_error:
+                                return False, str(arg_validation_error)
+
+                # Check if all required arguments are present
+                required_arguments = signature_data.function.parameters.get('required', [])
+                result, missing_arguments = check_required_arguments(call_data.arguments, required_arguments)
+                if not result:
+                    return False, f"Missing required arguments: {missing_arguments}"
+
+                return True, None
+        except Exception as e:
+            # Handle validation errors for the function signature
+            return False, str(e)
+
+    # No matching function signature found
+    return False, f"No matching function signature found for function: {call_data.name}"
+
+def check_required_arguments(call_arguments, required_arguments):
+    missing_arguments = [arg for arg in required_arguments if arg not in call_arguments]
+    return not bool(missing_arguments), missing_arguments
+
+def validate_enum_value(arg_name, arg_value, enum_values):
+    if arg_value not in enum_values:
+        raise Exception(
+            f"Invalid value '{arg_value}' for parameter {arg_name}. Expected one of {', '.join(map(str, enum_values))}"
+        )
+
+def validate_argument_type(arg_name, arg_value, arg_schema):
+    arg_type = arg_schema.get('type', None)
+    if arg_type:
+        if arg_type == 'string' and 'enum' in arg_schema:
+            enum_values = arg_schema['enum']
+            if None not in enum_values and enum_values != []:
+                try:
+                    validate_enum_value(arg_name, arg_value, enum_values)
+                except Exception as e:
+                    # Propagate the validation error message
+                    raise Exception(f"Error validating function call: {e}")
+
+        python_type = get_python_type(arg_type)
+        if not isinstance(arg_value, python_type):
+            raise Exception(f"Type mismatch for parameter {arg_name}. Expected: {arg_type}, Got: {type(arg_value)}")
+
+def get_python_type(json_type):
+    type_mapping = {
+        'string': str,
+        'number': (int, float),
+        'integer': int,
+        'boolean': bool,
+        'array': list,
+        'object': dict,
+        'null': type(None),
+    }
+    return type_mapping[json_type]
+
+def validate_json_data(json_object, json_schema):
+    valid = False
+    error_message = None
+    result_json = None
+
+    try:
+        # Attempt to load JSON using json.loads
+        try:
+            result_json = json.loads(json_object)
+        except json.decoder.JSONDecodeError:
+            # If json.loads fails, try ast.literal_eval
+            try:
+                result_json = ast.literal_eval(json_object)
+            except (SyntaxError, ValueError) as e:
+                try:
+                    result_json = extract_json_from_markdown(json_object)
+                except Exception as e:
+                    error_message = f"JSON decoding error: {e}"
+                    logger.info(f"Validation failed for JSON data: {error_message}")
+                    return valid, result_json, error_message
+
+        # Return early if both json.loads and ast.literal_eval fail
+        if result_json is None:
+            error_message = "Failed to decode JSON data"
+            logger.info(f"Validation failed for JSON data: {error_message}")
+            return valid, result_json, error_message
+
+        # Validate each item in the list against schema if it's a list
+        if isinstance(result_json, list):
+            for index, item in enumerate(result_json):
+                try:
+                    validate(instance=item, schema=json_schema)
+                    logger.info(f"Item {index+1} is valid against the schema.")
+                except JSONSchemaValidationError as e:
+                    error_message = f"Validation failed for item {index+1}: {e}"
+                    break
+        else:
+            # Default to validation without list
+            try:
+                validate(instance=result_json, schema=json_schema)
+            except JSONSchemaValidationError as e:
+                error_message = f"Validation failed: {e}"
+
+    except Exception as e:
+        error_message = f"Error occurred: {e}"
+
+    if error_message is None:
+        valid = True
+        logger.info("JSON data is valid against the schema.")
+    else:
+        logger.info(f"Validation failed for JSON data: {error_message}")
+
+    return valid, result_json, error_message
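A minimal sketch of the validator in use, assuming signatures in the OpenAI tool format (e.g. the output of convert_to_openai_tool); the tool definition below is hypothetical:

# Sketch: validating a model-produced call against a function signature.
from kitt.core.validator import validate_function_call_schema

signatures = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    },
}]

ok, error = validate_function_call_schema(
    {"name": "get_weather", "arguments": {"location": "Luxembourg"}}, signatures
)
# ok is True, error is None

ok, error = validate_function_call_schema(
    {"name": "get_weather", "arguments": {"location": 42}}, signatures
)
# ok is False; error reports a type mismatch for 'location'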
kitt/skills/__init__.py CHANGED

@@ -1,10 +1,11 @@
 from datetime import datetime
 import inspect
+from langchain.tools import StructuredTool

 from .common import execute_function_call, extract_func_args, vehicle as vehicle_obj
 from .weather import get_weather_current_location, get_weather, get_forecast
 from .routing import find_route
-from .poi import search_points_of_interests, search_along_route_w_coordinates
+from .poi import search_points_of_interest, search_along_route_w_coordinates
 from .vehicle import vehicle_status, set_vehicle_speed, set_vehicle_destination
 from .interpreter import code_interpreter

@@ -32,6 +33,8 @@ def format_functions_for_prompt_raven(*functions):
     """
     formatted_functions = []
     for func in functions:
+        if isinstance(func, StructuredTool):
+            func = func.func
         signature = f"{func.__name__}{inspect.signature(func)}"
         docstring = inspect.getdoc(func)
         formatted_functions.append(
@@ -40,4 +43,4 @@ def format_functions_for_prompt_raven(*functions):
     return "\n".join(formatted_functions)


-SKILLS_PROMPT = format_functions_for_prompt_raven(get_weather, get_forecast, find_route, search_points_of_interests)
+SKILLS_PROMPT = format_functions_for_prompt_raven(get_weather, get_forecast, find_route, search_points_of_interest)
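The new isinstance check is needed because @tool wraps a function in a StructuredTool, whose original callable lives on .func; a small sketch of the distinction, using a made-up tool:

# Sketch: a @tool-decorated function is a StructuredTool, not a plain callable.
from langchain.tools import StructuredTool, tool

@tool
def greet(name: str = "world"):
    """Say hello to someone."""
    return f"Hello, {name}!"

isinstance(greet, StructuredTool)     # True
greet.invoke(input={"name": "KITT"})  # invoke through the tool wrapper
greet.func                            # the plain function, usable with inspect.signature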
kitt/skills/poi.py CHANGED

@@ -1,5 +1,8 @@
 import json
+import urllib.parse
 import requests
+from loguru import logger
+from langchain.tools import tool
 from .common import config, vehicle


@@ -16,7 +19,8 @@ def _select_equally_spaced_coordinates(coords, number_of_points=10):
     return selected_coords


-def search_points_of_interests(search_query="french restaurant"):
+@tool
+def search_points_of_interest(search_query: str = "french restaurant"):
     """
     Get some of the closest points of interest matching the query.

@@ -27,16 +31,31 @@ def search_points_of_interest(search_query: str = "french restaurant"):
     # Extract the latitude and longitude of the vehicle
     vehicle_coordinates = getattr(vehicle, "location_coordinates")
     lat, lon = vehicle_coordinates
-    print(f"POI search vehicle's lat: {lat}, lon: {lon}")
+    logger.info(f"POI search vehicle's lat: {lat}, lon: {lon}")

     # https://developer.tomtom.com/search-api/documentation/search-service/search
-    r = requests.get(
-        f"https://api.tomtom.com/search/2/search/{search_query}.json?key={config.TOMTOM_API_KEY}&lat={lat}&lon={lon}&category&radius=1000&limit=100",
-        timeout=5,
-    )
+    # Encode the parameters
+    # Even with encoding tomtom doesn't return the correct results
+    search_query = search_query.replace("'", "")
+    encoded_search_query = urllib.parse.quote(search_query)
+
+    # Construct the URL
+    url = f"https://api.tomtom.com/search/2/search/{encoded_search_query}.json"
+    params = {
+        "key": config.TOMTOM_API_KEY,
+        "lat": lat,
+        "lon": lon,
+        "radius": 5000,
+        "idxSet": "POI",
+        "limit": 50
+    }
+
+    r = requests.get(url, params=params, timeout=5)

     # Parse JSON from the response
     data = r.json()
+
+    logger.debug(f"POI search response: {data}\n url:{url} params: {params}")
     # Extract results
     results = data["results"]

@@ -57,7 +76,7 @@ def search_points_of_interest(search_query: str = "french restaurant"):
     output = (
         f"There are {len(results)} options in the vicinity. The most relevant are: "
     )
-    return output + ".\n ".join(formatted_results)
+    return output + ".\n ".join(formatted_results), results[:3]


 def find_points_of_interest(lat="0", lon="0", type_of_poi="restaurant"):
@@ -69,12 +88,14 @@ def find_points_of_interest(lat="0", lon="0", type_of_poi="restaurant"):
     :param type_of_poi (string): Required. type of point of interest depending on what the user wants to do.
     """
     # https://developer.tomtom.com/search-api/documentation/search-service/points-of-interest-search
-    r = requests.get(
-        f"https://api.tomtom.com/search/2/search/{type_of_poi}"
-        ".json?key={0}&lat={1}&lon={2}&radius=10000&vehicleTypeSet=Car&idxSet=POI&limit=100".format(
-            config.TOMTOM_API_KEY, lat, lon
-        )
-    )
+    # Encode the parameters
+    encoded_type_of_poi = urllib.parse.quote(type_of_poi)
+
+    # Construct the URL
+    url = f"https://api.tomtom.com/search/2/search/{encoded_type_of_poi}.json?key={config.TOMTOM_API_KEY}&lat={lat}&lon={lon}&radius=10000&vehicleTypeSet=Car&idxSet=POI&limit=100"
+
+    r = requests.get(url, timeout=5)
+

     # Parse JSON from the response
     data = r.json()
@@ -103,7 +124,11 @@ def search_along_route_w_coordinates(points: list[tuple[float, float]], query: s
     """

     # The API endpoint for searching along a route
-    url = f"https://api.tomtom.com/search/2/searchAlongRoute/{query}.json?key={config.TOMTOM_API_KEY}&maxDetourTime=360&limit=20&sortBy=detourTime"
+
+    # urlencode the query
+    query = urllib.parse.quote(query)
+
+    url = f"https://api.tomtom.com/search/2/searchAlongRoute/{query}.json?key={config.TOMTOM_API_KEY}&maxDetourTime=600&limit=20&sortBy=detourTime"

     points = _select_equally_spaced_coordinates(points, number_of_points=20)
 
kitt/skills/routing.py CHANGED

@@ -1,5 +1,7 @@
 from datetime import datetime
 import requests
+from loguru import logger
+from langchain.tools import tool
 from .common import config, vehicle


@@ -120,6 +122,7 @@ def find_route_a_to_b(origin="", destination=""):
     return _format_tomtom_trip_info(trip_info, destination)


+@tool
 def find_route(destination):
     """Get a route to a destination from the current location of the vehicle.
kitt/skills/vehicle.py CHANGED

@@ -1,3 +1,5 @@
+from langchain.tools import tool
+
 from .common import vehicle, Speed


@@ -35,8 +37,8 @@ def vehicle_status() -> tuple[str, dict[str, str]]:
     return STATUS_TEMPLATE.format(**vs), vs


-
-def set_vehicle_speed(speed: Speed):
+@tool
+def set_vehicle_speed(speed: Speed) -> str:
     """Set the speed of the vehicle.
     Args:
         speed (Speed): The speed of the vehicle. ("slow", "fast")
@@ -44,7 +46,8 @@ def set_vehicle_speed(speed: Speed):
     vehicle.speed = speed
     return f"The vehicle speed is set to {speed.value}."

-def set_vehicle_destination(destination: str):
+@tool
+def set_vehicle_destination(destination: str) -> str:
     """Set the destination of the vehicle.
     Args:
         destination (str): The destination of the vehicle.
kitt/skills/weather.py CHANGED

@@ -1,5 +1,6 @@
 import requests
 from loguru import logger
+from langchain.tools import tool

 from .common import config, vehicle

@@ -19,7 +20,7 @@ def get_weather_current_location():
     return get_weather(location)


-# current weather API
+@tool
 def get_weather(location: str = "here"):
     """
     Get the current weather in a specified location.
@@ -70,7 +71,8 @@ def get_weather(location: str = "here"):
         # f"Humidity is at {humidity}%. "
         # f"Wind speed is {wind_kph} kph." if 'wind_kph' in weather_data['current'] else ""
     )
-    return weather_sentences, weather_data
+    # return weather_sentences, weather_data
+    return weather_sentences


 # weather forecast API
main.py CHANGED
@@ -12,7 +12,8 @@ import ollama
12
 
13
  from langchain.tools.base import StructuredTool
14
  from langchain.memory import ChatMessageHistory
15
- from langchain_core.utils.function_calling import convert_to_openai_function
 
16
  from loguru import logger
17
 
18
 
@@ -22,7 +23,7 @@ from kitt.skills import (
22
  get_forecast,
23
  vehicle_status as vehicle_status_fn,
24
  set_vehicle_speed,
25
- search_points_of_interests,
26
  search_along_route_w_coordinates,
27
  set_vehicle_destination,
28
  do_anything_else,
@@ -32,7 +33,8 @@ from kitt.skills import (
32
  )
33
  from kitt.skills import extract_func_args
34
  from kitt.core import voice_options, tts_gradio
35
- from kitt.core.model import process_query
 
36
  from kitt.core import utils as kitt_utils
37
 
38
 
@@ -68,6 +70,8 @@ Answer questions concisely and do not mention what you base your reply on.<|im_e
68
  <|im_start|>assistant
69
  """
70
 
 
 
71
 
72
  def get_prompt(template, input, history, tools):
73
  # "vehicle_status": vehicle_status_fn()[0]
@@ -98,6 +102,7 @@ def use_tool(func_name, kwargs, tools):
98
  hour_options = [f"{i:02d}:00:00" for i in range(24)]
99
 
100
 
 
101
  def search_along_route(query=""):
102
  """Search for points of interest along the route/way to the destination.
103
 
@@ -120,18 +125,29 @@ def get_vehicle_status(state):
120
 
121
 
122
  tools = [
123
- StructuredTool.from_function(get_weather),
124
- StructuredTool.from_function(find_route),
125
  # StructuredTool.from_function(vehicle_status_fn),
126
- StructuredTool.from_function(set_vehicle_speed),
127
- StructuredTool.from_function(search_points_of_interests),
128
- StructuredTool.from_function(search_along_route),
 
129
  # StructuredTool.from_function(date_time_info),
130
  # StructuredTool.from_function(get_weather_current_location),
131
- StructuredTool.from_function(code_interpreter),
132
  # StructuredTool.from_function(do_anything_else),
133
  ]
134
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  def run_generic_model(query):
137
  print(f"Running the generic model with query: {query}")
@@ -186,11 +202,16 @@ def run_nexusraven_model(query, voice_character, state):
186
 
187
 
188
  def run_llama3_model(query, voice_character, state):
 
 
 
 
189
  output_text = process_query(
190
  query,
191
  history=history,
192
  user_preferences=state["user_preferences"],
193
- tools=tools,
 
194
  backend=state["llm_backend"],
195
  )
196
  gr.Info(f"Output text: {output_text}, generating voice output...")
@@ -216,6 +237,9 @@ def run_model(query, voice_character, state):
216
  text, voice = run_llama3_model(query, voice_character, state)
217
  else:
218
  text, voice = "Error running model", None
 
 
 
219
  return text, voice, vehicle.model_dump_json()
220
 
221
 
@@ -285,8 +309,8 @@ def save_and_transcribe_audio(audio):
285
 
286
  def save_and_transcribe_run_model(audio, voice_character, state):
287
  text = save_and_transcribe_audio(audio)
288
- out_text, out_voice = run_model(text, voice_character, state)
289
- return text, out_text, out_voice
290
 
291
 
292
  def set_tts_enabled(tts_enabled, state):
@@ -314,6 +338,12 @@ def set_user_preferences(preferences, state):
314
  return state
315
 
316
 
 
 
 
 
 
 
317
  # to be able to use the microphone on chrome, you will have to go to chrome://flags/#unsafely-treat-insecure-origin-as-secure and enter http://10.186.115.21:7860/
318
  # in "Insecure origins treated as secure", enable it and relaunch chrome
319
 
@@ -322,6 +352,10 @@ def set_user_preferences(preferences, state):
  # What's the closest restaurant from here?


  def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = True):
      print(f"Running the demo with model: {model} and TTSServer: {tts_server}")
      with gr.Blocks(theme=gr.themes.Default()) as demo:
@@ -332,11 +366,13 @@ def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = Tr
                  "route_points": [],
                  "model": model,
                  "tts_enabled": tts_enabled,
-                 "llm_backend": "Ollama",
-                 "user_preferences": "",
              }
          )
          trip_points = gr.State(value=[])

          with gr.Row():
              with gr.Column(scale=1, min_width=300):
@@ -346,12 +382,6 @@ def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = Tr
                      value="08:00:00",
                      interactive=True,
                  )
-                 history = gr.Radio(
-                     ["Yes", "No"],
-                     label="Maintain the conversation history?",
-                     value="No",
-                     interactive=True,
-                 )
                  voice_character = gr.Radio(
                      choices=voice_options,
                      label="Choose a voice",
@@ -359,24 +389,24 @@ def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = Tr
                      show_label=True,
                  )
                  origin = gr.Textbox(
-                     value="Mondorf-les-Bains, Luxembourg",
                      label="Origin",
                      interactive=True,
                  )
                  destination = gr.Textbox(
-                     value="Rue Alphonse Weicker, Luxembourg",
                      label="Destination",
                      interactive=True,
                  )
                  preferences = gr.Textbox(
-                     value="I love italian food\nI like doing sports",
                      label="User preferences",
                      lines=3,
                      interactive=True,
                  )

              with gr.Column(scale=2, min_width=600):
-                 map_plot = gr.Plot()
                  trip_progress = gr.Slider(
                      0, 100, step=5, label="Trip progress", interactive=True
                  )
@@ -422,6 +452,12 @@ def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = Tr
                      value="Ollama",
                      interactive=True,
                  )
                  # Push button
                  clear_history_btn = gr.Button(value="Clear History")
              with gr.Column():
@@ -472,7 +508,7 @@ def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = Tr
          input_audio.stop_recording(
              fn=save_and_transcribe_run_model,
              inputs=[input_audio, voice_character, state],
-             outputs=[input_text, output_text, output_audio],
          )
          input_audio_debug.stop_recording(
              fn=save_and_transcribe_audio,
@@ -490,6 +526,10 @@ def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = Tr
          llm_backend.change(
              fn=set_llm_backend, inputs=[llm_backend, state], outputs=[state]
          )
      return demo

  from langchain.tools.base import StructuredTool
  from langchain.memory import ChatMessageHistory
+ from langchain_core.utils.function_calling import convert_to_openai_tool
+ from langchain.tools import tool
  from loguru import logger


      get_forecast,
      vehicle_status as vehicle_status_fn,
      set_vehicle_speed,
+     search_points_of_interest,
      search_along_route_w_coordinates,
      set_vehicle_destination,
      do_anything_else,

  )
  from kitt.skills import extract_func_args
  from kitt.core import voice_options, tts_gradio
+ # from kitt.core.model import process_query
+ from kitt.core.model import generate_function_call as process_query
  from kitt.core import utils as kitt_utils

 
  <|im_start|>assistant
  """

+ USER_PREFERENCES = "I love italian food\nI like doing sports"
+

  def get_prompt(template, input, history, tools):
      # "vehicle_status": vehicle_status_fn()[0]
 
  hour_options = [f"{i:02d}:00:00" for i in range(24)]


+ @tool
  def search_along_route(query=""):
      """Search for points of interest along the route/way to the destination.

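Aside: LangChain's `@tool` decorator wraps a plain function as a tool object, deriving the tool name, description, and argument schema from the function's signature and docstring (which is why `search_along_route` keeps its docstring). A minimal sketch of the behavior, using a made-up function rather than kitt code:

```python
from langchain.tools import tool

@tool
def add_numbers(a: int, b: int) -> int:
    """Add two integers."""  # the docstring becomes the tool description
    return a + b

print(add_numbers.name)                      # "add_numbers"
print(add_numbers.invoke({"a": 2, "b": 3}))  # 5
```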
 
 


  tools = [
+     # StructuredTool.from_function(get_weather),
+     # StructuredTool.from_function(find_route),
      # StructuredTool.from_function(vehicle_status_fn),
+     # StructuredTool.from_function(set_vehicle_speed),
+     # StructuredTool.from_function(set_vehicle_destination),
+     # StructuredTool.from_function(search_points_of_interest),
+     # StructuredTool.from_function(search_along_route),
      # StructuredTool.from_function(date_time_info),
      # StructuredTool.from_function(get_weather_current_location),
+     # StructuredTool.from_function(code_interpreter),
      # StructuredTool.from_function(do_anything_else),
  ]

+ functions = [
+     set_vehicle_speed,
+     set_vehicle_destination,
+     get_weather,
+     find_route,
+     search_points_of_interest,
+     search_along_route
+ ]
+ openai_tools = [convert_to_openai_tool(tool) for tool in functions]
+

  def run_generic_model(query):
      print(f"Running the generic model with query: {query}")
 


  def run_llama3_model(query, voice_character, state):
+
+     assert len (functions) > 0, "No functions to call"
+     assert len (openai_tools) > 0, "No openai tools to call"
+
      output_text = process_query(
          query,
          history=history,
          user_preferences=state["user_preferences"],
+         tools=openai_tools,
+         functions=functions,
          backend=state["llm_backend"],
      )
      gr.Info(f"Output text: {output_text}, generating voice output...")
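`run_llama3_model` now passes both the schemas (`tools=openai_tools`, what the model sees) and the callables (`functions=functions`, what gets executed). The actual dispatch lives in `kitt.core.model.generate_function_call`, which this diff does not show; a hypothetical dispatch step consistent with that split could look like:

```python
def dispatch(tool_call: dict, functions: list):
    """Hypothetical: route a model-emitted tool call to the matching callable."""
    name, args = tool_call["name"], tool_call.get("arguments", {})
    for fn in functions:
        # @tool-wrapped objects expose .name; plain functions expose __name__
        if getattr(fn, "name", getattr(fn, "__name__", None)) == name:
            return fn.invoke(args) if hasattr(fn, "invoke") else fn(**args)
    raise ValueError(f"Unknown tool: {name}")
```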
 
          text, voice = run_llama3_model(query, voice_character, state)
      else:
          text, voice = "Error running model", None
+
+     if not state["enable_history"]:
+         history.clear()
      return text, voice, vehicle.model_dump_json()

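This is where the new `enable_history` flag takes effect: with history disabled, the module-level `ChatMessageHistory` is wiped after every turn, so each query is answered statelessly. For reference, the relevant `ChatMessageHistory` behavior:

```python
from langchain.memory import ChatMessageHistory

history = ChatMessageHistory()
history.add_user_message("How's the weather?")
history.add_ai_message("Sunny, 22 degrees.")
print(len(history.messages))  # 2
history.clear()               # what run_model does when history is disabled
print(history.messages)       # []
```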
 
 

  def save_and_transcribe_run_model(audio, voice_character, state):
      text = save_and_transcribe_audio(audio)
+     out_text, out_voice, vehicle_status = run_model(text, voice_character, state)
+     return text, out_text, out_voice, vehicle_status


  def set_tts_enabled(tts_enabled, state):
 
      return state


+ def set_enable_history(enable_history, state):
+     new_enable_history = enable_history == "Yes"
+     logger.info(f"Enable history was {state['enable_history']} and changed to {new_enable_history}")
+     state["enable_history"] = new_enable_history
+     return state
+
  # to be able to use the microphone on chrome, you will have to go to chrome://flags/#unsafely-treat-insecure-origin-as-secure and enter http://10.186.115.21:7860/
  # in "Insecure origins treated as secure", enable it and relaunch chrome

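The handler mirrors `set_tts_enabled`, mapping the radio's "Yes"/"No" strings onto a boolean in the shared state dict. Given the definition above:

```python
state = {"enable_history": False}
state = set_enable_history("Yes", state)
assert state["enable_history"] is True
```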
 
 
  # What's the closest restaurant from here?


+ ORIGIN = "Mondorf-les-Bains, Luxembourg"
+ DESTINATION = "Rue Alphonse Weicker, Luxembourg"
+
+
  def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = True):
      print(f"Running the demo with model: {model} and TTSServer: {tts_server}")
      with gr.Blocks(theme=gr.themes.Default()) as demo:
 
                  "route_points": [],
                  "model": model,
                  "tts_enabled": tts_enabled,
+                 "llm_backend": "ollama",
+                 "user_preferences": USER_PREFERENCES,
+                 "enable_history": False,
              }
          )
          trip_points = gr.State(value=[])
+         plot, vehicle_status, _ = calculate_route_gradio(ORIGIN, DESTINATION)

          with gr.Row():
              with gr.Column(scale=1, min_width=300):
 
                      value="08:00:00",
                      interactive=True,
                  )
                  voice_character = gr.Radio(
                      choices=voice_options,
                      label="Choose a voice",
 
                      show_label=True,
                  )
                  origin = gr.Textbox(
+                     value=ORIGIN,
                      label="Origin",
                      interactive=True,
                  )
                  destination = gr.Textbox(
+                     value=DESTINATION,
                      label="Destination",
                      interactive=True,
                  )
                  preferences = gr.Textbox(
+                     value=USER_PREFERENCES,
                      label="User preferences",
                      lines=3,
                      interactive=True,
                  )

              with gr.Column(scale=2, min_width=600):
+                 map_plot = gr.Plot(value=plot, label="Map")
                  trip_progress = gr.Slider(
                      0, 100, step=5, label="Trip progress", interactive=True
                  )
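Because `calculate_route_gradio(ORIGIN, DESTINATION)` now runs at build time, the map starts populated instead of blank. The underlying pattern is seeding `gr.Plot` with an initial figure via `value=`; the figure below is a stand-in, since the diff only shows that the real one comes from `calculate_route_gradio`:

```python
import gradio as gr
import plotly.graph_objects as go

# Stand-in route between the two default endpoints
fig = go.Figure(go.Scattergeo(lat=[49.51, 49.60], lon=[6.28, 6.13], mode="lines"))

with gr.Blocks() as demo:
    map_plot = gr.Plot(value=fig, label="Map")  # rendered on page load
```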
 
                      value="Ollama",
                      interactive=True,
                  )
+                 enable_history = gr.Radio(
+                     ["Yes", "No"],
+                     label="Maintain the conversation history?",
+                     value="No",
+                     interactive=True,
+                 )
                  # Push button
                  clear_history_btn = gr.Button(value="Clear History")
              with gr.Column():
 
          input_audio.stop_recording(
              fn=save_and_transcribe_run_model,
              inputs=[input_audio, voice_character, state],
+             outputs=[input_text, output_text, output_audio, vehicle_status],
          )
          input_audio_debug.stop_recording(
              fn=save_and_transcribe_audio,
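`save_and_transcribe_run_model` now returns four values, so this `outputs=` list grows to four components; Gradio assigns the returned tuple to the output components positionally, and the two lengths must match. Schematically (component names are illustrative, assuming Gradio 4's event API):

```python
import gradio as gr

def on_stop(audio):
    # one return value per output component, in order
    return "transcript", "reply", None, "{}"

with gr.Blocks() as demo:
    mic = gr.Audio(sources=["microphone"])
    input_text, output_text = gr.Textbox(), gr.Textbox()
    output_audio = gr.Audio()
    vehicle_status = gr.Textbox()
    mic.stop_recording(
        fn=on_stop,
        inputs=[mic],
        outputs=[input_text, output_text, output_audio, vehicle_status],
    )
```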
 
          llm_backend.change(
              fn=set_llm_backend, inputs=[llm_backend, state], outputs=[state]
          )
+         enable_history.change(
+             fn=set_enable_history, inputs=[enable_history, state], outputs=[state]
+         )
+
      return demo

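For completeness, a typical way this factory is consumed; the launch flags are illustrative, not taken from this diff:

```python
demo = create_demo(tts_server=False, model="llama3", tts_enabled=True)
demo.launch(server_name="0.0.0.0", server_port=7860)  # port matches the chrome note above
```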