|
import pinecone |
|
import requests |
|
import streamlit as st |
|
import torch |
|
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
|
from config import config |
|
|
|
|
|
def search(text: str, k: int = 5, *, timeout: float = 30.0):
    """Get the k closest articles to the text.

    Embeds ``text`` with ``_get_embeddings`` and posts the resulting vector
    to the Pinecone ``/query`` REST endpoint whose host is built from
    ``config.pinecone_index`` and ``config.pinecone_env``.

    Args:
        text: Free-text query to embed and search for.
        k: Number of matches to request (sent as ``top_k``).
        timeout: Seconds to wait for Pinecone before giving up; forwarded to
            ``requests.post``. Without a timeout a stalled connection would
            block the caller indefinitely.

    Returns:
        The decoded JSON body of the query response.

    Raises:
        Exception: If the response status code is not 200.
        requests.exceptions.Timeout: If no response arrives within
            ``timeout`` seconds.
    """
    embeds = _get_embeddings(text)

    # NOTE(review): the "5b18b87" host segment is hard-coded; presumably it
    # identifies the Pinecone project -- confirm before reusing elsewhere.
    r = requests.post(
        f"https://{config.pinecone_index}-5b18b87.svc.{config.pinecone_env}.pinecone.io/query",
        headers={
            "Api-Key": config.pinecone_api_key,
            "accept": "application/json",
            "content-type": "application/json",
        },
        json={
            "vector": embeds,
            "top_k": k,
            "includeMetadata": True,
            "includeValues": False,
        },
        timeout=timeout,
    )

    # Anything other than 200 is surfaced with the status and raw body.
    if r.status_code != 200:
        raise Exception(f"Error: {r.status_code} - {r.text}")
    return r.json()
|
|
|
|
|
def _get_embeddings(text: str):
    """Turn ``text`` into an embedding using the model kept in session state.

    The text is tokenized with ``st.session_state.tokenizer``, passed through
    ``st.session_state.model`` with gradient tracking disabled, and the first
    element of the model output is averaged over dimension 1.

    Args:
        text: The text to embed.

    Returns:
        The averaged tensor after ``squeeze()``, converted with ``tolist()``.
    """
    tokenizer = st.session_state.tokenizer
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

    # Inference only: no gradients are needed for the forward pass.
    with torch.no_grad():
        model = st.session_state.model
        hidden = model(**encoded)[0]

    pooled = hidden.mean(dim=1)
    return pooled.squeeze().tolist()
|
|
|
|
|
|
|
# Page entry point: everything below is shown only when the entered password
# matches ``config.password``; otherwise a rejection message is written.
password = st.text_input("Password", type="password")
if password == config.password:
    st.title("PubMed Embeddings")
    st.subheader("Search for a PubMed article and get its id.")

    text = st.text_input("Search for a PubMed article", "Epidemiology of COVID-19")

    with st.spinner("Loading Embedding Model..."):
        # pinecone.init() names this parameter ``environment``; an ``env``
        # keyword is not one of its parameters, so the configured environment
        # was not being applied to the client.
        pinecone.init(api_key=config.pinecone_api_key, environment=config.pinecone_env)

        # Each object is created only if it is not already in session state.
        if "index" not in st.session_state:
            st.session_state.index = pinecone.Index(config.pinecone_index)
        if "tokenizer" not in st.session_state:
            st.session_state.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        if "model" not in st.session_state:
            st.session_state.model = AutoModel.from_pretrained(config.model_name)

    if st.button("Search"):
        with st.spinner("Searching..."):
            results = search(text)

        # One line per match: the match id and its score to two decimals.
        for res in results["matches"]:
            st.write(f"{res['id']} - confidence: {res['score']:.2f}")
else:
    st.write("Password incorrect!")
|
|
|
|