# rag-generate / app.py
# From davidberenstein1957's Hugging Face Space — "Update app.py" (commit f2a1ff7, verified)
import gradio as gr
import pandas as pd
from gradio_client import Client
from huggingface_hub import InferenceClient, get_token
from sentence_transformers import CrossEncoder
# Remote Gradio Space hosting the vector-search index; queried over HTTP per request.
gradio_client = Client("https://smol-blueprint-vector-search-hub.hf.space/")
# Model used to rescore (query, chunk) pairs after retrieval.
# NOTE(review): "all-MiniLM-L12-v2" is a bi-encoder checkpoint, not a cross-encoder
# one (those live under "cross-encoder/...") — verify the scores are meaningful.
reranker = CrossEncoder("sentence-transformers/all-MiniLM-L12-v2")
# Serverless Inference API client authenticated with the locally stored HF token.
inference_client = InferenceClient(api_key=get_token())
def similarity_search(query: str, k: int = 5):
    """Fetch the top-*k* nearest chunks for *query* from the remote vector-search Space.

    The Space's ``/similarity_search`` endpoint returns a payload with ``"data"``
    (rows) and ``"headers"`` (column names); this wraps it into a DataFrame.
    """
    payload = gradio_client.predict(api_name="/similarity_search", query=query, k=k)
    return pd.DataFrame(data=payload["data"], columns=payload["headers"])
def query_and_rerank_documents(query: str, k_retrieved: int = 10):
    """Retrieve candidate chunks, drop duplicates, and sort by reranker relevance.

    Each surviving (query, chunk) pair is scored by the module-level ``reranker``;
    the returned DataFrame is ordered by that score, best first, in a ``rank`` column.
    """
    candidates = similarity_search(query, k_retrieved).drop_duplicates("chunk")
    pairs = [[query, chunk] for chunk in candidates["chunk"]]
    candidates["rank"] = reranker.predict(pairs)
    return candidates.sort_values(by="rank", ascending=False)
def generate_response_api(query: str):
    """Answer *query* with the hosted SmolLM2 model via the chat-completion API.

    Returns the message object of the first choice; callers read ``.content``.
    """
    system_prompt = (
        "You will receive a query and context. Only return the answer based on "
        "the context without mentioning the context."
    )
    completion = inference_client.chat.completions.create(
        model="HuggingFaceTB/SmolLM2-360M-Instruct",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": query},
        ],
        max_tokens=2000,
    )
    return completion.choices[0].message
def rag_pipeline(query: str, k_retrieved: int = 10, k_reranked: int = 5):
    """End-to-end RAG: retrieve + rerank documents, then generate an answer.

    Only the top *k_reranked* chunks are embedded in the prompt as context.
    Returns ``(answer_text, reranked_documents_dataframe)``.
    """
    docs = query_and_rerank_documents(query, k_retrieved=k_retrieved)
    top_chunks = docs["chunk"].to_list()[:k_reranked]
    prompt = f"Context: {top_chunks}\n\nQuery: {query}"
    answer = generate_response_api(prompt)
    return answer.content, docs
# Gradio UI: a query box, two sliders controlling retrieval/rerank depth,
# and outputs for the generated answer plus the ranked source documents.
with gr.Blocks() as demo:
    gr.Markdown("""# RAG Hub Datasets

    Part of [smol blueprint](https://github.com/davidberenstein1957/smol-blueprint) - a smol blueprint for AI development, focusing on practical examples of RAG, information extraction, analysis and fine-tuning in the age of LLMs.""")

    with gr.Row():
        query_input = gr.Textbox(
            label="Query", placeholder="Enter your question here...", lines=3
        )

    with gr.Row():
        with gr.Column():
            # How many candidates to pull from the vector index (k_retrieved).
            retrieve_slider = gr.Slider(
                minimum=1,
                maximum=20,
                value=10,
                label="Number of documents to retrieve",
            )
        with gr.Column():
            # How many reranked chunks end up in the LLM prompt (k_reranked).
            rerank_slider = gr.Slider(
                minimum=1,
                maximum=10,
                value=5,
                label="Number of documents to use after reranking",
            )

    submit_btn = gr.Button("Submit")
    response_output = gr.Textbox(label="Response", lines=10)
    # Headers mirror the columns produced by query_and_rerank_documents.
    documents_output = gr.Dataframe(
        label="Documents", headers=["chunk", "url", "distance", "rank"], wrap=True
    )

    # rag_pipeline returns (answer_text, documents), matching the two outputs.
    submit_btn.click(
        fn=rag_pipeline,
        inputs=[query_input, retrieve_slider, rerank_slider],
        outputs=[response_output, documents_output],
    )

demo.launch()