# app.py — Hugging Face Space: "Checkpoints and Breakpoints"
# NOTE(review): the original first lines here were HTML-to-text residue from
# the hf.co blob viewer ("raw / history blame / 4.54 kB"), not Python source;
# converted to comments so the file parses.
import streamlit as st
from typing import TypedDict, Annotated
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.message import add_messages
from langchain_openai import ChatOpenAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import HumanMessage, ToolMessage, AIMessage
from langgraph.prebuilt import ToolNode, tools_condition
import os
# --- Streamlit UI header --------------------------------------------------
st.title("Checkpoints and Breakpoints")
st.caption("Demonstrating workflow execution with checkpoints and tool invocation.")

# Fetch API keys from the environment (set via Hugging Face Spaces Secrets).
openai_api_key = os.getenv("OPENAI_API_KEY")
tavily_api_key = os.getenv("TAVILY_API_KEY")

if openai_api_key and tavily_api_key:
    os.environ["OPENAI_API_KEY"] = openai_api_key
    os.environ["TAVILY_API_KEY"] = tavily_api_key

    class State(TypedDict):
        """Graph state: a message list merged with the add_messages reducer."""
        # add_messages appends new messages instead of overwriting the list.
        messages: Annotated[list, add_messages]

    # Initialize the LLM and bind the Tavily search tool to it.
    llm = ChatOpenAI(model="gpt-4o-mini")
    tool = TavilySearchResults(max_results=2)
    tools = [tool]
    llm_with_tools = llm.bind_tools(tools)

    def Agent(state: State):
        """Single agent step: run the tool-enabled LLM over the conversation.

        Returns a partial state update; add_messages appends the response.
        """
        response = llm_with_tools.invoke(state["messages"])
        return {"messages": [response]}

    # In-memory checkpointer so execution can pause and resume per thread.
    memory = MemorySaver()

    # --- Graph definition -------------------------------------------------
    graph = StateGraph(State)
    graph.add_node("Agent", Agent)
    graph.add_node("tools", ToolNode(tools=tools))
    # BUG FIX: the original routed Agent -> tools unconditionally (a router
    # that always returned "True"), producing an infinite Agent/tools loop.
    # tools_condition routes to the "tools" node only when the last AI
    # message carries tool calls, and to END otherwise.
    graph.add_conditional_edges("Agent", tools_condition)
    graph.add_edge("tools", "Agent")
    graph.set_entry_point("Agent")

    # interrupt_before=["tools"] is the breakpoint: execution pauses before
    # the tool node so the checkpointed state can be inspected or edited.
    app = graph.compile(checkpointer=memory, interrupt_before=["tools"])

    # --- Graph visualization ----------------------------------------------
    st.subheader("Graph Workflow")
    st.image(app.get_graph().draw_mermaid_png(), caption="Graph Visualization", use_container_width=True)

    # --- Input section ----------------------------------------------------
    st.subheader("Run the Workflow")
    user_input = st.text_input("Enter a message to start the graph:", "Search for the weather in Uttar Pradesh")
    thread_id = st.text_input("Thread ID", "1")

    # The thread config is needed by both the run and the snapshot sections.
    thread = {"configurable": {"thread_id": thread_id}}

    if st.button("Execute Workflow"):
        # Wrap the single human message in a list for the add_messages reducer.
        input_message = {"messages": [HumanMessage(content=user_input)]}
        st.write("### Execution Outputs")
        outputs = []
        for event in app.stream(input_message, thread, stream_mode="values"):
            if event.get("messages"):
                # BUG FIX: pretty_print() writes to stdout and returns None,
                # so st.code() showed "None"; pretty_repr() returns the string.
                rendered = event["messages"][-1].pretty_repr()
                outputs.append(rendered)
                st.code(rendered)
        if outputs:
            st.subheader("Intermediate Outputs")
            for i, output in enumerate(outputs, 1):
                st.write(f"**Step {i}:**")
                st.code(output)
        else:
            st.warning("No outputs generated. Adjust your input to trigger tools.")

    # --- Snapshot of current state ----------------------------------------
    # BUG FIX: this section was nested inside the "Execute Workflow" button
    # branch. Streamlit reruns the script on every interaction and a button
    # is True only on the rerun it was clicked, so a nested "Update State"
    # button could never fire. Kept at top level so it survives reruns.
    st.subheader("Current State Snapshot")
    snapshot = app.get_state(thread)
    if snapshot.values.get("messages"):
        current_message = snapshot.values["messages"][-1]
        st.code(current_message.pretty_repr())
        # Safe access to tool calls: AIMessage exposes tool_calls; other
        # message types may not.
        if getattr(current_message, "tool_calls", None):
            tool_call_id = current_message.tool_calls[0]["id"]
            manual_response = st.text_area("Manual Tool Response", "Enter response to continue...")
            if st.button("Update State"):
                # Answer the pending tool call manually, then add an AI turn,
                # emulating the tool's output at the breakpoint.
                new_messages = [
                    ToolMessage(content=manual_response, tool_call_id=tool_call_id),
                    AIMessage(content=manual_response),
                ]
                app.update_state(thread, {"messages": new_messages})
                st.success("State updated successfully!")
                st.code(app.get_state(thread).values["messages"][-1].pretty_repr())
        else:
            st.warning("No tool calls available for manual updates.")
    else:
        st.warning("No state messages available.")
else:
    st.error("API keys are missing! Please set `OPENAI_API_KEY` and `TAVILY_API_KEY` in Hugging Face Spaces Secrets.")