# matterattetatte's picture
# Update app.py
# 3f6753e verified
import streamlit as st
from smolagents import Tool, CodeAgent, HfApiModel
from langchain.text_splitter import RecursiveCharacterTextSplitter, MarkdownTextSplitter
from langchain_community.retrievers import BM25Retriever
from langchain.docstore.document import Document
from datasets import load_dataset, concatenate_datasets
# Page-level settings for the app.
st.set_page_config(
    page_title="Science Search Engine",
    # The icon literal was mis-encoded (UTF-8 bytes shown in the wrong
    # codepage); restored to the intended books emoji.
    page_icon="📚",
    layout="wide",
)
class RetrieverTool(Tool):
    """Agent tool that answers a query with BM25-ranked document chunks.

    The class attributes ``name``, ``description``, ``inputs`` and
    ``output_type`` describe the tool; the single ``inputs`` key has the
    same name as the ``forward`` parameter.
    """

    name = "retriever"
    description = "Uses BM25 search to retrieve relevant scientific documentation"
    inputs = {
        "query": {
            "type": "string",
            "description": "The scientific query in affirmative form rather than a question",
        }
    }
    output_type = "string"

    def __init__(self, docs, k1=1.5, b=0.75, **kwargs):
        """Index ``docs`` for BM25 retrieval.

        Args:
            docs: Non-empty sequence of documents exposing ``page_content``
                and ``metadata``.
            k1: BM25 ``k1`` parameter handed to the underlying scorer.
            b: BM25 ``b`` parameter handed to the underlying scorer.
            **kwargs: Forwarded unchanged to ``Tool.__init__``.

        Raises:
            ValueError: If ``docs`` is empty (there is nothing to index and
                the average length would be undefined).
        """
        super().__init__(**kwargs)
        if not docs:
            raise ValueError("RetrieverTool requires at least one document to index")
        # k1 / b are scorer settings, not retriever fields, so they are passed
        # through ``bm25_params``; as plain keyword arguments they would not
        # reach the BM25 implementation.
        self.retriever = BM25Retriever.from_documents(
            docs,
            bm25_params={"k1": k1, "b": b},
            k=12,
        )
        self.docs = docs
        # Mean whitespace-token count, used to report a relative length per hit.
        self.avg_doc_length = sum(len(doc.page_content.split()) for doc in docs) / len(docs)

    def forward(self, query: str) -> str:  # parameter name matches the key in ``inputs``
        """Retrieve documents for ``query`` and format them as one text report.

        Args:
            query: Free-text query; it is normalised by ``_preprocess_query``.

        Returns:
            A header line followed by one section per retrieved document with
            its length factor, content and (when present) metadata.
        """
        query = self._preprocess_query(query)
        # NOTE(review): newer LangChain releases deprecate
        # ``get_relevant_documents`` in favour of ``invoke`` - confirm against
        # the pinned version before switching.
        docs = self.retriever.get_relevant_documents(query)

        parts = ["Retrieved documents (ranked by relevance):\n\n"]
        for i, doc in enumerate(docs, 1):
            doc_length = len(doc.page_content.split())
            # Guard against an all-whitespace corpus (average length of zero).
            length_factor = doc_length / self.avg_doc_length if self.avg_doc_length else 0.0
            parts.append(f"Document {i} (Length Factor: {length_factor:.2f})\n")
            parts.append(f"{doc.page_content}\n\n")
            if doc.metadata:
                parts.append(f"Metadata: {doc.metadata}\n")
            parts.append("---\n\n")
        return "".join(parts)

    def _preprocess_query(self, query: str) -> str:
        """Lower-case ``query``, collapse whitespace and drop a leading question word.

        An empty or whitespace-only query yields an empty string.
        """
        question_words = ("what", "when", "where", "who", "why", "how")
        query_terms = query.lower().split()
        # ``query_terms`` is empty for blank input; check before indexing.
        if query_terms and query_terms[0] in question_words:
            query_terms = query_terms[1:]
        return " ".join(query_terms)
# Process documents
def prepare_docs(documents, *, chunk_size=1000, chunk_overlap=200):
    """Split ``documents`` into overlapping chunks with ``MarkdownTextSplitter``.

    Args:
        documents: Documents to split; passed to ``split_documents``.
        chunk_size: Chunk size given to the splitter (default 1000).
        chunk_overlap: Overlap between consecutive chunks given to the
            splitter (default 200).

    Returns:
        Whatever ``MarkdownTextSplitter.split_documents`` returns for the input.
    """
    text_splitter = MarkdownTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return text_splitter.split_documents(documents)
# Initialize agent
def create_rag_agent(processed_docs):
    """Return a verbose ``CodeAgent`` whose only tool is a BM25 retriever.

    Args:
        processed_docs: Already-chunked documents handed to ``RetrieverTool``.
    """
    tools = [RetrieverTool(processed_docs)]
    llm = HfApiModel()
    return CodeAgent(tools=tools, model=llm, verbose=True)
def format_search_results(results: str):
    """Render search results, split into findings and sources when possible.

    If ``results`` contains the sources marker, everything before the first
    marker is shown in a wide left column and everything after it in a
    narrower right column. Otherwise the text is rendered as-is.

    Args:
        results: Text produced by the agent. The caller passes the raw agent
            result, which is not guaranteed here to be a ``str``, so other
            values are converted with ``str()`` first.
    """
    # The emoji in the marker/headings were mis-encoded; restored to the
    # intended characters.
    marker = "### 📚 Sources:"
    if not isinstance(results, str):
        results = str(results)

    # ``partition`` splits on the first marker only, so repeated markers no
    # longer break two-value unpacking.
    main_content, found, sources = results.partition(marker)
    if not found:
        st.markdown(results)
        return

    # Create two columns with adjusted ratios
    col1, col2 = st.columns([3, 2])
    with col1:
        st.markdown("### 📖 Main Findings")
        st.markdown(main_content)
    with col2:
        st.markdown("### 📚 Sources")
        # NOTE(review): raw HTML from model output is rendered here; confirm
        # this is intended before exposing the app to untrusted content.
        st.markdown(sources, unsafe_allow_html=True)
@st.cache_resource
def get_agent():
    """Load the dataset, chunk it and build the retrieval agent in one step.

    Decorated with ``st.cache_resource`` so Streamlit can reuse the built
    agent instead of repeating the load on every call.
    """
    # Load dataset
    biology = load_dataset("camel-ai/biology")
    source_docs = concatenate_datasets([biology["train"]])

    # One Document per row: the answer text is the content, the question and
    # sub-topic travel along as metadata.
    documents = [
        Document(
            page_content=row['message_2'],
            metadata={
                "title": row['message_1'],
                "description": row['sub_topic'],
            },
        )
        for row in source_docs
    ]

    # Split into overlapping chunks, recording each chunk's start offset.
    processed_docs = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=500,
        add_start_index=True,
        strip_whitespace=True,
    ).split_documents(documents)

    # Create and return agent
    return CodeAgent(
        tools=[RetrieverTool(processed_docs)],
        model=HfApiModel(),
    )
# Streamlit UI
# (The title emoji was mis-encoded; restored to the intended character.)
st.title("📚 Scientific Search Engine")
st.markdown("""
This search engine uses advanced AI to help you explore science.
It provides detailed, sourced information from a curated database of scientific knowledge.
""")

# Initialize agent: build it once and keep it in the session state.
if 'agent' not in st.session_state:
    with st.spinner("Loading database..."):
        st.session_state.agent = get_agent()
# Search interface
# The label previously read "Search African History", which does not match
# this science search page; the emoji was also mis-encoded.
search_query = st.text_input(
    "🔍 Search Science",
    placeholder="E.g., Tell me about cancer in dogs",
    help="Enter any question about science",
)

# Advanced search options
with st.expander("Advanced Search Options"):
    search_type = st.radio(
        "Search Type",
        ["General Query", "Scientific branches"],
        help="Select the type of search you want to perform",
    )

# Only decorate a non-empty query: prefixing an empty input would make it
# truthy and send a topic-less prompt to the agent instead of showing the
# "please enter a query" warning below. (A former "Geographic Region" branch
# was unreachable because the radio never offers that option.)
search_query = search_query.strip()
if search_query and search_type == "Scientific branches":
    search_query = f"Focus on the specific scientific branch of: {search_query}"

# Search button
if st.button("Search", type="primary"):
    if search_query:
        with st.spinner("Searching records..."):
            # Top-level UI boundary: surface any failure to the user rather
            # than crashing the page.
            try:
                results = st.session_state.agent.run(search_query)
                # Use the formatter to display results
                format_search_results(results)
                # Add methodology note
                st.markdown("---")
                st.info(
                    "\n"
                    "💡 **How to read the results:**\n"
                    "- Main findings are summarized on the left\n"
                    "- Source references are numbered [Source X]\n"
                    "- Click on source details on the right to expand\n"
                    "- Follow the links to read the original articles\n"
                )
            except Exception as e:
                st.error(f"An error occurred during the search: {e}")
    else:
        st.warning("Please enter a search query to begin.")
# Sidebar with additional information
# The copy previously described an African-history engine and historical
# records; it now matches the science page and the dataset loaded by this app.
with st.sidebar:
    st.markdown("### About This Search Engine")
    st.markdown(
        "\n"
        "This search engine specializes in science, providing:\n"
        "- 📚 Detailed information\n"
        "- 🔍 Source verification\n"
        "- 🧬 Topic and sub-topic context\n"
    )
    st.markdown("### Data Sources")
    st.markdown("Our database is built from the CAMEL-AI biology dataset "
                "(`camel-ai/biology`).")

# Footer
st.markdown("---")
st.caption("Powered by SmolAgents, RAG, and Camel AI Dataset")