|
import streamlit as st |
|
from smolagents import Tool, CodeAgent, HfApiModel |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter, MarkdownTextSplitter |
|
from langchain_community.retrievers import BM25Retriever |
|
from langchain.docstore.document import Document |
|
from datasets import load_dataset, concatenate_datasets |
|
|
|
# Page-level Streamlit settings: browser-tab title, tab icon, and a full-width layout.
# This is the first Streamlit call in the file.
# NOTE(review): "π" looks like an emoji that was mis-decoded from UTF-8; confirm the
# intended icon character.
st.set_page_config(layout="wide", page_icon="π", page_title="Science Search Engine")
|
|
|
class RetrieverTool(Tool):
    """smolagents tool that performs BM25 keyword retrieval over a fixed document set.

    The BM25 index is built once in ``__init__``. Each ``forward`` call returns the
    top-``k`` matching chunks, with their metadata, rendered as one string for the
    agent to read.
    """

    name = "retriever"
    description = "Uses BM25 search to retrieve relevant scientific documentation"
    inputs = {
        "query": {
            "type": "string",
            "description": "The scientific query in affirmative form rather than a question",
        }
    }
    output_type = "string"

    def __init__(self, docs, k1=1.5, b=0.75, *, k=12, **kwargs):
        """Build the BM25 index over *docs*.

        Args:
            docs: LangChain ``Document`` chunks to index. Any iterable is accepted
                and materialized into a list.
            k1: BM25 term-frequency saturation parameter.
            b: BM25 document-length normalization parameter.
            k: Number of documents returned per query.
            **kwargs: Forwarded to ``Tool.__init__``.

        Raises:
            ValueError: If *docs* is empty.
        """
        super().__init__(**kwargs)
        docs = list(docs)  # a generator would otherwise be exhausted by the indexer
        if not docs:
            raise ValueError("RetrieverTool needs at least one document to index")

        self.retriever = BM25Retriever.from_documents(
            docs,
            k=k,
            # BM25Retriever hands only bm25_params to the BM25 scorer; bare k1=/b=
            # keyword arguments would not reach it.
            bm25_params={"k1": k1, "b": b},
            # Tokenize documents and queries identically (case-insensitive, punctuation
            # dropped) so the lowercased query from _preprocess_query can match them.
            preprocess_func=self._tokenize,
        )
        self.docs = docs
        # Mean whitespace-separated word count per chunk; used for the length factor.
        self.avg_doc_length = sum(len(doc.page_content.split()) for doc in docs) / len(docs)

    def forward(self, query: str) -> str:
        """Return the top-ranked chunks for *query* as one formatted string.

        Each hit is listed with its 1-based rank, a length factor (its word count
        divided by the corpus average), its text, and its metadata when present.
        A query with no searchable terms yields an explanatory message instead.
        """
        query = self._preprocess_query(query)
        if not self._tokenize(query):
            # BM25 scores every document 0 for an empty query, so any "hits" would be arbitrary.
            return "No search terms were provided; please supply a non-empty query."

        hits = self.retriever.invoke(query)

        parts = ["Retrieved documents (ranked by relevance):\n\n"]
        for rank, doc in enumerate(hits, 1):
            doc_length = len(doc.page_content.split())
            # Guard against a corpus whose chunks are all empty (average length 0).
            length_factor = doc_length / self.avg_doc_length if self.avg_doc_length else 0.0
            parts.append(f"Document {rank} (Length Factor: {length_factor:.2f})\n")
            parts.append(f"{doc.page_content}\n\n")
            if doc.metadata:
                parts.append(f"Metadata: {doc.metadata}\n")
            parts.append("---\n\n")
        return "".join(parts)

    def _preprocess_query(self, query: str) -> str:
        """Lowercase *query*, collapse whitespace, and drop one leading question word.

        Returns an empty string for an empty or whitespace-only query.
        """
        question_words = {"what", "when", "where", "who", "why", "how"}
        terms = query.lower().split()
        if terms and terms[0] in question_words:
            terms = terms[1:]
        return " ".join(terms)

    @staticmethod
    def _tokenize(text: str) -> list[str]:
        """Split *text* into lowercase ``\\w+`` tokens for BM25 indexing and querying."""
        import re  # local import keeps this block self-contained

        return re.findall(r"\w+", text.lower())
|
|
|
|
|
def prepare_docs(documents, chunk_size=1000, chunk_overlap=200):
    """Split *documents* into chunks with a Markdown-aware text splitter.

    NOTE(review): not called within this file; get_agent() uses its own
    RecursiveCharacterTextSplitter configuration.

    Args:
        documents: LangChain ``Document`` objects to split.
        chunk_size: Maximum chunk size passed to ``MarkdownTextSplitter``.
        chunk_overlap: Overlap between consecutive chunks passed to the splitter.

    Returns:
        The chunked documents produced by ``split_documents``.
    """
    splitter = MarkdownTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return splitter.split_documents(documents)
|
|
|
|
|
def create_rag_agent(processed_docs):
    """Build a verbose CodeAgent whose only tool is a BM25 retriever over *processed_docs*.

    Args:
        processed_docs: Chunked documents used to build the ``RetrieverTool`` index.

    Returns:
        A ``CodeAgent`` backed by a default ``HfApiModel``, with ``verbose=True``.
    """
    toolset = [RetrieverTool(processed_docs)]
    llm = HfApiModel()
    return CodeAgent(tools=toolset, model=llm, verbose=True)
|
|
|
def format_search_results(results: str):
    """Render the agent's answer, split into findings and sources when marked.

    When *results* contains the ``"### π Sources:"`` marker, everything before its
    first occurrence is shown under "Main Findings" in a wide left column, and
    everything after it under "Sources" in a narrower right column. Otherwise the
    whole answer is rendered as Markdown.

    Args:
        results: The agent's final answer. Non-string answers are rendered via ``str()``.
    """
    # NOTE(review): "π" in the marker and headings looks like an emoji that was
    # mis-decoded from UTF-8; confirm the intended character. The marker must match
    # the text the agent actually emits.
    marker = "### π Sources:"
    # The agent may return a non-string final answer; the membership test needs text.
    text = results if isinstance(results, str) else str(results)

    # partition() splits on the first marker only, so a repeated marker cannot
    # break the two-part unpacking.
    main_content, found, sources = text.partition(marker)
    if not found:
        st.markdown(text)
        return

    col1, col2 = st.columns([3, 2])
    with col1:
        st.markdown("### π Main Findings")
        st.markdown(main_content)
    with col2:
        st.markdown("### π Sources")
        # NOTE(review): these sources come from model output, which can echo retrieved
        # text or the user's query; unsafe_allow_html renders any HTML in it as-is.
        st.markdown(sources, unsafe_allow_html=True)
|
|
|
@st.cache_resource
def get_agent():
    """Load the biology corpus, chunk it, and build a CodeAgent that searches it.

    Wrapped in ``st.cache_resource`` so Streamlit reuses the built agent instead of
    re-downloading, re-chunking and re-indexing on every rerun.

    Returns:
        A ``CodeAgent`` (default ``HfApiModel``) whose only tool is a
        ``RetrieverTool`` over the chunked ``camel-ai/biology`` train split.
    """
    train_split = load_dataset("camel-ai/biology")["train"]

    # message_2 becomes the searchable text; message_1 and sub_topic are kept as
    # metadata so they appear alongside each retrieved chunk.
    documents = [
        Document(
            page_content=row["message_2"],
            metadata={
                "title": row["message_1"],
                "description": row["sub_topic"],
            },
        )
        for row in train_split
    ]

    # Consecutive chunks overlap by up to chunk_overlap, so text cut at a chunk
    # boundary still appears intact in a neighbour. add_start_index records each
    # chunk's offset in its metadata.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=500,
        add_start_index=True,
        strip_whitespace=True,
    )
    processed_docs = text_splitter.split_documents(documents)

    retriever_tool = RetrieverTool(processed_docs)
    return CodeAgent(
        tools=[retriever_tool],
        model=HfApiModel(),
    )
|
|
|
|
|
# Page header followed by a short introduction.
# NOTE(review): "π" in the title looks like a mis-decoded emoji; confirm the intended character.
_INTRO_MARKDOWN = """
This search engine uses advanced AI to help you explore science.
It provides detailed, sourced information from a curated database of scientific knowledge.
"""
st.title("π Scientific Search Engine")
st.markdown(_INTRO_MARKDOWN)

# Keep one agent per browser session in session state. get_agent() is wrapped in
# st.cache_resource, so later sessions reuse the already-built agent.
agent_missing = "agent" not in st.session_state
if agent_missing:
    with st.spinner("Loading database..."):
        st.session_state["agent"] = get_agent()
|
|
|
|
|
# Free-text query box. Surrounding whitespace is stripped so a blank entry counts
# as "no query" in the Search handler.
# NOTE(review): "π" in the label looks like a mis-decoded emoji; confirm the intended character.
search_query = st.text_input(
    "π Search Science",
    placeholder="E.g., Tell me about cancer in dogs",
    help="Enter any question about science",
).strip()

with st.expander("Advanced Search Options"):
    search_type = st.radio(
        "Search Type",
        ["General Query", "Scientific branches"],
        help="Select the type of search you want to perform",
    )

# Only "Scientific branches" reframes the query. The prefix is applied only to a
# non-empty query: prefixing an empty one would make it truthy and suppress the
# "Please enter a search query" warning.
if search_query and search_type == "Scientific branches":
    search_query = f"Focus on the specific scientific branch of: {search_query}"
|
|
|
|
|
# Reading guide shown beneath a successful search.
# NOTE(review): "π‘" looks like a mis-decoded emoji (probably a light bulb); confirm.
_RESULTS_GUIDE = """
π‘ **How to read the results:**
- Main findings are summarized on the left
- Source references are numbered [Source X]
- Click on source details on the right to expand
- Follow the links to read the original articles
"""

if st.button("Search", type="primary"):
    if not search_query:
        st.warning("Please enter a search query to begin.")
    else:
        with st.spinner("Searching records..."):
            try:
                answer = st.session_state.agent.run(search_query)
                format_search_results(answer)
                st.markdown("---")
                st.info(_RESULTS_GUIDE)
            except Exception as err:
                # Top-level UI boundary: show any agent or rendering failure to the user.
                st.error(f"An error occurred during the search: {err}")
|
|
|
|
|
# Sidebar description. The copy must match the corpus actually loaded in
# get_agent() (camel-ai/biology) and the metadata attached to each chunk.
# NOTE(review): the "π" bullet icons look like mis-decoded emoji; confirm the
# intended characters.
with st.sidebar:
    st.markdown("### About This Search Engine")
    st.markdown(
        """
This search engine specializes in biology, providing:
- π Detailed information
- π Source verification
- π Sub-topic context for each passage
"""
    )

    st.markdown("### Data Sources")
    st.markdown(
        "Our database is built from the Camel AI biology dataset (camel-ai/biology), "
        "split into searchable passages."
    )

# Footer: horizontal rule followed by a credits caption.
st.markdown("---")
st.caption("Powered by SmolAgents, RAG, and Camel AI Dataset")