DrishtiSharma committed on
Commit
a1ef31a
·
verified ·
1 Parent(s): b3ee6dc

Update app.py

Files changed (1)
  1. app.py +80 -83
app.py CHANGED
@@ -2,19 +2,13 @@ import streamlit as st
 import pandas as pd
 import sqlite3
 import os
-import json
-from pathlib import Path
-from datetime import datetime, timezone
 from crewai import Agent, Crew, Process, Task
 from crewai.tools import tool
 from langchain_groq import ChatGroq
 from langchain_openai import ChatOpenAI
-from langchain.schema.output import LLMResult
-from langchain_core.callbacks.base import BaseCallbackHandler
 from langchain_community.tools.sql_database.tool import (
     InfoSQLDatabaseTool,
     ListSQLDatabaseTool,
-    QuerySQLCheckerTool,
     QuerySQLDataBaseTool,
 )
 from langchain_community.utilities.sql_database import SQLDatabase
@@ -24,140 +18,143 @@ import tempfile
 st.title("Blah Blah App 🚀")
 st.write("Analyze datasets using natural language queries.")
 
-# Initialize LLM
-llm = None
+# LLM Initialization
+def initialize_llm(model_choice):
+    groq_api_key = os.getenv("GROQ_API_KEY")
+    openai_api_key = os.getenv("OPENAI_API_KEY")
+
+    if model_choice == "llama-3.3-70b":
+        if not groq_api_key:
+            st.error("Groq API key is missing.")
+            return None
+        return ChatGroq(groq_api_key=groq_api_key, model="groq/llama-3.3-70b-versatile")
+    elif model_choice == "GPT-4o":
+        if not openai_api_key:
+            st.error("OpenAI API key is missing.")
+            return None
+        return ChatOpenAI(api_key=openai_api_key, model="gpt-4o")
 
-# Model Selection
 model_choice = st.radio("Select LLM", ["GPT-4o", "llama-3.3-70b"], index=0, horizontal=True)
-
-
-# API Key Validation and LLM Initialization
-groq_api_key = os.getenv("GROQ_API_KEY")
-openai_api_key = os.getenv("OPENAI_API_KEY")
-
-if model_choice == "llama-3.3-70b":
-    if not groq_api_key:
-        st.error("Groq API key is missing. Please set the GROQ_API_KEY environment variable.")
-        llm = None
-    else:
-        llm = ChatGroq(groq_api_key=groq_api_key, model="groq/llama-3.3-70b-versatile")
-elif model_choice == "GPT-4o":
-    if not openai_api_key:
-        st.error("OpenAI API key is missing. Please set the OPENAI_API_KEY environment variable.")
-        llm = None
-    else:
-        llm = ChatOpenAI(api_key=openai_api_key, model="gpt-4o")
-
-# Initialize session state for data persistence
-if "df" not in st.session_state:
-    st.session_state.df = None
-
-# Dataset Input
-input_option = st.radio("Select Dataset Input:", ["Use Hugging Face Dataset", "Upload CSV File"])
-if input_option == "Use Hugging Face Dataset":
-    dataset_name = st.text_input("Enter Hugging Face Dataset Name:", value="HUPD/hupd")
-    if st.button("Load Dataset"):
-        try:
-            with st.spinner("Loading dataset..."):
+llm = initialize_llm(model_choice)
+
+# Dataset Loading
+def load_dataset_into_session():
+    input_option = st.radio("Select Dataset Input:", ["Use Hugging Face Dataset", "Upload CSV File"])
+    if input_option == "Use Hugging Face Dataset":
+        dataset_name = st.text_input("Enter Hugging Face Dataset Name:", value="HUPD/hupd")
+        if st.button("Load Dataset"):
+            try:
                 dataset = load_dataset(dataset_name, name="sample", split="train", trust_remote_code=True, uniform_split=True)
                 st.session_state.df = pd.DataFrame(dataset)
                 st.success(f"Dataset '{dataset_name}' loaded successfully!")
                 st.dataframe(st.session_state.df.head())
-        except Exception as e:
-            st.error(f"Error: {e}")
-elif input_option == "Upload CSV File":
-    uploaded_file = st.file_uploader("Upload CSV File:", type=["csv"])
-    if uploaded_file:
-        st.session_state.df = pd.read_csv(uploaded_file)
-        st.success("File uploaded successfully!")
-        st.dataframe(st.session_state.df.head())
+            except Exception as e:
+                st.error(f"Error: {e}")
+    elif input_option == "Upload CSV File":
+        uploaded_file = st.file_uploader("Upload CSV File:", type=["csv"])
+        if uploaded_file:
+            st.session_state.df = pd.read_csv(uploaded_file)
+            st.success("File uploaded successfully!")
+            st.dataframe(st.session_state.df.head())
 
+if "df" not in st.session_state:
+    st.session_state.df = None
+load_dataset_into_session()
 
-if st.session_state.df is not None:
-    # Database setup
+# Database Initialization
+def initialize_database(df):
     temp_dir = tempfile.TemporaryDirectory()
     db_path = os.path.join(temp_dir.name, "patent_data.db")
     connection = sqlite3.connect(db_path)
-    st.session_state.df.to_sql("patents", connection, if_exists="replace", index=False)
+    df.to_sql("patents", connection, if_exists="replace", index=False)
     db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
+    return db, temp_dir
 
-    # SQL Tools
+# SQL Tools
+def create_sql_tools(db):
     @tool("list_tables")
     def list_tables() -> str:
-        """List all tables in the patent database."""
         return ListSQLDatabaseTool(db=db).invoke("")
 
     @tool("tables_schema")
     def tables_schema(tables: str) -> str:
-        """Get schema and sample rows for given tables."""
         return InfoSQLDatabaseTool(db=db).invoke(tables)
 
     @tool("execute_sql")
     def execute_sql(sql_query: str) -> str:
-        """Execute a SQL query against the patent database."""
         return QuerySQLDataBaseTool(db=db).invoke(sql_query)
 
-    # --- CrewAI Agents for Patent Analysis ---
-    patent_sql_dev = Agent(
+    return list_tables, tables_schema, execute_sql
+
+# Agent Initialization
+def initialize_agents(llm, tools):
+    list_tables, tables_schema, execute_sql = tools
+
+    sql_agent = Agent(
         role="Patent Data Analyst",
         goal="Extract patent data using optimized SQL queries.",
-        backstory="An expert in writing optimized SQL queries for complex patent databases.",
+        backstory="Expert in optimized SQL for patent databases.",
         llm=llm,
         tools=[list_tables, tables_schema, execute_sql],
     )
 
-    patent_data_analyst = Agent(
+    analyst_agent = Agent(
         role="Patent Data Analyst",
         goal="Analyze the data and produce insights.",
-        backstory="A seasoned analyst who identifies trends and patterns in datasets.",
+        backstory="Data analyst identifying trends.",
        llm=llm,
     )
 
-    patent_report_writer = Agent(
+    writer_agent = Agent(
         role="Patent Report Writer",
-        goal="Summarize patent insights into a clear report.",
-        backstory="Expert in summarizing patent data insights into comprehensive reports.",
+        goal="Summarize patent insights into a report.",
+        backstory="Expert in clear, concise reporting.",
         llm=llm,
     )
 
-    # --- Crew Tasks ---
-    extract_data = Task(
+    return sql_agent, analyst_agent, writer_agent
+
+# Crew and Tasks Setup
+def setup_crew(sql_agent, analyst_agent, writer_agent):
+    extract_task = Task(
         description="Extract patents related to the query: {query}.",
         expected_output="Patent data matching the query.",
-        agent=patent_sql_dev,
+        agent=sql_agent,
     )
 
-    analyze_data = Task(
-        description="Analyze the extracted patent data for query: {query}.",
+    analyze_task = Task(
+        description="Analyze the extracted patent data.",
         expected_output="Analysis text summarizing findings.",
-        agent=patent_data_analyst,
-        context=[extract_data],
+        agent=analyst_agent,
+        context=[extract_task],
     )
 
-    write_report = Task(
-        description="Summarize analysis into an executive report.",
+    report_task = Task(
+        description="Summarize analysis into a report.",
         expected_output="Markdown report of insights.",
-        agent=patent_report_writer,
-        context=[analyze_data],
+        agent=writer_agent,
+        context=[analyze_task],
     )
 
-    # Assemble Crew
-    crew = Crew(
-        agents=[patent_sql_dev, patent_data_analyst, patent_report_writer],
-        tasks=[extract_data, analyze_data, write_report],
+    return Crew(
+        agents=[sql_agent, analyst_agent, writer_agent],
+        tasks=[extract_task, analyze_task, report_task],
         process=Process.sequential,
         verbose=True,
     )
 
-    #Query Input for Patent Analysis
+# Execution Flow
+if st.session_state.df is not None:
+    db, temp_dir = initialize_database(st.session_state.df)
+    tools = create_sql_tools(db)
+    sql_agent, analyst_agent, writer_agent = initialize_agents(llm, tools)
+    crew = setup_crew(sql_agent, analyst_agent, writer_agent)
+
     query = st.text_area("Enter Patent Analysis Query:", placeholder="e.g., 'How many patents related to Machine Learning were filed after 2016?'")
     if st.button("Submit Query"):
         with st.spinner("Processing your query..."):
-            inputs = {"query": query}
-            result = crew.kickoff(inputs=inputs)
+            result = crew.kickoff(inputs={"query": query})
             st.markdown("### 📊 Patent Analysis Report")
             st.markdown(result)
-
-    temp_dir.cleanup()
 else:
-    st.info("Please load a patent dataset to proceed.")
+    st.info("Please load a patent dataset to proceed.")
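Review note: `initialize_database()` runs on every Streamlit rerun, and since this commit also drops the old `temp_dir.cleanup()` call, the `TemporaryDirectory` is kept alive only by the `temp_dir` reference in the execution flow. A minimal sketch of one way to build the database once per session instead; the `db_handle` session key is hypothetical and not part of this commit:

# Hypothetical sketch (not in this commit): cache the SQLite handle across
# Streamlit reruns so the database is not rebuilt on every widget interaction.
# Keeping temp_dir in session state also prevents the TemporaryDirectory from
# being garbage-collected, which would delete the backing .db file.
if st.session_state.df is not None:
    if "db_handle" not in st.session_state:
        st.session_state.db_handle = initialize_database(st.session_state.df)
    db, temp_dir = st.session_state.db_handle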
 
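A related caveat: `initialize_llm()` returns `None` when the selected provider's API key is missing, yet the execution flow still passes `llm` into `initialize_agents()`, so the failure only surfaces once the crew runs. A minimal guard, sketched with the names from this commit:

# Hypothetical guard (not in this commit): halt the script early when no LLM
# was initialized; initialize_llm() has already displayed an st.error message.
if llm is None:
    st.stop()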