Technocoloredgeek commited on
Commit
273bfd6
1 Parent(s): 8187b01

Create multiagent.py

Browse files
Files changed (1) hide show
  1. multiagent.py +336 -0
multiagent.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Change to requirements caller
2
+ import sys
3
+ import subprocess
4
+
5
+ def run_pip_install():
6
+ packages = [
7
+ "langgraph",
8
+ "langchain",
9
+ "langchain_openai",
10
+ "langchain_experimental",
11
+ "qdrant-client",
12
+ "pymupdf",
13
+ "tiktoken",
14
+ "huggingface_hub",
15
+ "openai",
16
+ "tavily-python"
17
+ ]
18
+
19
+ package_string = " ".join(packages)
20
+
21
+ try:
22
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "-qU"] + packages)
23
+ print("All required packages have been installed successfully.")
24
+ except subprocess.CalledProcessError:
25
+ print(f"Failed to install packages. Please run the following command manually:")
26
+ print(f"%pip install -qU {package_string}")
27
+ sys.exit(1)
28
+
29
+ # Run pip install
30
+ run_pip_install()
31
+
32
+ import os
33
+ import functools
34
+ import operator
35
+ from typing import Annotated, List, Tuple, Union, Dict, Optional
36
+ from typing_extensions import TypedDict
37
+ import uuid
38
+ from pathlib import Path
39
+
40
+ from langchain_core.tools import tool
41
+ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
42
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
43
+ from langchain_openai import ChatOpenAI
44
+ from langchain.agents import AgentExecutor, create_openai_functions_agent
45
+ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
46
+ from langchain_community.tools.tavily_search import TavilySearchResults
47
+ from langchain_community.vectorstores import Qdrant
48
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
49
+ from langchain_openai.embeddings import OpenAIEmbeddings
50
+ from langgraph.graph import END, StateGraph
51
+ from huggingface_hub import hf_hub_download
52
+
53
+ # Environment setup
54
+ OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
55
+ TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")
56
+
57
+ if not OPENAI_API_KEY:
58
+ raise ValueError("OPENAI_API_KEY not found in environment variables")
59
+ if not TAVILY_API_KEY:
60
+ raise ValueError("TAVILY_API_KEY not found in environment variables")
61
+
62
+ os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
63
+ os.environ["TAVILY_API_KEY"] = TAVILY_API_KEY
64
+
65
+ # CHANGE TO HF DIRECTORY
66
+ WORKING_DIRECTORY = Path("/tmp/content/data")
67
+ WORKING_DIRECTORY.mkdir(parents=True, exist_ok=True)
68
+
69
+ # Utility functions
70
+ def create_random_subdirectory():
71
+ random_id = str(uuid.uuid4())[:8]
72
+ subdirectory_path = WORKING_DIRECTORY / random_id
73
+ subdirectory_path.mkdir(exist_ok=True)
74
+ return subdirectory_path
75
+
76
+ def get_current_files():
77
+ try:
78
+ files = [f.relative_to(WORKING_DIRECTORY) for f in WORKING_DIRECTORY.rglob("*") if f.is_file()]
79
+ return "\n".join(str(f) for f in files) if files else "No files written."
80
+ except Exception:
81
+ return "Unable to retrieve current files."
82
+
83
+ # Document loading change to upload in HF
84
+ def fetch_hbr_article():
85
+ pdf_path = hf_hub_download(repo_id="your-username/your-repo-name", filename="murthy-loneliness.pdf")
86
+ return PyMuPDFLoader(pdf_path).load()
87
+
88
+ # Document processing
89
+ def tiktoken_len(text):
90
+ tokens = tiktoken.encoding_for_model("gpt-4o-mini").encode(text)
91
+ return len(tokens)
92
+
93
+ text_splitter = RecursiveCharacterTextSplitter(
94
+ chunk_size=300,
95
+ chunk_overlap=0,
96
+ length_function=tiktoken_len,
97
+ )
98
+
99
+ docs = fetch_hbr_article()
100
+ split_chunks = text_splitter.split_documents(docs)
101
+
102
+ # Embedding and vector store setup
103
+ embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
104
+ qdrant_vectorstore = Qdrant.from_documents(
105
+ split_chunks,
106
+ embedding_model,
107
+ location=":memory:",
108
+ collection_name="extending_context_window_llama_3",
109
+ )
110
+ qdrant_retriever = qdrant_vectorstore.as_retriever()
111
+
112
+ # RAG setup
113
+ RAG_PROMPT = """
114
+ CONTEXT:
115
+ {context}
116
+
117
+ QUERY:
118
+ {question}
119
+
120
+ You are a helpful assistant. Use the available context to answer the question. If you can't answer the question, say you don't know.
121
+ """
122
+ rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
123
+ openai_chat_model = ChatOpenAI(model="gpt-4o-mini")
124
+
125
+ rag_chain = (
126
+ {"context": itemgetter("question") | qdrant_retriever, "question": itemgetter("question")}
127
+ | rag_prompt | openai_chat_model | StrOutputParser()
128
+ )
129
+
130
+ # Tool definitions
131
+ @tool
132
+ def create_outline(points: List[str], file_name: str) -> str:
133
+ """Create and save an outline."""
134
+ with (WORKING_DIRECTORY / file_name).open("w") as file:
135
+ for i, point in enumerate(points):
136
+ file.write(f"{i + 1}. {point}\n")
137
+ return f"Outline saved to {file_name}"
138
+
139
+ @tool
140
+ def read_document(file_name: str, start: Optional[int] = None, end: Optional[int] = None) -> str:
141
+ """Read the specified document."""
142
+ with (WORKING_DIRECTORY / file_name).open("r") as file:
143
+ lines = file.readlines()
144
+ if start is not None:
145
+ start = 0
146
+ return "\n".join(lines[start:end])
147
+
148
+ @tool
149
+ def write_document(content: str, file_name: str) -> str:
150
+ """Create and save a text document."""
151
+ with (WORKING_DIRECTORY / file_name).open("w") as file:
152
+ file.write(content)
153
+ return f"Document saved to {file_name}"
154
+
155
+ @tool
156
+ def edit_document(file_name: str, inserts: Dict[int, str] = {}) -> str:
157
+ """Edit a document by inserting text at specific line numbers."""
158
+ with (WORKING_DIRECTORY / file_name).open("r") as file:
159
+ lines = file.readlines()
160
+
161
+ sorted_inserts = sorted(inserts.items())
162
+ for line_number, text in sorted_inserts:
163
+ if 1 <= line_number <= len(lines) + 1:
164
+ lines.insert(line_number - 1, text + "\n")
165
+ else:
166
+ return f"Error: Line number {line_number} is out of range."
167
+
168
+ with (WORKING_DIRECTORY / file_name).open("w") as file:
169
+ file.writelines(lines)
170
+ return f"Document edited and saved to {file_name}"
171
+
172
+ @tool
173
+ def retrieve_information(query: str):
174
+ """Use Retrieval Augmented Generation to retrieve information about the 'murthy-loneliness' paper."""
175
+ return rag_chain.invoke({"question": query})
176
+
177
+ # Agent creation helpers
178
+ def create_team_agent(llm, tools, system_prompt, agent_name, team_members):
179
+ return create_agent(
180
+ llm,
181
+ tools,
182
+ f"{system_prompt}\nBelow are files currently in your directory:\n{{current_files}}",
183
+ team_members
184
+ )
185
+
186
+ def create_agent_node(agent, name):
187
+ return functools.partial(agent_node, agent=agent, name=name)
188
+
189
+ def add_agent_to_graph(graph, agent_name, agent_node):
190
+ graph.add_node(agent_name, agent_node)
191
+ graph.add_edge(agent_name, "supervisor")
192
+
193
+ def create_team_supervisor(llm, team_description, team_members):
194
+ return create_team_supervisor(
195
+ llm,
196
+ f"You are a supervisor tasked with managing a conversation between the"
197
+ f" following workers: {', '.join(team_members)}. {team_description}"
198
+ f" When all workers are finished, you must respond with FINISH.",
199
+ team_members
200
+ )
201
+
202
+ def create_team_chain(graph, team_members):
203
+ return (
204
+ functools.partial(enter_chain, members=team_members)
205
+ | graph.compile()
206
+ )
207
+
208
+ # LLM setup
209
+ llm = ChatOpenAI(model="gpt-4-turbo")
210
+
211
+ # Agent creation
212
+ tavily_tool = TavilySearchResults(max_results=5)
213
+
214
+ search_agent = create_team_agent(
215
+ llm,
216
+ [tavily_tool],
217
+ "You are a research assistant who can search for up-to-date info using the tavily search engine.",
218
+ "Search",
219
+ ["Search", "PaperInformationRetriever"]
220
+ )
221
+
222
+ research_agent = create_team_agent(
223
+ llm,
224
+ [retrieve_information],
225
+ "You are a research assistant who can provide specific information on the provided paper: 'murthy-loneliness.pdf'. You must only respond with information about the paper related to the request.",
226
+ "PaperInformationRetriever",
227
+ ["Search", "PaperInformationRetriever"]
228
+ )
229
+
230
+ doc_writer_agent = create_team_agent(
231
+ llm,
232
+ [write_document, edit_document, read_document],
233
+ "You are an expert writing technical social media posts.",
234
+ "DocWriter",
235
+ ["DocWriter", "NoteTaker", "CopyEditor", "VoiceEditor"]
236
+ )
237
+
238
+ note_taking_agent = create_team_agent(
239
+ llm,
240
+ [create_outline, read_document],
241
+ "You are an expert senior researcher tasked with writing a social media post outline and taking notes to craft a social media post.",
242
+ "NoteTaker",
243
+ ["DocWriter", "NoteTaker", "CopyEditor", "VoiceEditor"]
244
+ )
245
+
246
+ copy_editor_agent = create_team_agent(
247
+ llm,
248
+ [write_document, edit_document, read_document],
249
+ "You are an expert copy editor who focuses on fixing grammar, spelling, and tone issues.",
250
+ "CopyEditor",
251
+ ["DocWriter", "NoteTaker", "CopyEditor", "VoiceEditor"]
252
+ )
253
+
254
+ voice_editor_agent = create_team_agent(
255
+ llm,
256
+ [write_document, edit_document, read_document],
257
+ "You are an expert in crafting and refining the voice and tone of social media posts. You edit the document to ensure it has a consistent, professional, and engaging voice appropriate for social media platforms.",
258
+ "VoiceEditor",
259
+ ["DocWriter", "NoteTaker", "CopyEditor", "VoiceEditor"]
260
+ )
261
+
262
+ # Node creation
263
+ search_node = create_agent_node(search_agent, "Search")
264
+ research_node = create_agent_node(research_agent, "PaperInformationRetriever")
265
+ doc_writing_node = create_agent_node(doc_writer_agent, "DocWriter")
266
+ note_taking_node = create_agent_node(note_taking_agent, "NoteTaker")
267
+ copy_editing_node = create_agent_node(copy_editor_agent, "CopyEditor")
268
+ voice_node = create_agent_node(voice_editor_agent, "VoiceEditor")
269
+
270
+ # Graph creation
271
+ research_graph = StateGraph(ResearchTeamState)
272
+ add_agent_to_graph(research_graph, "Search", search_node)
273
+ add_agent_to_graph(research_graph, "PaperInformationRetriever", research_node)
274
+
275
+ authoring_graph = StateGraph(DocWritingState)
276
+ add_agent_to_graph(authoring_graph, "DocWriter", doc_writing_node)
277
+ add_agent_to_graph(authoring_graph, "NoteTaker", note_taking_node)
278
+ add_agent_to_graph(authoring_graph, "CopyEditor", copy_editing_node)
279
+ add_agent_to_graph(authoring_graph, "VoiceEditor", voice_node)
280
+
281
+ # Supervisor creation
282
+ research_supervisor = create_team_supervisor(
283
+ llm,
284
+ "Given the following user request, determine the subject to be researched and respond with the worker to act next.",
285
+ ["Search", "PaperInformationRetriever"]
286
+ )
287
+
288
+ doc_writing_supervisor = create_team_supervisor(
289
+ llm,
290
+ "Given the following user request, determine which worker should act next. Each worker will perform a task and respond with their results and status.",
291
+ ["DocWriter", "NoteTaker", "CopyEditor", "VoiceEditor"]
292
+ )
293
+
294
+ # Graph compilation
295
+ research_graph.add_node("supervisor", research_supervisor)
296
+ research_graph.set_entry_point("supervisor")
297
+ research_chain = create_team_chain(research_graph, research_graph.nodes)
298
+
299
+ authoring_graph.add_node("supervisor", doc_writing_supervisor)
300
+ authoring_graph.set_entry_point("supervisor")
301
+ authoring_chain = create_team_chain(authoring_graph, authoring_graph.nodes)
302
+
303
+ # Meta-supervisor setup
304
+ super_graph = StateGraph(State)
305
+ super_graph.add_node("Research team", get_last_message | research_chain | join_graph)
306
+ super_graph.add_node("SocialMedia team", get_last_message | authoring_chain | join_graph)
307
+ super_graph.add_node("supervisor", supervisor_node)
308
+
309
+ super_graph.add_edge("Research team", "supervisor")
310
+ super_graph.add_edge("SocialMedia team", "supervisor")
311
+ super_graph.add_conditional_edges(
312
+ "supervisor",
313
+ lambda x: x["next"],
314
+ {
315
+ "SocialMedia team": "SocialMedia team",
316
+ "Research team": "Research team",
317
+ "FINISH": END,
318
+ },
319
+ )
320
+ super_graph.set_entry_point("supervisor")
321
+ super_graph = super_graph.compile()
322
+
323
+ # Example usage
324
+ user_input = input("Enter your request for the social media post: ")
325
+
326
+ for s in super_graph.stream(
327
+ {
328
+ "messages": [
329
+ HumanMessage(content=user_input)
330
+ ],
331
+ },
332
+ {"recursion_limit": 50},
333
+ ):
334
+ if "__end__" not in s:
335
+ print(s)
336
+ print("---")