|
import os |
|
import streamlit as st |
|
import json |
|
from langchain_openai import ChatOpenAI |
|
from langchain_core.tools import tool |
|
from langchain_community.tools.tavily_search import TavilySearchResults |
|
from langgraph.graph import StateGraph, END |
|
from typing import TypedDict, Annotated, Sequence |
|
from langchain_core.messages import BaseMessage |
|
import operator |
|
import networkx as nx |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
# Read credentials from the environment; neither key is hard-coded.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

# Fail fast in the UI when either key is missing; st.stop() halts this
# Streamlit script run so nothing below executes without credentials.
if not OPENAI_API_KEY or not TAVILY_API_KEY:

    st.error("API keys not found. Please set OPENAI_API_KEY and TAVILY_API_KEY as environment variables.")

    st.stop()


# Deterministic chat model (temperature=0); ChatOpenAI picks up
# OPENAI_API_KEY from the environment on its own.
model = ChatOpenAI(temperature=0)
|
|
|
|
|
@tool
def multiply(first_number: int, second_number: int) -> int:
    # NOTE: the docstring below doubles as the tool description the LLM sees
    # when deciding which tool to call, so its wording is left untouched.
    """Multiplies two integers together."""
    return first_number * second_number
|
|
|
@tool
def search(query: str):
    """Performs web search on the user query."""
    # A fresh Tavily client is created per call, capped at one result;
    # whatever it returns for the query is passed back verbatim.
    searcher = TavilySearchResults(max_results=1)
    return searcher.invoke(query)
|
|
|
# Registry of tools exposed to the model, keyed by tool name so
# invoke_tool can dispatch on the name the model requests.
tools = [search, multiply]

# Loop variable is `t` (not `tool`) so the imported @tool decorator
# is not shadowed at module level.
tool_map = {t.name: t for t in tools}

# Model instance that can emit tool_calls for the tools above.
model_with_tools = model.bind_tools(tools)
|
|
|
|
|
class AgentState(TypedDict):
    """Shared LangGraph state: a growing conversation transcript."""

    # Annotated with operator.add so LangGraph *appends* each node's
    # returned messages to the list instead of replacing it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
|
|
|
|
|
def invoke_model(state):
    """Call the tool-bound LLM and append its reply to the state.

    The full accumulated message history is forwarded to the model
    (previously only the last entry was sent, which discarded any prior
    turns or tool results; in this graph the agent runs once at entry,
    so the observable behavior is unchanged).

    Returns:
        dict: ``{"messages": [AIMessage]}`` — merged into the state via
        ``operator.add``.
    """
    messages = state['messages']
    return {"messages": [model_with_tools.invoke(messages)]}
|
|
|
def invoke_tool(state):
    """Execute the tool requested by the most recent model message.

    Reads the last message's ``tool_calls``, JSON-decodes the arguments,
    runs the matching tool from ``tool_map``, and returns the raw tool
    output wrapped as a one-element message list.

    Raises:
        ValueError: if no tool call is present, or the user declines the
            search in human-in-the-loop mode. (The caller catches
            ``Exception`` broadly, so narrowing from ``Exception`` is
            backward-compatible.)
        KeyError: if the requested tool name is not in ``tool_map``.
    """
    tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
    if not tool_calls:
        raise ValueError("No tool input found.")

    # Only one tool call is executed per step; keep the original behavior
    # of taking the last one when the model emitted several.
    tool_details = tool_calls[-1]
    function_spec = tool_details.get("function")
    selected_tool = function_spec.get("name")
    st.sidebar.write(f"Selected tool: {selected_tool}")

    # Optional human confirmation gate before performing a web search.
    if selected_tool == "search" and st.session_state.get('human_loop'):
        answer = st.sidebar.radio("Proceed with web search?", ["Yes", "No"])
        if answer == "No":
            raise ValueError("User canceled the search tool execution.")

    arguments = json.loads(function_spec.get("arguments"))
    response = tool_map[selected_tool].invoke(arguments)
    return {"messages": [response]}
|
|
|
def router(state):
    """Route after the agent node: run a tool if one was requested, else end.

    Returns "tool" when the latest message carries at least one tool call
    in its additional_kwargs, and "end" otherwise.
    """
    latest_message = state['messages'][-1]
    requested_calls = latest_message.additional_kwargs.get("tool_calls", [])
    return "tool" if requested_calls else "end"
|
|
|
|
|
# Assemble the two-node workflow: the agent node produces a model reply,
# the router then either dispatches to the tool node or ends the run.
graph = StateGraph(AgentState)

graph.add_node("agent", invoke_model)

graph.add_node("tool", invoke_tool)

# Map the router's return values ("tool"/"end") onto graph targets.
graph.add_conditional_edges("agent", router, {"tool": "tool", "end": END})

# Tool output terminates the run — there is no loop back to the agent.
graph.add_edge("tool", END)

graph.set_entry_point("agent")

compiled_app = graph.compile()
|
|
|
|
|
def render_graph_nx(graph):
    """Render a static diagram of the agent workflow in Streamlit.

    NOTE(review): the ``graph`` parameter is currently unused — the edges
    are hard-coded to mirror the StateGraph built above. The parameter is
    kept so the call site is unchanged; consider deriving edges from it.
    """
    G = nx.DiGraph()
    G.add_edge("agent", "tool", label="invoke tool")
    G.add_edge("agent", "end", label="end condition")
    G.add_edge("tool", "end", label="finish")

    # Fixed seed keeps the layout stable across Streamlit reruns.
    pos = nx.spring_layout(G, seed=42)

    # Draw into an explicit figure: st.pyplot(<module>) is deprecated,
    # and closing the figure afterwards avoids matplotlib accumulating
    # open figures on every rerun.
    fig, ax = plt.subplots(figsize=(8, 6))
    nx.draw(G, pos, ax=ax, with_labels=True, node_color="lightblue",
            node_size=3000, font_size=10, font_weight="bold")
    edge_labels = nx.get_edge_attributes(G, "label")
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels,
                                 font_size=9, ax=ax)
    ax.set_title("Workflow Graph")
    st.pyplot(fig)
    plt.close(fig)
|
|
|
|
|
# --- Streamlit page scaffolding ---
st.title("LLM Tool Workflow Demo")

st.write("This app demonstrates LLM-based tool usage with and without human intervention.")


# Sidebar toggle persisted into session state; invoke_tool reads
# 'human_loop' to decide whether to ask before running the search tool.
st.sidebar.header("Configuration")

st.session_state['human_loop'] = st.sidebar.checkbox("Enable Human-in-the-Loop (For Search)", value=False)
|
|
|
|
|
prompt = st.text_input("Enter your question:", "What is 24 * 365?")

if st.button("Run Workflow"):

    st.subheader("Execution Results")

    try:

        # Stream the compiled graph; each yielded item maps node name -> output.
        intermediate_outputs = []

        for s in compiled_app.stream({"messages": [prompt]}):

            intermediate_outputs.append(s)

            st.write("Response:", list(s.values())[0])

            st.write("---")


        st.sidebar.write("### Intermediate Outputs")

        for i, output in enumerate(intermediate_outputs):

            st.sidebar.write(f"Step {i+1}: {output}")

    except Exception as e:

        # Surfaces tool cancellations and API errors in the UI instead of crashing.
        st.error(f"Error occurred: {e}")


st.subheader("Workflow Graph")

# NOTE(review): render_graph_nx ignores its argument and draws a
# hard-coded diagram mirroring the StateGraph built above.
render_graph_nx(graph)