DrishtiSharma's picture
Update app.py
c47fc55 verified
raw
history blame
7.08 kB
import streamlit as st
import pandas as pd
import os
from pandasai import SmartDataframe
from pandasai.llm import OpenAI
import tempfile
import matplotlib.pyplot as plt
from datasets import load_dataset
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
import time
# Load environment variables
# Each credential is None when the corresponding variable is not set; the
# model factory below checks for that before building a client.
openai_api_key, groq_api_key = (
    os.environ.get(env_name) for env_name in ("OPENAI_API_KEY", "GROQ_API_KEY")
)

_PAGE_TITLE = "Chat with Patent Dataset Using PandasAI"
st.title(_PAGE_TITLE)
# Initialize the LLM based on user selection
def initialize_llm(model_choice):
    """Build the chat model matching the label selected in the UI.

    Args:
        model_choice: Radio-button label. "llama-3.3-70b" selects a Groq-hosted
            Llama model; "GPT-4o" selects OpenAI's GPT-4o.

    Returns:
        A ``ChatGroq`` or ``ChatOpenAI`` instance, or ``None`` when the API key
        for the chosen provider is missing (an error is shown in the UI) or the
        label is not recognised.
    """
    if model_choice == "llama-3.3-70b":
        if not groq_api_key:
            st.error("Groq API key is missing. Please set the GROQ_API_KEY environment variable.")
            return None
        # ChatGroq expects the bare Groq model id. A "groq/" prefix is the
        # LiteLLM provider-routing convention and is not a valid Groq model name.
        return ChatGroq(groq_api_key=groq_api_key, model="llama-3.3-70b-versatile")
    if model_choice == "GPT-4o":
        if not openai_api_key:
            st.error("OpenAI API key is missing. Please set the OPENAI_API_KEY environment variable.")
            return None
        return ChatOpenAI(api_key=openai_api_key, model="gpt-4o")
    # Unrecognised selection: no model is available.
    return None
# Select LLM model
# The option labels are exactly the strings initialize_llm() dispatches on;
# the first entry ("GPT-4o") is preselected.
_LLM_OPTIONS = ["GPT-4o", "llama-3.3-70b"]
model_choice = st.radio("Select LLM", _LLM_OPTIONS, index=0, horizontal=True)
llm = initialize_llm(model_choice)
# Dataset loading without caching to support progress bar
def load_huggingface_dataset(
    dataset_name,
    *,
    config_name="sample",
    split="train",
    trust_remote_code=True,
    uniform_split=True,
):
    """Load one split of a Hugging Face Hub dataset as a pandas DataFrame.

    A Streamlit progress bar is shown while loading. The keyword defaults
    reproduce the options this app uses for ``HUPD/hupd``.

    Args:
        dataset_name: Hub repository id, e.g. ``"HUPD/hupd"``.
        config_name: Builder configuration, passed to ``load_dataset`` as
            ``name=``; ``None`` selects the dataset's default configuration.
        split: Split to load.
        trust_remote_code: Forwarded to ``load_dataset``.
        uniform_split: HUPD builder option. Pass ``None`` to omit it for
            datasets whose builder does not define it.

    Returns:
        The split as a ``pd.DataFrame``.

    Raises:
        Any exception from ``load_dataset`` or the DataFrame conversion is
        re-raised unchanged after the progress bar is reset to 0.
    """
    load_kwargs = {"name": config_name, "split": split, "trust_remote_code": trust_remote_code}
    # Builders that do not define this option reject it, so only pass it when set.
    if uniform_split is not None:
        load_kwargs["uniform_split"] = uniform_split

    progress_bar = st.progress(0)
    try:
        progress_bar.progress(10)
        # SECURITY NOTE(review): the caller takes dataset_name from a free-text
        # box, and trust_remote_code=True lets `datasets` execute that
        # repository's loading script on this server. Consider an allow-list of
        # dataset names, or trust_remote_code=False for untrusted input.
        dataset = load_dataset(dataset_name, **load_kwargs)
        progress_bar.progress(50)
        df = dataset.to_pandas() if hasattr(dataset, "to_pandas") else pd.DataFrame(dataset)
        progress_bar.progress(100)
        return df
    except Exception:
        progress_bar.progress(0)  # Reset progress bar on failure
        raise
def load_uploaded_csv(uploaded_file):
    """Parse an uploaded CSV into a DataFrame while showing a progress bar.

    Args:
        uploaded_file: The file object returned by ``st.file_uploader`` (any
            input ``pd.read_csv`` accepts also works).

    Returns:
        The parsed ``pd.DataFrame``.

    Raises:
        Any exception from ``pd.read_csv`` is re-raised unchanged after the
        progress bar is reset to 0.
    """
    progress_bar = st.progress(0)
    try:
        progress_bar.progress(10)
        # Rewind in case the buffer has already been read (e.g. by an earlier
        # parse attempt); plain paths have no seek() and are left alone.
        if hasattr(uploaded_file, "seek"):
            uploaded_file.seek(0)
        progress_bar.progress(50)
        df = pd.read_csv(uploaded_file)
        progress_bar.progress(100)
        return df
    except Exception:
        progress_bar.progress(0)  # Reset progress bar on failure
        raise
# Dataset selection logic
def load_dataset_into_session():
    """Render the dataset-source picker and load the chosen dataset.

    Sources offered:
      * "Use Repo Directory Dataset": reads ``./source/test.csv`` on button press.
      * "Use Hugging Face Dataset": fetches a Hub dataset with
        ``load_huggingface_dataset`` on button press.
      * "Upload CSV File": parses the uploaded file with ``load_uploaded_csv``.

    A successful load stores the frame in ``st.session_state.df``. Failures are
    reported with ``st.error`` and leave any previously loaded frame in place.
    """
    # Session-state key remembering which uploaded file is already loaded.
    upload_state_key = "loaded_upload_id"

    input_option = st.radio(
        "Select Dataset Input:",
        ["Use Repo Directory Dataset", "Use Hugging Face Dataset", "Upload CSV File"],
        index=1,
        horizontal=True,
    )

    # Option 1: Load dataset from the repo directory
    if input_option == "Use Repo Directory Dataset":
        file_path = "./source/test.csv"
        if st.button("Load Dataset"):
            try:
                with st.spinner("Loading dataset from the repo directory..."):
                    st.session_state.df = pd.read_csv(file_path)
                # Another source replaced the data; a later upload must be re-parsed.
                st.session_state.pop(upload_state_key, None)
                st.success(f"File loaded successfully from '{file_path}'!")
            except Exception as e:
                st.error(f"Error loading dataset from the repo directory: {e}")

    # Option 2: Load dataset from Hugging Face
    elif input_option == "Use Hugging Face Dataset":
        dataset_name = st.text_input("Enter Hugging Face Dataset Name:", value="HUPD/hupd")
        if st.button("Load Dataset"):
            dataset_name = dataset_name.strip()
            if not dataset_name:
                st.warning("Please enter a Hugging Face dataset name.")
            else:
                try:
                    st.session_state.df = load_huggingface_dataset(dataset_name)
                    st.session_state.pop(upload_state_key, None)
                    st.success(f"Hugging Face Dataset '{dataset_name}' loaded successfully!")
                except Exception as e:
                    st.error(f"Error loading Hugging Face dataset: {e}")

    # Option 3: Upload CSV File
    elif input_option == "Upload CSV File":
        uploaded_file = st.file_uploader("Upload a CSV File:", type=["csv"])
        if not uploaded_file:
            # No file selected (or it was removed): let the next upload,
            # even of the same file, be parsed again.
            st.session_state.pop(upload_state_key, None)
        else:
            # Streamlit reruns this script on every widget interaction and the
            # uploader keeps returning the current file, so only parse it when
            # it differs from the one already loaded.
            upload_id = getattr(uploaded_file, "file_id", None) or (
                uploaded_file.name,
                uploaded_file.size,
            )
            if st.session_state.get(upload_state_key) != upload_id:
                try:
                    st.session_state.df = load_uploaded_csv(uploaded_file)
                    st.session_state[upload_state_key] = upload_id
                    st.success("File uploaded successfully!")
                except Exception as e:
                    st.error(f"Error reading uploaded file: {e}")
# Load dataset into session
load_dataset_into_session()


def _render_chat_response(response):
    """Display a PandasAI chat answer.

    DataFrame answers are shown as an interactive table; any other value is
    stringified into a success message.
    """
    if isinstance(response, pd.DataFrame):
        st.dataframe(response)
    else:
        st.success(f"Response: {response}")


def _render_plot(result):
    """Display the chart produced by a PandasAI plot query.

    Args:
        result: Return value of ``SmartDataframe.chat`` for the plot query.

    If ``result`` is a path to an existing file, it is shown directly.
    NOTE(review): recent PandasAI releases save generated charts to disk and
    return that path; confirm for the installed version. Otherwise the current
    matplotlib figure is rendered. If that figure has no axes, nothing was
    drawn, and the raw result is reported instead of a blank image.
    """
    if isinstance(result, str) and os.path.isfile(result):
        st.image(result, caption="Generated Plot", use_container_width=True)
        return
    fig = plt.gcf()
    if not fig.get_axes():
        st.warning(f"No plot was generated. Response: {result}")
        return
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_plot_path = os.path.join(temp_dir, "plot.png")
        fig.savefig(temp_plot_path)
        st.image(temp_plot_path, caption="Generated Plot", use_container_width=True)


if "df" in st.session_state and llm:
    df = st.session_state.df

    # Display dataset metadata
    st.write("### Dataset Metadata")
    st.text(f"Number of Rows: {df.shape[0]}")
    st.text(f"Number of Columns: {df.shape[1]}")
    # str() each label: str.join raises TypeError on non-string column names.
    st.text(f"Column Names: {', '.join(map(str, df.columns))}")

    # Display dataset preview
    st.write("### Dataset Preview")
    num_rows = st.slider("Select number of rows to display:", min_value=5, max_value=50, value=10)
    st.dataframe(df.head(num_rows))

    # Create SmartDataFrame
    chat_df = SmartDataframe(df, config={"llm": llm})

    # Streamlit reruns this whole script on every widget interaction (e.g. the
    # preview slider above). Both query boxes are prefilled, so the LLM calls
    # are gated behind explicit buttons to avoid re-sending them on each rerun.

    # Chat functionality
    st.write("### Chat with Patent Data")
    user_query = st.text_input(
        "Enter your question about the patent data:",
        value="Have the patents with the numbers 14908945, 14994130, 14909084, and 14995057 been accepted or rejected? What are their titles?",
    )
    if st.button("Ask", key="ask_query"):
        if not user_query.strip():
            st.warning("Please enter a question.")
        else:
            try:
                _render_chat_response(chat_df.chat(user_query))
            except Exception as e:
                st.error(f"Error: {e}")

    # Plot generation functionality
    st.write("### Generate and View Graphs")
    plot_query = st.text_input(
        "Enter a query to generate a graph:",
        value="What is the distribution of patents categorized as 'ACCEPTED', 'REJECTED', or 'PENDING'?",
    )
    if st.button("Generate Graph", key="generate_graph"):
        if not plot_query.strip():
            st.warning("Please enter a graph query.")
        else:
            # Start from a clean pyplot state so a leftover figure is never
            # mistaken for this query's chart.
            plt.close("all")
            try:
                _render_plot(chat_df.chat(plot_query))
            except Exception as e:
                st.error(f"Error: {e}")
            finally:
                # pyplot keeps figures alive until they are closed explicitly.
                plt.close("all")

    # Download processed dataset
    #st.write("### Download Processed Dataset")
    #st.download_button(
    #    label="Download Dataset as CSV",
    #    data=df.to_csv(index=False),
    #    file_name="processed_dataset.csv",
    #    mime="text/csv"
    #)
# Sidebar instructions
# Static help text. The emoji were previously mis-encoded (UTF-8 bytes decoded
# as a Windows code page), and the example bullets are indented three spaces
# so Markdown nests them under item 3.
with st.sidebar:
    st.header("📋 Instructions:")
    st.markdown(
        "1. Choose an LLM (Groq-based or OpenAI-based) to interact with the data.\n"
        "2. Upload, select, or fetch the dataset using the provided options.\n"
        "3. Ask questions about the patent data, or enter a query to generate and view graphs based on patent attributes.\n"
        "   - Example: 'Predict if the patent will be accepted.'\n"
        "   - Example: 'What is the primary classification of this patent?'\n"
        "   - Example: 'Summarize the abstract of this patent.'\n"
    )
    st.markdown("---")
    st.header("📚 References:")
    st.markdown(
        "1. [Chat With Your CSV File With PandasAI - Prince Krampah](https://medium.com/aimonks/chat-with-your-csv-file-with-pandasai-22232a13c7b7)"
    )