Spaces:
Sleeping
Sleeping
Vitomir Jovanović
committed on
Commit
·
591de4e
1
Parent(s):
348df3a
Glancing + new data
Browse files- Procfile.yaml +0 -1
- README.md +1 -1
- app.py +4 -5
- {models → data}/prompts_data.jsonl +0 -0
- fast_api.py +1 -2
- models/__pycache__/data_reader.cpython-312.pyc +0 -0
- models/__pycache__/prompt_search_engine.cpython-312.pyc +0 -0
- models/data_reader.py +17 -9
- models/prompt_search_engine.py +2 -6
- models/vectorizer.py +0 -33
Procfile.yaml
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
web: gunicorn -w 1 -k uvicorn.workers.UvicornWorker main:app --bind 0.0.0.0:8084 & streamlit run app.py --server.port 7860
|
|
|
|
README.md
CHANGED
@@ -24,7 +24,7 @@ Script creates swagger app with endpoints on [localhost:8084](http://127.0.0.1:8
|
|
24 |
data_reader.py
|
25 |
```
|
26 |
creates data of various prompts for encoding into vector database, from prompt-picture dataset.
|
27 |
-
Local database encoded only
|
28 |
The Faiss index used here is small and unoptimized, intended for experimental datasets. Search is brute-force, not optimized.
|
29 |
|
30 |
### Streamlit
|
|
|
24 |
data_reader.py
|
25 |
```
|
26 |
creates data of various prompts for encoding into vector database, from prompt-picture dataset.
|
27 |
+
The local database encodes only 11,000 prompts.
|
28 |
The Faiss index used here is small and unoptimized, intended for experimental datasets. Search is brute-force, not optimized.
|
29 |
|
30 |
### Streamlit
|
app.py
CHANGED
@@ -1,15 +1,11 @@
|
|
1 |
import streamlit as st
|
2 |
-
from models.vectorizer import Vectorizer
|
3 |
from models.prompt_search_engine import PromptSearchEngine
|
4 |
from models.data_reader import load_prompts_from_jsonl
|
5 |
-
from models.Query import Query, SimilarPrompt, SearchResponse, PromptVector, VectorResponse
|
6 |
-
from sentence_transformers import SentenceTransformer
|
7 |
-
import os
|
8 |
|
9 |
# Cache the prompts data to avoid reloading every time
|
10 |
@st.cache_data
|
11 |
def load_prompts():
|
12 |
-
prompt_path = "
|
13 |
return load_prompts_from_jsonl(prompt_path)
|
14 |
|
15 |
# Cache the search engine initialization
|
@@ -36,12 +32,15 @@ k = st.number_input("Number of similar prompts to retrieve:", min_value=1, max_v
|
|
36 |
# Button to trigger search
|
37 |
if st.button("Search Prompts"):
|
38 |
if query_input:
|
|
|
39 |
similar_prompts, distances = search_engine.most_similar(query_input, top_k=k)
|
|
|
40 |
|
41 |
# Format and display search results
|
42 |
st.write(f"Search Results: ")
|
43 |
for i, (prompt, distance) in enumerate(zip(similar_prompts, distances)):
|
44 |
st.write(f"{i+1}. Prompt: {prompt}, Distance: {distance}")
|
|
|
45 |
else:
|
46 |
st.error("Please enter a prompt.")
|
47 |
|
|
|
1 |
import streamlit as st
|
|
|
2 |
from models.prompt_search_engine import PromptSearchEngine
|
3 |
from models.data_reader import load_prompts_from_jsonl
|
|
|
|
|
|
|
4 |
|
5 |
# Cache the prompts data to avoid reloading every time
|
6 |
@st.cache_data
|
7 |
def load_prompts():
|
8 |
+
prompt_path = "data/prompts_data.jsonl"
|
9 |
return load_prompts_from_jsonl(prompt_path)
|
10 |
|
11 |
# Cache the search engine initialization
|
|
|
32 |
# Button to trigger search
|
33 |
if st.button("Search Prompts"):
|
34 |
if query_input:
|
35 |
+
print(f'Search engine is searching the most similar prompts for query {query_input}')
|
36 |
similar_prompts, distances = search_engine.most_similar(query_input, top_k=k)
|
37 |
+
print(f'Those are: {similar_prompts}, {distances}')
|
38 |
|
39 |
# Format and display search results
|
40 |
st.write(f"Search Results: ")
|
41 |
for i, (prompt, distance) in enumerate(zip(similar_prompts, distances)):
|
42 |
st.write(f"{i+1}. Prompt: {prompt}, Distance: {distance}")
|
43 |
+
print(f'Those are: {prompt}, {distance}')
|
44 |
else:
|
45 |
st.error("Please enter a prompt.")
|
46 |
|
{models → data}/prompts_data.jsonl
RENAMED
The diff for this file is too large to render.
See raw diff
|
|
fast_api.py
CHANGED
@@ -5,7 +5,6 @@ import uvicorn
|
|
5 |
import socket
|
6 |
import logging
|
7 |
import datetime
|
8 |
-
from models.vectorizer import Vectorizer
|
9 |
from models.prompt_search_engine import PromptSearchEngine
|
10 |
from models.data_reader import load_prompts_from_jsonl
|
11 |
from models.Query import Query, Query_Multiple, SearchResponse, SimilarPrompt, PromptVector, VectorResponse
|
@@ -15,7 +14,7 @@ from sentence_transformers import SentenceTransformer
|
|
15 |
|
16 |
|
17 |
|
18 |
-
prompt_path = r"C:\Users\jov2bg\Desktop\PromptSearch\search_engine\
|
19 |
|
20 |
|
21 |
app = FastAPI(title="Search Prompt Engine", description="API for prompt search", version="1.0")
|
|
|
5 |
import socket
|
6 |
import logging
|
7 |
import datetime
|
|
|
8 |
from models.prompt_search_engine import PromptSearchEngine
|
9 |
from models.data_reader import load_prompts_from_jsonl
|
10 |
from models.Query import Query, Query_Multiple, SearchResponse, SimilarPrompt, PromptVector, VectorResponse
|
|
|
14 |
|
15 |
|
16 |
|
17 |
+
prompt_path = r"C:\Users\jov2bg\Desktop\PromptSearch\search_engine\data\prompts_data.jsonl"
|
18 |
|
19 |
|
20 |
app = FastAPI(title="Search Prompt Engine", description="API for prompt search", version="1.0")
|
models/__pycache__/data_reader.cpython-312.pyc
CHANGED
Binary files a/models/__pycache__/data_reader.cpython-312.pyc and b/models/__pycache__/data_reader.cpython-312.pyc differ
|
|
models/__pycache__/prompt_search_engine.cpython-312.pyc
CHANGED
Binary files a/models/__pycache__/prompt_search_engine.cpython-312.pyc and b/models/__pycache__/prompt_search_engine.cpython-312.pyc differ
|
|
models/data_reader.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
from datasets import load_dataset
|
2 |
import json
|
|
|
3 |
|
4 |
|
5 |
# Load the dataset
|
@@ -10,21 +11,28 @@ num_shards = 46 # Number of webdataset tar files
|
|
10 |
|
11 |
def download_data(base_url, num_shards):
|
12 |
# Download the data
|
|
|
13 |
urls = [base_url.format(i=i) for i in range(num_shards)]
|
14 |
dataset = load_dataset("webdataset", data_files={"train": urls}, split="train", streaming=True)
|
15 |
return dataset
|
16 |
|
17 |
-
|
|
|
|
|
18 |
# Write data to the jsonl file
|
19 |
prompts = {}
|
|
|
|
|
20 |
with open(jsonl_file_path, 'w') as f:
|
21 |
-
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
24 |
|
25 |
|
26 |
def read_data(jsonl_file_path):
|
27 |
-
|
28 |
# Read data from the jsonl file
|
29 |
with open(jsonl_file_path, 'r') as f:
|
30 |
for line in f:
|
@@ -36,15 +44,15 @@ def load_prompts_from_jsonl(file_path):
|
|
36 |
prompts = []
|
37 |
with open(file_path, 'r') as f:
|
38 |
for line in f:
|
39 |
-
data = json.loads(line)
|
40 |
-
prompts.append(data)
|
41 |
print("Data loaded successfully.")
|
42 |
return prompts
|
43 |
|
44 |
|
45 |
if __name__ == "__main__":
|
46 |
-
jsonl_file_path = r"C:\Users\jov2bg\Desktop\PromptSearch\search_engine\
|
47 |
num_shards = 1
|
48 |
-
dataset = download_data(
|
49 |
extract_prompts(dataset, jsonl_file_path)
|
50 |
read_data(jsonl_file_path)
|
|
|
1 |
from datasets import load_dataset
|
2 |
import json
|
3 |
+
from tqdm import tqdm
|
4 |
|
5 |
|
6 |
# Load the dataset
|
|
|
11 |
|
12 |
def download_data(base_url, num_shards):
|
13 |
# Download the data
|
14 |
+
print("Downloading data...")
|
15 |
urls = [base_url.format(i=i) for i in range(num_shards)]
|
16 |
dataset = load_dataset("webdataset", data_files={"train": urls}, split="train", streaming=True)
|
17 |
return dataset
|
18 |
|
19 |
+
|
20 |
+
|
21 |
+
def extract_prompts(dataset, jsonl_file_path):
|
22 |
# Write data to the jsonl file
|
23 |
prompts = {}
|
24 |
+
print('Extracting data to:', jsonl_file_path)
|
25 |
+
|
26 |
with open(jsonl_file_path, 'w') as f:
|
27 |
+
with tqdm(desc="Processing prompts", unit=" prompt") as pbar:
|
28 |
+
for index, row in enumerate(dataset):
|
29 |
+
prompts[index] = row['json']['prompt']
|
30 |
+
f.write(json.dumps(prompts[index]) + '\n')
|
31 |
+
|
32 |
+
pbar.update(1)
|
33 |
|
34 |
|
35 |
def read_data(jsonl_file_path):
|
|
|
36 |
# Read data from the jsonl file
|
37 |
with open(jsonl_file_path, 'r') as f:
|
38 |
for line in f:
|
|
|
44 |
prompts = []
|
45 |
with open(file_path, 'r') as f:
|
46 |
for line in f:
|
47 |
+
data = json.loads(line)
|
48 |
+
prompts.append(data)
|
49 |
print("Data loaded successfully.")
|
50 |
return prompts
|
51 |
|
52 |
|
53 |
if __name__ == "__main__":
|
54 |
+
jsonl_file_path = r"C:\Users\jov2bg\Desktop\PromptSearch\search_engine\data\prompts_data_new.jsonl"
|
55 |
num_shards = 1
|
56 |
+
dataset = download_data(base_url, num_shards)
|
57 |
extract_prompts(dataset, jsonl_file_path)
|
58 |
read_data(jsonl_file_path)
|
models/prompt_search_engine.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
from typing import Sequence, List, Tuple
|
2 |
-
from models.vectorizer import Vectorizer
|
3 |
import numpy as np
|
4 |
from sentence_transformers import SentenceTransformer
|
5 |
import faiss
|
6 |
|
7 |
class PromptSearchEngine:
|
|
|
8 |
def __init__(self, model_name='bert-base-nli-mean-tokens'):
|
9 |
print("Search engine started!")
|
10 |
self.model = SentenceTransformer(model_name)
|
@@ -27,7 +27,7 @@ class PromptSearchEngine:
|
|
27 |
print('Finding the most similar vectors')
|
28 |
query_embedding = self.model.encode([query]).astype('float32')
|
29 |
|
30 |
-
# Optimizovana pretraga ali moramo promeniti vrstu indeksa
|
31 |
distances, indices = self.index.search(query_embedding, top_k)
|
32 |
|
33 |
# Retrieve the corresponding prompts for the found indices
|
@@ -47,12 +47,8 @@ class PromptSearchEngine:
|
|
47 |
|
48 |
# Get all vectors from FAISS
|
49 |
index_vectors = index.reconstruct_n(0, index.ntotal) # Reconstruct all vectors in the index
|
50 |
-
|
51 |
-
|
52 |
index_norms = np.linalg.norm(index_vectors, axis=1, keepdims=True)
|
53 |
normalized_index_vectors = index_vectors / index_norms
|
54 |
-
|
55 |
-
|
56 |
cosine_similarities = np.dot(normalized_index_vectors, query_norm.T)
|
57 |
|
58 |
return cosine_similarities
|
|
|
1 |
from typing import Sequence, List, Tuple
|
|
|
2 |
import numpy as np
|
3 |
from sentence_transformers import SentenceTransformer
|
4 |
import faiss
|
5 |
|
6 |
class PromptSearchEngine:
|
7 |
+
'''Instantiate the language model and index for searching the most similar prompts. Performs the semantic search.'''
|
8 |
def __init__(self, model_name='bert-base-nli-mean-tokens'):
|
9 |
print("Search engine started!")
|
10 |
self.model = SentenceTransformer(model_name)
|
|
|
27 |
print('Finding the most similar vectors')
|
28 |
query_embedding = self.model.encode([query]).astype('float32')
|
29 |
|
30 |
+
# Optimized search, but the index type must be changed for real-world use
|
31 |
distances, indices = self.index.search(query_embedding, top_k)
|
32 |
|
33 |
# Retrieve the corresponding prompts for the found indices
|
|
|
47 |
|
48 |
# Get all vectors from FAISS
|
49 |
index_vectors = index.reconstruct_n(0, index.ntotal) # Reconstruct all vectors in the index
|
|
|
|
|
50 |
index_norms = np.linalg.norm(index_vectors, axis=1, keepdims=True)
|
51 |
normalized_index_vectors = index_vectors / index_norms
|
|
|
|
|
52 |
cosine_similarities = np.dot(normalized_index_vectors, query_norm.T)
|
53 |
|
54 |
return cosine_similarities
|
models/vectorizer.py
DELETED
@@ -1,33 +0,0 @@
|
|
1 |
-
|
2 |
-
from sentence_transformers import SentenceTransformer
|
3 |
-
import numpy as np
|
4 |
-
from typing import Sequence
|
5 |
-
import faiss
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
class Vectorizer:
|
12 |
-
def __init__(self, model) -> None:
|
13 |
-
"""Initialize the vectorizer with a pre-trained embedding model.
|
14 |
-
Args: model: The pre-trained embedding model to use for transforming prompts.
|
15 |
-
"""
|
16 |
-
self.model = model
|
17 |
-
self.index_size = 50000
|
18 |
-
self.index = faiss.IndexFlatIP(self.index_size)
|
19 |
-
self.cached_index_idx_to_retrieval_db_idx = []
|
20 |
-
|
21 |
-
|
22 |
-
def transform_and_add_to_index(self, prompts: Sequence[str]) -> np.ndarray:
|
23 |
-
"""Transform texts into numerical vectors using the specified model.
|
24 |
-
Args: prompts: The sequence of raw corpus prompts. Returns: Vectorized prompts
|
25 |
-
"""
|
26 |
-
embeddings = self.model.encode(prompts)
|
27 |
-
embedding_dimension = embeddings.shape[1]
|
28 |
-
print('Embedding dimension:', embedding_dimension)
|
29 |
-
|
30 |
-
self.index.add(np.array(embeddings))
|
31 |
-
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|