# SexBot / agent_workflow.py
from typing import Any
from llama_index.core.agent.react import ReActChatFormatter, ReActOutputParser
from llama_index.core.agent.react.types import (
ActionReasoningStep,
ObservationReasoningStep,
)
from llama_index.core.llms.llm import LLM
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.tools.types import BaseTool
from llama_index.core.workflow import (
    Context,
    Event,
    StartEvent,
    StopEvent,
    Workflow,
    step,
)
from llama_index.core.llms import ChatMessage
from llama_index.core.tools import ToolSelection, ToolOutput
from llama_index.llms.openai import OpenAI
from llama_index.llms.ollama import Ollama
from llama_index.core import Settings

# Use a local Ollama model and register it as the global default LLM.
llm = Ollama(model="HammerAI/neuraldaredevil-abliterated")
Settings.llm = llm
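# Custom events routing control flow between the workflow steps below:
# StartEvent -> PrepEvent -> InputEvent -> (ToolCallEvent -> PrepEvent)* -> StopEvent.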

class PrepEvent(Event):
    """Signal to prepare the next reasoning iteration."""

class InputEvent(Event):
    """Formatted chat history to send to the LLM."""
    input: list[ChatMessage]

class ToolCallEvent(Event):
    """Tool calls parsed from the LLM's reasoning step."""
    tool_calls: list[ToolSelection]

class FunctionOutputEvent(Event):
    """Output of a single tool invocation."""
    output: ToolOutput

class ReActAgent(Workflow):
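    """ReAct agent implemented as a LlamaIndex Workflow.

    Each run loops through prompt formatting, LLM reasoning, and tool
    execution until the output parser reports a final response.
    """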
def __init__(
self,
*args: Any,
llm: LLM | None = None,
tools: list[BaseTool] | None = None,
extra_context: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self.tools = tools or []
self.llm = llm or OpenAI()
        # Use the resolved LLM (not the possibly-None argument) for the memory buffer.
        self.memory = ChatMemoryBuffer.from_defaults(llm=self.llm)
self.formatter = ReActChatFormatter(context=extra_context or "")
self.output_parser = ReActOutputParser()
        self.sources: list[ToolOutput] = []
@step
async def new_user_msg(self, ctx: Context, ev: StartEvent) -> PrepEvent:
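        """Reset per-run state and record the new user message in memory."""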
# clear sources
self.sources = []
# get user input
user_input = ev.input
user_msg = ChatMessage(role="user", content=user_input)
self.memory.put(user_msg)
# clear current reasoning
await ctx.set("current_reasoning", [])
return PrepEvent()
@step
async def prepare_chat_history(
self, ctx: Context, ev: PrepEvent
) -> InputEvent:
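        """Format chat history, tools, and current reasoning into an LLM prompt."""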
# get chat history
chat_history = self.memory.get()
current_reasoning = await ctx.get("current_reasoning", default=[])
llm_input = self.formatter.format(
self.tools, chat_history, current_reasoning=current_reasoning
)
return InputEvent(input=llm_input)
@step
    async def handle_llm_input(
        self, ctx: Context, ev: InputEvent
    ) -> ToolCallEvent | PrepEvent | StopEvent:
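        """Parse the LLM output into a reasoning step and route on its type."""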
chat_history = ev.input
response = await self.llm.achat(chat_history)
try:
reasoning_step = self.output_parser.parse(response.message.content)
(await ctx.get("current_reasoning", default=[])).append(
reasoning_step
)
if reasoning_step.is_done:
self.memory.put(
ChatMessage(
role="assistant", content=reasoning_step.response
)
)
print(reasoning_step.response)
return StopEvent(
result={
"response": reasoning_step.response,
"sources": [*self.sources],
"reasoning": await ctx.get(
"current_reasoning", default=[]
),
}
)
elif isinstance(reasoning_step, ActionReasoningStep):
tool_name = reasoning_step.action
tool_args = reasoning_step.action_input
print(f"current thought:{reasoning_step.thought} --{tool_name}({tool_args}")
return ToolCallEvent(
tool_calls=[
ToolSelection(
tool_id="fake",
tool_name=tool_name,
tool_kwargs=tool_args,
)
]
)
except Exception as e:
(await ctx.get("current_reasoning", default=[])).append(
ObservationReasoningStep(
observation=f"There was an error in parsing my reasoning: {e}"
)
)
# if no tool calls or final response, iterate again
return PrepEvent()
@step
async def handle_tool_calls(
self, ctx: Context, ev: ToolCallEvent
) -> PrepEvent:
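        """Invoke each requested tool and record its output as an observation."""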
tool_calls = ev.tool_calls
tools_by_name = {tool.metadata.get_name(): tool for tool in self.tools}
# call tools -- safely!
for tool_call in tool_calls:
tool = tools_by_name.get(tool_call.tool_name)
if not tool:
(await ctx.get("current_reasoning", default=[])).append(
ObservationReasoningStep(
observation=f"Tool {tool_call.tool_name} does not exist"
)
)
continue
try:
tool_output = tool(**tool_call.tool_kwargs)
print(f"Tool output: {tool_output}")
self.sources.append(tool_output)
(await ctx.get("current_reasoning", default=[])).append(
ObservationReasoningStep(observation=tool_output.content)
)
except Exception as e:
(await ctx.get("current_reasoning", default=[])).append(
ObservationReasoningStep(
observation=f"Error calling tool {tool.metadata.get_name()}: {e}"
)
)
        # prep the next iteration
return PrepEvent()

from llama_index.core.tools import FunctionTool
# `workflow.modules` is expected to expose the search functions used below.
from workflow.modules import *
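# NOTE: a minimal stand-in for `workflow.modules`, assuming both tools take a
# query string and return text (hypothetical signatures, for illustration only):
#
#   def video_search(query: str) -> str:
#       """Search for relevant videos and return a text summary."""
#       ...
#
#   def general_search(query: str) -> str:
#       """Run a general search and return a text summary."""
#       ...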
async def main():
tools = [
FunctionTool.from_defaults(video_search),
FunctionTool.from_defaults(general_search),
]
agent = ReActAgent(
llm=llm, tools=tools, timeout=120, verbose=True
)
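    # agent.run returns the StopEvent payload: a dict with "response",
    # "sources", and "reasoning" keys (see handle_llm_input above).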
ret = await agent.run(input="How to build trust with partner?")
if __name__ == "__main__":
import asyncio
asyncio.run(main())