import gradio as gr
import pandas as pd
from gradio_client import Client
from huggingface_hub import get_token, InferenceClient
from sentence_transformers import CrossEncoder


# Clients for retrieval (a hosted vector-search Space), reranking, and generation.
gradio_client = Client("https://smol-blueprint-vector-search-hub.hf.space/")
# Reranking needs a cross-encoder checkpoint; all-MiniLM-L12-v2 is a bi-encoder
# and would get a randomly initialized classification head.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")
inference_client = InferenceClient(api_key=get_token())


def similarity_search(query: str, k: int = 5):
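    """Call the vector-search Space and return the hits as a pandas DataFrame."""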
    results = gradio_client.predict(api_name="/similarity_search", query=query, k=k)
    return pd.DataFrame(data=results["data"], columns=results["headers"])


def query_and_rerank_documents(query: str, k_retrieved: int = 10):
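    """Retrieve candidate chunks, drop duplicates, and rerank them with the cross-encoder."""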
    documents = similarity_search(query, k_retrieved)
    documents = documents.drop_duplicates("chunk")
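    # Score each (query, chunk) pair with the cross-encoder; higher means more relevant.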
    documents["rank"] = reranker.predict([[query, hit] for hit in documents["chunk"]])
    reranked_documents = documents.sort_values(by="rank", ascending=False)
    return reranked_documents


def generate_response_api(query: str):
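    """Send the query (with context) to the hosted model and return the reply message."""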
    messages = [
        {
            "role": "system",
            "content": "You will receive a query and context. Only return the answer based on the context without mentioning the context.",
        },
        {"role": "user", "content": query},
    ]
    completion = inference_client.chat.completions.create(
        model="HuggingFaceTB/SmolLM2-360M-Instruct", messages=messages, max_tokens=2000
    )

    return completion.choices[0].message


def rag_pipeline(query: str, k_retrieved: int = 10, k_reranked: int = 5):
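    """Retrieve and rerank documents, then answer the query grounded on the top chunks."""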
    documents = query_and_rerank_documents(query, k_retrieved=k_retrieved)
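    # Keep only the top-k reranked chunks as context for generation.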
    query_with_context = (
        f"Context: {documents['chunk'].to_list()[:k_reranked]}\n\nQuery: {query}"
    )
    return generate_response_api(query_with_context).content, documents


with gr.Blocks() as demo:
    gr.Markdown("""# RAG Hub Datasets 
                
                Part of [smol blueprint](https://github.com/davidberenstein1957/smol-blueprint) - a smol blueprint for AI development, focusing on practical examples of RAG, information extraction, analysis and fine-tuning in the age of LLMs.""")

    with gr.Row():
        query_input = gr.Textbox(
            label="Query", placeholder="Enter your question here...", lines=3
        )

    with gr.Row():
        with gr.Column():
            retrieve_slider = gr.Slider(
                minimum=1,
                maximum=20,
                value=10,
                label="Number of documents to retrieve",
            )
        with gr.Column():
            rerank_slider = gr.Slider(
                minimum=1,
                maximum=10,
                value=5,
                label="Number of documents to use after reranking",
            )

    submit_btn = gr.Button("Submit")
    response_output = gr.Textbox(label="Response", lines=10)
    documents_output = gr.Dataframe(
        label="Documents", headers=["chunk", "url", "distance", "rank"], wrap=True
    )

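    # Clicking Submit runs the full pipeline and populates both outputs.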
    submit_btn.click(
        fn=rag_pipeline,
        inputs=[query_input, retrieve_slider, rerank_slider],
        outputs=[response_output, documents_output],
    )

demo.launch()