DrishtiSharma committed on
Commit 3d6c113 · verified · 1 Parent(s): c648049

Update interim.py

Files changed (1):
  1. interim.py +26 -22
interim.py CHANGED
@@ -8,6 +8,7 @@ from datetime import datetime, timezone
 from crewai import Agent, Crew, Process, Task
 from crewai.tools import tool
 from langchain_groq import ChatGroq
+from langchain_openai import ChatOpenAI
 from langchain.schema.output import LLMResult
 from langchain_core.callbacks.base import BaseCallbackHandler
 from langchain_community.tools.sql_database.tool import (
@@ -23,30 +24,33 @@ import tempfile
 # API Key
 os.environ["GROQ_API_KEY"] = st.secrets.get("GROQ_API_KEY", "")
 
-# Initialize LLM
-class LLMCallbackHandler(BaseCallbackHandler):
-    def __init__(self, log_path: Path):
-        self.log_path = log_path
-
-    def on_llm_start(self, serialized, prompts, **kwargs):
-        with self.log_path.open("a", encoding="utf-8") as file:
-            file.write(json.dumps({"event": "llm_start", "text": prompts[0], "timestamp": datetime.now().isoformat()}) + "\n")
-
-    def on_llm_end(self, response: LLMResult, **kwargs):
-        generation = response.generations[-1][-1].message.content
-        with self.log_path.open("a", encoding="utf-8") as file:
-            file.write(json.dumps({"event": "llm_end", "text": generation, "timestamp": datetime.now().isoformat()}) + "\n")
-
-llm = ChatGroq(
-    temperature=0,
-    model_name="groq/llama-3.3-70b-versatile",
-    max_tokens=120,
-    callbacks=[LLMCallbackHandler(Path("prompts.jsonl"))],
-)
-
 st.title("Blah Blah App Using CrewAI 🚀")
 st.write("Analyze datasets using natural language queries powered by SQL and CrewAI.")
 
+# Initialize LLM
+llm = None
+
+# Model Selection
+model_choice = st.radio("Select LLM", ["GPT-4o", "llama-3.3-70b"], index=0, horizontal=True)
+
+
+# API Key Validation and LLM Initialization
+groq_api_key = os.getenv("GROQ_API_KEY")
+openai_api_key = os.getenv("OPENAI_API_KEY")
+
+if model_choice == "llama-3.3-70b":
+    if not groq_api_key:
+        st.error("Groq API key is missing. Please set the GROQ_API_KEY environment variable.")
+        llm = None
+    else:
+        llm = ChatGroq(groq_api_key=groq_api_key, model="groq/llama-3.3-70b-versatile")
+elif model_choice == "GPT-4o":
+    if not openai_api_key:
+        st.error("OpenAI API key is missing. Please set the OPENAI_API_KEY environment variable.")
+        llm = None
+    else:
+        llm = ChatOpenAI(api_key=openai_api_key, model="gpt-4o")
+
 # Initialize session state for data persistence
 if "df" not in st.session_state:
     st.session_state.df = None
@@ -148,7 +152,7 @@ if st.session_state.df is not None:
         verbose=True,
     )
 
-    # Query Input for Patent Analysis
+    #Query Input for Patent Analysis
     query = st.text_area("Enter Patent Analysis Query:", placeholder="e.g., 'How many patents related to Machine Learning were filed after 2016?'")
    if st.button("Submit Query"):
        with st.spinner("Processing your query..."):
 
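A minimal sketch, not part of the commit, of how the llm object selected by the radio-button block above is typically handed to a CrewAI agent later in interim.py. The Agent construction surrounding the verbose=True, line is not shown in this diff, so the role/goal/backstory strings and the st.stop() guard below are hypothetical placeholders.

import os

import streamlit as st
from crewai import Agent
from langchain_groq import ChatGroq

# Mirror the Groq branch of the selection block; llm stays None when no key is set.
groq_api_key = os.getenv("GROQ_API_KEY")
llm = ChatGroq(groq_api_key=groq_api_key, model="groq/llama-3.3-70b-versatile") if groq_api_key else None

if llm is None:
    st.error("No usable LLM configured.")
    st.stop()  # halt the Streamlit run until an API key is provided

# Hypothetical agent wiring; the real role/goal text lives in the part of interim.py not shown here.
analyst = Agent(
    role="Patent Data Analyst",
    goal="Answer natural-language questions about the uploaded patent dataset via SQL",
    backstory="Translates user queries into SQL and summarizes the results.",
    llm=llm,  # the ChatGroq (or ChatOpenAI) instance chosen by the radio button
    verbose=True,
)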