DrishtiSharma commited on
Commit
6988bcd
·
verified ·
1 Parent(s): 088adc4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -12
app.py CHANGED
@@ -8,9 +8,8 @@ from langgraph.graph import StateGraph, END
8
  from typing import TypedDict, Annotated, Sequence
9
  from langchain_core.messages import BaseMessage
10
  import operator
11
- import pygraphviz as pgv
12
- from PIL import Image
13
- import tempfile
14
 
15
  # Set API keys and validate credentials
16
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
@@ -89,13 +88,20 @@ graph.add_edge("tool", END)
89
  graph.set_entry_point("agent")
90
  compiled_app = graph.compile()
91
 
92
- # Function to render graph
93
- def render_graph(graph):
94
- dot_string = graph.get_graph().to_string()
95
- G = pgv.AGraph(string=dot_string)
96
- with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
97
- G.draw(temp_file.name, prog="dot", format="png")
98
- return Image.open(temp_file.name)
 
 
 
 
 
 
 
99
 
100
  # Streamlit UI
101
  st.title("LLM Tool Workflow Demo")
@@ -124,5 +130,4 @@ if st.button("Run Workflow"):
124
 
125
  # Display Graph
126
  st.subheader("Workflow Graph")
127
- graph_image = render_graph(graph)
128
- st.image(graph_image, caption="Workflow Graph", use_column_width=True)
 
8
  from typing import TypedDict, Annotated, Sequence
9
  from langchain_core.messages import BaseMessage
10
  import operator
11
+ import networkx as nx
12
+ import matplotlib.pyplot as plt
 
13
 
14
  # Set API keys and validate credentials
15
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
 
88
  graph.set_entry_point("agent")
89
  compiled_app = graph.compile()
90
 
91
+ # Function to render graph with NetworkX
92
+ def render_graph_nx(graph):
93
+ G = nx.DiGraph()
94
+ G.add_edge("agent", "tool", label="invoke tool")
95
+ G.add_edge("agent", "end", label="end condition")
96
+ G.add_edge("tool", "end", label="finish")
97
+
98
+ pos = nx.spring_layout(G, seed=42)
99
+ plt.figure(figsize=(8, 6))
100
+ nx.draw(G, pos, with_labels=True, node_color="lightblue", node_size=3000, font_size=10, font_weight="bold")
101
+ edge_labels = nx.get_edge_attributes(G, "label")
102
+ nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=9)
103
+ plt.title("Workflow Graph")
104
+ st.pyplot(plt)
105
 
106
  # Streamlit UI
107
  st.title("LLM Tool Workflow Demo")
 
130
 
131
  # Display Graph
132
  st.subheader("Workflow Graph")
133
+ render_graph_nx(graph)