chainyo committed on
Commit
c24fab8
·
1 Parent(s): 3622f89

init space

Browse files
Files changed (3) hide show
  1. config.py +29 -0
  2. main.py +64 -0
  3. requirements.txt +7 -0
config.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os import getenv
2
+ from dotenv import load_dotenv
3
+
4
+ from pydantic import BaseSettings
5
+
6
+
7
class Config(BaseSettings):
    """Typed application settings.

    Groups the values needed to reach the Pinecone index and to describe
    the embedding model. All fields are required; pydantic validates them
    against the annotated types when the class is instantiated.
    """

    # --- Pinecone connection ---
    pinecone_api_key: str
    pinecone_env: str
    pinecone_index: str

    # --- Embedding model ---
    embedding_dim: int
    embedding_version: str
    embedding_dir: str
    model_name: str
18
+
19
+
20
# Merge variables from a local .env file (if present) into the process
# environment before they are read below.
load_dotenv()

# Module-level settings instance imported by the rest of the application.
# Every value is handed to pydantic as the raw environment string (or None
# when the variable is unset). The `int` annotation on `embedding_dim` makes
# pydantic perform the conversion, so a missing or malformed EMBEDDING_DIM is
# reported as a validation error naming the field, instead of the bare
# TypeError that an eager `int(None)` would raise here.
config = Config(
    pinecone_api_key=getenv("PINECONE_API_KEY"),
    pinecone_env=getenv("PINECONE_ENV"),
    pinecone_index=getenv("PINECONE_INDEX"),
    embedding_dim=getenv("EMBEDDING_DIM"),
    embedding_version=getenv("EMBEDDING_VERSION"),
    embedding_dir=getenv("EMBEDDING_DIR"),
    model_name=getenv("MODEL_NAME"),
)
main.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pinecone
2
+ import requests
3
+ import streamlit as st
4
+ import torch
5
+
6
+ from transformers import AutoTokenizer, AutoModel
7
+
8
+ from config import config
9
+
10
+
11
def search(text: str, k: int = 5, timeout: float = 30.0):
    """Get the k closest articles to the text.

    The text is embedded with ``_get_embeddings`` and the resulting vector
    is posted to the Pinecone index's REST ``/query`` endpoint.

    Args:
        text: Free-text query to embed and look up.
        k: Number of matches to request (sent as ``top_k``).
        timeout: Seconds to wait for Pinecone before giving up. Without a
            timeout ``requests`` can block forever on a stalled connection,
            which would leave the UI spinner hanging.

    Returns:
        The decoded JSON body of the response.

    Raises:
        Exception: If the endpoint answers with a status other than 200.
        requests.exceptions.RequestException: On connection failure or
            when ``timeout`` elapses.
    """
    embeds = _get_embeddings(text)

    # NOTE(review): the "-5b18b87" host suffix is hard-coded; it looks like a
    # project-specific identifier -- confirm before reusing with another index.
    url = f"https://{config.pinecone_index}-5b18b87.svc.{config.pinecone_env}.pinecone.io/query"
    headers = {
        "Api-Key": config.pinecone_api_key,
        "accept": "application/json",
        "content-type": "application/json",
    }
    payload = {
        "vector": embeds,
        "top_k": k,
        "includeMetadata": True,
        "includeValues": False,
    }

    r = requests.post(url, headers=headers, json=payload, timeout=timeout)

    if r.status_code != 200:
        raise Exception(f"Error: {r.status_code} - {r.text}")
    return r.json()
34
+
35
+
36
def _get_embeddings(text: str):
    """Embed ``text`` with the tokenizer and model held in the session state.

    The text is tokenized into PyTorch tensors, run through the model with
    gradient tracking disabled, and the first element of the model output is
    averaged over dimension 1. The result is squeezed and returned as a
    plain Python list.
    """
    encoded = st.session_state.tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = st.session_state.model(**encoded)
        token_states = outputs[0]

    pooled = token_states.mean(dim=1)
    return pooled.squeeze().tolist()
43
+
44
+
45
# ---- Streamlit page -------------------------------------------------------
# Streamlit re-executes this script on every interaction, so heavyweight
# objects are stored in st.session_state and only created when absent.
st.title("PubMed Embeddings")
st.subheader("Search for a PubMed article and get its id.")

text = st.text_input("Search for a PubMed article", "Epidemiology of COVID-19")

with st.spinner("Loading Embedding Model..."):
    # The pinecone-client 2.x keyword is `environment`; an `env=` keyword is
    # not a recognised parameter, so the configured environment would not be
    # applied.
    pinecone.init(api_key=config.pinecone_api_key, environment=config.pinecone_env)
    if "index" not in st.session_state:
        st.session_state.index = pinecone.Index(config.pinecone_index)
    if "tokenizer" not in st.session_state:
        st.session_state.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    if "model" not in st.session_state:
        st.session_state.model = AutoModel.from_pretrained(config.model_name)

if st.button("Search"):
    with st.spinner("Searching..."):
        # Call the module-level search(); nothing ever stores an
        # `embeds_handler` in the session state, so going through it raised
        # AttributeError as soon as the button was pressed.
        results = search(text)

    # One line per match: article id followed by its similarity score.
    for res in results["matches"]:
        st.write(f"{res['id']} - confidence: {res['score']:.2f}")
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ pinecone-client>=2.1.0
2
+ python-dotenv>=0.21.1
3
+ pydantic>=1.10.4
4
+ requests>=2.26.0
5
+ streamlit>=1.17.0
6
+ transformers>=4.26.0
7
+ torch>=1.12.0