import os
from datetime import datetime

import streamlit as st

from patentwiz import preprocess_data, qa_agent, main
from patentwiz.main import PROMPT

# Read the OpenAI credential from the environment. When it is missing or
# empty, report the problem in the page and halt this script run.
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
    st.error(
        "OPENAI_API_KEY not found! "
        "Please set it in the environment variables or Hugging Face Secrets."
    )
    st.stop()

# Page heading followed by a short description of what the app does.
_APP_DESCRIPTION = (
    "Analyze patents to extract physical measurements such as length, mass, time, and more. "
    "Provide a date to download patents, and analyze them using GPT models."
)

st.title("Technical Measurements Extractor for Patents")
st.write(_APP_DESCRIPTION)

# Input form. The four values collected here (date text, patent count, model
# name, logging flag) are read by the "Analyze Patents" handler below.
_MODEL_OPTIONS = ["gpt-3.5-turbo", "gpt-4"]

st.header("Enter Details for Patent Analysis")

user_date_input = st.text_input(
    "Enter a date in the format 'YYYY-MM-DD':",
    help="e.g., 2024-01-01",
)

num_patents_to_analyze = st.number_input(
    "Number of patents to analyze:",
    min_value=1,
    value=1,
    step=1,
    help="Specify how many patents you want to analyze.",
)

model_choice = st.selectbox(
    "Select a model for analysis:",
    _MODEL_OPTIONS,
    help="Choose the OpenAI GPT model for the analysis.",
)

logging_enabled = st.checkbox(
    "Enable logging?",
    value=False,
    help="Toggle logging for debugging purposes.",
)

# Runs when the user clicks the button: validate the date, fetch that day's
# patents via preprocess_data, then call qa_agent once per selected patent and
# render the accumulated cost and the per-patent results.
if st.button("Analyze Patents"):
    # Tolerate stray whitespace around the date (e.g. from copy/paste).
    date_text = user_date_input.strip()
    input_date = None

    if not date_text:
        st.error("Please enter a valid date!")
    else:
        # Only the parse is guarded by the ValueError handler, so a ValueError
        # raised later by the download or analysis steps is not misreported
        # as a date-format problem.
        try:
            input_date = datetime.strptime(date_text, "%Y-%m-%d")
        except ValueError as ve:
            st.error(f"Invalid date format: {ve}")

    if input_date is not None:
        year, month, day = input_date.year, input_date.month, input_date.day

        # UI boundary: show any failure from the pipeline in the page instead
        # of letting the script die with a traceback.
        try:
            with st.spinner("Downloading and extracting patents..."):
                saved_patent_names = preprocess_data.parse_and_save_patents(
                    year, month, day, logging_enabled
                )

            if not saved_patent_names:
                st.error("No patents found for the given date.")
                # NOTE(review): st.stop() is called inside the broad handler,
                # as in the original; confirm the installed Streamlit signals
                # a stop with something `except Exception` does not catch.
                st.stop()

            st.success(f"{len(saved_patent_names)} patents found and processed!")

            # call_QA_to_json is given the full list plus an index, so only
            # the number of leading patents to process is needed here.
            patent_count = min(num_patents_to_analyze, len(saved_patent_names))
            total_cost = 0
            results = []

            st.write("Starting patent analysis...")
            for patent_index in range(patent_count):
                cost, output = qa_agent.call_QA_to_json(
                    PROMPT,
                    year,
                    month,
                    day,
                    saved_patent_names,
                    patent_index,
                    logging_enabled,
                    model_choice,
                )
                total_cost += cost
                results.append(output)

            st.write(f"**Total Cost:** ${total_cost:.4f}")
            st.write("### Analysis Results:")
            for result_number, result in enumerate(results, start=1):
                st.subheader(f"Patent {result_number}")
                st.json(result)
        except Exception as e:
            st.error(f"An unexpected error occurred: {e}")