import streamlit as st
from typing import TypedDict, Annotated
from langgraph.graph import StateGraph
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.message import add_messages
from langchain_openai import ChatOpenAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import HumanMessage, ToolMessage, AIMessage
from langgraph.prebuilt import ToolNode, tools_condition
import os
# --- Streamlit UI header ---
st.title("Checkpoints and Breakpoints")
st.caption("Demonstrating LangGraph workflow execution with interruptions and tool invocation.")

# --- API key bootstrap ---
# Both keys are required (typically injected via Hugging Face Spaces Secrets);
# halt the script early if either one is absent.
_openai_key = os.getenv("OPENAI_API_KEY")
_tavily_key = os.getenv("TAVILY_API_KEY")
if not (_openai_key and _tavily_key):
    st.error("API keys are missing! Set OPENAI_API_KEY and TAVILY_API_KEY in Hugging Face Spaces Secrets.")
    st.stop()
# Re-export into the process environment for downstream SDK clients.
os.environ["OPENAI_API_KEY"] = _openai_key
os.environ["TAVILY_API_KEY"] = _tavily_key
# Define State Class
class State(TypedDict):
    """Graph state: the conversation history.

    `add_messages` makes LangGraph append/merge each node's returned messages
    into this list instead of overwriting it.
    """
    # Running list of chat messages exchanged between user, agent, and tools.
    messages: Annotated[list, add_messages]
# --- Model + search tool wiring ---
# `llm_with_tools` is the tool-aware chat model invoked by the Agent node;
# `tool`/`tools` are also consumed later when the ToolNode is built.
llm = ChatOpenAI(model="gpt-4o-mini")
tools = [TavilySearchResults(max_results=2)]
tool = tools[0]
llm_with_tools = llm.bind_tools(tools)
# Agent Node
def Agent(state: State):
    """Single LLM step: log the incoming messages to the sidebar, invoke the
    tool-bound model, and append its reply to the conversation state."""
    incoming = state["messages"]
    st.sidebar.write("Agent received input:", incoming)
    reply = llm_with_tools.invoke(incoming)
    st.sidebar.write("Agent Response:", reply)
    return {"messages": [reply]}
# --- Graph assembly ---
memory = MemorySaver()
graph = StateGraph(State)

# Nodes: the LLM agent plus a prebuilt executor for the search tool.
graph.add_node("Agent", Agent)
graph.add_node("tools", ToolNode(tools=[tool]))

# Wiring: start at the agent; `tools_condition` routes to "tools" whenever the
# agent's reply contains tool calls, and tool results loop back to the agent.
graph.set_entry_point("Agent")
graph.add_conditional_edges("Agent", tools_condition)
graph.add_edge("tools", "Agent")

# Compile with an in-memory checkpointer, pausing just before tool execution.
app = graph.compile(checkpointer=memory, interrupt_before=["tools"])
# --- Graph visualization and run inputs ---
st.subheader("Graph Visualization")
st.image(
    app.get_graph().draw_mermaid_png(),
    caption="Workflow Graph",
    use_container_width=True,
)

st.subheader("Run the Workflow")
user_input = st.text_input("Enter a message to start the graph:", "Search for the weather in Uttar Pradesh")
thread_id = st.text_input("Thread ID", "1")
if st.button("Execute Workflow"):
    # Per-thread checkpoint config; reusing a thread_id resumes its history.
    thread = {"configurable": {"thread_id": thread_id}}
    input_message = {"messages": [HumanMessage(content=user_input)]}
    st.write("### Execution Outputs")
    outputs = []
    # Execute the workflow
    try:
        # Stream until the graph completes or hits the interrupt_before=["tools"]
        # breakpoint; each event carries the full message list accumulated so far.
        for event in app.stream(input_message, thread, stream_mode="values"):
            latest = event["messages"][-1].content
            st.code(latest)
            outputs.append(latest)

        # Display Intermediate Outputs
        if outputs:
            st.subheader("Intermediate Outputs")
            for idx, output in enumerate(outputs, start=1):
                st.write(f"**Step {idx}:**")
                st.code(output)
        else:
            st.warning("No outputs generated yet.")

        # Show State Snapshot (the checkpointed state at the pause point).
        st.subheader("Current State Snapshot")
        snapshot = app.get_state(thread)
        current_message = snapshot.values["messages"][-1]
        # BUG FIX: pretty_print() writes to stdout and returns None, so the
        # original st.code(...) rendered nothing useful in the UI;
        # pretty_repr() returns the formatted string instead.
        st.code(current_message.pretty_repr())

        # Handle Tool Calls with Manual Input
        if getattr(current_message, "tool_calls", None):
            tool_call_id = current_message.tool_calls[0]["id"]
            st.warning("Execution paused before tool execution. Provide manual input to resume.")
            manual_response = st.text_area("Manual Tool Response", "Enter the tool's response here...")
            # NOTE(review): an st.button nested inside another st.button's branch
            # never fires in Streamlit — clicking it reruns the script and the
            # outer button reads False, so this resume path is unreachable.
            # A real fix requires persisting the "executed" flag (and the
            # compiled graph/checkpointer) in st.session_state; left in place
            # here to keep this change minimal and self-contained.
            if st.button("Resume Execution"):
                new_messages = [
                    # Satisfy the pending tool call, then inject the manual
                    # text as the assistant's final answer.
                    ToolMessage(content=manual_response, tool_call_id=tool_call_id),
                    AIMessage(content=manual_response),
                ]
                app.update_state(thread, {"messages": new_messages})
                st.success("State updated! Rerun the workflow to continue.")
                st.code(app.get_state(thread).values["messages"][-1].pretty_repr())
        else:
            st.info("No tool calls detected at this step.")
    except Exception as e:
        # Top-level boundary for the whole run; surface the failure in the UI.
        st.error(f"Error during execution: {e}")