chore: Update vehicle speed and destination handling functions
Files changed:
- kitt/core/__init__.py +3 -3
- kitt/core/legacy.py +122 -0
- kitt/core/model.py +342 -142
- kitt/core/schema.py +23 -0
- kitt/core/utils.py +26 -0
- kitt/core/validator.py +134 -0
- kitt/skills/__init__.py +5 -2
- kitt/skills/poi.py +39 -14
- kitt/skills/routing.py +3 -0
- kitt/skills/vehicle.py +6 -3
- kitt/skills/weather.py +4 -2
- main.py +65 -25
kitt/core/__init__.py CHANGED
@@ -21,13 +21,13 @@ voices = [
        "Attenborough",
        neutral=f"{file_full_path}/audio/attenborough/neutral.wav",
        angry=None,
-        speed=1.
+        speed=1.2,
    ),
    Voice(
        "Rick",
        neutral=f"{file_full_path}/audio/rick/neutral.wav",
        angry=None,
-        speed=1.
+        speed=1.2,
    ),
    Voice(
        "Freeman",
@@ -45,7 +45,7 @@ voices = [
        "Darth Wader",
        neutral=f"{file_full_path}/audio/darth/neutral.wav",
        angry=None,
-        speed=1.
+        speed=1.15,
    ),
]
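The change above only bumps the per-voice TTS playback rate. For orientation, a minimal sketch of the shape these entries appear to have — the Voice container itself is defined elsewhere in kitt/core and this stand-in is an assumption based on the four fields visible in the diff:

from dataclasses import dataclass
from typing import Optional

# Hypothetical stand-in for the Voice container used in kitt/core/__init__.py.
@dataclass
class Voice:
    name: str
    neutral: Optional[str] = None  # path to a reference .wav for the neutral tone
    angry: Optional[str] = None    # no angry reference clip is shipped yet
    speed: float = 1.0             # TTS playback-rate multiplier

rick = Voice("Rick", neutral="audio/rick/neutral.wav", angry=None, speed=1.2)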
kitt/core/legacy.py ADDED
@@ -0,0 +1,122 @@
+import uuid
+import json
+import re
+from loguru import logger
+
+def use_tool(tool_call, tools):
+    func_name = tool_call["name"]
+    kwargs = tool_call["arguments"]
+    for tool in tools:
+        if tool.name == func_name:
+            return tool.invoke(input=kwargs)
+    raise ValueError(f"Tool {func_name} not found.")
+
+
+def parse_tool_calls(text):
+    logger.debug(f"Start parsing tool_calls: {text}")
+    pattern = r"<tool_call>\s*(\{.*?\})\s*</tool_call>"
+
+    if not text.startswith("<tool_call>"):
+        if "<tool_call>" in text:
+            raise ValueError("<text_and_tool_call>")
+
+        if "<tool_response>" in text:
+            raise ValueError("<tool_response>")
+        return [], []
+
+    matches = re.findall(pattern, text, re.DOTALL)
+    tool_calls = []
+    errors = []
+    for match in matches:
+        try:
+            tool_call = json.loads(match)
+            tool_calls.append(tool_call)
+        except json.JSONDecodeError as e:
+            errors.append(f"Invalid JSON in tool call: {e}")
+
+    logger.debug(f"Tool calls: {tool_calls}, errors: {errors}")
+    return tool_calls, errors
+
+def process_response(user_query, res, history, tools, depth):
+    """Returns True if the response contains tool calls, False otherwise."""
+    logger.debug(f"Processing response: {res}")
+    tool_results = f"Agent iteration {depth} to assist with user query: {user_query}\n"
+    tool_call_id = uuid.uuid4().hex
+    try:
+        tool_calls, errors = parse_tool_calls(res)
+    except ValueError as e:
+        if "<text_and_tool_call>" in str(e):
+            tool_results += "<tool_response>If you need to call a tool your response must be wrapped in <tool_call></tool_call>. Try again, you are great.</tool_response>"
+            history.add_message(
+                ToolMessage(content=tool_results, tool_call_id=tool_call_id)
+            )
+            return True, [], []
+        if "<tool_response>" in str(e):
+            tool_results += "<tool_response>Tool results are not allowed in the response.</tool_response>"
+            history.add_message(
+                ToolMessage(content=tool_results, tool_call_id=tool_call_id)
+            )
+            return True, [], []
+    # TODO: Handle errors
+    if not tool_calls:
+        logger.debug("No tool calls found in response.")
+        return False, tool_calls, errors
+    # tool_results = ""
+
+    for tool_call in tool_calls:
+        # TODO: Extra Validation
+        # Call the function
+        try:
+            result = use_tool(tool_call, tools)
+            logger.debug(f"Tool call {tool_call} result: {result}")
+            if isinstance(result, tuple):
+                result = result[1]
+            tool_results += f"<tool_response>\n{result}\n</tool_response>\n"
+        except Exception as e:
+            logger.error(f"Error calling tool: {e}")
+    # Currently only to mimic OpneAI's behavior
+    # But it could be used for tracking function calls
+
+    tool_results = tool_results.strip()
+    print(f"Tool results: {tool_results}")
+    history.add_message(ToolMessage(content=tool_results, tool_call_id=tool_call_id))
+    return True, tool_calls, errors
+
+
+def process_query(
+    user_query: str,
+    history: ChatMessageHistory,
+    user_preferences,
+    tools,
+    backend="ollama",
+):
+    # Add vehicle status to the history
+    user_query_status = f"consider the vehicle status:\n{vehicle_status()[0]}\nwhen responding to the following query:\n{user_query}"
+    history.add_message(HumanMessage(content=user_query_status))
+    for depth in range(10):
+        # out = run_inference_step(depth, history, tools, schema_json)
+        out = run_inference_step(
+            depth,
+            history,
+            tools,
+            schema_json,
+            user_preferences=user_preferences,
+            backend=backend,
+        )
+        logger.info(f"Inference step result:\n{out}")
+        history.add_message(AIMessage(content=out))
+        to_continue, tool_calls, errors = process_response(
+            user_query, out, history, tools, depth
+        )
+        if errors:
+            history.add_message(AIMessage(content=f"Errors in tool calls: {errors}"))
+
+        if not to_continue:
+            print(f"This is the answer, no more iterations: {out}")
+            return out
+        # Otherwise, tools result is already added to history, we just need to continue the loop.
+    # If we get here something went wrong.
+    history.add_message(
+        AIMessage(content="Sorry, I am not sure how to help you with that.")
+    )
+    return "Sorry, I am not sure how to help you with that."
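The legacy parser keys purely off literal <tool_call> tags and a JSON body. A minimal standalone sketch of that extraction step, with a made-up model response for illustration (not part of this commit):

import json
import re

# Same pattern the legacy parse_tool_calls uses: a JSON object wrapped in tags.
pattern = r"<tool_call>\s*(\{.*?\})\s*</tool_call>"

# Invented model output for illustration.
response = '<tool_call>\n{"arguments": {"location": "Rome"}, "name": "get_weather"}\n</tool_call>'

calls = [json.loads(m) for m in re.findall(pattern, response, re.DOTALL)]
print(calls[0]["name"], calls[0]["arguments"])  # get_weather {'location': 'Rome'}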
kitt/core/model.py CHANGED
@@ -1,18 +1,22 @@
+import ast
import json
import re
import uuid
+from enum import Enum
+from typing import List
+import xml.etree.ElementTree as ET

from langchain.memory import ChatMessageHistory
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
-from langchain_core.utils.function_calling import
-import
+from langchain_core.utils.function_calling import convert_to_openai_tool
+from langchain.tools.base import StructuredTool
from ollama import Client
from pydantic import BaseModel
from loguru import logger

-
from kitt.skills import vehicle_status
from kitt.skills.common import config
+from .validator import validate_function_call_schema


class FunctionCall(BaseModel):
@@ -28,14 +32,52 @@ class FunctionCall(BaseModel):
    """The name of the function to call."""


+class ResponseType(Enum):
+    TOOL_CALL = "tool_call"
+    TEXT = "text"
+
+
+class AssistantResponse(BaseModel):
+    tool_calls: List[FunctionCall]
+    """The tool call to make to get the response."""
+
+    response_type: ResponseType = (
+        ResponseType.TOOL_CALL
+    )  # The type of response to make to the user. Either 'tool_call' or 'text'.
+    """The type of response to make to the user. Either 'tool_call' or 'text'."""
+
+    response: str
+
+
schema_json = json.loads(FunctionCall.schema_json())
+# schema_json = json.loads(AssistantResponse.schema_json())
+
HRMS_SYSTEM_PROMPT = """<|im_start|>system
-You are a
-
+You are a helpful assistant that answers in JSON. Here's the json schema you must adhere to:
+<schema>
+{schema}
+<schema><|im_end|>"""
+
+
+HRMS_SYSTEM_PROMPT = """<|im_start|>system
+Role:
+Your name is KITT. You are embodied in a Car. The user is a human who is a passenger in the car. You have autonomy to use the tools available to you to assist the user.
+You are the AI assistant in the car. From the information in <car_status></car_status you know where you are, the destination, and the current date and time.
+You are witty, helpful, and have a good sense of humor. You are a function calling AI agent with self-recursion.
You are provided with function signatures within <tools></tools> XML tags.
+User preferences are provided in <user_preferences></user_preferences> XML tags. Use them if needed.

+<car_status>
+{car_status}
+</car_status>
+
+<user_preferences>
+{user_preferences}
+</user_preferences>
+
+Objective:
You may use agentic frameworks for reasoning and planning to help with user query.
-Please call
+Please call one or two functions at a time, the function results to be provided to you immediately. Try to answer the user query, with as little back and forth as possible.
Don't make assumptions about what values to plug into function arguments.
Once you have called a function, results will be fed back to you within <tool_response></tool_response> XML tags.
Don't make assumptions about tool results if <tool_response> XML tags are not present since function hasn't been executed yet.
@@ -44,49 +86,38 @@ At each iteration please continue adding the your analysis to previous summary.
Your final response should directly answer the user query. Don't tell what you are doing, just do it.


+Tools:
Here are the available tools:
<tools> {tools} </tools>
-If the provided function signatures doesn't have the function you must call, you may write executable python code in markdown syntax and call code_interpreter() function as follows:
-<tool_call>
-{{"arguments": {{"code_markdown": <python-code>, "name": "code_interpreter"}}}}
-</tool_call>
Make sure that the json object above with code markdown block is parseable with json.loads() and the XML block with XML ElementTree.
When using tools, ensure to only use the tools provided and not make up any data and do not provide any explanation as to which tool you are using and why.

-Example 1:
-User: How is the weather?
-Assistant:
-<tool_call>
-{{"arguments": {{"location": ""}}, "name": "get_weather"}}
-</tool_call>
-
-Example 2:
-User: Is there a Spa nearby?
-Assistant:
-<tool_call>
-{{"arguments": {{"search_query": "Spa"}}, "name": "search_points_of_interests"}}
-</tool_call>
-
-Example 3:
-User: How long will it take to get to the destination?
-Assistant:
-<tool_call>
-{{"arguments": {{"destination": ""}}, "name": "calculate_route"}}
-
When asked for the weather or points of interest, use the appropriate tool with the current location of the car. Unless the user provides a location, then use that location.
Always assume user wants to travel by car.

+Schema:
Use the following pydantic model json schema for each tool call you will make:
{schema}

+Instructions:
At the very first turn you don't have <tool_results> so you shouldn't not make up the results.
Please keep a running summary with analysis of previous function results and summaries from previous iterations.
Do not stop calling functions until the task has been accomplished or you've reached max iteration of 10.
+Calling multiple functions at once can overload the system and increase cost so call one function at a time please.
If you plan to continue with analysis, always call another function.
-For each function call return a valid json object (using
+For each function call return a valid json object (using double quotes) with function name and arguments within <tool_call></tool_call> XML tags as follows:
+<tool_call>
+{{"arguments": <args-dict>, "name": <function-name>}}
+</tool_call>
+If there are more than one function call, return multiple <tool_call></tool_call> XML tags, for example:
<tool_call>
{{"arguments": <args-dict>, "name": <function-name>}}
</tool_call>
+<tool_call>
+{{"arguments": <args-dict>, "name": <function-name>}}
+</tool_call>
+You have to open and close the XML tags for each function call.
+
<|im_end|>"""
AI_PREAMBLE = """
<|im_start|>assistant
@@ -103,6 +134,32 @@ HRMS_TEMPLATE_TOOL_RESULT = """
<|im_end|>"""


+"""
+Below are a few examples, but they are not exhaustive. You can call any tool as long as it is within the <tools></tools> XML tags. Also examples are simplified and don't include all the tags you will see in the conversation.
+Example 1:
+User: How is the weather?
+Assistant:
+<tool_call>
+{{"arguments": {{"location": ""}}, "name": "get_weather"}}
+</tool_call>
+
+Example 2:
+User: Is there a Spa nearby?
+Assistant:
+<tool_call>
+{{"arguments": {{"search_query": "Spa"}}, "name": "search_points_of_interest"}}
+</tool_call>
+
+
+Example 3:
+User: How long will it take to get to the destination?
+Assistant:
+<tool_call>
+{{"arguments": {{"destination": ""}}, "name": "calculate_route"}}
+</tool_call>
+"""
+
+
def append_message(prompt, h):
    if h.type == "human":
        prompt += HRMS_TEMPLATE_USER.format(user_input=h.content)
@@ -113,7 +170,7 @@ def append_message(prompt, h):
    return prompt


-def get_prompt(template, history, tools, schema, car_status=None):
+def get_prompt(template, history, tools, schema, user_preferences, car_status=None):
    if not car_status:
        # car_status = vehicle.dict()
        car_status = vehicle_status()[0]
@@ -124,6 +181,7 @@ def get_prompt(template, history, tools, schema, car_status=None):
        "schema": schema,
        "tools": tools,
        "car_status": car_status,
+        "user_preferences": user_preferences,
    }

    prompt = template.format(**kwargs).replace("{{", "{").replace("}}", "}")
@@ -137,99 +195,31 @@ def get_prompt(template, history, tools, schema, car_status=None):
    return prompt


-def use_tool(tool_call, tools):
-    func_name = tool_call["name"]
-    kwargs = tool_call["arguments"]
-    for tool in tools:
-        if tool.name == func_name:
-            return tool.invoke(input=kwargs)
-    return None
-

-def parse_tool_calls(text):
-    logger.debug(f"Start parsing tool_calls: {text}")
-    pattern = r"<tool_call>\s*(\{.*?\})\s*</tool_call>"
-
-    if not text.startswith("<tool_call>"):
-        if "<tool_call>" in text:
-            raise ValueError("<text_and_tool_call>")
-        return [], []
-
-    matches = re.findall(pattern, text, re.DOTALL)
-    tool_calls = []
-    errors = []
-    for match in matches:
-        try:
-            tool_call = json.loads(match)
-            tool_calls.append(tool_call)
-        except json.JSONDecodeError as e:
-            errors.append(f"Invalid JSON in tool call: {e}")
-
-    logger.debug(f"Tool calls: {tool_calls}, errors: {errors}")
-    return tool_calls, errors
-
-
-def process_response(user_query, res, history, tools, depth):
-    """Returns True if the response contains tool calls, False otherwise."""
-    logger.debug(f"Processing response: {res}")
-    tool_results = f"Agent iteration {depth} to assist with user query: {user_query}\n"
-    tool_call_id = uuid.uuid4().hex
-    try:
-        tool_calls, errors = parse_tool_calls(res)
-    except ValueError as e:
-        if "<text_and_tool_call>" in str(e):
-            tool_results += f"A mix of text and tool_call was found, you must either answer the query in a short sentence or use tool_call not both. Try again, this time only using tool_call."
-            history.add_message(
-                ToolMessage(content=tool_results, tool_call_id=tool_call_id)
-            )
-            return True, [], []
-    # TODO: Handle errors
-    if not tool_calls:
-        return False, tool_calls, errors
-    # tool_results = ""
-
-    for tool_call in tool_calls:
-        # TODO: Extra Validation
-        # Call the function
-        try:
-            result = use_tool(tool_call, tools)
-            if isinstance(result, tuple):
-                result = result[1]
-            tool_results += f"<tool_response>\n{result}\n</tool_response>\n"
-        except Exception as e:
-            print(e)
-    # Currently only to mimic OpneAI's behavior
-    # But it could be used for tracking function calls
-
-    tool_results = tool_results.strip()
-    print(f"Tool results: {tool_results}")
-    history.add_message(ToolMessage(content=tool_results, tool_call_id=tool_call_id))
-    return True, tool_calls, errors


def run_inference_ollama(prompt):
    data = {
-        "prompt": prompt
-        + "\nThis is the first turn and you don't have <tool_results> to analyze yet"
-        + AI_PREAMBLE,
+        "prompt": prompt,
        # "streaming": False,
        # "model": "smangrul/llama-3-8b-instruct-function-calling",
        # "model": "elvee/hermes-2-pro-llama-3:8b-Q5_K_M",
        # "model": "NousResearch/Hermes-2-Pro-Llama-3-8B",
-
-        "model": "dolphin-llama3:8b",
+        "model": "interstellarninja/hermes-2-pro-llama-3-8b",
+        # "model": "dolphin-llama3:8b",
        # "model": "dolphin-llama3:70b",
        "raw": True,
        "options": {
-            "temperature": 0.
+            "temperature": 0.7,
            # "max_tokens": 1500,
            "num_predict": 1500,
            # "mirostat": 1,
            # "mirostat_tau": 2,
-            "repeat_penalty": 1.
+            "repeat_penalty": 1.2,
            "top_k": 25,
            "top_p": 0.5,
            "num_ctx": 8000,
+            # "stop": ["<|im_end|>"]
            # "num_predict": 1500,
            # "max_tokens": 1500,
        },
@@ -248,14 +238,26 @@ def run_inference_ollama(prompt):


def run_inference_step(
-    depth, history, tools, schema_json,
+    depth, history, tools, schema_json, user_preferences, backend="ollama"
):
    # If we decide to call a function, we need to generate the prompt for the model
    # based on the history of the conversation so far.
    # not break the loop
-    openai_tools = [
-    prompt = get_prompt(
+    openai_tools = [convert_to_openai_tool(tool) for tool in tools]
+    prompt = get_prompt(
+        HRMS_SYSTEM_PROMPT,
+        history,
+        openai_tools,
+        schema_json,
+        user_preferences=user_preferences,
+    )
+    logger.debug(f"History is: {history.messages}")
+
+    # if depth == 0:
+    #     prompt += "\nThis is the first turn and you don't have <tool_results> to analyze yet."
+    prompt += AI_PREAMBLE
+
+    logger.info(f"Prompt is:\n{prompt}")

    if backend == "ollama":
        output = run_inference_ollama(prompt)
@@ -272,9 +274,7 @@ def run_inference_replicate(prompt):
    replicate = Client(api_token=config.REPLICATE_API_KEY)

    input = {
-        "prompt": prompt
-        + "\nThis is the first turn and you don't have <tool_results> to analyze yet"
-        + AI_PREAMBLE,
+        "prompt": prompt,
        "temperature": 0.5,
        "system_prompt": "",
        "max_new_tokens": 1024,
@@ -283,41 +283,241 @@ def run_inference_replicate(prompt):
    }

    output = replicate.run(
-        "mikeei/dolphin-2.9-llama3-8b-gguf:0f79fb14c45ae2b92e1f07d872dceed3afafcacd903258df487d3bec9e393cb2",
+        # "mikeei/dolphin-2.9-llama3-8b-gguf:0f79fb14c45ae2b92e1f07d872dceed3afafcacd903258df487d3bec9e393cb2",
+        "sasan-j/hermes-2-pro-llama-3-8b:28b1dc16f47d9df68d9839418282315d5e78d9e2ab3fa6ff15728c76ae71a6d6",
        input=input,
    )
    out = "".join(output)

+    logger.debug(f"Response from Ollama:\nOut:{out}")
+
    return out


-def
-    backend
-)
-)
-)
+def run_inference(prompt, backend="ollama"):
+    prompt += AI_PREAMBLE
+
+    logger.info(f"Prompt is:\n{prompt}")
+
+    if backend == "ollama":
+        output = run_inference_ollama(prompt)
+    else:
+        output = run_inference_replicate(prompt)
+
+    logger.debug(f"Response from model: {output}")
+    return output
+
+
+def validate_and_extract_tool_calls(assistant_content):
+    validation_result = False
+    tool_calls = []
+    error_message = None
+
+    try:
+        # wrap content in root element
+        xml_root_element = f"<root>{assistant_content}</root>"
+        root = ET.fromstring(xml_root_element)
+
+        # extract JSON data
+        for element in root.findall(".//tool_call"):
+            json_data = None
+            try:
+                json_text = element.text.strip()
+
+                try:
+                    # Prioritize json.loads for better error handling
+                    json_data = json.loads(json_text)
+                except json.JSONDecodeError as json_err:
+                    try:
+                        # Fallback to ast.literal_eval if json.loads fails
+                        json_data = ast.literal_eval(json_text)
+                    except (SyntaxError, ValueError) as eval_err:
+                        error_message = (
+                            f"JSON parsing failed with both json.loads and ast.literal_eval:\n"
+                            f"- JSON Decode Error: {json_err}\n"
+                            f"- Fallback Syntax/Value Error: {eval_err}\n"
+                            f"- Problematic JSON text: {json_text}"
+                        )
+                        logger.error(error_message)
+                        continue
+            except Exception as e:
+                error_message = f"Cannot strip text: {e}"
+                logger.error(error_message)
+
+            if json_data is not None:
+                tool_calls.append(json_data)
+                validation_result = True
+
+    except ET.ParseError as err:
+        error_message = f"XML Parse Error: {err}"
+        logger.error(f"XML Parse Error: {err}")
+
+    # Return default values if no valid data is extracted
+    return validation_result, tool_calls, error_message
+
+
+def execute_function_call(tool_call, functions):
+    function_name = tool_call.get("name")
+    for tool in functions:
+        if tool.name == function_name:
+            function_to_call = tool
+            break
+    else:
+        raise ValueError(f"Function {function_name} not found.")
+    function_args = tool_call.get("arguments", {})
+
+    logger.info(f"Invoking function call {function_name} ...")
+    if isinstance(function_to_call, StructuredTool):
+        function_response = function_to_call.invoke(input=function_args)
+    else:
+        function_response = function_to_call(*function_args.values())
+    results_dict = f'{{"name": "{function_name}", "content": {function_response}}}'
+    return results_dict
+
+
+def process_completion_and_validate(completion):
+
+    # I think I don't need this.
+    # assistant_message = get_assistant_message(completion, eos_token="<|im_end|>")
+    assistant_message = completion.strip()
+
+    if assistant_message:
+        validation, tool_calls, error_message = validate_and_extract_tool_calls(
+            assistant_message
+        )
+
+        if validation:
+            logger.info(f"parsed tool calls:\n{json.dumps(tool_calls, indent=2)}")
+            return tool_calls, assistant_message, error_message
+        else:
+            tool_calls = None
+            return tool_calls, assistant_message, error_message
+    else:
+        logger.warning("Assistant message is None")
+        raise ValueError("Assistant message is None")
+
+
+UNRESOLVED_MSG = "I'm sorry, I'm not sure how to help you with that."
+
+
+def get_assistant_message(completion, eos_token):
+    """define and match pattern to find the assistant message"""
+    completion = completion.strip()
+    assistant_pattern = re.compile(
+        r"<\|im_start\|>\s*assistant((?:(?!<\|im_start\|>\s*assistant).)*)$", re.DOTALL
+    )
+    assistant_match = assistant_pattern.search(completion)
+    if assistant_match:
+        assistant_content = assistant_match.group(1).strip()
+        return assistant_content.replace(eos_token, "")
+    else:
+        assistant_content = None
+        logger.info("No match found for the assistant pattern")
+        return assistant_content
+
+
+def generate_function_call(
+    query, history, user_preferences, tools, functions, backend, max_depth=5
+) -> str:
+    """
+    Largely taken from https://github.com/NousResearch/Hermes-Function-Calling
+    """
+
+    try:
+        depth = 0
+        # user_message = f"{query}\nThis is the first turn and you don't have <tool_results> to analyze yet"
+        user_message = f"{query}"
+        # chat = [{"role": "user", "content": user_message}]
+        history.add_message(HumanMessage(content=user_message))
+
+        # openai_tools = [convert_to_openai_function(tool) for tool in tools]
+        prompt = get_prompt(
+            HRMS_SYSTEM_PROMPT,
+            history,
+            tools,
+            schema_json,
+            user_preferences=user_preferences,
+        )
+        logger.debug(f"History is: {history.json()}")
+
+        # if depth == 0:
+        #     prompt += "\nThis is the first turn and you don't have <tool_results> to analyze yet."
+        completion = run_inference(prompt, backend=backend)
+
+        def recursive_loop(prompt, completion, depth) -> str:
+            nonlocal max_depth
+            tool_calls, assistant_message, error_message = (
+                process_completion_and_validate(completion)
+            )
+            # prompt.append({"role": "assistant", "content": assistant_message})
+            history.add_message(AIMessage(content=assistant_message))
+
+            tool_message = (
+                f"Agent iteration {depth} to assist with user query: {query}\n"
+            )
+            if tool_calls:
+                logger.info(f"Assistant Message:\n{assistant_message}")
+                for tool_call in tool_calls:
+                    validation, message = validate_function_call_schema(
+                        tool_call, tools
+                    )
+                    if validation:
+                        try:
+                            function_response = execute_function_call(
+                                tool_call, functions=functions
+                            )
+                            tool_message += f"<tool_response>\n{function_response}\n</tool_response>\n"
+                            logger.info(
+                                f"Here's the response from the function call: {tool_call.get('name')}\n{function_response}"
+                            )
+                        except Exception as e:
+                            logger.warning(f"Could not execute function: {e}")
+                            tool_message += f"<tool_response>\nThere was an error when executing the function: {tool_call.get('name')}\nHere's the error traceback: {e}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
+                    else:
+                        logger.error(message)
+                        tool_message += f"<tool_response>\nThere was an error validating function call against function signature: {tool_call.get('name')}\nHere's the error traceback: {message}\nPlease call this function again with correct arguments within XML tags <tool_call></tool_call>\n</tool_response>\n"
+                # prompt.append({"role": "tool", "content": tool_message})
+                history.add_message(
+                    ToolMessage(content=tool_message, tool_call_id=uuid.uuid4().hex)
+                )
+
+                depth += 1
+                if depth >= max_depth:
+                    logger.warning(
+                        f"Maximum recursion depth reached ({max_depth}). Stopping recursion."
+                    )
+                    return UNRESOLVED_MSG
+
+                prompt = get_prompt(
+                    HRMS_SYSTEM_PROMPT,
+                    history,
+                    tools,
+                    schema_json,
+                    user_preferences=user_preferences,
+                )
+                completion = run_inference(prompt, backend=backend)
+                return recursive_loop(prompt, completion, depth)
+            elif error_message:
+                logger.info(f"Assistant Message:\n{assistant_message}")
+                tool_message += f"<tool_response>\nThere was an error parsing function calls\n Here's the error stack trace: {error_message}\nPlease call the function again with correct syntax<tool_response>"
+                prompt.append({"role": "tool", "content": tool_message})
+
+                depth += 1
+                if depth >= max_depth:
+                    logger.warning(
+                        f"Maximum recursion depth reached ({max_depth}). Stopping recursion."
+                    )
+                    return UNRESOLVED_MSG
+
+                completion = run_inference(prompt, backend=backend)
+                return recursive_loop(prompt, completion, depth)
+            else:
+                logger.info(f"Assistant Message:\n{assistant_message}")
+                return assistant_message

+        return recursive_loop(prompt, completion, depth)  # noqa
+
+    except Exception as e:
+        logger.error(f"Exception occurred: {e}")
+        return UNRESOLVED_MSG
+        # raise e
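The key shift in model.py is that tool calls are no longer regex-scanned out of the raw completion; the text is wrapped in a synthetic root element so that ElementTree can find every <tool_call>, with ast.literal_eval as a fallback for single-quoted pseudo-JSON. A standalone sketch of that mechanism, using a made-up completion:

import ast
import json
import xml.etree.ElementTree as ET

# Invented assistant output with two tool calls, mirroring the prompt format above.
content = """<tool_call>
{"arguments": {"location": ""}, "name": "get_weather"}
</tool_call>
<tool_call>
{"arguments": {"search_query": "Spa"}, "name": "search_points_of_interest"}
</tool_call>"""

# Same trick as validate_and_extract_tool_calls: a synthetic <root> makes the
# fragment a single well-formed XML document.
root = ET.fromstring(f"<root>{content}</root>")
for element in root.findall(".//tool_call"):
    text = element.text.strip()
    try:
        call = json.loads(text)          # strict JSON first
    except json.JSONDecodeError:
        call = ast.literal_eval(text)    # tolerate single-quoted pseudo-JSON
    print(call["name"], call["arguments"])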
kitt/core/schema.py ADDED
@@ -0,0 +1,23 @@
+from pydantic import BaseModel
+from typing import List, Dict, Literal, Optional
+
+class FunctionCall(BaseModel):
+    arguments: dict
+    """
+    The arguments to call the function with, as generated by the model in JSON
+    format. Note that the model does not always generate valid JSON, and may
+    hallucinate parameters not defined by your function schema. Validate the
+    arguments in your code before calling your function.
+    """
+
+    name: str
+    """The name of the function to call."""
+
+class FunctionDefinition(BaseModel):
+    name: str
+    description: Optional[str] = None
+    parameters: Optional[Dict[str, object]] = None
+
+class FunctionSignature(BaseModel):
+    function: FunctionDefinition
+    type: Literal["function"]
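These models mirror the OpenAI tool-schema layout, so a tool definition produced by convert_to_openai_tool can be parsed into a FunctionSignature directly. A small self-contained sketch (the classes are repeated here so it runs standalone, and the tool dict is invented for illustration):

from typing import Dict, Literal, Optional
from pydantic import BaseModel

class FunctionDefinition(BaseModel):
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, object]] = None

class FunctionSignature(BaseModel):
    function: FunctionDefinition
    type: Literal["function"]

# Invented OpenAI-style tool dict, the shape convert_to_openai_tool produces.
tool_dict = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather in a specified location.",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    },
}

sig = FunctionSignature(**tool_dict)
print(sig.function.name)  # get_weather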
kitt/core/utils.py CHANGED
@@ -1,3 +1,5 @@
+import json
+import re
from typing import List, Tuple, Optional, Union


@@ -33,3 +35,27 @@ def plot_route(points, vehicle: Union[tuple[float, float], None] = None):
    fig.update_geos(fitbounds="locations")
    fig.update_layout(margin={"r": 20, "t": 20, "l": 20, "b": 20})
    return fig
+
+
+def extract_json_from_markdown(text):
+    """
+    Extracts the JSON string from the given text using a regular expression pattern.
+
+    Args:
+        text (str): The input text containing the JSON string.
+
+    Returns:
+        dict: The JSON data loaded from the extracted string, or None if the JSON string is not found.
+    """
+    json_pattern = r'```json\r?\n(.*?)\r?\n```'
+    match = re.search(json_pattern, text, re.DOTALL)
+    if match:
+        json_string = match.group(1)
+        try:
+            data = json.loads(json_string)
+            return data
+        except json.JSONDecodeError as e:
+            print(f"Error decoding JSON string: {e}")
+    else:
+        print("JSON string not found in the text.")
+    return None
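extract_json_from_markdown is the last-resort decoder the validator falls back to when the model wraps its JSON in a fenced code block. A quick standalone check of the pattern it matches, with an invented model reply:

import json
import re

json_pattern = r'```json\r?\n(.*?)\r?\n```'  # same pattern as extract_json_from_markdown

reply = 'Here you go:\n```json\n{"name": "get_weather", "arguments": {"location": "Paris"}}\n```'
match = re.search(json_pattern, reply, re.DOTALL)
if match:
    # {'name': 'get_weather', 'arguments': {'location': 'Paris'}}
    print(json.loads(match.group(1)))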
kitt/core/validator.py ADDED
@@ -0,0 +1,134 @@
+import ast
+import json
+from jsonschema import validate
+from pydantic import ValidationError
+from loguru import logger
+
+from .utils import extract_json_from_markdown
+from .schema import FunctionCall, FunctionSignature
+
+def validate_function_call_schema(call, signatures):
+    try:
+        call_data = FunctionCall(**call)
+    except ValidationError as e:
+        return False, str(e)
+
+    for signature in signatures:
+        try:
+            signature_data = FunctionSignature(**signature)
+            if signature_data.function.name == call_data.name:
+                # Validate types in function arguments
+                for arg_name, arg_schema in signature_data.function.parameters.get('properties', {}).items():
+                    if arg_name in call_data.arguments:
+                        call_arg_value = call_data.arguments[arg_name]
+                        if call_arg_value:
+                            try:
+                                validate_argument_type(arg_name, call_arg_value, arg_schema)
+                            except Exception as arg_validation_error:
+                                return False, str(arg_validation_error)
+
+                # Check if all required arguments are present
+                required_arguments = signature_data.function.parameters.get('required', [])
+                result, missing_arguments = check_required_arguments(call_data.arguments, required_arguments)
+                if not result:
+                    return False, f"Missing required arguments: {missing_arguments}"
+
+                return True, None
+        except Exception as e:
+            # Handle validation errors for the function signature
+            return False, str(e)
+
+    # No matching function signature found
+    return False, f"No matching function signature found for function: {call_data.name}"
+
+def check_required_arguments(call_arguments, required_arguments):
+    missing_arguments = [arg for arg in required_arguments if arg not in call_arguments]
+    return not bool(missing_arguments), missing_arguments
+
+def validate_enum_value(arg_name, arg_value, enum_values):
+    if arg_value not in enum_values:
+        raise Exception(
+            f"Invalid value '{arg_value}' for parameter {arg_name}. Expected one of {', '.join(map(str, enum_values))}"
+        )
+
+def validate_argument_type(arg_name, arg_value, arg_schema):
+    arg_type = arg_schema.get('type', None)
+    if arg_type:
+        if arg_type == 'string' and 'enum' in arg_schema:
+            enum_values = arg_schema['enum']
+            if None not in enum_values and enum_values != []:
+                try:
+                    validate_enum_value(arg_name, arg_value, enum_values)
+                except Exception as e:
+                    # Propagate the validation error message
+                    raise Exception(f"Error validating function call: {e}")
+
+        python_type = get_python_type(arg_type)
+        if not isinstance(arg_value, python_type):
+            raise Exception(f"Type mismatch for parameter {arg_name}. Expected: {arg_type}, Got: {type(arg_value)}")
+
+def get_python_type(json_type):
+    type_mapping = {
+        'string': str,
+        'number': (int, float),
+        'integer': int,
+        'boolean': bool,
+        'array': list,
+        'object': dict,
+        'null': type(None),
+    }
+    return type_mapping[json_type]
+
+def validate_json_data(json_object, json_schema):
+    valid = False
+    error_message = None
+    result_json = None
+
+    try:
+        # Attempt to load JSON using json.loads
+        try:
+            result_json = json.loads(json_object)
+        except json.decoder.JSONDecodeError:
+            # If json.loads fails, try ast.literal_eval
+            try:
+                result_json = ast.literal_eval(json_object)
+            except (SyntaxError, ValueError) as e:
+                try:
+                    result_json = extract_json_from_markdown(json_object)
+                except Exception as e:
+                    error_message = f"JSON decoding error: {e}"
+                    logger.info(f"Validation failed for JSON data: {error_message}")
+                    return valid, result_json, error_message
+
+        # Return early if both json.loads and ast.literal_eval fail
+        if result_json is None:
+            error_message = "Failed to decode JSON data"
+            logger.info(f"Validation failed for JSON data: {error_message}")
+            return valid, result_json, error_message
+
+        # Validate each item in the list against schema if it's a list
+        if isinstance(result_json, list):
+            for index, item in enumerate(result_json):
+                try:
+                    validate(instance=item, schema=json_schema)
+                    logger.info(f"Item {index+1} is valid against the schema.")
+                except ValidationError as e:
+                    error_message = f"Validation failed for item {index+1}: {e}"
+                    break
+        else:
+            # Default to validation without list
+            try:
+                validate(instance=result_json, schema=json_schema)
+            except ValidationError as e:
+                error_message = f"Validation failed: {e}"
+
+    except Exception as e:
+        error_message = f"Error occurred: {e}"
+
+    if error_message is None:
+        valid = True
+        logger.info("JSON data is valid against the schema.")
+    else:
+        logger.info(f"Validation failed for JSON data: {error_message}")
+
+    return valid, result_json, error_message
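Taken together, validate_function_call_schema checks a parsed tool call against the JSON-schema half of a matching FunctionSignature: argument types first, then required arguments. A sketch of the success and failure paths, assuming the kitt package is importable; the get_weather signature dict is invented for illustration:

from kitt.core.validator import validate_function_call_schema

signatures = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather in a specified location.",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    },
}]

ok, err = validate_function_call_schema(
    {"name": "get_weather", "arguments": {"location": "Paris"}}, signatures
)
print(ok, err)  # True None

ok, err = validate_function_call_schema(
    {"name": "get_weather", "arguments": {"location": 42}}, signatures
)
print(ok, err)  # False, with a type-mismatch message for "location"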
kitt/skills/__init__.py CHANGED
@@ -1,10 +1,11 @@
from datetime import datetime
import inspect
+from langchain.tools import StructuredTool

from .common import execute_function_call, extract_func_args, vehicle as vehicle_obj
from .weather import get_weather_current_location, get_weather, get_forecast
from .routing import find_route
-from .poi import
+from .poi import search_points_of_interest, search_along_route_w_coordinates
from .vehicle import vehicle_status, set_vehicle_speed, set_vehicle_destination
from .interpreter import code_interpreter

@@ -32,6 +33,8 @@ def format_functions_for_prompt_raven(*functions):
    """
    formatted_functions = []
    for func in functions:
+        if isinstance(func, StructuredTool):
+            func = func.func
        signature = f"{func.__name__}{inspect.signature(func)}"
        docstring = inspect.getdoc(func)
        formatted_functions.append(
@@ -40,4 +43,4 @@ def format_functions_for_prompt_raven(*functions):
    return "\n".join(formatted_functions)


-SKILLS_PROMPT = format_functions_for_prompt_raven(get_weather, get_forecast, find_route,
+SKILLS_PROMPT = format_functions_for_prompt_raven(get_weather, get_forecast, find_route, search_points_of_interest)
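The new isinstance branch exists because skills decorated with @tool are StructuredTool objects, not plain functions, so inspect.signature would no longer describe the underlying callable. A minimal sketch of the unwrap (the hasattr guard is added here to stay version-agnostic across langchain releases):

import inspect
from langchain.tools import tool

@tool
def get_weather(location: str = "here"):
    """Get the current weather in a specified location."""
    return f"Sunny in {location}"

# StructuredTool keeps the original callable on .func; plain functions pass through.
func = get_weather.func if hasattr(get_weather, "func") else get_weather
print(f"{func.__name__}{inspect.signature(func)}")  # get_weather(location: str = 'here')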
kitt/skills/poi.py CHANGED
@@ -1,5 +1,8 @@
import json
+import urllib.parse
import requests
+from loguru import logger
+from langchain.tools import tool
from .common import config, vehicle


@@ -16,7 +19,8 @@ def _select_equally_spaced_coordinates(coords, number_of_points=10):
    return selected_coords


-def search_points_of_interests(search_query="french restaurant"):
+@tool
+def search_points_of_interest(search_query: str ="french restaurant"):
    """
    Get some of the closest points of interest matching the query.

@@ -27,16 +31,31 @@ def search_points_of_interests(search_query="french restaurant"):
    # Extract the latitude and longitude of the vehicle
    vehicle_coordinates = getattr(vehicle, "location_coordinates")
    lat, lon = vehicle_coordinates
-
+    logger.info(f"POI search vehicle's lat: {lat}, lon: {lon}")

    # https://developer.tomtom.com/search-api/documentation/search-service/search
-
-
-
-    )
+    # Encode the parameters
+    # Even with encoding tomtom doesn't return the correct results
+    search_query = search_query.replace("'", "")
+    encoded_search_query = urllib.parse.quote(search_query)
+
+    # Construct the URL
+    url = f"https://api.tomtom.com/search/2/search/{encoded_search_query}.json"
+    params = {
+        "key": config.TOMTOM_API_KEY,
+        "lat": lat,
+        "lon": lon,
+        "radius": 5000,
+        "idxSet": "POI",
+        "limit": 50
+    }
+
+    r = requests.get(url, params=params, timeout=5)

    # Parse JSON from the response
    data = r.json()
+
+    logger.debug(f"POI search response: {data}\n url:{url} params: {params}")
    # Extract results
    results = data["results"]

@@ -57,7 +76,7 @@ def search_points_of_interests(search_query="french restaurant"):
    output = (
        f"There are {len(results)} options in the vicinity. The most relevant are: "
    )
-    return output + ".\n ".join(formatted_results)
+    return output + ".\n ".join(formatted_results), results[:3]


def find_points_of_interest(lat="0", lon="0", type_of_poi="restaurant"):
@@ -69,12 +88,14 @@ def find_points_of_interest(lat="0", lon="0", type_of_poi="restaurant"):
    :param type_of_poi (string): Required. type of point of interest depending on what the user wants to do.
    """
    # https://developer.tomtom.com/search-api/documentation/search-service/points-of-interest-search
-
-
-
-
-
-
+    # Encode the parameters
+    encoded_type_of_poi = urllib.parse.quote(type_of_poi)
+
+    # Construct the URL
+    url = f"https://api.tomtom.com/search/2/search/{encoded_type_of_poi}.json?key={config.TOMTOM_API_KEY}&lat={lat}&lon={lon}&radius=10000&vehicleTypeSet=Car&idxSet=POI&limit=100"
+
+    r = requests.get(url, timeout=5)
+

    # Parse JSON from the response
    data = r.json()
@@ -103,7 +124,11 @@ def search_along_route_w_coordinates(points: list[tuple[float, float]], query: s
    """

    # The API endpoint for searching along a route
-
+
+    # urlencode the query
+    query = urllib.parse.quote(query)
+
+    url = f"https://api.tomtom.com/search/2/searchAlongRoute/{query}.json?key={config.TOMTOM_API_KEY}&maxDetourTime=600&limit=20&sortBy=detourTime"

    points = _select_equally_spaced_coordinates(points, number_of_points=20)
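The recurring change in poi.py is to percent-encode user-supplied text before splicing it into the TomTom URL path, stripping apostrophes first because (per the inline comment) TomTom mishandles them even when encoded. The encoding step in isolation, with an invented query:

import urllib.parse

search_query = "Luigi's pizza & pasta"
search_query = search_query.replace("'", "")  # TomTom chokes on apostrophes
encoded = urllib.parse.quote(search_query)
print(encoded)  # Luigis%20pizza%20%26%20pasta

url = f"https://api.tomtom.com/search/2/search/{encoded}.json"
print(url)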
kitt/skills/routing.py CHANGED
@@ -1,5 +1,7 @@
from datetime import datetime
import requests
+from loguru import logger
+from langchain.tools import tool
from .common import config, vehicle


@@ -120,6 +122,7 @@ def find_route_a_to_b(origin="", destination=""):
    return _format_tomtom_trip_info(trip_info, destination)


+@tool
def find_route(destination):
    """Get a route to a destination from the current location of the vehicle.
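Decorating find_route with @tool turns it into a LangChain tool object that the agent loop can look up by .name and run with .invoke, which is how use_tool and execute_function_call dispatch. A toy equivalent (behavior may vary slightly across langchain versions):

from langchain.tools import tool

@tool
def find_route(destination: str) -> str:
    """Get a route to a destination from the current location of the vehicle."""
    return f"Routing to {destination}"

print(find_route.name)                                     # find_route
print(find_route.invoke(input={"destination": "Geneva"}))  # Routing to Geneva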
kitt/skills/vehicle.py CHANGED
@@ -1,3 +1,5 @@
+from langchain.tools import tool
+
from .common import vehicle, Speed


@@ -35,8 +37,8 @@ def vehicle_status() -> tuple[str, dict[str, str]]:
    return STATUS_TEMPLATE.format(**vs), vs


-
-def set_vehicle_speed(speed: Speed):
+@tool
+def set_vehicle_speed(speed: Speed) -> str:
    """Set the speed of the vehicle.
    Args:
        speed (Speed): The speed of the vehicle. ("slow", "fast")
@@ -44,7 +46,8 @@ def set_vehicle_speed(speed: Speed):
    vehicle.speed = speed
    return f"The vehicle speed is set to {speed.value}."

-
+@tool
+def set_vehicle_destination(destination: str) -> str:
    """Set the destination of the vehicle.
    Args:
        destination (str): The destination of the vehicle.
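set_vehicle_speed takes a Speed value; assuming Speed is a small string enum in kitt/skills/common (only "slow" and "fast" appear in the docstring), the mutation reduces to this self-contained sketch:

from enum import Enum

# Hypothetical stand-in for kitt.skills.common.Speed; the diff only reveals
# that it has .value and that "slow" and "fast" are the documented options.
class Speed(Enum):
    SLOW = "slow"
    FAST = "fast"

class Vehicle:
    speed: Speed = Speed.SLOW

vehicle = Vehicle()

def set_vehicle_speed(speed: Speed) -> str:
    vehicle.speed = speed
    return f"The vehicle speed is set to {speed.value}."

print(set_vehicle_speed(Speed.FAST))  # The vehicle speed is set to fast.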
kitt/skills/weather.py CHANGED
@@ -1,5 +1,6 @@
import requests
from loguru import logger
+from langchain.tools import tool

from .common import config, vehicle

@@ -19,7 +20,7 @@ def get_weather_current_location():
    return get_weather(location)


-
+@tool
def get_weather(location: str = "here"):
    """
    Get the current weather in a specified location.
@@ -70,7 +71,8 @@ def get_weather(location: str = "here"):
        # f"Humidity is at {humidity}%. "
        # f"Wind speed is {wind_kph} kph." if 'wind_kph' in weather_data['current'] else ""
    )
-    return weather_sentences, weather_data
+    # return weather_sentences, weather_data
+    return weather_sentences


# weather forecast API
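get_weather used to return a (sentence, raw_data) tuple, which is why the legacy agent loop special-cased tuples with result[1]; returning only the sentence keeps the @tool output a plain string. A sketch of the difference, with invented bodies:

def get_weather_old(location="here"):
    weather_sentences = f"The current weather in {location} is sunny."
    weather_data = {"condition": "sunny"}   # stand-in for the raw API payload
    return weather_sentences, weather_data  # tuple forced callers to unpack

def get_weather_new(location="here"):
    weather_sentences = f"The current weather in {location} is sunny."
    return weather_sentences                # a plain string is tool-friendly

result = get_weather_old()
if isinstance(result, tuple):               # the old loop's workaround
    result = result[1]
print(result)
print(get_weather_new())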
main.py
CHANGED
@@ -12,7 +12,8 @@ import ollama
 
 from langchain.tools.base import StructuredTool
 from langchain.memory import ChatMessageHistory
-from langchain_core.utils.function_calling import …
+from langchain_core.utils.function_calling import convert_to_openai_tool
+from langchain.tools import tool
 from loguru import logger
 
 
@@ -22,7 +23,7 @@ from kitt.skills import (
     get_forecast,
     vehicle_status as vehicle_status_fn,
     set_vehicle_speed,
-    …
+    search_points_of_interest,
     search_along_route_w_coordinates,
     set_vehicle_destination,
     do_anything_else,
@@ -32,7 +33,8 @@ from kitt.skills import (
 )
 from kitt.skills import extract_func_args
 from kitt.core import voice_options, tts_gradio
-from kitt.core.model import process_query
+# from kitt.core.model import process_query
+from kitt.core.model import generate_function_call as process_query
 from kitt.core import utils as kitt_utils
 
 
@@ -68,6 +70,8 @@ Answer questions concisely and do not mention what you base your reply on.<|im_e
 <|im_start|>assistant
 """
 
+USER_PREFERENCES = "I love italian food\nI like doing sports"
+
 
 def get_prompt(template, input, history, tools):
     # "vehicle_status": vehicle_status_fn()[0]
@@ -98,6 +102,7 @@ def use_tool(func_name, kwargs, tools):
 hour_options = [f"{i:02d}:00:00" for i in range(24)]
 
 
+@tool
 def search_along_route(query=""):
     """Search for points of interest along the route/way to the destination.
 
@@ -120,18 +125,29 @@ def get_vehicle_status(state):
 
 
 tools = [
-    StructuredTool.from_function(get_weather),
-    StructuredTool.from_function(find_route),
+    # StructuredTool.from_function(get_weather),
+    # StructuredTool.from_function(find_route),
     # StructuredTool.from_function(vehicle_status_fn),
-    StructuredTool.from_function(set_vehicle_speed),
-    StructuredTool.from_function(…
-    StructuredTool.from_function(…
+    # StructuredTool.from_function(set_vehicle_speed),
+    # StructuredTool.from_function(set_vehicle_destination),
+    # StructuredTool.from_function(search_points_of_interest),
+    # StructuredTool.from_function(search_along_route),
     # StructuredTool.from_function(date_time_info),
     # StructuredTool.from_function(get_weather_current_location),
-    StructuredTool.from_function(code_interpreter),
+    # StructuredTool.from_function(code_interpreter),
     # StructuredTool.from_function(do_anything_else),
 ]
 
+functions = [
+    set_vehicle_speed,
+    set_vehicle_destination,
+    get_weather,
+    find_route,
+    search_points_of_interest,
+    search_along_route,
+]
+openai_tools = [convert_to_openai_tool(tool) for tool in functions]
+
 
 def run_generic_model(query):
     print(f"Running the generic model with query: {query}")
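Note: `convert_to_openai_tool` reads each decorated function's name, signature, and docstring and emits an OpenAI function-calling schema, roughly this shape for `set_vehicle_speed` (property details abridged):

    {
        "type": "function",
        "function": {
            "name": "set_vehicle_speed",
            "description": "Set the speed of the vehicle. ...",
            "parameters": {
                "type": "object",
                "properties": {"speed": {...}},
                "required": ["speed"],
            },
        },
    }

Keeping `functions` (the callables) and `openai_tools` (their schemas) as parallel lists lets the prompt advertise what the model may call while execution still goes through the real tools.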
@@ -186,11 +202,16 @@ def run_nexusraven_model(query, voice_character, state):
 
 
 def run_llama3_model(query, voice_character, state):
+
+    assert len(functions) > 0, "No functions to call"
+    assert len(openai_tools) > 0, "No openai tools to call"
+
     output_text = process_query(
         query,
         history=history,
         user_preferences=state["user_preferences"],
-        tools=…
+        tools=openai_tools,
+        functions=functions,
         backend=state["llm_backend"],
     )
     gr.Info(f"Output text: {output_text}, generating voice output...")
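Note: `generate_function_call` is aliased to `process_query` so call sites stay unchanged; it now receives both the schemas (for prompting) and the callables (for execution). The dispatch step presumably pairs the function name the model emits with the matching tool, along these lines (a sketch, not the actual kitt.core.model code):

    def dispatch(call_name: str, arguments: dict, functions: list) -> str:
        # @tool-decorated callables expose .name, so the emitted function
        # name can be matched back to a runnable tool.
        by_name = {f.name: f for f in functions}
        return by_name[call_name].invoke(arguments)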
@@ -216,6 +237,9 @@ def run_model(query, voice_character, state):
         text, voice = run_llama3_model(query, voice_character, state)
     else:
         text, voice = "Error running model", None
+
+    if not state["enable_history"]:
+        history.clear()
     return text, voice, vehicle.model_dump_json()
 
 
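Note: the history gate relies on `history` being the module-level `ChatMessageHistory` that `run_llama3_model` threads into `process_query`; clearing it after every turn makes the default "No" setting behave as single-turn chat. The assumed flow:

    from langchain.memory import ChatMessageHistory

    history = ChatMessageHistory()     # module level, shared across turns (assumed)
    state = {"enable_history": False}  # the default set in create_demo below

    # ... a model run appends the user query and the reply to history ...
    if not state["enable_history"]:
        history.clear()                # next turn starts from an empty context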
@@ -285,8 +309,8 @@ def save_and_transcribe_audio(audio):
 
 def save_and_transcribe_run_model(audio, voice_character, state):
     text = save_and_transcribe_audio(audio)
-    out_text, out_voice = run_model(text, voice_character, state)
-    return text, out_text, out_voice
+    out_text, out_voice, vehicle_status = run_model(text, voice_character, state)
+    return text, out_text, out_voice, vehicle_status
 
 
 def set_tts_enabled(tts_enabled, state):
@@ -314,6 +338,12 @@ def set_user_preferences(preferences, state):
     return state
 
 
+def set_enable_history(enable_history, state):
+    new_enable_history = enable_history == "Yes"
+    logger.info(f"Enable history was {state['enable_history']} and changed to {new_enable_history}")
+    state["enable_history"] = new_enable_history
+    return state
+
 # to be able to use the microphone on chrome, you will have to go to chrome://flags/#unsafely-treat-insecure-origin-as-secure and enter http://10.186.115.21:7860/
 # in "Insecure origins treated as secure", enable it and relaunch chrome
 
@@ -322,6 +352,10 @@ def set_user_preferences(preferences, state):
 # What's the closest restaurant from here?
 
 
+ORIGIN = "Mondorf-les-Bains, Luxembourg"
+DESTINATION = "Rue Alphonse Weicker, Luxembourg"
+
+
 def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = True):
     print(f"Running the demo with model: {model} and TTSServer: {tts_server}")
     with gr.Blocks(theme=gr.themes.Default()) as demo:
@@ -332,11 +366,13 @@ def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = Tr
                 "route_points": [],
                 "model": model,
                 "tts_enabled": tts_enabled,
-                "llm_backend": "…
-                "user_preferences": …
+                "llm_backend": "ollama",
+                "user_preferences": USER_PREFERENCES,
+                "enable_history": False,
             }
         )
         trip_points = gr.State(value=[])
+        plot, vehicle_status, _ = calculate_route_gradio(ORIGIN, DESTINATION)
 
         with gr.Row():
             with gr.Column(scale=1, min_width=300):
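Note: computing the route once at build time seeds the UI before any interaction; from this call site, `calculate_route_gradio` evidently returns the map figure, an initial vehicle-status payload, and route points (unused here). The same initial-value pattern in isolation (hypothetical figure):

    import gradio as gr
    import plotly.graph_objects as go

    fig = go.Figure()  # stand-in for the precomputed route figure
    with gr.Blocks() as demo:
        map_plot = gr.Plot(value=fig, label="Map")  # rendered on load, no event needed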
@@ -346,12 +382,6 @@ def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = Tr
                     value="08:00:00",
                     interactive=True,
                 )
-                history = gr.Radio(
-                    ["Yes", "No"],
-                    label="Maintain the conversation history?",
-                    value="No",
-                    interactive=True,
-                )
                 voice_character = gr.Radio(
                     choices=voice_options,
                     label="Choose a voice",
@@ -359,24 +389,24 @@ def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = Tr
                     show_label=True,
                 )
                 origin = gr.Textbox(
-                    value=…
+                    value=ORIGIN,
                     label="Origin",
                     interactive=True,
                 )
                 destination = gr.Textbox(
-                    value=…
+                    value=DESTINATION,
                     label="Destination",
                     interactive=True,
                 )
                 preferences = gr.Textbox(
-                    value=…
+                    value=USER_PREFERENCES,
                     label="User preferences",
                     lines=3,
                     interactive=True,
                 )
 
             with gr.Column(scale=2, min_width=600):
-                map_plot = gr.Plot()
+                map_plot = gr.Plot(value=plot, label="Map")
                 trip_progress = gr.Slider(
                     0, 100, step=5, label="Trip progress", interactive=True
                 )
@@ -422,6 +452,12 @@ def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = Tr
                     value="Ollama",
                     interactive=True,
                 )
+                enable_history = gr.Radio(
+                    ["Yes", "No"],
+                    label="Maintain the conversation history?",
+                    value="No",
+                    interactive=True,
+                )
                 # Push button
                 clear_history_btn = gr.Button(value="Clear History")
             with gr.Column():
@@ -472,7 +508,7 @@ def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = Tr
         input_audio.stop_recording(
             fn=save_and_transcribe_run_model,
             inputs=[input_audio, voice_character, state],
-            outputs=[input_text, output_text, output_audio],
+            outputs=[input_text, output_text, output_audio, vehicle_status],
         )
         input_audio_debug.stop_recording(
             fn=save_and_transcribe_audio,
@@ -490,6 +526,10 @@ def create_demo(tts_server: bool = False, model="llama3", tts_enabled: bool = Tr
         llm_backend.change(
             fn=set_llm_backend, inputs=[llm_backend, state], outputs=[state]
         )
+        enable_history.change(
+            fn=set_enable_history, inputs=[enable_history, state], outputs=[state]
+        )
+
     return demo
 
 
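Note: the history radio removed from the left column earlier in this diff reappears here as `enable_history`, now backed by state instead of an unused component. Reduced to its essentials, the wiring is (self-contained sketch):

    import gradio as gr

    def set_enable_history(choice, state):
        state["enable_history"] = choice == "Yes"
        return state

    with gr.Blocks() as demo:
        state = gr.State(value={"enable_history": False})
        enable_history = gr.Radio(
            ["Yes", "No"], value="No", label="Maintain the conversation history?"
        )
        enable_history.change(
            fn=set_enable_history, inputs=[enable_history, state], outputs=[state]
        )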