DrishtiSharma's picture
Create interim.py
aad2d05 verified
raw
history blame
3.18 kB
import os
from datetime import datetime
import streamlit as st
from patentwiz import preprocess_data, qa_agent
from patentwiz.main import PROMPT
# Fail fast when no OpenAI key is configured. The key is never passed to the
# patentwiz helpers in this script, so presumably they read OPENAI_API_KEY
# from the environment themselves — TODO confirm.
api_key = os.environ.get("OPENAI_API_KEY")
if api_key is None or api_key == "":
    st.error("OPENAI_API_KEY not found! Please set it in the environment variables or Hugging Face Secrets.")
    st.stop()
# Page header: the app title followed by a one-paragraph summary of what the
# app does and what the user has to provide.
_APP_DESCRIPTION = (
    "Analyze patents to extract physical measurements such as length, mass, time, and more. "
    "Provide a date to download patents, and analyze them using GPT models."
)
st.title("Technical Measurements Extractor for Patents")
st.write(_APP_DESCRIPTION)
# --- User input section ------------------------------------------------------
st.header("Enter Details for Patent Analysis")

# Free-text date; the analysis handler parses it with the "%Y-%m-%d" format.
user_date_input = st.text_input(
    "Enter a date in the format 'YYYY-MM-DD':",
    help="e.g., 2024-01-01",
)

# How many of the downloaded patents to analyze (integer, at least 1).
num_patents_to_analyze = st.number_input(
    "Number of patents to analyze:",
    min_value=1,
    value=1,
    step=1,
    help="Specify how many patents you want to analyze.",
)

# OpenAI model name forwarded to the QA agent.
_MODEL_OPTIONS = ["gpt-3.5-turbo", "gpt-4"]
model_choice = st.selectbox(
    "Select a model for analysis:",
    _MODEL_OPTIONS,
    help="Choose the OpenAI GPT model for the analysis.",
)

# Passed through to the patentwiz helpers as their logging flag.
logging_enabled = st.checkbox(
    "Enable logging?",
    value=False,
    help="Toggle logging for debugging purposes.",
)
# Run Analysis Button


def _analyze_patents(date_text, max_patents, model_name, enable_logging):
    """Download the patents for one date and analyze the first few with GPT.

    Every outcome is reported through Streamlit widgets and nothing is
    returned. Outcomes include validation errors, progress, per-patent
    results and failures.

    Args:
        date_text: Raw text from the date input, expected as ``YYYY-MM-DD``.
            Surrounding whitespace is ignored; ``None``/empty is rejected.
        max_patents: Upper bound on how many downloaded patents are analyzed.
            The first ``max_patents`` entries, in the order returned by
            ``parse_and_save_patents``, are used.
        model_name: Model identifier forwarded to ``call_QA_to_json``.
        enable_logging: Flag forwarded to both patentwiz helpers.
    """
    date_text = (date_text or "").strip()
    if not date_text:
        st.error("Please enter a valid date!")
        return

    # Parse the date in its own try block so that only a malformed date is
    # reported as "Invalid date format". A ValueError raised later by the
    # download or analysis steps must not be mislabelled as a date problem.
    try:
        input_date = datetime.strptime(date_text, "%Y-%m-%d")
    except ValueError as ve:
        st.error(f"Invalid date format: {ve}")
        return
    year, month, day = input_date.year, input_date.month, input_date.day

    try:
        # Step 1: Download and preprocess patents
        with st.spinner("Downloading and extracting patents..."):
            saved_patent_names = preprocess_data.parse_and_save_patents(
                year, month, day, enable_logging
            )
        if not saved_patent_names:
            st.error("No patents found for the given date.")
            return
        st.success(f"{len(saved_patent_names)} patents found and processed!")

        # Step 2: Analyze patents using GPT (the first N, not a random sample)
        selected_patents = saved_patent_names[:max_patents]
        if len(selected_patents) < max_patents:
            st.warning(
                f"Only {len(selected_patents)} patent(s) available for this date; "
                "analyzing all of them."
            )
        total_cost = 0
        results = []
        st.write("Starting patent analysis...")
        # call_QA_to_json receives the full list plus an index. Because
        # selected_patents is a prefix of saved_patent_names, index i names
        # the same file in both lists.
        for i, _patent_file in enumerate(selected_patents):
            cost, output = qa_agent.call_QA_to_json(
                PROMPT,
                year,
                month,
                day,
                saved_patent_names,
                i,
                enable_logging,
                model_name,
            )
            total_cost += cost
            results.append(output)

        # Step 3: Display results
        st.write(f"**Total Cost:** ${total_cost:.4f}")
        st.write("### Analysis Results:")
        for idx, result in enumerate(results, start=1):
            st.subheader(f"Patent {idx}")
            st.json(result)
    except Exception as e:
        # Top-level UI boundary: surface any failure to the user instead of
        # letting the script crash.
        st.error(f"An unexpected error occurred: {e}")


if st.button("Analyze Patents"):
    _analyze_patents(
        user_date_input, num_patents_to_analyze, model_choice, logging_enabled
    )