"""Streamlit demo of LangGraph checkpoints and breakpoints.

The graph alternates between an LLM ``Agent`` node and an ``ExecuteTools``
node (Tavily search). It is compiled with ``interrupt_before=["ExecuteTools"]``,
so a run pauses before any tool executes. The pending tool call can then be
inspected and, optionally, answered manually via ``update_state``.
"""

import os
from typing import Annotated, TypedDict

import streamlit as st
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import tools_condition  # noqa: F401  (routing below is custom)

# Streamlit UI Header
st.title("Checkpoints and Breakpoints")
st.caption("Demonstrating LangGraph workflow execution with interruptions and tool invocation.")

# Fetch API Keys. They are read from os.environ, so they are already set in
# the environment; we only verify that both are present.
openai_api_key = os.getenv("OPENAI_API_KEY")
tavily_api_key = os.getenv("TAVILY_API_KEY")
if not openai_api_key or not tavily_api_key:
    st.error("API keys are missing! Set OPENAI_API_KEY and TAVILY_API_KEY in Hugging Face Spaces Secrets.")
    st.stop()


class State(TypedDict):
    """Graph state: the conversation so far, merged with ``add_messages``."""

    messages: Annotated[list, add_messages]


# Initialize LLM and Tools
llm = ChatOpenAI(model="gpt-4o-mini")
tool = TavilySearchResults(max_results=2)
llm_with_tools = llm.bind_tools([tool])


def Agent(state: State):
    """Invoke the tool-bound LLM on the conversation and append its reply.

    Args:
        state: Current graph state; ``state["messages"]`` is sent to the LLM.

    Returns:
        A partial state update containing the single LLM response message.
    """
    st.sidebar.write("Agent Input State:", state["messages"])
    response = llm_with_tools.invoke(state["messages"])
    st.sidebar.write("Agent Response:", response)
    return {"messages": [response]}


def ExecuteTools(state: State):
    """Run the tool calls requested by the last message.

    Every tool call receives a ToolMessage, including calls to unknown tools,
    which get an error note. OpenAI rejects a follow-up request in which an
    assistant tool call id is left without a matching tool message.

    Args:
        state: Current graph state; the last message is expected to be the
            AIMessage that requested the tools.

    Returns:
        A partial state update with one ToolMessage per tool call.
    """
    tool_calls = getattr(state["messages"][-1], "tool_calls", None) or []
    responses = []
    for call in tool_calls:
        tool_name = call["name"]
        args = call["args"]
        st.sidebar.write("Tool Call Detected:", tool_name, args)
        if tool_name == tool.name:
            tool_response = tool.invoke({"query": args["query"]})
            st.sidebar.write("Tool Response:", tool_response)
            content = str(tool_response)
        else:
            content = f"Error: unknown tool '{tool_name}'."
        responses.append(ToolMessage(content=content, tool_call_id=call["id"]))
    return {"messages": responses}


@st.cache_resource
def get_checkpointer() -> MemorySaver:
    """Return a single MemorySaver shared across Streamlit reruns.

    Streamlit re-executes this script on every widget interaction. A fresh
    MemorySaver per run would discard every saved checkpoint, so an
    interrupted thread could never be inspected or updated after a click.
    """
    return MemorySaver()


# Memory Checkpoint
memory = get_checkpointer()

# Build the Graph
graph = StateGraph(State)
graph.add_node("Agent", Agent)
graph.add_node("ExecuteTools", ExecuteTools)


def custom_tools_condition(state: State):
    """Return "True" if the last message requests tools, else "False".

    "True" routes to ExecuteTools. "False" routes to END, which is what
    terminates the run once the agent answers without calling a tool.
    """
    return "True" if getattr(state["messages"][-1], "tool_calls", None) else "False"


graph.add_conditional_edges("Agent", custom_tools_condition, {"True": "ExecuteTools", "False": END})
graph.add_edge("ExecuteTools", "Agent")
graph.set_entry_point("Agent")

# Compile the Graph; interrupt_before is the "breakpoint" that pauses each run
# just before tools execute.
app = graph.compile(checkpointer=memory, interrupt_before=["ExecuteTools"])

# Display Graph Visualization
st.subheader("Graph Visualization")
try:
    st.image(app.get_graph().draw_mermaid_png(), caption="Workflow Graph", use_container_width=True)
except Exception as e:
    # PNG rendering may depend on an external service; fall back to the source.
    st.warning(f"Could not render graph image: {e}")
    st.code(app.get_graph().draw_mermaid())

# Run the Workflow
st.subheader("Run the Workflow")
user_input = st.text_input("Enter a message to start the graph:", "Search for the weather in Uttar Pradesh")
thread_id = st.text_input("Thread ID", "1")
thread = {"configurable": {"thread_id": thread_id}}

if st.button("Execute Workflow"):
    input_message = {"messages": [HumanMessage(content=user_input)]}
    st.write("### Execution Outputs")
    outputs = []
    try:
        # Stream the graph execution; stops at the ExecuteTools breakpoint.
        for event in app.stream(input_message, thread, stream_mode="values"):
            output_message = event["messages"][-1]
            st.code(output_message.content)
            outputs.append(output_message.content)
            st.sidebar.write("Intermediate State:", event["messages"])
    except Exception as e:
        st.error(f"Error during execution: {e}")
    else:
        # Display Intermediate Outputs
        if outputs:
            st.subheader("Intermediate Outputs")
            for idx, output in enumerate(outputs, start=1):
                st.write(f"**Step {idx}:**")
                st.code(output)
        else:
            st.warning("No outputs generated. Check the workflow or tool calls.")

# Snapshot and manual update are rendered outside the "Execute Workflow"
# branch. A button nested inside another button's branch can never fire,
# because clicking it reruns the script with the outer button reading False.
snapshot = app.get_state(thread)
messages = snapshot.values.get("messages", [])
if messages:
    # Snapshot of Current State
    st.subheader("Current State Snapshot")
    current_message = messages[-1]
    # pretty_repr() returns the text; pretty_print() only prints and returns None.
    st.code(current_message.pretty_repr())

    # Manual Update for Interrupted State
    pending_calls = getattr(current_message, "tool_calls", None) or []
    if pending_calls:
        manual_response = st.text_area("Manual Tool Response", "Enter response to continue execution...")
        if st.button("Update State"):
            # Answer every pending tool call id, then add the final AI reply.
            new_messages = [
                ToolMessage(content=manual_response, tool_call_id=call["id"])
                for call in pending_calls
            ]
            new_messages.append(AIMessage(content=manual_response))
            try:
                app.update_state(thread, {"messages": new_messages})
            except Exception as e:
                st.error(f"Error during execution: {e}")
            else:
                st.success("State updated successfully!")
                st.code(app.get_state(thread).values["messages"][-1].pretty_repr())
    else:
        st.warning("No tool calls detected to update the state.")