RajMl commited on
Commit
090aae2
1 Parent(s): d9ba12f

Update aux_functions/chat_sql_function.py

Browse files
Files changed (1) hide show
  1. aux_functions/chat_sql_function.py +47 -44
aux_functions/chat_sql_function.py CHANGED
@@ -1,44 +1,47 @@
1
- import os
2
- from dotenv import load_dotenv
3
- from langchain_community.utilities import SQLDatabase
4
- from langchain_community.agent_toolkits import create_sql_agent
5
- from langchain_openai import ChatOpenAI
6
- from huggingface_hub import notebook_login
7
- from huggingface_hub import hf_secrets
8
-
9
-
10
- # Load environment variables from .env file
11
- def initiate_chat(querry):
12
- load_dotenv()
13
-
14
- notebook_login()
15
-
16
- # Set environment variables for API keys
17
- # os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
18
- # os.environ["LANGCHAIN_API_KEY"] = os.getenv("LANGCHAIN_API_KEY")
19
- # os.environ["LANGCHAIN_TRACING_V2"] = "true"
20
-
21
- # Define the SQL database URI
22
- db_uri = "sqlite:///db.db"
23
-
24
- # Initialize the SQLDatabase object
25
- db = SQLDatabase.from_uri(db_uri)
26
- api_key = hf_secrets.get("open_api")
27
-
28
- # Initialize the ChatOpenAI object with the desired model
29
- llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0, api_key=api_key)
30
- # llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
31
-
32
- # Create the SQL agent
33
- agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
34
-
35
- # Define the query to get the sum of all salaries
36
-
37
-
38
- # Execute the query using the agent
39
- try:
40
- result = agent_executor.invoke(querry)
41
- return result
42
- except Exception as e:
43
- print(f"Error: {e}")
44
- return "Error: " + str(e)
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from langchain_community.utilities import SQLDatabase
4
+ from langchain_community.agent_toolkits import create_sql_agent
5
+ from langchain_openai import ChatOpenAI
6
+ from huggingface_hub import notebook_login
7
+ from huggingface_hub import hf_secrets
8
+
9
+
10
+ # Load environment variables from .env file
11
+ def initiate_chat(querry):
12
+ load_dotenv()
13
+
14
+ notebook_login()
15
+
16
+ # Set environment variables for API keys
17
+ # os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
18
+ # os.environ["LANGCHAIN_API_KEY"] = os.getenv("LANGCHAIN_API_KEY")
19
+ # os.environ["LANGCHAIN_TRACING_V2"] = "true"
20
+
21
+ # Define the SQL database URI
22
+ db_uri = "sqlite:///db.db"
23
+
24
+ # Initialize the SQLDatabase object
25
+ db = SQLDatabase.from_uri(db_uri)
26
+ import os
27
+
28
+ api_key = os.getenv("open_ai")
29
+
30
+
31
+ # Initialize the ChatOpenAI object with the desired model
32
+ llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0, api_key=api_key)
33
+ # llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
34
+
35
+ # Create the SQL agent
36
+ agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
37
+
38
+ # Define the query to get the sum of all salaries
39
+
40
+
41
+ # Execute the query using the agent
42
+ try:
43
+ result = agent_executor.invoke(querry)
44
+ return result
45
+ except Exception as e:
46
+ print(f"Error: {e}")
47
+ return "Error: " + str(e)