Spaces:
Running
Running
import gradio as gr
import pandas as pd
from gradio_client import Client
from huggingface_hub import InferenceClient, get_token
from sentence_transformers import CrossEncoder
# Client for the remote Gradio Space that exposes the vector-search endpoint
# queried by similarity_search() below.
gradio_client = Client("https://smol-blueprint-vector-search-hub.hf.space/")
# Cross-encoder used to rescore retrieved chunks against the query.
# NOTE(review): "all-MiniLM-L12-v2" is usually published as a bi-encoder —
# confirm this checkpoint is intended for CrossEncoder use.
reranker = CrossEncoder("sentence-transformers/all-MiniLM-L12-v2")
# Hugging Face Inference API client, authenticated with the locally stored token.
inference_client = InferenceClient(api_key=get_token())
def similarity_search(query: str, k: int = 5) -> pd.DataFrame:
    """Query the remote vector-search Space and return the top-k hits.

    Args:
        query: Free-text search query.
        k: Number of nearest-neighbour chunks to retrieve.

    Returns:
        DataFrame built from the endpoint response; columns come from the
        response's "headers" field (presumably chunk/url/distance — the
        downstream UI expects those names; verify against the Space).
    """
    results = gradio_client.predict(api_name="/similarity_search", query=query, k=k)
    return pd.DataFrame(data=results["data"], columns=results["headers"])
def query_and_rerank_documents(query: str, k_retrieved: int = 10):
    """Retrieve candidate chunks, then reorder them with the cross-encoder.

    Args:
        query: Free-text search query.
        k_retrieved: How many candidates to pull from the vector index.

    Returns:
        The retrieved documents, deduplicated on "chunk", with a "rank"
        column holding the cross-encoder score, sorted best-first.
    """
    candidates = similarity_search(query, k_retrieved).drop_duplicates("chunk")
    # Score every (query, chunk) pair in one batched predict call.
    pairs = [[query, chunk] for chunk in candidates["chunk"]]
    candidates["rank"] = reranker.predict(pairs)
    return candidates.sort_values(by="rank", ascending=False)
def generate_response_api(query: str):
    """Ask the hosted SmolLM2 chat model to answer *query*.

    The caller is expected to have embedded any retrieval context inside the
    query string; the system prompt tells the model to answer from that
    context only.

    Returns:
        The assistant message object of the first completion choice.
    """
    system_prompt = (
        "You will receive a query and context. Only return the answer based "
        "on the context without mentioning the context."
    )
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": query},
    ]
    completion = inference_client.chat.completions.create(
        model="HuggingFaceTB/SmolLM2-360M-Instruct",
        messages=chat,
        max_tokens=2000,
    )
    return completion.choices[0].message
def rag_pipeline(query: str, k_retrieved: int = 10, k_reranked: int = 5):
    """Run the full RAG flow: retrieve, rerank, then generate an answer.

    Args:
        query: User question.
        k_retrieved: Candidates pulled from the vector index.
        k_reranked: Top reranked chunks passed to the LLM as context.

    Returns:
        Tuple of (answer text, reranked documents DataFrame).
    """
    documents = query_and_rerank_documents(query, k_retrieved=k_retrieved)
    top_chunks = documents["chunk"].to_list()[:k_reranked]
    prompt = f"Context: {top_chunks}\n\nQuery: {query}"
    answer = generate_response_api(prompt)
    return answer.content, documents
# ---------------------------------------------------------------------------
# Gradio UI: a query box, two sliders controlling retrieval/reranking depth,
# and outputs for the generated answer plus its supporting documents.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("""# RAG Hub Datasets
Part of [smol blueprint](https://github.com/davidberenstein1957/smol-blueprint) - a smol blueprint for AI development, focusing on practical examples of RAG, information extraction, analysis and fine-tuning in the age of LLMs.""")
    with gr.Row():
        # Free-text question from the user.
        query_input = gr.Textbox(
            label="Query", placeholder="Enter your question here...", lines=3
        )
    with gr.Row():
        with gr.Column():
            # Maps to rag_pipeline's k_retrieved.
            retrieve_slider = gr.Slider(
                minimum=1,
                maximum=20,
                value=10,
                label="Number of documents to retrieve",
            )
        with gr.Column():
            # Maps to rag_pipeline's k_reranked.
            rerank_slider = gr.Slider(
                minimum=1,
                maximum=10,
                value=5,
                label="Number of documents to use after reranking",
            )
    submit_btn = gr.Button("Submit")
    response_output = gr.Textbox(label="Response", lines=10)
    # Displays the reranked documents DataFrame returned by rag_pipeline.
    documents_output = gr.Dataframe(
        label="Documents", headers=["chunk", "url", "distance", "rank"], wrap=True
    )
    # Wire the button to the end-to-end RAG pipeline.
    submit_btn.click(
        fn=rag_pipeline,
        inputs=[query_input, retrieve_slider, rerank_slider],
        outputs=[response_output, documents_output],
    )
demo.launch()