# rag-retrieve / app.py
# davidberenstein1957's picture
# Update app.py
# e9fd1e4 verified
import duckdb
import gradio as gr
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding
from huggingface_hub import get_token
# Embedding model: a Model2Vec static embedding wrapped in a SentenceTransformer.
static_embedding = StaticEmbedding.from_model2vec("minishlab/potion-base-8M")
model = SentenceTransformer(modules=[static_embedding])
# Vector width reported by the model; used to size the fixed-length FLOAT arrays in SQL.
embedding_dimensions = model.get_sentence_embedding_dimension()

# Names of the DuckDB table, the source dataset, and the embedding columns.
table_name = "fineweb"
dataset_name = "ai-blueprint/fineweb-bbc-news-embeddings"
embedding_column = "embeddings"
embedding_column_float = embedding_column + "_float"

# One-off database setup, run at import time:
#   1. install and load the `vss` extension,
#   2. materialise the dataset's parquet files into a table, adding a copy of the
#      embedding column cast to a fixed-size float array,
#   3. build an HNSW index with the cosine metric on that cast column.
duckdb.sql(
    f"""
INSTALL vss;
LOAD vss;
CREATE TABLE {table_name} AS
SELECT *, {embedding_column}::float[{embedding_dimensions}] as {embedding_column_float}
FROM 'hf://datasets/{dataset_name}/**/*.parquet';
CREATE INDEX my_hnsw_index ON {table_name} USING HNSW ({embedding_column_float}) WITH (metric = 'cosine');
"""
)
def similarity_search(query: str, k: int = 5):
    """Return the rows of the table that are closest to ``query``.

    The query text is embedded with the module-level ``model`` and compared
    against the stored vectors with DuckDB's ``array_cosine_distance``; rows
    are ordered by ascending distance.

    Args:
        query: Free-text search string to embed.
        k: Maximum number of rows to return. Coerced to a non-negative
            ``int`` because UI widgets such as sliders may deliver floats,
            and because the value is spliced into the SQL text.

    Returns:
        The result of ``.to_df()`` on the DuckDB relation, containing the
        matching rows plus a ``distance`` column, with the two embedding
        columns dropped.

    Raises:
        ValueError / TypeError: if ``k`` cannot be converted to ``int``.
    """
    # Force an integer LIMIT: guards against fractional slider values and
    # against arbitrary text reaching the SQL string through ``k``.
    limit = max(int(k), 0)

    embedding = model.encode(query).tolist()
    # The vector is inlined as a list literal (Python's list repr) and cast to
    # the same fixed-size FLOAT array type as the indexed column.
    # NOTE(review): presumably inlined as a constant so the HNSW index can be
    # used for the ORDER BY ... LIMIT query -- confirm against the vss docs.
    df = duckdb.sql(
        query=f"""
SELECT *, array_cosine_distance({embedding_column_float}, {embedding}::FLOAT[{embedding_dimensions}]) as distance
FROM {table_name}
ORDER BY distance
LIMIT {limit};
"""
    ).to_df()
    # The raw vectors are not useful to display, so strip them from the result.
    df = df.drop(columns=[embedding_column, embedding_column_float])
    return df
# Gradio UI: a query box and a result-count slider wired to `similarity_search`,
# with the matches shown in a dataframe.
with gr.Blocks() as demo:
    gr.Markdown("""# RAG - retrieve
Executes vector search on top of [fineweb-bbc-news-embeddings](https://huggingface.co/datasets/ai-blueprint/fineweb-bbc-news-embeddings) using DuckDB.
Part of [AI blueprint](https://github.com/huggingface/ai-blueprint) - a blueprint for AI development, focusing on practical examples of RAG, information extraction, analysis and fine-tuning in the age of LLMs. """)
    query = gr.Textbox(label="Query")
    # step=1 keeps the slider on whole numbers; without an explicit step Gradio
    # derives a fractional one from the range, which makes no sense for a row count.
    k = gr.Slider(1, 50, value=5, step=1, label="Number of results")
    btn = gr.Button("Search")
    results = gr.Dataframe(headers=["url", "chunk", "distance"], wrap=True)
    btn.click(fn=similarity_search, inputs=[query, k], outputs=[results])

demo.launch()