michaelfeil committed on
Commit 11ac6f7
1 Parent(s): 02e3a03

add mongooseminer

Files changed (3)
  1. Dockerfile +8 -0
  2. main.py +160 -0
  3. search.py +65 -0
Dockerfile ADDED
@@ -0,0 +1,8 @@
+ FROM python:3.10-slim
+
+ RUN pip install groq gradio "infinity_emb[all]" usearch
+
+ WORKDIR /app
+ COPY . .
+
+ CMD python main.py
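For local testing, a minimal sketch of building and running this image; the image tag `mongooseminer` is a placeholder, and `GRADIO_SERVER_NAME=0.0.0.0` is assumed so Gradio (default port 7860) is reachable from outside the container:

docker build -t mongooseminer .  # hypothetical tag
docker run --rm -e GROQ_API_KEY=$GROQ_API_KEY -e GRADIO_SERVER_NAME=0.0.0.0 -p 7860:7860 mongooseminer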
main.py ADDED
@@ -0,0 +1,160 @@
+ import gradio as gr
+ import os
+ import json
+
+ from groq import Groq
+ from search import answer_query
+
+ # Load a local .env file if python-dotenv is available; otherwise rely on
+ # environment variables already being set.
+ try:
+     from dotenv import load_dotenv
+     load_dotenv(dotenv_path="./.env")
+ except ImportError:
+     pass
+
+ client = Groq(
+     api_key=os.environ.get("GROQ_API_KEY"),
+ )
+
+ tools = [
+     {
+         "type": "function",
+         "function": {
+             "name": "get_related_functions",
+             "description": "Get docstrings for internal functions for any library on PyPI.",
+             "parameters": {
+                 "type": "object",
+                 "properties": {
+                     "user_query": {
+                         "type": "string",
+                         "description": "A query to retrieve docstrings and find useful information.",
+                     }
+                 },
+                 "required": ["user_query"],
+             },
+         },
+     }
+ ]
+
+
+ def user(user_message, history):
+     return "", history + [[user_message, None]]
+
+
+ def get_related_functions(user_query: str) -> str:
+     # Retrieve the closest matching docstrings and join them into a single
+     # string for the tool response.
+     docstring_top10 = answer_query(user_query)
+     return "\n\n".join(docstring_top10)
+
+
+ def generate_rag(history):
+     messages = [
+         {
+             "role": "system",
+             "content": "You are a function calling LLM that uses the data extracted from the get_related_functions function to answer questions around writing Python code. Use the extracted docstrings to write better code."
+         },
+         {
+             "role": "user",
+             "content": history[-1][0],
+         }
+     ]
+     history[-1][1] = ""
+     tool_call_count = 0
+     max_tool_calls = 3
+     while tool_call_count <= max_tool_calls:
+         response = client.chat.completions.create(
+             model="llama3-70b-8192",
+             messages=messages,
+             tools=tools if tool_call_count < max_tool_calls else None,
+             tool_choice="auto",
+             max_tokens=4096
+         )
+         tool_call_count += 1
+         response_message = response.choices[0].message
+         tool_calls = response_message.tool_calls
+
+         if tool_calls:
+             available_functions = {
+                 "get_related_functions": get_related_functions,
+             }
+             messages.append(response_message)
+
+             for tool_call in tool_calls:
+                 function_name = tool_call.function.name
+                 function_to_call = available_functions[function_name]
+                 function_args = json.loads(tool_call.function.arguments)
+                 function_response = function_to_call(
+                     user_query=function_args.get("user_query")
+                 )
+                 messages.append(
+                     {
+                         "tool_call_id": tool_call.id,
+                         "role": "tool",
+                         "name": function_name,
+                         "content": function_response,
+                     }
+                 )
+         else:
+             break
+
+     history[-1][1] += response_message.content
+     return history
+
+
+ def generate_llama3(history):
+     history[-1][1] = ""
+     stream = client.chat.completions.create(
+         messages=[
+             # Set an optional system message. This sets the behavior of the
+             # assistant and can be used to provide specific instructions for
+             # how it should behave throughout the conversation.
+             {
+                 "role": "system",
+                 "content": "You are a helpful assistant."
+             },
+             # Set a user message for the assistant to respond to.
+             {
+                 "role": "user",
+                 "content": history[-1][0],
+             }
+         ],
+         stream=True,
+         model="llama3-8b-8192",
+         max_tokens=1024,
+         temperature=0
+     )
+
+     # Stream tokens into the chatbot as they arrive; chunks without content
+     # (e.g. the final chunk) are skipped.
+     for chunk in stream:
+         if chunk.choices[0].delta.content is not None:
+             history[-1][1] += chunk.choices[0].delta.content
+             yield history
+
+
+ with gr.Blocks() as demo:
+
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown("# Mongoose Miner Search Demo")
+             gr.Markdown(
+                 "Augmenting LLM code generation with function-level search across all of PyPI.")
+
+     with gr.Row():
+         chatbot = gr.Chatbot(height="35rem", label="Llama3 unaugmented")
+         chatbot2 = gr.Chatbot(
+             height="35rem", label="Llama3 with MongooseMiner Search")
+     msg = gr.Textbox()
+
+     clear = gr.Button("Clear")
+
+     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+         generate_llama3, chatbot, chatbot
+     )
+     msg.submit(user, [msg, chatbot2], [msg, chatbot2], queue=False).then(
+         generate_rag, chatbot2, chatbot2
+     )
+     clear.click(lambda: None, None, chatbot, queue=False)
+     clear.click(lambda: None, None, chatbot2, queue=False)
+
+
+ demo.queue()
+ demo.launch()
search.py ADDED
@@ -0,0 +1,65 @@
+ from infinity_emb import AsyncEmbeddingEngine, EngineArgs
+ import numpy as np
+ from usearch.index import Index
+ import asyncio
+ import pandas as pd
+
+ engine = AsyncEmbeddingEngine.from_args(
+     EngineArgs(
+         model_name_or_path="michaelfeil/jina-embeddings-v2-base-code",
+         batch_size=8,
+     )
+ )
+
+
+ async def embed_texts(texts: list[str]) -> np.ndarray:
+     async with engine:
+         # engine.embed returns (embeddings, usage); keep only the embeddings.
+         embeddings = (await engine.embed(texts))[0]
+         return np.array(embeddings)
+
+
+ def embed_texts_sync(texts: list[str]) -> np.ndarray:
+     return asyncio.run(embed_texts(texts))
+
+
+ index = None
+ docs_index = None
+
+
+ def build_index(demo_mode=True):
+     global index, docs_index
+     index = Index(
+         ndim=embed_texts_sync(["Hi"]).shape[
+             -1
+         ],  # Number of dimensions in input vectors
+         metric="cos",  # Choose 'l2sq', 'haversine' or other metric, default = 'ip'
+         dtype="f16",  # Quantize to 'f16' or 'i8' if needed, default = 'f32'
+         connectivity=16,  # How frequent the connections in the graph should be, optional
+         expansion_add=128,  # Control the recall of indexing, optional
+         expansion_search=64,  # Control the quality of search, optional
+     )
+     if demo_mode:
+         docs_index = [
+             "torch.add(*demo)",
+             "torch.mul(*demo)",
+             "torch.div(*demo)",
+             "torch.sub(*demo)",
+         ]
+         embeddings = embed_texts_sync(docs_index)
+         index.add(np.arange(len(docs_index)), embeddings)
+         return
+     # TODO: Michael, load parquet with embeddings
+
+
+ if index is None:
+     build_index()
+
+
+ def answer_query(query: str) -> list[str]:
+     embedding = embed_texts_sync([query])
+     # Search with a single 1-D vector so that iterating yields individual matches.
+     matches = index.search(embedding[0], 10)
+     texts = [docs_index[match.key] for match in matches]
+     return texts
+
+
+ if __name__ == "__main__":
+     print(answer_query("torch.mul(*demo2)"))