DrishtiSharma committed
Commit e1b05e1 · verified · 1 Parent(s): 5cc7611

Update app.py

Files changed (1)
  1. app.py +12 -14
app.py CHANGED
@@ -23,7 +23,7 @@ import tempfile
 # API Key
 os.environ["GROQ_API_KEY"] = st.secrets.get("GROQ_API_KEY", "")
 
-# LLM Callback Logger
+# Initialize LLM
 class LLMCallbackHandler(BaseCallbackHandler):
     def __init__(self, log_path: Path):
         self.log_path = log_path
@@ -37,45 +37,45 @@ class LLMCallbackHandler(BaseCallbackHandler):
         with self.log_path.open("a", encoding="utf-8") as file:
             file.write(json.dumps({"event": "llm_end", "text": generation, "timestamp": datetime.now().isoformat()}) + "\n")
 
-# Initialize LLM
 llm = ChatGroq(
     temperature=0,
     model_name="mixtral-8x7b-32768",
     callbacks=[LLMCallbackHandler(Path("prompts.jsonl"))],
 )
 
-# Streamlit UI
 st.title("SQL-RAG Using CrewAI 🚀")
 st.write("Analyze datasets using natural language queries powered by SQL and CrewAI.")
 
+# Initialize session state for data persistence
+if "df" not in st.session_state:
+    st.session_state.df = None
+
 # Dataset Input
 input_option = st.radio("Select Dataset Input:", ["Use Hugging Face Dataset", "Upload CSV File"])
-df = None
-
 if input_option == "Use Hugging Face Dataset":
     dataset_name = st.text_input("Enter Hugging Face Dataset Name:", value="Einstellung/demo-salaries")
     if st.button("Load Dataset"):
         try:
             with st.spinner("Loading dataset..."):
                 dataset = load_dataset(dataset_name, split="train")
-                df = pd.DataFrame(dataset)
+                st.session_state.df = pd.DataFrame(dataset)
                 st.success(f"Dataset '{dataset_name}' loaded successfully!")
-                st.dataframe(df.head())
+                st.dataframe(st.session_state.df.head())
         except Exception as e:
             st.error(f"Error: {e}")
-else:
+elif input_option == "Upload CSV File":
     uploaded_file = st.file_uploader("Upload CSV File:", type=["csv"])
     if uploaded_file:
-        df = pd.read_csv(uploaded_file)
+        st.session_state.df = pd.read_csv(uploaded_file)
         st.success("File uploaded successfully!")
-        st.dataframe(df.head())
+        st.dataframe(st.session_state.df.head())
 
 # SQL-RAG Analysis
-if df is not None:
+if st.session_state.df is not None:
     temp_dir = tempfile.TemporaryDirectory()
     db_path = os.path.join(temp_dir.name, "data.db")
     connection = sqlite3.connect(db_path)
-    df.to_sql("salaries", connection, if_exists="replace", index=False)
+    st.session_state.df.to_sql("salaries", connection, if_exists="replace", index=False)
     db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
 
     @tool("list_tables")
@@ -98,7 +98,6 @@ if df is not None:
         """Check the validity of a SQL query."""
         return QuerySQLCheckerTool(db=db, llm=llm).invoke({"query": sql_query})
 
-    # Agents
     sql_dev = Agent(
         role="Senior Database Developer",
        goal="Extract data using optimized SQL queries.",
@@ -121,7 +120,6 @@
         llm=llm,
     )
 
-    # Tasks
     extract_data = Task(
         description="Extract data based on the query: {query}.",
         expected_output="Database results matching the query.",
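What the commit actually fixes: Streamlit re-executes the script from top to bottom on every widget interaction, so the old module-level `df = None` reset the DataFrame on the rerun triggered by the very button that loaded it, and the analysis block below never saw the data. Moving the DataFrame into `st.session_state` makes it survive reruns; the bare `else:` also becomes an explicit `elif input_option == "Upload CSV File":`. A minimal standalone sketch of the session-state pattern (not part of the commit):

import streamlit as st

# st.session_state survives Streamlit's top-to-bottom rerun on every
# widget interaction; a plain module-level variable is reset each time.
if "clicks" not in st.session_state:
    st.session_state.clicks = 0  # runs only on the first execution

if st.button("Increment"):
    st.session_state.clicks += 1  # persists across subsequent reruns

st.write("Button clicked", st.session_state.clicks, "times")

Run with `streamlit run sketch.py`: with a plain variable instead of session state, the counter would never climb past 1.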
 
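For context on the block the last three hunks touch: the app materializes the session DataFrame into a throwaway SQLite file and wraps it in LangChain's SQLDatabase for the agent tools. A rough standalone sketch of that setup; the table name `salaries` and file name `data.db` come from the diff, while the sample data and import path are assumptions:

import os
import sqlite3
import tempfile

import pandas as pd
from langchain_community.utilities.sql_database import SQLDatabase

# Hypothetical stand-in for st.session_state.df.
df = pd.DataFrame({"job_title": ["Data Scientist", "ML Engineer"],
                   "salary_in_usd": [120000, 135000]})

# Write the DataFrame into a temporary SQLite database, as in the diff.
temp_dir = tempfile.TemporaryDirectory()
db_path = os.path.join(temp_dir.name, "data.db")
connection = sqlite3.connect(db_path)
df.to_sql("salaries", connection, if_exists="replace", index=False)

# Wrap it so LangChain's SQL tools (list/info/query/checker) can use it.
db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
print(db.get_usable_table_names())  # ['salaries']

Because everything under `if st.session_state.df is not None:` re-runs on each interaction, the database is rebuilt from the session DataFrame every rerun: consistent, at the cost of one rewrite per interaction.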
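The diff shows `LLMCallbackHandler` only in part. A hedged reconstruction of the whole handler, assuming the standard LangChain callback signatures; only `__init__` and the `on_llm_end` write visible in the hunks are confirmed, and the `on_llm_start` counterpart is a guess suggested by the paired "llm_end" event name:

import json
from datetime import datetime
from pathlib import Path

from langchain_core.callbacks import BaseCallbackHandler


class LLMCallbackHandler(BaseCallbackHandler):
    """Append one JSON object per LLM event to a .jsonl log."""

    def __init__(self, log_path: Path):
        self.log_path = log_path

    def on_llm_start(self, serialized, prompts, **kwargs):
        # Assumed counterpart to the on_llm_end write shown in the diff.
        with self.log_path.open("a", encoding="utf-8") as file:
            file.write(json.dumps({"event": "llm_start", "prompts": prompts,
                                   "timestamp": datetime.now().isoformat()}) + "\n")

    def on_llm_end(self, response, **kwargs):
        # Matches the fragment visible in the diff.
        generation = response.generations[0][0].text
        with self.log_path.open("a", encoding="utf-8") as file:
            file.write(json.dumps({"event": "llm_end", "text": generation,
                                   "timestamp": datetime.now().isoformat()}) + "\n")

One JSON object per line (JSON Lines) keeps `prompts.jsonl` appendable across repeated callbacks and lets each record be parsed independently with `json.loads`.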