File size: 6,233 Bytes
13fbd2e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
from typing import Any, List
from llama_index.core.agent.react import ReActChatFormatter, ReActOutputParser
from llama_index.core.agent.react.types import (
ActionReasoningStep,
ObservationReasoningStep,
)
from llama_index.core.llms.llm import LLM
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.tools.types import BaseTool
from llama_index.core.workflow import (
Context,
Workflow,
StartEvent,
StopEvent,
step,
)
from llama_index.llms.openai import OpenAI
from llama_index.core.llms import ChatMessage
from llama_index.core.tools import ToolSelection, ToolOutput
from llama_index.core.workflow import Event
from llama_index.llms.ollama import Ollama
from llama_index.core import Settings
# One Ollama-backed model instance is created at import time; it is registered
# as the library-wide default on ``Settings`` and also kept under the
# module-level name ``llm`` for code further down in this file.
Settings.llm = llm = Ollama(model="HammerAI/neuraldaredevil-abliterated")
class PrepEvent(Event):
    """Payload-free event that triggers ``prepare_chat_history``.

    Returned by ``new_user_msg``, ``handle_llm_input`` and
    ``handle_tool_calls`` to start another prompt-building pass.
    """
class InputEvent(Event):
    """Event carrying the formatted chat messages to send to the LLM.

    Produced by ``prepare_chat_history`` and consumed by ``handle_llm_input``.
    """

    # Messages produced by the ReAct chat formatter.
    input: List[ChatMessage]
class ToolCallEvent(Event):
    """Event listing the tool invocations requested by the LLM.

    Produced by ``handle_llm_input`` and consumed by ``handle_tool_calls``.
    """

    # One entry per tool the agent should run in this iteration.
    tool_calls: List[ToolSelection]
class FunctionOutputEvent(Event):
    """Event wrapping the output of a single tool call.

    NOTE(review): no workflow step visible in this file emits or consumes
    this event; confirm whether it is still needed.
    """

    # Result object returned by the executed tool.
    output: ToolOutput
class ReActAgent(Workflow):
    """ReAct-style agent expressed as a workflow of event-driven steps.

    The steps form a loop:

    1. ``new_user_msg`` stores the user's message and resets per-run state.
    2. ``prepare_chat_history`` formats memory + reasoning into an LLM prompt.
    3. ``handle_llm_input`` calls the LLM and parses the reply into a
       reasoning step; it either finishes, requests a tool, or loops again.
    4. ``handle_tool_calls`` runs the requested tools, records their
       observations, and loops back to step 2.

    The reasoning trace of the current run is kept in the workflow
    ``Context`` under the key ``"current_reasoning"``.
    """

    def __init__(
        self,
        *args: Any,
        llm: LLM | None = None,
        tools: list[BaseTool] | None = None,
        extra_context: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Create the agent.

        Args:
            *args: Forwarded unchanged to ``Workflow.__init__``.
            llm: Model used for chat completion; a default ``OpenAI()``
                instance is created when omitted.
            tools: Tools the agent may call; defaults to no tools.
            extra_context: Optional text handed to the ReAct chat formatter
                as its ``context``; defaults to an empty string.
            **kwargs: Forwarded unchanged to ``Workflow.__init__``.
        """
        super().__init__(*args, **kwargs)

        self.tools = tools or []
        self.llm = llm or OpenAI()

        # Build the memory from the *resolved* model so the buffer is
        # configured for the LLM that is actually used, even when the caller
        # passed ``llm=None``.
        self.memory = ChatMemoryBuffer.from_defaults(llm=self.llm)
        self.formatter = ReActChatFormatter(context=extra_context or "")
        self.output_parser = ReActOutputParser()
        # Tool outputs collected during the current run.
        self.sources: list[ToolOutput] = []

    async def _record_reasoning(self, ctx: Context, reasoning_step: Any) -> list:
        """Append ``reasoning_step`` to the run's reasoning trace.

        The list is written back to the context explicitly, so the update is
        kept even if ``ctx.get`` handed back the default or a copy.

        Returns:
            The updated reasoning list.
        """
        trace = await ctx.get("current_reasoning", default=[])
        trace.append(reasoning_step)
        await ctx.set("current_reasoning", trace)
        return trace

    @step
    async def new_user_msg(self, ctx: Context, ev: StartEvent) -> PrepEvent:
        """Start a run: store the user's message and reset per-run state.

        Args:
            ctx: Workflow context holding the reasoning trace.
            ev: Start event; its ``input`` attribute is the user's text.

        Returns:
            A ``PrepEvent`` that triggers prompt preparation.
        """
        # clear sources
        self.sources = []

        # get user input
        user_input = ev.input
        user_msg = ChatMessage(role="user", content=user_input)
        self.memory.put(user_msg)

        # clear current reasoning
        await ctx.set("current_reasoning", [])

        return PrepEvent()

    @step
    async def prepare_chat_history(
        self, ctx: Context, ev: PrepEvent
    ) -> InputEvent:
        """Format chat memory and the current reasoning into an LLM prompt.

        Args:
            ctx: Workflow context holding the reasoning trace.
            ev: Trigger event; carries no data.

        Returns:
            An ``InputEvent`` whose ``input`` is the formatted message list.
        """
        # get chat history
        chat_history = self.memory.get()
        current_reasoning = await ctx.get("current_reasoning", default=[])
        llm_input = self.formatter.format(
            self.tools, chat_history, current_reasoning=current_reasoning
        )
        return InputEvent(input=llm_input)

    @step
    async def handle_llm_input(
        self, ctx: Context, ev: InputEvent
    ) -> ToolCallEvent | StopEvent | PrepEvent:
        """Call the LLM, parse its reply, and decide what happens next.

        Args:
            ctx: Workflow context holding the reasoning trace.
            ev: Event carrying the formatted chat messages.

        Returns:
            * ``StopEvent`` when the parsed step is done; its result is a
              dict with the keys ``"response"``, ``"sources"`` and
              ``"reasoning"``.
            * ``ToolCallEvent`` when the parsed step is an
              ``ActionReasoningStep``.
            * ``PrepEvent`` otherwise (including when parsing failed), so the
              loop runs another iteration.
        """
        chat_history = ev.input

        response = await self.llm.achat(chat_history)

        try:
            reasoning_step = self.output_parser.parse(response.message.content)
            reasoning = await self._record_reasoning(ctx, reasoning_step)

            if reasoning_step.is_done:
                self.memory.put(
                    ChatMessage(
                        role="assistant", content=reasoning_step.response
                    )
                )
                print(reasoning_step.response)
                return StopEvent(
                    result={
                        "response": reasoning_step.response,
                        # Copy so later runs cannot mutate the returned list.
                        "sources": [*self.sources],
                        "reasoning": reasoning,
                    }
                )

            if isinstance(reasoning_step, ActionReasoningStep):
                tool_name = reasoning_step.action
                tool_args = reasoning_step.action_input
                print(f"current thought:{reasoning_step.thought} --{tool_name}({tool_args}")
                return ToolCallEvent(
                    tool_calls=[
                        ToolSelection(
                            tool_id="fake",
                            tool_name=tool_name,
                            tool_kwargs=tool_args,
                        )
                    ]
                )
        except Exception as e:
            # Deliberately broad: any failure is fed back to the model as an
            # observation so it can correct itself on the next iteration.
            await self._record_reasoning(
                ctx,
                ObservationReasoningStep(
                    observation=f"There was an error in parsing my reasoning: {e}"
                ),
            )

        # if no tool calls or final response, iterate again
        return PrepEvent()

    @step
    async def handle_tool_calls(
        self, ctx: Context, ev: ToolCallEvent
    ) -> PrepEvent:
        """Run the requested tools and record each result as an observation.

        Unknown tool names and tool exceptions are not raised; they are
        recorded as observations so the model can react to them.

        Args:
            ctx: Workflow context holding the reasoning trace.
            ev: Event listing the tool selections to execute.

        Returns:
            A ``PrepEvent`` that starts the next iteration.
        """
        tool_calls = ev.tool_calls
        tools_by_name = {tool.metadata.get_name(): tool for tool in self.tools}

        # call tools -- safely!
        for tool_call in tool_calls:
            tool = tools_by_name.get(tool_call.tool_name)
            if not tool:
                await self._record_reasoning(
                    ctx,
                    ObservationReasoningStep(
                        observation=f"Tool {tool_call.tool_name} does not exist"
                    ),
                )
                continue

            try:
                # The tool is invoked without ``await``, i.e. synchronously
                # inside this coroutine.
                tool_output = tool(**tool_call.tool_kwargs)
                print(f"Tool output: {tool_output}")
                self.sources.append(tool_output)
                await self._record_reasoning(
                    ctx,
                    ObservationReasoningStep(observation=tool_output.content),
                )
            except Exception as e:
                await self._record_reasoning(
                    ctx,
                    ObservationReasoningStep(
                        observation=f"Error calling tool {tool.metadata.get_name()}: {e}"
                    ),
                )

        # prep the next iteration
        return PrepEvent()
from llama_index.core.tools import FunctionTool
from llama_index.llms.openai import OpenAI
from workflow.modules import *
async def main():
    """Build a ``ReActAgent`` with the two search tools and run one sample query.

    ``video_search`` and ``general_search`` are not defined in this file;
    presumably they come from the ``workflow.modules`` wildcard import --
    TODO confirm.

    Returns:
        Whatever ``agent.run`` resolves to, so callers can inspect the
        outcome. The original bound it to an unused local and discarded it.
    """
    tools = [
        FunctionTool.from_defaults(search_fn)
        for search_fn in (video_search, general_search)
    ]

    agent = ReActAgent(llm=llm, tools=tools, timeout=120, verbose=True)

    return await agent.run(input="How to build trust with partner?")
if __name__ == "__main__":
    # Script entry point: drive the async ``main`` coroutine to completion
    # on a fresh event loop.
    import asyncio

    entry_coroutine = main()
    asyncio.run(entry_coroutine)
|