|
from typing import Any, List |
|
|
|
from llama_index.core.agent.react import ReActChatFormatter, ReActOutputParser |
|
from llama_index.core.agent.react.types import ( |
|
ActionReasoningStep, |
|
ObservationReasoningStep, |
|
) |
|
from llama_index.core.llms.llm import LLM |
|
from llama_index.core.memory import ChatMemoryBuffer |
|
from llama_index.core.tools.types import BaseTool |
|
from llama_index.core.workflow import ( |
|
Context, |
|
Workflow, |
|
StartEvent, |
|
StopEvent, |
|
step, |
|
) |
|
from llama_index.llms.openai import OpenAI |
|
from llama_index.core.llms import ChatMessage |
|
from llama_index.core.tools import ToolSelection, ToolOutput |
|
from llama_index.core.workflow import Event |
|
from llama_index.llms.ollama import Ollama |
|
from llama_index.core import Settings |
|
|
|
# Module-wide LLM: an Ollama-served model, also installed as the LlamaIndex
# default through ``Settings`` so components created without an explicit LLM
# pick it up.
llm = Ollama(model="HammerAI/neuraldaredevil-abliterated")
Settings.llm = llm
|
|
|
class PrepEvent(Event):
    """Payload-free event that asks the workflow to rebuild the LLM prompt."""
|
|
|
|
|
class InputEvent(Event):
    """Event carrying the formatted chat messages to send to the LLM."""

    # Messages passed to the LLM chat call.
    input: list[ChatMessage]
|
|
|
|
|
class ToolCallEvent(Event):
    """Event carrying the tool invocations requested by the LLM."""

    # One entry per tool the agent should run.
    tool_calls: list[ToolSelection]
|
|
|
|
|
class FunctionOutputEvent(Event):
    """Event wrapping the output produced by a single tool call."""

    # Result object returned by the tool.
    output: ToolOutput
|
|
|
|
|
class ReActAgent(Workflow):
    """ReAct-style agent expressed as a LlamaIndex workflow.

    A run cycles through these steps until the LLM produces a final answer:

    1. ``new_user_msg`` stores the user message and resets per-run state.
    2. ``prepare_chat_history`` formats memory and reasoning into LLM input.
    3. ``handle_llm_input`` calls the LLM and parses its reply into a
       reasoning step (final answer, tool request, or parse error).
    4. ``handle_tool_calls`` runs the requested tools and records their
       output as observations.
    """

    def __init__(
        self,
        *args: Any,
        llm: LLM | None = None,
        tools: list[BaseTool] | None = None,
        extra_context: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Set up the agent.

        Args:
            *args: Positional arguments forwarded to ``Workflow``.
            llm: LLM used for reasoning; a default ``OpenAI()`` instance is
                created when omitted.
            tools: Tools the agent may call; defaults to no tools.
            extra_context: Optional text handed to the ReAct chat formatter
                as ``context``; defaults to an empty string.
            **kwargs: Keyword arguments forwarded to ``Workflow``.
        """
        super().__init__(*args, **kwargs)
        self.tools = tools or []
        self.llm = llm or OpenAI()

        # Give the memory buffer the LLM that is actually in use (including
        # the OpenAI fallback), not the possibly-None constructor argument.
        self.memory = ChatMemoryBuffer.from_defaults(llm=self.llm)
        self.formatter = ReActChatFormatter(context=extra_context or "")
        self.output_parser = ReActOutputParser()
        # Tool outputs gathered during the current run.
        self.sources: list[ToolOutput] = []

    async def _record_step(self, ctx: Context, reasoning_step: Any) -> list:
        """Append a step to the run's ``current_reasoning`` list and store it."""
        current_reasoning = await ctx.get("current_reasoning", default=[])
        current_reasoning.append(reasoning_step)
        # Write the list back explicitly so the update is kept even when the
        # key was missing and ``ctx.get`` returned the fresh default list.
        await ctx.set("current_reasoning", current_reasoning)
        return current_reasoning

    @step
    async def new_user_msg(self, ctx: Context, ev: StartEvent) -> PrepEvent:
        """Start a run: clear sources, store the user message, reset reasoning."""
        self.sources = []

        user_msg = ChatMessage(role="user", content=ev.input)
        self.memory.put(user_msg)

        await ctx.set("current_reasoning", [])

        return PrepEvent()

    @step
    async def prepare_chat_history(
        self, ctx: Context, ev: PrepEvent
    ) -> InputEvent:
        """Format chat memory plus the reasoning so far into LLM input."""
        chat_history = self.memory.get()
        current_reasoning = await ctx.get("current_reasoning", default=[])
        llm_input = self.formatter.format(
            self.tools, chat_history, current_reasoning=current_reasoning
        )
        return InputEvent(input=llm_input)

    @step
    async def handle_llm_input(
        self, ctx: Context, ev: InputEvent
    ) -> ToolCallEvent | StopEvent | PrepEvent:
        """Call the LLM and route on the parsed reasoning step.

        Returns:
            ``StopEvent`` with ``response``, ``sources`` and ``reasoning``
            when the step is final; ``ToolCallEvent`` when the step requests
            a tool; ``PrepEvent`` to iterate again otherwise, including when
            parsing fails.
        """
        chat_history = ev.input

        response = await self.llm.achat(chat_history)

        try:
            reasoning_step = self.output_parser.parse(response.message.content)
            current_reasoning = await self._record_step(ctx, reasoning_step)

            if reasoning_step.is_done:
                self.memory.put(
                    ChatMessage(
                        role="assistant", content=reasoning_step.response
                    )
                )
                print(reasoning_step.response)
                return StopEvent(
                    result={
                        "response": reasoning_step.response,
                        "sources": [*self.sources],
                        "reasoning": current_reasoning,
                    }
                )
            elif isinstance(reasoning_step, ActionReasoningStep):
                tool_name = reasoning_step.action
                tool_args = reasoning_step.action_input
                print(
                    f"current thought:{reasoning_step.thought} "
                    f"--{tool_name}({tool_args})"
                )
                return ToolCallEvent(
                    tool_calls=[
                        ToolSelection(
                            # Placeholder id; the tool is looked up by name.
                            tool_id="fake",
                            tool_name=tool_name,
                            tool_kwargs=tool_args,
                        )
                    ]
                )
        except Exception as e:
            # Deliberate best-effort: surface the failure to the LLM as an
            # observation so the next iteration can correct itself.
            await self._record_step(
                ctx,
                ObservationReasoningStep(
                    observation=f"There was an error in parsing my reasoning: {e}"
                ),
            )

        # Neither a final answer nor a tool request: iterate again.
        return PrepEvent()

    @step
    async def handle_tool_calls(
        self, ctx: Context, ev: ToolCallEvent
    ) -> PrepEvent:
        """Run each requested tool and record the outcome as an observation.

        Unknown tool names and tool exceptions are recorded as observations
        rather than raised, so the loop continues.
        """
        tools_by_name = {tool.metadata.get_name(): tool for tool in self.tools}

        for tool_call in ev.tool_calls:
            tool = tools_by_name.get(tool_call.tool_name)
            if tool is None:
                await self._record_step(
                    ctx,
                    ObservationReasoningStep(
                        observation=f"Tool {tool_call.tool_name} does not exist"
                    ),
                )
                continue

            try:
                tool_output = tool(**tool_call.tool_kwargs)
                print(f"Tool output: {tool_output}")
                self.sources.append(tool_output)
                await self._record_step(
                    ctx,
                    ObservationReasoningStep(observation=tool_output.content),
                )
            except Exception as e:
                await self._record_step(
                    ctx,
                    ObservationReasoningStep(
                        observation=f"Error calling tool {tool.metadata.get_name()}: {e}"
                    ),
                )

        return PrepEvent()
|
|
|
from llama_index.core.tools import FunctionTool |
|
from llama_index.llms.openai import OpenAI |
|
from workflow.modules import * |
|
|
|
|
|
|
|
async def main() -> Any:
    """Run the ReAct agent on one sample question and return its result.

    Builds function tools from ``video_search`` and ``general_search``
    (presumably supplied by the ``workflow.modules`` wildcard import; verify),
    runs the agent with the module-level ``llm``, and returns whatever the
    workflow run produces.
    """
    tools = [
        FunctionTool.from_defaults(video_search),
        FunctionTool.from_defaults(general_search),
    ]

    agent = ReActAgent(llm=llm, tools=tools, timeout=120, verbose=True)

    # Hand the run result back to the caller instead of binding it to an
    # unused local and dropping it.
    return await agent.run(input="How to build trust with partner?")
|
|
|
# Script entry point: drive the async ``main`` coroutine to completion.
if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
|
|
|
|