# Author: DrishtiSharma
# Commit 5456c64: "Create interim.py" (verified)
import os
import streamlit as st
import json
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated, Sequence
from langchain_core.messages import BaseMessage
import operator
import networkx as nx
import matplotlib.pyplot as plt
# Credentials come from the environment. Both are required, so the Streamlit
# page shows an error and halts rendering when either one is missing or empty.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")
if not (OPENAI_API_KEY and TAVILY_API_KEY):
    st.error("API keys not found. Please set OPENAI_API_KEY and TAVILY_API_KEY as environment variables.")
    st.stop()
# Chat model with temperature 0. No model name is passed, so ChatOpenAI's own
# default is used. NOTE(review): the key is presumably picked up by ChatOpenAI
# from the environment; it is not passed explicitly here.
model = ChatOpenAI(temperature=0)
# Define Tools
@tool
def multiply(first_number: int, second_number: int) -> int:
"""Multiplies two integers together."""
return first_number * second_number
@tool
def search(query: str):
    """Performs web search on the user query."""
    # The docstring above doubles as the tool description for the model, so it
    # is left unchanged. A new Tavily client is built on every call and asked
    # for only the single top result.
    return TavilySearchResults(max_results=1).invoke(query)
# Tools offered to the model.
tools = [search, multiply]
# Name -> tool lookup used by invoke_tool to dispatch a requested call.
# (`t` rather than `tool` so the loop variable does not read like the decorator.)
tool_map = {t.name: t for t in tools}
# Model instance that advertises the tools above in its requests.
model_with_tools = model.bind_tools(tools)
# Define Agent State class
class AgentState(TypedDict):
    """Shared state passed between the LangGraph nodes of this workflow."""

    # Running list of messages. The operator.add annotation makes LangGraph
    # concatenate each node's returned "messages" list onto the existing one
    # instead of overwriting it.
    # NOTE(review): the UI seeds this with a plain prompt string and invoke_tool
    # appends raw tool results, neither of which is a BaseMessage; the
    # annotation is presumably not validated at runtime — confirm.
    messages: Annotated[Sequence[BaseMessage], operator.add]
# Define workflow nodes
def invoke_model(state):
    """Send the newest message to the tool-aware model and return its reply.

    The reply is returned as a partial state update ({"messages": [...]}),
    which LangGraph appends to the existing message list.
    """
    # NOTE(review): only the latest message is sent, not the full history.
    # That suffices here because the graph enters "agent" once and never loops
    # back to it, but adding a tool -> agent edge would drop earlier context.
    latest = state['messages'][-1]
    reply = model_with_tools.invoke(latest)
    return {"messages": [reply]}
def invoke_tool(state):
    """Execute every tool call requested by the most recent model message.

    Tool calls are read in OpenAI format from
    ``additional_kwargs["tool_calls"]`` of the last message: each entry has
    ``function.name`` and a JSON-encoded ``function.arguments`` string. Each
    call is dispatched through ``tool_map``, and the results are returned in
    request order as a partial state update.

    Raises:
        Exception: if the last message carries no tool calls.
        ValueError: if the user declines a web search while human-in-the-loop
            is enabled, or if the model names a tool that is not registered.
    """
    tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls") or []
    if not tool_calls:
        raise Exception("No tool input found.")

    selected_tools = [call["function"]["name"] for call in tool_calls]
    for selected_tool in selected_tools:
        st.sidebar.write(f"Selected tool: {selected_tool}")

    # Ask once per node run, even when several searches were requested: a second
    # st.sidebar.radio with the same label would be a duplicate widget.
    # NOTE(review): the radio defaults to "Yes", and changing it triggers a new
    # Streamlit script run in which the "Run Workflow" button is no longer
    # pressed. The choice therefore probably cannot cancel the in-flight run —
    # confirm and consider a session-state driven approval flow.
    if "search" in selected_tools and st.session_state.get('human_loop'):
        response = st.sidebar.radio("Proceed with web search?", ["Yes", "No"])
        if response == "No":
            raise ValueError("User canceled the search tool execution.")

    results = []
    for call, selected_tool in zip(tool_calls, selected_tools):
        if selected_tool not in tool_map:
            raise ValueError(f"Model requested an unknown tool: {selected_tool}")
        arguments = json.loads(call["function"]["arguments"])
        results.append(tool_map[selected_tool].invoke(arguments))
    return {"messages": results}
def router(state):
    """Choose the edge taken after the agent node.

    Returns:
        "tool" if the latest message requests at least one tool call,
        otherwise "end". A missing, empty or None ``tool_calls`` entry all
        route to "end".
    """
    tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls")
    return "tool" if tool_calls else "end"
# Graph setup. Topology: "agent" is the entry node. The router sends it either
# to "tool" or straight to END, and "tool" always finishes the run, so there is
# no loop back to the agent.
graph = StateGraph(AgentState)
graph.add_node("agent", invoke_model)
graph.add_node("tool", invoke_tool)
# router returns "tool" or "end"; this mapping turns those labels into targets.
graph.add_conditional_edges("agent", router, {"tool": "tool", "end": END})
graph.add_edge("tool", END)
graph.set_entry_point("agent")
# Runnable that the UI drives via compiled_app.stream(...).
compiled_app = graph.compile()
# Function to render graph with NetworkX
def render_graph_nx(graph):
    """Draw the agent/tool workflow as a NetworkX diagram on the Streamlit page.

    Args:
        graph: Currently unused; kept for interface compatibility. The edges
            below are hard-coded to mirror the StateGraph built in this module
            (agent -> tool, agent -> end, tool -> end).
    """
    G = nx.DiGraph()
    G.add_edge("agent", "tool", label="invoke tool")
    G.add_edge("agent", "end", label="end condition")
    G.add_edge("tool", "end", label="finish")
    # Fixed seed keeps the layout identical across script reruns.
    pos = nx.spring_layout(G, seed=42)
    # Draw on an explicit Figure rather than pyplot's global state, so a real
    # Figure object can be given to st.pyplot and then released.
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        nx.draw(G, pos, ax=ax, with_labels=True, node_color="lightblue",
                node_size=3000, font_size=10, font_weight="bold")
        edge_labels = nx.get_edge_attributes(G, "label")
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels,
                                     font_size=9, ax=ax)
        ax.set_title("Workflow Graph")
        st.pyplot(fig)
    finally:
        # This runs on every script execution; without closing, each run would
        # leave another open figure registered with pyplot.
        plt.close(fig)
# Streamlit UI
st.title("LLM Tool Workflow Demo")
st.write("This app demonstrates LLM-based tool usage with and without human intervention.")
# Sidebar for options; invoke_tool reads the flag back from session_state.
st.sidebar.header("Configuration")
st.session_state['human_loop'] = st.sidebar.checkbox("Enable Human-in-the-Loop (For Search)", value=False)
# Input prompt
prompt = st.text_input("Enter your question:", "What is 24 * 365?")
if st.button("Run Workflow"):
    st.subheader("Execution Results")
    if not (prompt and prompt.strip()):
        # Nothing meaningful to send to the model; skip the API call.
        st.warning("Please enter a question before running the workflow.")
    else:
        intermediate_outputs = []
        try:
            # Each streamed item maps the node that just ran to its state update.
            for s in compiled_app.stream({"messages": [prompt]}):
                intermediate_outputs.append(s)
                st.write("Response:", list(s.values())[0])
                st.write("---")
        except Exception as e:
            # Top-level UI boundary: report any workflow failure to the user.
            st.error(f"Error occurred: {e}")
        # Runs after success or failure, so completed steps stay visible even
        # when a later step raises (e.g. a cancelled search).
        st.sidebar.write("### Intermediate Outputs")
        for step, output in enumerate(intermediate_outputs, start=1):
            st.sidebar.write(f"Step {step}: {output}")
# Display Graph
st.subheader("Workflow Graph")
render_graph_nx(graph)