MongooseMiner / main.py
michaelfeil
add mongooseminer
11ac6f7
raw
history blame contribute delete
No virus
4.83 kB
import gradio as gr
import os
import json
from groq import Groq
from search import answer_query
# Optionally populate the process environment from a local ".env" file.
# python-dotenv is treated as an optional dependency: when it is not
# installed, the script relies on variables already present in the environment.
try:
    from dotenv import load_dotenv
except ImportError:
    # Only a missing package is ignored; other failures should stay visible
    # instead of being hidden by a bare ``except``.
    pass
else:
    load_dotenv(dotenv_path="./.env")
# Shared Groq API client used by both generators below. The key is read from
# the GROQ_API_KEY environment variable (``None`` is passed when it is unset).
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
# Schema of the single string argument accepted by the tool.
_user_query_schema = {
    "type": "string",
    "description": "A query to retrieve docstrings and find useful information.",
}

# Tool definitions handed to the chat-completions call in ``generate_rag``.
# Exactly one function tool is declared: ``get_related_functions``.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_related_functions",
            "description": "Get docstrings for internal functions for any library on PyPi.",
            "parameters": {
                "type": "object",
                "properties": {"user_query": _user_query_schema},
                "required": ["user_query"],
            },
        },
    },
]
def user(user_message, history):
    """Append the submitted message to the chat history as an open turn.

    Args:
        user_message: Text the user just submitted.
        history: Current chat history, a list of ``[user, bot]`` pairs.

    Returns:
        A 2-tuple ``("", new_history)``: the empty string is the new textbox
        value, and ``new_history`` is a new list equal to ``history`` plus
        ``[user_message, None]`` (the bot reply is not filled in yet).
        ``history`` itself is not modified.
    """
    pending_turn = [user_message, None]
    extended_history = history + [pending_turn]
    return "", extended_history
def get_related_functions(user_query: str) -> dict:
    """Return the top search hit for ``user_query``.

    This is the callable behind the ``get_related_functions`` tool. It
    forwards the query to ``answer_query`` and returns the first element of
    the result.

    Args:
        user_query: Query text, forwarded unchanged to ``answer_query``.

    Returns:
        The first element of the ``answer_query`` result. When the result is
        empty, a short explanatory string is returned instead of raising
        ``IndexError``, so a query without hits does not abort the chat turn.
        NOTE(review): the ``dict`` annotation is inherited from the original
        code; the real element type depends on ``answer_query`` -- confirm.
    """
    docstring_top10 = answer_query(user_query)
    # NOTE(review): looks like leftover debug output; the message is unrelated
    # to the query. Kept unchanged so stdout output stays the same.
    print("added torch mul")
    try:
        return docstring_top10[0]
    except IndexError:
        # Empty search result: give the model something it can act on.
        return "No related functions were found for this query."
def _run_tool_call(tool_call):
    """Execute one model-requested tool call and return its result as text."""
    available_functions = {
        "get_related_functions": get_related_functions,
    }
    function_name = tool_call.function.name
    function_to_call = available_functions.get(function_name)
    if function_to_call is None:
        # The model may request a tool that does not exist; report that back
        # to the model instead of failing with KeyError.
        return f"Error: unknown function {function_name!r}."
    try:
        function_args = json.loads(tool_call.function.arguments)
    except json.JSONDecodeError as err:
        return f"Error: tool arguments are not valid JSON ({err})."
    if not isinstance(function_args, dict):
        return "Error: tool arguments must be a JSON object."
    function_response = function_to_call(
        user_query=function_args.get("user_query")
    )
    if isinstance(function_response, str):
        return function_response
    # Tool message content is sent as text, so serialise anything else.
    # ``default=str`` is a deliberate fallback for non-JSON-native values.
    return json.dumps(function_response, ensure_ascii=False, default=str)


def generate_rag(history):
    """Answer the latest user message, letting the model call the search tool.

    Only the most recent user message (``history[-1][0]``) is sent to the
    model; earlier turns are not included. The model may request
    ``get_related_functions`` for up to ``max_tool_calls`` rounds; each tool
    result is appended to the conversation and the model is queried again.
    One more round is then made without tools so a plain answer is produced.

    Args:
        history: Chat history, a list of ``[user, bot]`` pairs. The last
            pair's bot slot is overwritten in place.

    Returns:
        The same ``history`` list, with the last bot slot set to the text of
        the final model message (empty string when the message has no text).
    """
    messages = [
        {
            "role": "system",
            "content": "You are a function calling LLM that uses the data extracted from the get_related_functions function to answer questions around writing Python code. Use the extraced docstrings to write better code."
        },
        {
            "role": "user",
            "content": history[-1][0],
        },
    ]
    history[-1][1] = ""
    max_tool_calls = 3

    # Rounds 0..max_tool_calls-1 offer the tool; the last round does not.
    for round_index in range(max_tool_calls + 1):
        request = {
            "model": "llama3-70b-8192",
            "messages": messages,
            "max_tokens": 4096,
        }
        if round_index < max_tool_calls:
            # Send ``tool_choice`` only together with ``tools``; the original
            # sent "auto" even on the round where tools were withheld.
            request["tools"] = tools
            request["tool_choice"] = "auto"
        response = client.chat.completions.create(**request)
        response_message = response.choices[0].message
        tool_calls = response_message.tool_calls
        if not tool_calls:
            break

        # Record the assistant's tool request, then one result per call.
        messages.append(response_message)
        for tool_call in tool_calls:
            messages.append(
                {
                    "tool_call_id": tool_call.id,
                    "role": "tool",
                    "name": tool_call.function.name,
                    "content": _run_tool_call(tool_call),
                }
            )

    # ``content`` can be None (e.g. a message made only of tool calls).
    history[-1][1] += response_message.content or ""
    return history
def generate_llama3(history):
    """Stream an unaugmented model reply into the last chat turn.

    Only the most recent user message (``history[-1][0]``) is sent, together
    with a fixed system prompt; earlier turns are not included.

    Args:
        history: Chat history, a list of ``[user, bot]`` pairs. The last
            pair's bot slot is reset to ``""`` and then extended in place.

    Yields:
        The same ``history`` list after each text chunk has been appended to
        the last bot slot.
    """
    history[-1][1] = ""
    stream = client.chat.completions.create(
        messages=[
            # System message: fixes the assistant's general behaviour.
            {
                "role": "system",
                "content": "you are a helpful assistant."
            },
            # The user message the assistant should respond to.
            {
                "role": "user",
                "content": history[-1][0],
            },
        ],
        stream=True,
        model="llama3-8b-8192",
        max_tokens=1024,
        temperature=0,
    )
    for chunk in stream:
        if not chunk.choices:
            # Defensive: nothing to read from a chunk without choices.
            continue
        piece = chunk.choices[0].delta.content
        if piece is None:
            # Skip chunks without text instead of ending the generator, so a
            # text-less chunk in mid-stream cannot truncate the reply.
            continue
        history[-1][1] += piece
        yield history
# Build the two-panel comparison UI and start the app.
# NOTE(review): the source had lost its indentation; the layout nesting below
# (title row, then a row with both chat panels, then textbox and button) is
# the most natural reading -- confirm against the intended layout.
with gr.Blocks() as demo:
    # Title and tagline.
    with gr.Row():
        with gr.Column():
            gr.Markdown("# Mongoose Miner Search Demo")
            gr.Markdown(
                "Augmenting LLM code generation with function-level search across all of PyPi."
            )

    # Side-by-side chat panels: plain model vs. search-augmented model.
    with gr.Row():
        chatbot = gr.Chatbot(height="35rem", label="Llama3 unaugmented")
        chatbot2 = gr.Chatbot(
            height="35rem", label="Llama3 with MongooseMiner Search"
        )

    msg = gr.Textbox()
    clear = gr.Button("Clear")

    # Each panel is paired with the generator that fills it.
    panel_responders = (
        (chatbot, generate_llama3),
        (chatbot2, generate_rag),
    )

    # On submit: record the message in the panel (and clear the textbox),
    # then run that panel's generator on the updated history.
    for panel, responder in panel_responders:
        msg.submit(user, [msg, panel], [msg, panel], queue=False).then(
            responder, panel, panel
        )

    # The Clear button resets both panels.
    for panel, _ in panel_responders:
        clear.click(lambda: None, None, panel, queue=False)

demo.queue()
demo.launch()