DrishtiSharma committed on
Commit 739c523 · verified · 1 Parent(s): 4b8d5d1

Create interim.py

Files changed (1)
  1. interim.py +162 -0
interim.py ADDED
@@ -0,0 +1,162 @@
+ import streamlit as st
+ import pandas as pd
+ import sqlite3
+ import os
+ import json
+ from pathlib import Path
+ from datetime import datetime, timezone
+ from crewai import Agent, Crew, Process, Task
+ from crewai.tools import tool
+ from langchain_groq import ChatGroq
+ from langchain.schema.output import LLMResult
+ from langchain_core.callbacks.base import BaseCallbackHandler
+ from langchain_community.tools.sql_database.tool import (
+     InfoSQLDatabaseTool,
+     ListSQLDatabaseTool,
+     QuerySQLCheckerTool,
+     QuerySQLDataBaseTool,
+ )
+ from langchain_community.utilities.sql_database import SQLDatabase
+ from datasets import load_dataset
+ import tempfile
+
+ # API Key
+ os.environ["GROQ_API_KEY"] = st.secrets.get("GROQ_API_KEY", "")
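+ # st.secrets reads from .streamlit/secrets.toml; the empty-string fallback lets the
+ # app start without a key (Groq requests will fail until a real key is configured).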
+
+ # Callback handler that logs each prompt/completion pair to a JSONL file
+ class LLMCallbackHandler(BaseCallbackHandler):
+     def __init__(self, log_path: Path):
+         self.log_path = log_path
+
+     def on_llm_start(self, serialized, prompts, **kwargs):
+         with self.log_path.open("a", encoding="utf-8") as file:
+             file.write(json.dumps({"event": "llm_start", "text": prompts[0], "timestamp": datetime.now(timezone.utc).isoformat()}) + "\n")
+
+     def on_llm_end(self, response: LLMResult, **kwargs):
+         generation = response.generations[-1][-1].message.content
+         with self.log_path.open("a", encoding="utf-8") as file:
+             file.write(json.dumps({"event": "llm_end", "text": generation, "timestamp": datetime.now(timezone.utc).isoformat()}) + "\n")
+
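+ # Every LLM call appends an llm_start/llm_end record to prompts.jsonl, so the log grows across sessions.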
+ # Initialize LLM
+ llm = ChatGroq(
+     temperature=0,
+     model_name="groq/llama-3.3-70b-versatile",
+     max_tokens=120,  # hard cap on completion length; long reports will be truncated
+     callbacks=[LLMCallbackHandler(Path("prompts.jsonl"))],
+ )
+
+ st.title("Blah Blah App Using CrewAI 🚀")
+ st.write("Analyze datasets using natural language queries powered by SQL and CrewAI.")
+
+ # Initialize session state for data persistence
+ if "df" not in st.session_state:
+     st.session_state.df = None
+
+ # Dataset Input
+ input_option = st.radio("Select Dataset Input:", ["Use Hugging Face Dataset", "Upload CSV File"])
+ if input_option == "Use Hugging Face Dataset":
+     dataset_name = st.text_input("Enter Hugging Face Dataset Name:", value="HUPD/hupd")
+     if st.button("Load Dataset"):
+         try:
+             with st.spinner("Loading dataset..."):
+                 dataset = load_dataset(dataset_name, name="sample", split="train", trust_remote_code=True, uniform_split=True)
+                 st.session_state.df = pd.DataFrame(dataset)
+                 st.success(f"Dataset '{dataset_name}' loaded successfully!")
+                 st.dataframe(st.session_state.df.head())
+         except Exception as e:
+             st.error(f"Error: {e}")
+ elif input_option == "Upload CSV File":
+     uploaded_file = st.file_uploader("Upload CSV File:", type=["csv"])
+     if uploaded_file:
+         st.session_state.df = pd.read_csv(uploaded_file)
+         st.success("File uploaded successfully!")
+         st.dataframe(st.session_state.df.head())
+
+
+ if st.session_state.df is not None:
+     # Database setup
+     temp_dir = tempfile.TemporaryDirectory()
+     db_path = os.path.join(temp_dir.name, "patent_data.db")
+     connection = sqlite3.connect(db_path)
+     st.session_state.df.to_sql("patents", connection, if_exists="replace", index=False)
+     db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
+
+     # SQL Tools
+     @tool("list_tables")
+     def list_tables() -> str:
+         """List all tables in the patent database."""
+         return ListSQLDatabaseTool(db=db).invoke("")
+
+     @tool("tables_schema")
+     def tables_schema(tables: str) -> str:
+         """Get schema and sample rows for given tables."""
+         return InfoSQLDatabaseTool(db=db).invoke(tables)
+
+     @tool("execute_sql")
+     def execute_sql(sql_query: str) -> str:
+         """Execute a SQL query against the patent database."""
+         return QuerySQLDataBaseTool(db=db).invoke(sql_query)
+
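+     # All three tools above close over db, which is rebuilt from session state on each Streamlit rerun.
+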
+     # --- CrewAI Agents for Patent Analysis ---
+     patent_sql_dev = Agent(
+         role="Patent SQL Developer",
+         goal="Extract patent data using optimized SQL queries.",
+         backstory="An expert in writing optimized SQL queries for complex patent databases.",
+         llm=llm,
+         tools=[list_tables, tables_schema, execute_sql],
+     )
+
+     patent_data_analyst = Agent(
+         role="Patent Data Analyst",
+         goal="Analyze the data and produce insights.",
+         backstory="A seasoned analyst who identifies trends and patterns in datasets.",
+         llm=llm,
+     )
+
+     patent_report_writer = Agent(
+         role="Patent Report Writer",
+         goal="Summarize patent insights into a clear report.",
+         backstory="Expert in summarizing patent data insights into comprehensive reports.",
+         llm=llm,
+     )
+
+     # --- Crew Tasks ---
+     extract_data = Task(
+         description="Extract patents related to the query: {query}.",
+         expected_output="Patent data matching the query.",
+         agent=patent_sql_dev,
+     )
+
+     analyze_data = Task(
+         description="Analyze the extracted patent data for query: {query}.",
+         expected_output="Analysis text summarizing findings.",
+         agent=patent_data_analyst,
+         context=[extract_data],
+     )
+
+     write_report = Task(
+         description="Summarize analysis into an executive report.",
+         expected_output="Markdown report of insights.",
+         agent=patent_report_writer,
+         context=[analyze_data],
+     )
+
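+     # Tasks are chained via context: extract_data -> analyze_data -> write_report,
+     # so each step receives the previous step's output.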
+     # Assemble Crew
+     crew = Crew(
+         agents=[patent_sql_dev, patent_data_analyst, patent_report_writer],
+         tasks=[extract_data, analyze_data, write_report],
+         process=Process.sequential,
+         verbose=True,
+     )
+
+     # Query Input for Patent Analysis
+     query = st.text_area("Enter Patent Analysis Query:", placeholder="e.g., 'How many patents related to Machine Learning were filed after 2016?'")
+     if st.button("Submit Query"):
+         with st.spinner("Processing your query..."):
+             inputs = {"query": query}
+             result = crew.kickoff(inputs=inputs)
+             st.markdown("### 📊 Patent Analysis Report")
+             st.markdown(str(result))  # kickoff returns a CrewOutput; str() yields the final text
+
+     connection.close()  # release the SQLite handle before removing the temp directory
+     temp_dir.cleanup()
+ else:
+     st.info("Please load a patent dataset to proceed.")
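
To try this commit locally, something like the following should work (assuming the usual PyPI package names matching the imports above): install streamlit, pandas, datasets, crewai, langchain, langchain-groq and langchain-community; put GROQ_API_KEY = "<your key>" in .streamlit/secrets.toml, which is where st.secrets reads it from; then run streamlit run interim.py.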