import sys
import os
import re
import shutil
import time
import streamlit as st
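
# Make the app's working directory importable so the local patent_downloader module resolves.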
sys.path.append(os.path.abspath("."))
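# NOTE: These are the legacy `langchain` import paths; on langchain >= 0.1 most of
# these integrations live in `langchain_community` (e.g. langchain_community.vectorstores.Chroma).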
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.llms import OpenAI
from langchain.document_loaders import UnstructuredPDFLoader
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import NLTKTextSplitter
from patent_downloader import PatentDownloader
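
# Chroma persists its index files here; "." keeps them in the app's working directory.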
PERSISTED_DIRECTORY = "."
# Fetch API key securely from the environment
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    st.error("Critical Error: OpenAI API key not found in the environment variables. Please configure it.")
    st.stop()

def check_poppler_installed():
    """Fail fast if poppler's pdfinfo CLI is missing (required for PDF parsing)."""
    if not shutil.which("pdfinfo"):
        raise EnvironmentError(
            "Poppler is not installed or not in PATH. Install 'poppler-utils' for PDF processing."
        )


check_poppler_installed()

def load_docs(document_path):
    """Parse a PDF and split it into ~1000-character chunks on sentence boundaries."""
    loader = UnstructuredPDFLoader(document_path)
    documents = loader.load()
    text_splitter = NLTKTextSplitter(chunk_size=1000)
    return text_splitter.split_documents(documents)
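# NOTE: NLTKTextSplitter relies on NLTK's "punkt" sentence tokenizer; if the data
# is missing, download it once with nltk.download("punkt").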

def already_indexed(vectordb, file_name):
    """Check whether this file's chunks are already in the Chroma collection."""
    indexed_sources = set(
        x["source"] for x in vectordb.get(include=["metadatas"])["metadatas"]
    )
    return file_name in indexed_sources
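# NOTE: LangChain document loaders stamp each chunk's metadata with a "source"
# field (the input file path), which is what the membership check above relies on.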

def load_chain(file_name):
    loaded_patent = st.session_state.get("LOADED_PATENT")

    vectordb = Chroma(
        persist_directory=PERSISTED_DIRECTORY,
        embedding_function=HuggingFaceEmbeddings(),
    )
    if loaded_patent == file_name or already_indexed(vectordb, file_name):
        st.write("Already indexed.")
    else:
        # Re-index from scratch: drop the old collection and embed the new document
        vectordb.delete_collection()
        docs = load_docs(file_name)
        st.write("Number of chunks: ", len(docs))
        vectordb = Chroma.from_documents(
            docs, HuggingFaceEmbeddings(), persist_directory=PERSISTED_DIRECTORY
        )
        vectordb.persist()
        st.session_state["LOADED_PATENT"] = file_name

    # Buffer the full chat history so follow-up questions keep their context
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        input_key="question",
        output_key="answer",
    )
    return ConversationalRetrievalChain.from_llm(
        OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY),
        vectordb.as_retriever(search_kwargs={"k": 3}),
        return_source_documents=False,
        memory=memory,
    )
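# NOTE: ConversationalRetrievalChain condenses each follow-up question together
# with the buffered chat history into a standalone query, then answers it from
# the k=3 retrieved chunks.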

def extract_patent_number(url):
    """Pull a patent number such as US9145062 out of a Google Patents URL."""
    pattern = r"/patent/([A-Z]{2}\d+)"
    match = re.search(pattern, url)
    return match.group(1) if match else None

def download_pdf(patent_number):
    patent_downloader = PatentDownloader()
    patent_downloader.download(patent=patent_number)
    return f"{patent_number}.pdf"
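# NOTE: This assumes the local PatentDownloader saves the PDF as
# "<patent_number>.pdf" in the working directory; adjust the returned path if
# that module writes elsewhere.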

if __name__ == "__main__":
    st.set_page_config(
        page_title="Patent Chat: Google Patents Chat Demo",
        page_icon="📖",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    st.header("📖 Patent Chat: Google Patents Chat Demo")

    # Allow the user to input the Google patent link
    patent_link = st.text_input("Enter Google Patent Link:", key="PATENT_LINK")
    if not patent_link:
        st.warning("Please enter a Google patent link to proceed.")
        st.stop()
    else:
        st.session_state["patent_link_configured"] = True

    patent_number = extract_patent_number(patent_link)
    if not patent_number:
        st.error("Invalid patent link format. Please provide a valid Google patent link.")
        st.stop()
    st.write("Patent number: ", patent_number)

    # Download the patent PDF once; later reruns reuse the file on disk
    pdf_path = f"{patent_number}.pdf"
    if os.path.isfile(pdf_path):
        st.write("File already downloaded.")
    else:
        st.write("Downloading patent file...")
        pdf_path = download_pdf(patent_number)
        st.write("File downloaded.")

    chain = load_chain(pdf_path)
    if "messages" not in st.session_state:
        st.session_state["messages"] = [
            {"role": "assistant", "content": "How can I help you?"}
        ]

    # Replay the conversation so far on each rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if user_input := st.chat_input("What is your question?"):
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.markdown(user_input)

        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            full_response = ""
            with st.spinner("Generating an answer..."):
                assistant_response = chain({"question": user_input})
            # Simulated streaming: the chain call above is blocking, so reveal the
            # finished answer word by word behind a typing cursor
            for chunk in assistant_response["answer"].split():
                full_response += chunk + " "
                time.sleep(0.05)
                message_placeholder.markdown(full_response + "▌")
            message_placeholder.markdown(full_response)
        st.session_state.messages.append(
            {"role": "assistant", "content": full_response}
        )