DrishtiSharma committed on
Commit
5456c64
·
verified ·
1 Parent(s): 6988bcd

Create interim.py

Browse files
Files changed (1) hide show
  1. interim.py +133 -0
interim.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ import json
4
+ from langchain_openai import ChatOpenAI
5
+ from langchain_core.tools import tool
6
+ from langchain_community.tools.tavily_search import TavilySearchResults
7
+ from langgraph.graph import StateGraph, END
8
+ from typing import TypedDict, Annotated, Sequence
9
+ from langchain_core.messages import BaseMessage
10
+ import operator
11
+ import networkx as nx
12
+ import matplotlib.pyplot as plt
13
+
14
# Load the required API credentials from the environment.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

# Halt the Streamlit app early when either credential is missing.
if not (OPENAI_API_KEY and TAVILY_API_KEY):
    st.error("API keys not found. Please set OPENAI_API_KEY and TAVILY_API_KEY as environment variables.")
    st.stop()
21
+
22
# Initialize the OpenAI chat model; temperature=0 makes responses deterministic.
model = ChatOpenAI(temperature=0)
24
+
25
# Define Tools
@tool
def multiply(first_number: int, second_number: int) -> int:
    """Multiplies two integers together."""
    # Simple calculator tool exposed to the LLM.
    # (Docstring kept verbatim: LangChain sends it to the model as the
    # tool description, so it is runtime-visible text.)
    product = first_number * second_number
    return product
30
+
31
@tool
def search(query: str):
    """Performs web search on the user query."""
    # Ask Tavily for the single best hit and hand it back unchanged.
    # (Docstring kept verbatim: it is the tool description the LLM sees.)
    client = TavilySearchResults(max_results=1)
    return client.invoke(query)
37
+
38
+ tools = [search, multiply]
39
+ tool_map = {tool.name: tool for tool in tools}
40
+
41
+ model_with_tools = model.bind_tools(tools)
42
+
43
# Define Agent State class
class AgentState(TypedDict):
    # Conversation history flowing through the graph. The operator.add
    # annotation is the LangGraph reducer: each node's returned messages
    # are appended to the existing list instead of replacing it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
46
+
47
# Define workflow nodes
def invoke_model(state):
    """Call the tool-enabled LLM on the most recent message and append its reply."""
    latest_message = state['messages'][-1]
    reply = model_with_tools.invoke(latest_message)
    return {"messages": [reply]}
52
+
53
def invoke_tool(state):
    """Execute the tool requested by the model's latest message.

    Reads the last message's ``tool_calls``, invokes the matching tool
    with its JSON-decoded arguments, and returns the tool output as a
    new message.

    Raises:
        ValueError: when the model produced no tool call, or the user
            declined a web search in human-in-the-loop mode.
    """
    tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])

    # Original looped over all calls keeping only the final one; index the
    # last element directly — only a single tool call is honored per step.
    if not tool_calls:
        raise ValueError("No tool input found.")
    tool_details = tool_calls[-1]

    selected_tool = tool_details.get("function").get("name")
    st.sidebar.write(f"Selected tool: {selected_tool}")

    if selected_tool == "search":
        # NOTE(review): a radio widget created mid-run defaults to "Yes",
        # so this confirmation likely never actually blocks a rerun —
        # confirm the intended human-in-the-loop behavior.
        if st.session_state.get('human_loop'):
            response = st.sidebar.radio("Proceed with web search?", ["Yes", "No"])
            if response == "No":
                raise ValueError("User canceled the search tool execution.")

    arguments = json.loads(tool_details.get("function").get("arguments"))
    response = tool_map[selected_tool].invoke(arguments)
    return {"messages": [response]}
74
+
75
def router(state):
    """Route to the "tool" node when the model requested a tool call, else "end"."""
    last_message = state['messages'][-1]
    pending_calls = last_message.additional_kwargs.get("tool_calls", [])
    return "tool" if pending_calls else "end"
81
+
82
# Graph setup: the agent node decides, the router dispatches to a tool or ends.
graph = StateGraph(AgentState)
graph.add_node("agent", invoke_model)
graph.add_node("tool", invoke_tool)
graph.set_entry_point("agent")
graph.add_conditional_edges("agent", router, {"tool": "tool", "end": END})
graph.add_edge("tool", END)
compiled_app = graph.compile()
90
+
91
# Function to render graph with NetworkX
def render_graph_nx(graph):
    """Draw a static picture of the workflow as a directed graph in Streamlit.

    The *graph* argument is accepted for interface compatibility but is not
    used: the drawing is a hard-coded mirror of the workflow topology.
    """
    G = nx.DiGraph()
    G.add_edge("agent", "tool", label="invoke tool")
    G.add_edge("agent", "end", label="end condition")
    G.add_edge("tool", "end", label="finish")

    # Fixed seed keeps the layout stable across Streamlit reruns.
    pos = nx.spring_layout(G, seed=42)
    # Create an explicit Figure: st.pyplot(plt) with the module is deprecated
    # and ambiguous about which figure it renders.
    fig = plt.figure(figsize=(8, 6))
    nx.draw(G, pos, with_labels=True, node_color="lightblue", node_size=3000, font_size=10, font_weight="bold")
    edge_labels = nx.get_edge_attributes(G, "label")
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=9)
    plt.title("Workflow Graph")
    st.pyplot(fig)
    # Close the figure so repeated reruns don't accumulate open figures.
    plt.close(fig)
105
+
106
# Streamlit UI
st.title("LLM Tool Workflow Demo")
st.write("This app demonstrates LLM-based tool usage with and without human intervention.")

# Sidebar for options
st.sidebar.header("Configuration")
st.session_state['human_loop'] = st.sidebar.checkbox("Enable Human-in-the-Loop (For Search)", value=False)

# Input prompt
prompt = st.text_input("Enter your question:", "What is 24 * 365?")
if st.button("Run Workflow"):
    st.subheader("Execution Results")
    try:
        intermediate_outputs = []
        # Stream the graph execution, showing each node's output as it arrives.
        for step in compiled_app.stream({"messages": [prompt]}):
            intermediate_outputs.append(step)
            st.write("Response:", next(iter(step.values())))
            st.write("---")

        st.sidebar.write("### Intermediate Outputs")
        for step_number, output in enumerate(intermediate_outputs, start=1):
            st.sidebar.write(f"Step {step_number}: {output}")
    except Exception as e:
        st.error(f"Error occurred: {e}")

# Display Graph
st.subheader("Workflow Graph")
render_graph_nx(graph)