# embeds/main.py
# Author: chainyo
# Last commit: 1b29f8c ("fix password")
import hmac

import pinecone
import requests
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel

from config import config
def search(text: str, k: int = 5, timeout: float = 30.0):
    """Query Pinecone for the ``k`` vectors closest to the embedding of ``text``.

    Args:
        text: Free-text query; embedded with ``_get_embeddings``.
        k: Number of matches to request (sent as Pinecone ``top_k``).
        timeout: Seconds ``requests`` waits for the Pinecone HTTP call before
            raising a timeout error. Without a timeout a stalled connection
            would block the Streamlit script indefinitely.

    Returns:
        The decoded JSON body of the query response (metadata included,
        vector values excluded).

    Raises:
        Exception: If Pinecone answers with a non-200 status; the message
            carries the status code and the response body.
        requests.RequestException: On connection failures or timeouts.
    """
    embeds = _get_embeddings(text)
    # NOTE(review): "5b18b87" is baked into the host name; presumably the
    # Pinecone project id. Consider moving it into ``config``.
    url = f"https://{config.pinecone_index}-5b18b87.svc.{config.pinecone_env}.pinecone.io/query"
    r = requests.post(
        url,
        headers={
            "Api-Key": config.pinecone_api_key,
            "accept": "application/json",
            "content-type": "application/json",
        },
        json={
            "vector": embeds,
            "top_k": k,
            "includeMetadata": True,
            "includeValues": False,
        },
        timeout=timeout,
    )
    if r.status_code == 200:
        return r.json()
    raise Exception(f"Error: {r.status_code} - {r.text}")
def _get_embeddings(text: str):
    """Embed ``text`` by averaging the model's first output over dim 1.

    Uses the tokenizer and model stored in ``st.session_state``. The first
    model output is treated as the last hidden states; averaging over dim 1
    is presumably the token axis (standard ``transformers`` encoder layout,
    not checked here). Padding positions are not masked out, which is
    harmless for a single input string because no padding is added.

    Returns:
        The pooled vector as a plain Python list (JSON-serialisable), after
        squeezing away size-1 dimensions.
    """
    tokenizer = st.session_state.tokenizer
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    model = st.session_state.model
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**encoded)
    hidden = outputs[0]
    pooled = hidden.mean(dim=1)
    return pooled.squeeze().tolist()
# --- Streamlit page: password gate, then semantic search over PubMed ids ---
password = st.text_input("Password", type="password")
# Constant-time comparison so response timing does not reveal how much of the
# password matched. Both sides are encoded because compare_digest rejects
# non-ASCII str input. Assumes config.password is a str (as the original
# equality check against the text input implies).
if hmac.compare_digest(password.encode("utf-8"), config.password.encode("utf-8")):
    st.title("PubMed Embeddings")
    st.subheader("Search for a PubMed article and get its id.")
    text = st.text_input("Search for a PubMed article", "Epidemiology of COVID-19")
    with st.spinner("Loading Embedding Model..."):
        # Streamlit reruns this whole script on every interaction; the
        # session_state guards make each piece of setup run once per session.
        if "index" not in st.session_state:
            # pinecone.init's keyword is `environment`; the previous `env=`
            # is not one of its parameters, so the configured environment
            # was not applied.
            pinecone.init(api_key=config.pinecone_api_key, environment=config.pinecone_env)
            st.session_state.index = pinecone.Index(config.pinecone_index)
        if "tokenizer" not in st.session_state:
            st.session_state.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        if "model" not in st.session_state:
            st.session_state.model = AutoModel.from_pretrained(config.model_name)
    if st.button("Search"):
        if not text.strip():
            st.warning("Please enter a search query.")
        else:
            with st.spinner("Searching..."):
                try:
                    results = search(text)
                except Exception as err:
                    # UI boundary: show the failure instead of a raw traceback.
                    st.error(f"Search failed: {err}")
                else:
                    matches = results.get("matches", [])
                    if not matches:
                        st.write("No matching articles found.")
                    for res in matches:
                        st.write(f"{res['id']} - confidence: {res['score']:.2f}")
elif password:
    # Only complain once the user has actually typed something.
    st.write("Password incorrect!")