DrishtiSharma committed on
Commit
5b20995
Β·
verified Β·
1 Parent(s): 9144783

Create test.py

Browse files
Files changed (1) hide show
  1. test.py +133 -0
test.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from typing import TypedDict, Annotated
3
+ from langgraph.graph import StateGraph
4
+ from langgraph.checkpoint.memory import MemorySaver
5
+ from langgraph.graph.message import add_messages
6
+ from langchain_openai import ChatOpenAI
7
+ from langchain_community.tools.tavily_search import TavilySearchResults
8
+ from langchain_core.messages import HumanMessage, ToolMessage, AIMessage
9
+ from langgraph.prebuilt import tools_condition
10
+ import os
11
+
12
+ # Streamlit UI Header
13
+ st.title("Checkpoints and Breakpoints")
14
+ st.caption("Demonstrating LangGraph workflow execution with interruptions and tool invocation.")
15
+
16
+ # Fetch API Keys
17
+ openai_api_key = os.getenv("OPENAI_API_KEY")
18
+ tavily_api_key = os.getenv("TAVILY_API_KEY")
19
+
20
+ if not openai_api_key or not tavily_api_key:
21
+ st.error("API keys are missing! Set OPENAI_API_KEY and TAVILY_API_KEY in Hugging Face Spaces Secrets.")
22
+ st.stop()
23
+
24
+ os.environ["OPENAI_API_KEY"] = openai_api_key
25
+ os.environ["TAVILY_API_KEY"] = tavily_api_key
26
+
27
+ # Define State Class
28
+ class State(TypedDict):
29
+ messages: Annotated[list, add_messages]
30
+
31
+ # Initialize LLM and Tools
32
+ llm = ChatOpenAI(model="gpt-4o-mini")
33
+ tool = TavilySearchResults(max_results=2)
34
+ llm_with_tools = llm.bind_tools([tool])
35
+
36
+ # Agent Node
37
+ def Agent(state: State):
38
+ st.sidebar.write("Agent Input State:", state["messages"])
39
+ response = llm_with_tools.invoke(state["messages"])
40
+ st.sidebar.write("Agent Response:", response)
41
+ return {"messages": [response]}
42
+
43
+ # Tools Execution Node
44
+ def ExecuteTools(state: State):
45
+ tool_calls = state["messages"][-1].tool_calls
46
+ responses = []
47
+
48
+ if tool_calls:
49
+ for call in tool_calls:
50
+ tool_name = call["name"]
51
+ args = call["args"]
52
+ st.sidebar.write("Tool Call Detected:", tool_name, args)
53
+
54
+ if tool_name == "tavily_search_results_json":
55
+ tool_response = tool.invoke({"query": args["query"]})
56
+ st.sidebar.write("Tool Response:", tool_response)
57
+ responses.append(ToolMessage(content=str(tool_response), tool_call_id=call["id"]))
58
+ return {"messages": responses}
59
+
60
+ # Memory Checkpoint
61
+ memory = MemorySaver()
62
+
63
+ # Build the Graph
64
+ graph = StateGraph(State)
65
+ graph.add_node("Agent", Agent)
66
+ graph.add_node("ExecuteTools", ExecuteTools)
67
+
68
+ # Add Conditional Edge to Check for Tools
69
+ def custom_tools_condition(state: State):
70
+ return "True" if state["messages"][-1].tool_calls else "False"
71
+
72
+ graph.add_conditional_edges("Agent", custom_tools_condition, {"True": "ExecuteTools", "False": "Agent"})
73
+ graph.add_edge("ExecuteTools", "Agent")
74
+ graph.set_entry_point("Agent")
75
+
76
+ # Compile the Graph
77
+ app = graph.compile(checkpointer=memory, interrupt_before=["ExecuteTools"])
78
+
79
+ # Display Graph Visualization
80
+ st.subheader("Graph Visualization")
81
+ st.image(app.get_graph().draw_mermaid_png(), caption="Workflow Graph", use_container_width=True)
82
+
83
+ # Run the Workflow
84
+ st.subheader("Run the Workflow")
85
+ user_input = st.text_input("Enter a message to start the graph:", "Search for the weather in Uttar Pradesh")
86
+ thread_id = st.text_input("Thread ID", "1")
87
+
88
+ if st.button("Execute Workflow"):
89
+ thread = {"configurable": {"thread_id": thread_id}}
90
+ input_message = {"messages": [HumanMessage(content=user_input)]}
91
+
92
+ st.write("### Execution Outputs")
93
+ outputs = []
94
+
95
+ try:
96
+ # Stream the graph execution
97
+ for event in app.stream(input_message, thread, stream_mode="values"):
98
+ output_message = event["messages"][-1]
99
+ st.code(output_message.content)
100
+ outputs.append(output_message.content)
101
+ st.sidebar.write("Intermediate State:", event["messages"])
102
+
103
+ # Display Intermediate Outputs
104
+ if outputs:
105
+ st.subheader("Intermediate Outputs")
106
+ for idx, output in enumerate(outputs, start=1):
107
+ st.write(f"**Step {idx}:**")
108
+ st.code(output)
109
+ else:
110
+ st.warning("No outputs generated. Check the workflow or tool calls.")
111
+
112
+ # Snapshot of Current State
113
+ st.subheader("Current State Snapshot")
114
+ snapshot = app.get_state(thread)
115
+ current_message = snapshot.values["messages"][-1]
116
+ st.code(current_message.pretty_print())
117
+
118
+ # Manual Update for Interrupted State
119
+ if hasattr(current_message, "tool_calls") and current_message.tool_calls:
120
+ tool_call_id = current_message.tool_calls[0]["id"]
121
+ manual_response = st.text_area("Manual Tool Response", "Enter response to continue execution...")
122
+ if st.button("Update State"):
123
+ new_messages = [
124
+ ToolMessage(content=manual_response, tool_call_id=tool_call_id),
125
+ AIMessage(content=manual_response),
126
+ ]
127
+ app.update_state(thread, {"messages": new_messages})
128
+ st.success("State updated successfully!")
129
+ st.code(app.get_state(thread).values["messages"][-1].pretty_print())
130
+ else:
131
+ st.warning("No tool calls detected to update the state.")
132
+ except Exception as e:
133
+ st.error(f"Error during execution: {e}")