mena-app / app.py
wjbmattingly's picture
Add application file
289ba91
from dotenv import load_dotenv
import os
import pandas as pd
# Load environment variables from .env file
load_dotenv()
import gradio as gr
from weaviate.classes.query import QueryReference
import weaviate
from sentence_transformers import SentenceTransformer
from weaviate.auth import Auth
# Sentence-embedding model, loaded once at import time; search_papers() uses
# it to turn the user's query text into a vector.
model = SentenceTransformer('all-MiniLM-L6-v2')
# Weaviate connection settings, read from the process environment (which
# load_dotenv() above may have populated from a .env file). os.getenv returns
# None for a variable that is not set.
WEAVIATE_URL = os.getenv("WEAVIATE_URL")
WEAVIATE_API_KEY = os.getenv("WEAVIATE_API_KEY")
# Number of result cards rendered on each page of search results.
RESULTS_PER_PAGE = 5
# Stylesheet passed to gr.Blocks(css=...). It styles the classes attached via
# elem_classes in the UI ("container", "search-box", "search-button",
# "pagination-button") and the classes used in the result-card HTML built by
# format_page ("paper-card", "card-header", "card-content").
custom_css = """
.container {
max-width: 1000px !important;
margin: 0 auto !important;
padding: 2rem !important;
background-color: #f8fafc !important; /* Light blue-gray background */
}
.search-box {
margin-bottom: 2rem !important;
}
.search-button {
background-color: #0f172a !important; /* Deep blue */
color: #ffffff !important;
border-radius: 6px !important;
transition: background-color 0.3s ease !important;
}
.search-button:hover {
background-color: #1e293b !important; /* Slightly lighter blue on hover */
}
.pagination-button {
background-color: #ffffff !important;
color: #0f172a !important;
border: 1px solid #cbd5e1 !important;
border-radius: 6px !important;
min-width: 100px !important;
transition: all 0.3s ease !important;
}
.pagination-button:hover {
background-color: #f1f5f9 !important;
border-color: #94a3b8 !important;
}
.paper-card {
border: 1px solid #e2e8f0 !important;
border-radius: 12px !important;
margin-bottom: 1.5rem !important;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1) !important;
background: #ffffff !important;
transition: transform 0.2s ease, box-shadow 0.2s ease !important;
}
.paper-card:hover {
transform: translateY(-2px) !important;
box-shadow: 0 6px 12px -2px rgba(0, 0, 0, 0.15) !important;
}
.card-header {
background: #f1f5f9 !important;
padding: 1.25rem !important;
border-bottom: 1px solid #e2e8f0 !important;
border-radius: 12px 12px 0 0 !important;
cursor: pointer !important;
}
.card-header h3 {
color: #0f172a !important; /* Darker text for better contrast */
font-size: 1.1rem !important;
margin: 0 !important;
font-weight: 600 !important;
}
.card-content {
padding: 1.25rem !important;
color: #0f172a !important; /* Changed from #334155 to darker color */
line-height: 1.6 !important;
}
/* Additional styles for better typography and links */
a {
color: #2563eb !important;
text-decoration: none !important;
transition: color 0.2s ease !important;
}
a:hover {
color: #1d4ed8 !important;
}
/* Style for the main title */
h1 {
color: #0f172a !important;
font-weight: 700 !important;
margin-bottom: 2rem !important;
}
/* Style for the search input */
.gradio-textbox input {
border: 2px solid #e2e8f0 !important;
border-radius: 8px !important;
padding: 0.75rem !important;
transition: border-color 0.3s ease !important;
}
.gradio-textbox input:focus {
border-color: #2563eb !important;
outline: none !important;
box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1) !important;
}
/* Make sure all text content has good contrast */
p, span, label {
color: #0f172a !important; /* Consistent dark color for all text */
}
/* Style for labels and other UI text */
.gradio-textbox label {
color: #0f172a !important;
font-weight: 500 !important;
}
/* Page label styling */
.gradio-label {
color: #0f172a !important;
font-weight: 500 !important;
font-size: 0.875rem !important; /* Smaller font size */
}
/* Make sure author links maintain proper color */
.card-content a {
color: #2563eb !important;
}
"""
def search_papers(query):
    """Run a semantic search over the "Work" collection and tabulate the hits.

    The query text is embedded with the module-level ``model`` and sent to
    Weaviate as a near-vector search (at most 1000 hits), requesting each
    work's title, abstract, OpenAlex id and its linked authors.

    Args:
        query: Free-text search string entered by the user.

    Returns:
        A ``(df, count)`` pair. ``df`` is a DataFrame with the columns
        ``title``, ``work_url``, ``abstract`` and ``authors`` (``authors`` is
        a comma-separated string of pre-rendered HTML links) and ``count`` is
        its number of rows. For a blank query, or when nothing matches,
        ``(None, 0)`` is returned so callers can always unpack two values.
    """
    import html  # local import: only needed to escape author markup

    # Nothing to embed: report "no results" in the same two-value shape as
    # the success path instead of a differently sized tuple.
    if not query or not query.strip():
        return None, 0

    vector_query = model.encode(query)
    client = weaviate.connect_to_weaviate_cloud(
        cluster_url=WEAVIATE_URL,
        auth_credentials=Auth.api_key(WEAVIATE_API_KEY),
    )
    try:
        work_collection = client.collections.get("Work")
        # Fetch every hit in one request; paging happens later on the
        # DataFrame, not against the database.
        response = work_collection.query.near_vector(
            near_vector=vector_query,
            return_properties=["title", "abstract", "open_alex_id"],
            limit=1000,  # Adjust this based on your needs
            return_references=[
                QueryReference(
                    link_on="authors",
                    return_properties=["display_name", "open_alex_id", "concept_ids"]
                )
            ]
        )
    finally:
        # Always release the connection, even if the query raised.
        client.close()

    if not response.objects:
        return None, 0

    results = []
    for work in response.objects:
        # A dict is used as an ordered set: duplicates collapse while the
        # authors keep the order in which they were returned.
        author_links = {}
        authors = (work.references or {}).get('authors')
        if authors:
            for author in authors.objects:
                # Escape both values: they are interpolated into an HTML
                # attribute and HTML text respectively.
                author_url = html.escape(str(author.properties['open_alex_id']), quote=True)
                author_name = html.escape(str(author.properties['display_name']))
                link = f"<a href='{author_url}' target='_blank' style='color: #2563eb !important;'>{author_name}</a>"
                author_links[link] = None
        results.append({
            'title': work.properties['title'],
            'work_url': work.properties['open_alex_id'],
            'abstract': work.properties['abstract'],
            'authors': ', '.join(author_links),
        })
    return pd.DataFrame(results), len(results)
def _render_card(position, row):
    """Return the HTML for one collapsible result card."""
    abstract = row.abstract
    # None or NaN (NaN is the only value that is not equal to itself) means
    # the work has no abstract; avoid printing the literal "None"/"nan".
    if abstract is None or abstract != abstract:
        abstract = "No abstract available."
    return f"""
<div class="paper-card">
<div class="card-header"
onclick="this.nextElementSibling.style.display = this.nextElementSibling.style.display === 'none' ? 'block' : 'none'">
<h3>{position}. {row.title}</h3>
</div>
<div class="card-content" style="display:none;">
<p style="color: #0f172a !important;"><b style="color: #0f172a !important;">Authors:</b> <span style="color: #0f172a !important;">{row.authors}</span></p>
<p>{abstract}</p>
<p><a href="{row.work_url}" target="_blank"
style="color: #2563eb !important; text-decoration: none;">View on OpenAlex →</a></p>
</div>
</div>
"""


def format_page(df, page_num):
    """Render one page of search results as HTML.

    Args:
        df: DataFrame with ``title``, ``authors``, ``abstract`` and
            ``work_url`` columns, or ``None`` when there is nothing to show.
        page_num: 1-based page to render. Values outside the valid range are
            clamped to the first/last page.

    Returns:
        A ``(results_html, page_label_html)`` pair. ``results_html`` holds
        one card per row of the page (numbered across the whole result set);
        ``page_label_html`` is a centred "Page X of Y" label. With no rows,
        the pair is ``("No results found", <"Page 0 of 0" label>)``.
    """
    if df is None or len(df) == 0:
        return "No results found", '<div style="text-align: center; margin: 1rem 0; color: #0f172a;">Page 0 of 0</div>'

    # Ceiling division: a partially filled last page still counts as a page.
    total_pages = (len(df) + RESULTS_PER_PAGE - 1) // RESULTS_PER_PAGE
    # Clamp so a stale or out-of-range page never renders a blank page with
    # an impossible label such as "Page 7 of 2".
    page_num = min(max(page_num, 1), total_pages)

    start_idx = (page_num - 1) * RESULTS_PER_PAGE
    page_df = df.iloc[start_idx:start_idx + RESULTS_PER_PAGE]

    results_html = "".join(
        _render_card(position, row)
        for position, row in enumerate(page_df.itertuples(), start=start_idx + 1)
    )
    return results_html, f'<div style="text-align: center; margin: 1rem 0; color: #0f172a;">Page {page_num} of {total_pages}</div>'
# Gradio interface: a search box, the paginated result cards and the
# Previous/Next controls, plus the callbacks that connect them.
# NOTE(review): the nesting of the layout containers below is a best-effort
# reconstruction; confirm it against the intended layout.
with gr.Blocks(css=custom_css) as demo:
    # Label shown before any search and after a blank query.
    initial_page_label = '<div style="text-align: center; margin: 1rem 0; color: #0f172a;">Page 1 of 1</div>'

    with gr.Column(elem_classes="container"):
        gr.Markdown("# MENA Open-Alex Semantic Search")
        with gr.Column(elem_classes="search-box"):
            query_input = gr.Textbox(
                label="Enter your query:",
                placeholder="Search for papers..."
            )
            search_button = gr.Button("Search", elem_classes="search-button")
        # Results display
        results_output = gr.HTML()
        page_label = gr.HTML(value=initial_page_label)
        # Pagination controls
        with gr.Row():
            prev_button = gr.Button("Previous", elem_classes="pagination-button")
            next_button = gr.Button("Next", elem_classes="pagination-button")
        # Per-session state: the current 1-based page number and the
        # DataFrame holding the results of the last search (None before
        # the first search).
        page_number = gr.State(value=1)
        results_df = gr.State(value=None)

    def search_with_page(query, page):
        """Run a new search and show its first page.

        ``page`` is accepted because it is wired as an input but is not
        used: a new search always restarts at page 1.

        Returns ``(results_html, page_label_html, df, page_number)``.
        """
        if not query or not query.strip():
            return "Please enter a search query", initial_page_label, None, 1
        # Only the first element of the result is needed here. Anything that
        # is not a DataFrame (e.g. a status message) means there are no rows
        # to page through.
        first, *_ = search_papers(query)
        df = first if isinstance(first, pd.DataFrame) else None
        return (*format_page(df, 1), df, 1)

    def prev_page(query, page, df):
        """Step back one page, staying on page 1 when already there.

        Returns ``(results_html, page_label_html, page_number)``.
        """
        target = page - 1 if page > 1 else page
        return (*format_page(df, target), target)

    def next_page(query, page, df):
        """Step forward one page, staying put on the last page.

        Returns ``(results_html, page_label_html, page_number)``.
        """
        # Before the first search df is None; treat that as zero pages
        # instead of calling len() on None.
        if df is None:
            total_pages = 0
        else:
            total_pages = (len(df) + RESULTS_PER_PAGE - 1) // RESULTS_PER_PAGE
        target = page + 1 if page < total_pages else page
        return (*format_page(df, target), target)

    # Event wiring
    search_button.click(
        fn=search_with_page,
        inputs=[query_input, page_number],
        outputs=[results_output, page_label, results_df, page_number]
    )
    prev_button.click(
        fn=prev_page,
        inputs=[query_input, page_number, results_df],
        outputs=[results_output, page_label, page_number]
    )
    next_button.click(
        fn=next_page,
        inputs=[query_input, page_number, results_df],
        outputs=[results_output, page_label, page_number]
    )
demo.launch()