File size: 6,792 Bytes
3c8ea82 3eeb9d5 3c8ea82 3eeb9d5 3c8ea82 3eeb9d5 3c8ea82 3eeb9d5 3c8ea82 3eeb9d5 3c8ea82 3eeb9d5 3c8ea82 3eeb9d5 3c8ea82 3eeb9d5 3c8ea82 3eeb9d5 3c8ea82 3eeb9d5 3c8ea82 3eeb9d5 3c8ea82 3eeb9d5 3c8ea82 3eeb9d5 3c8ea82 3eeb9d5 3c8ea82 3eeb9d5 3c8ea82 3eeb9d5 3c8ea82 3eeb9d5 3c8ea82 3eeb9d5 3c8ea82 3eeb9d5 3c8ea82 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
import os
import streamlit as st
from streamlit_chat import message as st_message
from sqlalchemy import create_engine
from langchain.agents import Tool, initialize_agent
from langchain.chains.conversation.memory import ConversationBufferMemory
from llama_index import GPTSQLStructStoreIndex, LLMPredictor, ServiceContext
from llama_index import SQLDatabase as llama_SQLDatabase
from llama_index.indices.struct_store import SQLContextContainerBuilder
from constants import (
DEFAULT_SQL_PATH,
DEFAULT_BUSINESS_TABLE_DESCRP,
DEFAULT_VIOLATIONS_TABLE_DESCRP,
DEFAULT_INSPECTIONS_TABLE_DESCRP,
DEFAULT_LC_TOOL_DESCRP,
)
from utils import get_sql_index_tool, get_llm
@st.cache_resource
def initialize_index(
    llm_name, model_temperature, table_context_dict, api_key, sql_path=DEFAULT_SQL_PATH
):
    """Build a GPTSQLStructStoreIndex over the database at ``sql_path``.

    Args:
        llm_name: Model name forwarded to ``get_llm``.
        model_temperature: Sampling temperature forwarded to ``get_llm``.
        table_context_dict: Optional mapping of table name to a free-text
            description. When ``None`` no context container is built and the
            index receives ``sql_context_container=None``.
        api_key: API key forwarded to ``get_llm``.
        sql_path: SQLAlchemy connection URL handed to ``create_engine``.

    Returns:
        The constructed ``GPTSQLStructStoreIndex``. Because of
        ``st.cache_resource``, repeated calls with equal arguments return the
        same object.
    """
    model = get_llm(llm_name, model_temperature, api_key)
    database = llama_SQLDatabase(create_engine(sql_path))

    # Table descriptions are optional; only build a container when provided.
    if table_context_dict is None:
        container = None
    else:
        container = SQLContextContainerBuilder(
            database, context_dict=table_context_dict
        ).build_context_container()

    predictor = LLMPredictor(llm=model)
    ctx = ServiceContext.from_defaults(llm_predictor=predictor)

    # No nodes are passed in (empty list); presumably the index is used only
    # to query the tables already present in ``database`` — confirm.
    return GPTSQLStructStoreIndex(
        [],
        sql_database=database,
        sql_context_container=container,
        service_context=ctx,
    )
@st.cache_resource
def initialize_chain(llm_name, model_temperature, lc_descrp, api_key, _sql_index):
    """Build a conversational LangChain agent whose only tool is the SQL index.

    Args:
        llm_name: Model name forwarded to ``get_llm``.
        model_temperature: Sampling temperature forwarded to ``get_llm``.
        lc_descrp: Description of the "SQL Index" tool shown to the agent.
        api_key: API key forwarded to ``get_llm``.
        _sql_index: Index produced by ``initialize_index``. Its
            ``sql_context_container.context_dict`` is passed to
            ``get_sql_index_tool``. The leading underscore tells
            ``st.cache_resource`` not to hash this argument.

    Returns:
        The agent created by ``initialize_agent`` using the
        "chat-conversational-react-description" agent type, with a
        ``ConversationBufferMemory`` keyed as ``"chat_history"``.

    NOTE(review): the result is cached, so equal arguments return the same
    agent object, including its conversation memory.
    """
    table_contexts = _sql_index.sql_context_container.context_dict
    index_tool = Tool(
        name="SQL Index",
        func=get_sql_index_tool(_sql_index, table_contexts),
        description=lc_descrp,
    )

    chat_model = get_llm(llm_name, model_temperature, api_key=api_key)
    history = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

    return initialize_agent(
        [index_tool],
        chat_model,
        agent="chat-conversational-react-description",
        verbose=True,
        memory=history,
    )
# Page header and introduction.
st.title("🦙 Llama Index SQL Sandbox 🦙")
# Adjacent string literals are concatenated with no separator, so every
# sentence that continues the same paragraph ends with an explicit space.
st.markdown(
    (
        "This sandbox uses a SQLite database by default, powered by [Llama Index](https://gpt-index.readthedocs.io/en/latest/index.html), ChatGPT, and LangChain.\n\n"
        "The database contains information on health violations and inspections at restaurants in San Francisco. "
        "This data is spread across three tables - businesses, inspections, and violations.\n\n"
        "Using the Setup tab, you can adjust LLM settings, change the context for the SQL tables, and change the tool description for LangChain. "
        "The other tabs will perform chatbot and text2sql operations.\n\n"
        "Read more about LlamaIndex's structured data support [here!](https://gpt-index.readthedocs.io/en/latest/guides/tutorials/sql_guide.html)"
    )
)

# Three tabs: configuration, direct text-to-SQL, and the LangChain agent.
setup_tab, llama_tab, lc_tab = st.tabs(
    ["Setup", "Llama Index", "Langchain+Llama Index"]
)
# Setup tab: collects LLM settings, table descriptions and the agent tool
# description. The other tabs read the module-level names assigned here.
with setup_tab:
    st.subheader("LLM Setup")
    api_key = st.text_input("Enter your OpenAI API key here", type="password")
    llm_name = st.selectbox(
        "Which LLM?", ["text-davinci-003", "gpt-3.5-turbo", "gpt-4"]
    )
    model_temperature = st.slider(
        "LLM Temperature", min_value=0.0, max_value=1.0, step=0.1
    )

    st.subheader("Table Setup")
    # Each text area has its own label so the user can tell which table
    # description they are editing.
    business_table_descrp = st.text_area(
        "Business table description", value=DEFAULT_BUSINESS_TABLE_DESCRP
    )
    violations_table_descrp = st.text_area(
        "Violations table description", value=DEFAULT_VIOLATIONS_TABLE_DESCRP
    )
    inspections_table_descrp = st.text_area(
        "Inspections table description", value=DEFAULT_INSPECTIONS_TABLE_DESCRP
    )
    # Keys are presumably the table names in the database at DEFAULT_SQL_PATH
    # — confirm against the schema.
    table_context_dict = {
        "businesses": business_table_descrp,
        "inspections": inspections_table_descrp,
        "violations": violations_table_descrp,
    }
    # When unchecked, the other tabs pass None instead of table_context_dict.
    use_table_descrp = st.checkbox("Use table descriptions?", value=True)
    lc_descrp = st.text_area("LangChain Tool Description", value=DEFAULT_LC_TOOL_DESCRP)
# Llama Index tab: build (or fetch the cached) index, then run one
# natural-language query and show the result next to the generated SQL.
with llama_tab:
    st.subheader("Text2SQL with Llama Index")
    if st.button("Initialize Index", key="init_index_1"):
        st.session_state["llama_index"] = initialize_index(
            llm_name,
            model_temperature,
            table_context_dict if use_table_descrp else None,
            api_key,
        )

    if "llama_index" in st.session_state:
        query_text = st.text_input(
            "Query:", value="Which restaurant has the most violations?"
        )
        use_nl = st.checkbox("Return natural language response?")
        if st.button("Run Query") and query_text:
            with st.spinner("Getting response..."):
                # UI boundary: any failure is shown to the user instead of
                # crashing the page.
                try:
                    query_engine = st.session_state["llama_index"].as_query_engine(
                        synthesize_response=use_nl
                    )
                    response = query_engine.query(query_text)
                    response_text = str(response)
                    response_sql = response.extra_info["sql_query"]
                except Exception as e:
                    response_text = "Error running SQL Query."
                    response_sql = str(e)

            col1, col2 = st.columns(2)
            with col1:
                st.text("SQL Result:")
                st.markdown(response_text)
            with col2:
                st.text("SQL Query:")
                # st.code rather than st.markdown: SQL such as COUNT(*) contains
                # markdown emphasis characters that would corrupt the display.
                st.code(response_sql, language="sql")
# LangChain tab: build the index and agent, then run a simple chat loop whose
# transcript lives in st.session_state["chat_history"].
with lc_tab:
    st.subheader("Langchain + Llama Index SQL Demo")
    if st.button("Initialize Agent"):
        st.session_state["llama_index"] = initialize_index(
            llm_name,
            model_temperature,
            table_context_dict if use_table_descrp else None,
            api_key,
        )
        st.session_state["lc_agent"] = initialize_chain(
            llm_name,
            model_temperature,
            lc_descrp,
            api_key,
            st.session_state["llama_index"],
        )
        # initialize_chain is cached, so equal settings return the same agent
        # object. Clear its conversation memory so it matches the emptied
        # on-screen history below.
        # NOTE(review): the cached agent is shared by every session using the
        # same settings, so this also clears their memory.
        st.session_state["lc_agent"].memory.clear()
        st.session_state["chat_history"] = []

    model_input = st.text_input(
        "Message:", value="Which restaurant has the most violations?"
    )
    if "lc_agent" in st.session_state and st.button("Send"):
        # Store (is_user, text) pairs so the speaker is explicit. Inferring it
        # from a "User: " substring misclassifies and truncates any message
        # that happens to contain that text.
        st.session_state["chat_history"].append((True, model_input))
        with st.spinner("Getting response..."):
            # UI boundary, consistent with the Llama Index tab: report
            # failures in the chat instead of crashing the page.
            try:
                response = st.session_state["lc_agent"].run(input=model_input)
            except Exception as e:
                response = f"Error running agent: {e}"
        st.session_state["chat_history"].append((False, response))

    if "chat_history" in st.session_state:
        for idx, (is_user, text) in enumerate(st.session_state["chat_history"]):
            # A unique key per message prevents duplicate-component errors when
            # the same text is rendered more than once.
            st_message(text, is_user=is_user, key=f"chat_msg_{idx}")
|