import gradio as gr
from huggingface_hub import InferenceClient
import pandas as pd
import transformers
import torch
from sentence_transformers import SentenceTransformer, util
""" | |
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co./docs/huggingface_hub/v0.22.2/en/guides/inference | |
""" | |
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta") | |
# Load the SBERT model used for retrieval
sbert_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# Function to initialize the Llama 3 8B model pipeline
def initiate_pipeline():
    model = "meta-llama/Meta-Llama-3-8B-Instruct"
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    return transformers.pipeline(
        "text-generation",
        model=model,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device=device,
    )
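# Note: torch.bfloat16 assumes hardware with bf16 support (e.g. recent NVIDIA
# GPUs); on older GPUs torch.float16, or torch.float32 on CPU, may be the
# safer dtype choice.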
# Initialize the model
llama_model = initiate_pipeline()

# Load the Q&A pairs from the CSV
qa_data = pd.read_csv("rag_juri_cv.csv")
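# The code below assumes rag_juri_cv.csv has (at least) two columns,
# QUESTION and ANSWER, one Q&A pair per row, for example:
#   QUESTION,ANSWER
#   "Where did Juri study?","Juri studied at ..."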
# Function to retrieve the top 5 relevant Q&A pairs using Sentence-BERT
def retrieve_top_k(query, k=5):
    # Combine the questions from the CSV into a list
    questions = qa_data['QUESTION'].tolist()
    # Encode the questions and the query using Sentence-BERT
    question_embeddings = sbert_model.encode(questions, convert_to_tensor=True)
    query_embedding = sbert_model.encode(query, convert_to_tensor=True)
    # Compute cosine similarities between the query and all questions
    cosine_scores = util.pytorch_cos_sim(query_embedding, question_embeddings).flatten()
    # Get the indices of the top k most similar questions; clamp k to the
    # number of questions, and convert to NumPy so pandas .iloc accepts them
    top_k_indices = torch.topk(cosine_scores, k=min(k, len(questions))).indices.cpu().numpy()
    # Retrieve the corresponding Q&A pairs
    top_k_qa = qa_data.iloc[top_k_indices]
    return top_k_qa
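# Example use (illustrative only): retrieve_top_k("What languages does Juri speak?")
# returns a 5-row slice of qa_data ordered by cosine similarity to the query.
#
# Possible optimization (a sketch, not part of the original app): the question
# embeddings never change, so they could be encoded once at startup and reused,
# e.g.
#   QUESTION_EMBEDDINGS = sbert_model.encode(qa_data['QUESTION'].tolist(),
#                                            convert_to_tensor=True)
# and retrieve_top_k would then skip the per-call encoding of the questions.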
def chatbot(query):
    # Retrieve the top 5 relevant Q&A pairs
    top_k_qa = retrieve_top_k(query)
    # Generate the prefix, context, and suffix for the prompt,
    # using the Llama 3 instruct chat-template tokens
    prefix = """
<|begin_of_text|><|start_header_id|>user<|end_header_id|>You are a chatbot specialized in answering questions about Juri Grosjean's CV.
Please only use the information provided in the context to answer the question.
Here is the question to answer:
""" + query + "\n\n"
    context = "This is the context information to answer the question:\n"
    for index, row in top_k_qa.iterrows():
        context += f"Information {index}: {row['ANSWER']}\n\n"
    suffix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
    prompt = prefix + context + suffix
    # Generate a response
    outputs = llama_model(
        prompt,
        max_new_tokens=500,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    # Extract and return the chatbot's answer: generated_text includes the
    # prompt, so slice it off rather than splitting on the string "assistant",
    # which would break if the answer itself contains that word
    output = outputs[0]["generated_text"]
    return output[len(prompt):].strip()
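# Alternative (a sketch): the transformers text-generation pipeline also
# accepts return_full_text=False, which makes it return only the completion,
# so no slicing would be needed:
#   outputs = llama_model(prompt, max_new_tokens=500, do_sample=True,
#                         temperature=0.6, top_p=0.9, return_full_text=False)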
# Set up the Gradio interface
demo = gr.Interface(
    fn=chatbot,
    inputs=gr.Textbox(lines=5, placeholder="Ask a question about Juri Grosjean's CV"),
    outputs="text",
)

if __name__ == "__main__":
    demo.launch()