rag-generate / app.py
davidberenstein1957's picture
app.py
ffa07db verified
raw
history blame
2.45 kB
import gradio as gr
from gradio_client import Client
from huggingface_hub import get_token, InferenceClient
from llama_cpp import Llama
# Shared model instance used by `generate`. The GGUF file is fetched from the
# Hugging Face Hub on first run, so importing this module downloads/loads it.
llm = Llama.from_pretrained(
    repo_id="HuggingFaceTB/SmolLM2-360M-Instruct-GGUF",
    filename="smollm2-360m-instruct-q8_0.gguf",
    # llama-cpp-python defaults to a 512-token context, which silently clamps
    # `max_tokens` (default 4000 in `generate`) and rejects longer RAG prompts
    # outright. 8192 matches SmolLM2's trained context length.
    n_ctx=8192,
    verbose=False,
)
def generate(
    user_prompt: str,
    system_prompt: str = "You are a helpful assistant.",
    max_tokens: int = 4000,
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
):
    """Run a single chat completion against the module-level ``llm``.

    Builds a two-message conversation (system + user) and forwards the
    sampling settings to ``Llama.create_chat_completion``.

    Args:
        user_prompt: The user's query, sent as the ``user`` message.
        system_prompt: Instruction sent as the ``system`` message.
        max_tokens: Upper bound on generated tokens. Coerced to ``int``;
            ``None`` is passed through unchanged.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens considered; coerced to
            ``int``.
        presence_penalty: Penalty forwarded as-is.
        frequency_penalty: Penalty forwarded as-is.

    Returns:
        The chat-completion object returned by
        ``llm.create_chat_completion`` (unmodified).
    """
    # Gradio's Number component yields floats unless precision=0, while
    # llama.cpp expects integer token counts and an integer top-k (the sampler
    # is bound through ctypes as a C int), so normalise both here.
    if max_tokens is not None:
        max_tokens = int(max_tokens)
    top_k = int(top_k)

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    return llm.create_chat_completion(
        messages,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
    )
def _generate_text(
    user_prompt: str,
    system_prompt: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> str:
    """Call ``generate`` and return only the assistant's reply text.

    ``generate`` returns the whole chat-completion object; handing that to a
    Textbox would render its ``repr`` instead of the generated answer.
    """
    completion = generate(
        user_prompt=user_prompt,
        system_prompt=system_prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
    )
    return completion["choices"][0]["message"]["content"]


with gr.Blocks() as demo:
    gr.Markdown(
        """# RAG - generate
Generate a response to a query using [HuggingFaceTB/SmolLM2-360M-Instruct and llama-cpp-python](https://huggingface.co/HuggingFaceTB/SmolLM2-360M-Instruct-GGUF?library=llama-cpp-python).
Part of [ai-blueprint](https://github.com/davidberenstein1957/ai-blueprint) - a blueprint for AI development, focusing on applied examples of RAG, information extraction, analysis and fine-tuning in the age of LLMs and agents."""
    )
    with gr.Row():
        # Pre-filled so an untouched box doesn't replace the model's default
        # instruction with an empty system message.
        system_prompt = gr.Textbox(
            label="System prompt", lines=3, value="You are a helpful assistant."
        )
        user_prompt = gr.Textbox(label="Query", lines=3)
    with gr.Accordion("kwargs"):
        with gr.Row(variant="panel"):
            # precision=0 makes Gradio round and return ints for the
            # integer-valued sampling settings.
            max_tokens = gr.Number(label="Max tokens", value=512, precision=0)
            temperature = gr.Number(label="Temperature", value=0.2)
            top_p = gr.Number(label="Top p", value=0.95)
            top_k = gr.Number(label="Top k", value=40, precision=0)
    submit_btn = gr.Button("Submit")
    response_output = gr.Textbox(label="Response", lines=10)
    submit_btn.click(
        fn=_generate_text,
        inputs=[
            user_prompt,
            system_prompt,
            max_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=[response_output],
        # Keep the public API endpoint name the app exposed before the
        # wrapper was introduced (Gradio otherwise derives it from fn name).
        api_name="generate",
    )
demo.launch()