DrishtiSharma committed · verified
Commit 3b5459d · 1 Parent(s): ed49154

Update app.py

Files changed (1):
  1. app.py +26 -41

app.py CHANGED
@@ -6,7 +6,7 @@ from langgraph.graph.message import add_messages
  from langchain_openai import ChatOpenAI
  from langchain_community.tools.tavily_search import TavilySearchResults
  from langchain_core.messages import HumanMessage, ToolMessage, AIMessage
- from langgraph.prebuilt import tools_condition
+ from langgraph.prebuilt import ToolNode, tools_condition
  import os

  # Streamlit UI Header
@@ -31,52 +31,38 @@ class State(TypedDict):
  # Initialize LLM and Tools
  llm = ChatOpenAI(model="gpt-4o-mini")
  tool = TavilySearchResults(max_results=2)
- llm_with_tools = llm.bind_tools([tool])
+ tools = [tool]
+ llm_with_tools = llm.bind_tools(tools)

  # Agent Node
  def Agent(state: State):
-     st.sidebar.write("Agent Input State:", state["messages"])
+     st.sidebar.write("Agent received input:", state["messages"])
      response = llm_with_tools.invoke(state["messages"])
      st.sidebar.write("Agent Response:", response)
      return {"messages": [response]}

- # Tools Execution Node
- def ExecuteTools(state: State):
-     tool_calls = state["messages"][-1].tool_calls
-     responses = []
-
-     if tool_calls:
-         for call in tool_calls:
-             tool_name = call["name"]
-             args = call["args"]
-             st.sidebar.write("Tool Call Detected:", tool_name, args)
-
-             if tool_name == "tavily_search_results_json":
-                 tool_response = tool.invoke({"query": args["query"]})
-                 st.sidebar.write("Tool Response:", tool_response)
-                 responses.append(ToolMessage(content=str(tool_response), tool_call_id=call["id"]))
-     return {"messages": responses}
-
- # Memory Checkpoint
+ # Set up Graph
  memory = MemorySaver()
-
- # Build the Graph
  graph = StateGraph(State)
+
+ # Add nodes
  graph.add_node("Agent", Agent)
- graph.add_node("ExecuteTools", ExecuteTools)
+ tool_node = ToolNode(tools=[tool])
+ graph.add_node("tools", tool_node)

- graph.add_conditional_edges("Agent", tools_condition, {"True": "ExecuteTools", "False": "Agent"})
- graph.add_edge("ExecuteTools", "Agent")
+ # Add edges
+ graph.add_conditional_edges("Agent", tools_condition)
+ graph.add_edge("tools", "Agent")
  graph.set_entry_point("Agent")

- # Compile the Graph
- app = graph.compile(checkpointer=memory, interrupt_before=["ExecuteTools"])
+ # Compile with Breakpoint
+ app = graph.compile(checkpointer=memory, interrupt_before=["tools"])

  # Display Graph Visualization
  st.subheader("Graph Visualization")
  st.image(app.get_graph().draw_mermaid_png(), caption="Workflow Graph", use_container_width=True)

- # Run the Workflow
+ # Input Section
  st.subheader("Run the Workflow")
  user_input = st.text_input("Enter a message to start the graph:", "Search for the weather in Uttar Pradesh")
  thread_id = st.text_input("Thread ID", "1")
@@ -88,13 +74,11 @@ if st.button("Execute Workflow"):
      st.write("### Execution Outputs")
      outputs = []

+     # Execute the workflow
      try:
-         # Stream the graph execution
          for event in app.stream(input_message, thread, stream_mode="values"):
-             output_message = event["messages"][-1]
-             st.code(output_message.content)
-             outputs.append(output_message.content)
-             st.sidebar.write("Intermediate State:", event["messages"])
+             st.code(event["messages"][-1].content)
+             outputs.append(event["messages"][-1].content)

          # Display Intermediate Outputs
          if outputs:
@@ -103,27 +87,28 @@ if st.button("Execute Workflow"):
                  st.write(f"**Step {idx}:**")
                  st.code(output)
          else:
-             st.warning("No outputs generated. Check the workflow or tool calls.")
+             st.warning("No outputs generated yet.")

-         # Snapshot of Current State
+         # Show State Snapshot
          st.subheader("Current State Snapshot")
          snapshot = app.get_state(thread)
          current_message = snapshot.values["messages"][-1]
          st.code(current_message.pretty_print())

-         # Manual Update for Interrupted State
+         # Handle Tool Calls with Manual Input
          if hasattr(current_message, "tool_calls") and current_message.tool_calls:
              tool_call_id = current_message.tool_calls[0]["id"]
-             manual_response = st.text_area("Manual Tool Response", "Enter response to continue execution...")
-             if st.button("Update State"):
+             st.warning("Execution paused before tool execution. Provide manual input to resume.")
+             manual_response = st.text_area("Manual Tool Response", "Enter the tool's response here...")
+             if st.button("Resume Execution"):
                  new_messages = [
                      ToolMessage(content=manual_response, tool_call_id=tool_call_id),
                      AIMessage(content=manual_response),
                  ]
                  app.update_state(thread, {"messages": new_messages})
-                 st.success("State updated successfully!")
+                 st.success("State updated! Rerun the workflow to continue.")
                  st.code(app.get_state(thread).values["messages"][-1].pretty_print())
          else:
-             st.warning("No tool calls detected to update the state.")
+             st.info("No tool calls detected at this step.")
      except Exception as e:
          st.error(f"Error during execution: {e}")
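For context, the commit replaces the hand-rolled ExecuteTools node with LangGraph's prebuilt ToolNode and tools_condition, and moves the breakpoint to the new "tools" node. Below is a minimal sketch of that pattern outside Streamlit, assuming OPENAI_API_KEY and TAVILY_API_KEY are set in the environment; the query string and thread id are illustrative, not taken from app.py.

from typing import Annotated
from typing_extensions import TypedDict

from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition


class State(TypedDict):
    messages: Annotated[list, add_messages]


tool = TavilySearchResults(max_results=2)
tools = [tool]
llm_with_tools = ChatOpenAI(model="gpt-4o-mini").bind_tools(tools)


def Agent(state: State):
    # Call the tool-bound LLM on the accumulated message history.
    return {"messages": [llm_with_tools.invoke(state["messages"])]}


graph = StateGraph(State)
graph.add_node("Agent", Agent)
graph.add_node("tools", ToolNode(tools=tools))
# tools_condition routes to the node named "tools" when the last AI message
# carries tool calls, and to END otherwise.
graph.add_conditional_edges("Agent", tools_condition)
graph.add_edge("tools", "Agent")
graph.set_entry_point("Agent")

# interrupt_before=["tools"] parks every run just before tool execution.
app = graph.compile(checkpointer=MemorySaver(), interrupt_before=["tools"])

config = {"configurable": {"thread_id": "1"}}
for event in app.stream(
    {"messages": [HumanMessage(content="Search for the weather in Uttar Pradesh")]},
    config,
    stream_mode="values",
):
    event["messages"][-1].pretty_print()

print(app.get_state(config).next)  # ('tools',) while paused at the breakpoint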
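Once a run is paused at the "tools" breakpoint, it can be continued either by letting ToolNode perform the real Tavily call, or, as the updated app.py does, by writing a manual ToolMessage into the checkpointed state. A sketch of both paths follows, reusing app and config from the sketch above; the manual response text and the as_node="tools" argument are illustrative additions and not part of app.py.

from langchain_core.messages import AIMessage, ToolMessage

snapshot = app.get_state(config)
paused_msg = snapshot.values["messages"][-1]

if getattr(paused_msg, "tool_calls", None):
    # Path A: resume as-is. Streaming None on the same thread continues from
    # the checkpoint, so ToolNode executes the pending tool call for real.
    # for event in app.stream(None, config, stream_mode="values"):
    #     event["messages"][-1].pretty_print()

    # Path B: answer the pending tool call by hand instead of running the tool.
    # as_node="tools" makes the write count as the tool node's output, so a
    # later app.stream(None, config, ...) would pick up again at the Agent node.
    manual_response = "It is 31 degrees Celsius and sunny."  # illustrative
    app.update_state(
        config,
        {"messages": [
            ToolMessage(content=manual_response,
                        tool_call_id=paused_msg.tool_calls[0]["id"]),
            AIMessage(content=manual_response),
        ]},
        as_node="tools",
    )
    app.get_state(config).values["messages"][-1].pretty_print()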