Shahabmoin committed on
Commit
7bf1eec
·
verified ·
1 Parent(s): a199ef4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -24
app.py CHANGED
@@ -1,31 +1,71 @@
import os

import streamlit as st
import pandas as pd
from groq import Groq

# Initialize Groq API client.
# SECURITY: the API key was previously hardcoded in this file. A secret that
# has been committed to version control must be treated as compromised —
# revoke it and supply a fresh key via the GROQ_API_KEY environment variable.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
client = Groq(api_key=GROQ_API_KEY)
# Helper functions
def preprocess_data(uploaded_file):
    """Load the uploaded CSV file-like object into a pandas DataFrame."""
    frame = pd.read_csv(uploaded_file)
    return frame
def generate_report(data, query):
    """Placeholder report generator.

    TODO: wire in the retrieval + generation pipeline. The DataFrame
    argument is currently unused; only the query is echoed back.
    """
    report_text = "Report for query: {}".format(query)
    return report_text
# --- Streamlit front end ---
st.title("Energy Usage Analysis Report Generator")

uploaded_csv = st.file_uploader("Upload your CSV file", type=["csv"])
if uploaded_csv:
    # Show the user a quick preview of what was just uploaded.
    df = preprocess_data(uploaded_csv)
    st.write("Dataset Preview:")
    st.dataframe(df.head())

    user_query = st.text_input("Enter your query:")
    if user_query:
        generated = generate_report(df, user_query)
        st.write("Generated Report:")
        st.text(generated)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import numpy as np
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer
from groq import Groq

# Load pre-trained Sentence Transformer model (produces 384-dim embeddings).
model = SentenceTransformer('all-MiniLM-L6-v2')

# Initialize Groq API client.
# SECURITY: the API key was previously hardcoded here and committed to
# version control — treat that key as compromised (revoke it) and provide a
# fresh one through the GROQ_API_KEY environment variable instead.
# (This also puts the already-imported `os` module to its intended use.)
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
client = Groq(api_key=GROQ_API_KEY)
# Generate embeddings using Sentence Transformers
def generate_embeddings(text):
    """Encode *text* into a dense vector via the shared SentenceTransformer."""
    vector = model.encode(text)
    return vector
# Build FAISS index for retrieval
def build_faiss_index(data):
    """Embed every row of *data* and index the vectors for L2 search.

    Parameters
    ----------
    data : pandas.DataFrame
        Each row is serialized with ``to_string()`` before embedding.

    Returns
    -------
    (faiss.IndexFlatL2, numpy.ndarray)
        The populated index and the float32 embedding matrix.

    Raises
    ------
    ValueError
        If *data* has no rows (nothing to index).
    """
    if len(data) == 0:
        raise ValueError("Cannot build a FAISS index from an empty DataFrame")
    embeddings = [generate_embeddings(row.to_string()) for _, row in data.iterrows()]
    embeddings = np.array(embeddings).astype("float32")
    # Derive the dimension from the actual model output instead of
    # hardcoding 384, so swapping the embedding model cannot silently
    # produce a dimension-mismatched index.
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index, embeddings
# Query FAISS index
def query_index(query, data, index, k=5):
    """Return the *k* rows of *data* most similar to *query*.

    Parameters
    ----------
    query : str
        Natural-language query to embed and search for.
    data : pandas.DataFrame
        The frame the index was built from; row order must match the
        insertion order used in ``build_faiss_index``.
    index : faiss.Index
        A populated FAISS index over embeddings of *data*'s rows.
    k : int, optional
        Number of nearest neighbours to retrieve (default 5, matching the
        previous hardcoded behaviour).

    Returns
    -------
    pandas.DataFrame
        The retrieved rows, nearest first.
    """
    query_embedding = generate_embeddings(query).astype("float32")
    distances, indices = index.search(np.array([query_embedding]), k=k)
    # FAISS pads the result with -1 when the index holds fewer than k
    # vectors; drop those sentinels before indexing into the DataFrame,
    # otherwise iloc[-1] would silently return the wrong (last) row.
    hits = [i for i in indices[0] if i != -1]
    results = data.iloc[hits]
    return results
# Generate a detailed report using Groq's generative model
def generate_report_with_groq(query, results):
    """Ask the Groq chat model to turn retrieved rows into a narrative report.

    *results* is rendered to plain text and sent as a single user message;
    the model's reply text is returned as-is.
    """
    input_text = f"Based on the query '{query}', the following insights are generated:\n\n{results.to_string(index=False)}"
    chat = client.chat.completions.create(
        model="llama3-8b-8192",
        messages=[{"role": "user", "content": input_text}],
        stream=False,
    )
    return chat.choices[0].message.content
# Main entry point: load CSV -> build index -> retrieve -> generate report.
def _main():
    """Run the end-to-end retrieval-augmented reporting workflow."""
    # Load dataset
    csv_path = "energy_usage_data.csv"  # Ensure this CSV is in the working directory
    try:
        data = pd.read_csv(csv_path)
    except FileNotFoundError:
        # Fail with a clear message rather than an uncaught traceback.
        print(f"Dataset not found: {csv_path}")
        return

    # Preprocess data (if needed) — blank out NaNs so row serialization is clean.
    data.fillna("", inplace=True)

    # Build FAISS index
    print("Building FAISS index...")
    index, embeddings = build_faiss_index(data)

    # User query (fixed demo query for now)
    query = "Show households with high energy usage in the North region"
    print(f"User Query: {query}")

    # Query FAISS index
    print("Retrieving relevant data...")
    results = query_index(query, data, index)

    # Generate report
    print("Generating report using Groq API...")
    report = generate_report_with_groq(query, results)

    # Output the report
    print("Generated Report:\n")
    print(report)


if __name__ == "__main__":
    _main()