|
|
|
|
|
|
|
|
|
import os |
|
import numpy as np |
|
import gradio as gr |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import faiss |
|
|
|
|
|
def load_embeddings(embeddings_folder='embeddings'):
    """Load every ``.npy`` embedding file from *embeddings_folder*.

    Each file is expected to hold an ``(n, dim)`` array; a 1-D array is
    treated as a single ``(1, dim)`` embedding.  The file's base name
    (without the ``.npy`` suffix) becomes the metadata label for every
    row that file contributes.

    Parameters
    ----------
    embeddings_folder : str
        Directory containing the ``.npy`` files.

    Returns
    -------
    tuple[numpy.ndarray, list[str]]
        A ``(total_rows, dim)`` matrix of all embeddings stacked
        together, and a parallel list with one label per row.

    Raises
    ------
    FileNotFoundError
        If the folder contains no ``.npy`` files (previously this
        surfaced as a cryptic ``ValueError`` from ``np.vstack([])``).
    """
    all_embeddings = []
    metadata = []

    # Sort for a deterministic row order -- os.listdir order is arbitrary,
    # and the FAISS index built from this matrix must be reproducible.
    for file in sorted(os.listdir(embeddings_folder)):
        if file.endswith('.npy'):
            embedding = np.load(os.path.join(embeddings_folder, file))
            # Promote 1-D vectors to a single-row matrix so vstack and
            # the per-row label count below are well defined.
            embedding = np.atleast_2d(embedding)
            all_embeddings.append(embedding)

            meta_info = file.replace('.npy', '')
            metadata.extend([meta_info] * embedding.shape[0])

    if not all_embeddings:
        raise FileNotFoundError(
            f"No .npy embedding files found in '{embeddings_folder}'"
        )

    return np.vstack(all_embeddings), metadata
|
|
|
# Load every stored embedding plus a parallel list of per-row labels.
# Done once at import time so the FAISS index below is built a single time.
embeddings, metadata = load_embeddings()


# Exact (brute-force) L2 nearest-neighbour index over all embedding rows;
# retrieve_documents() queries this index at request time.
dimension = embeddings.shape[1]

index = faiss.IndexFlatL2(dimension)

index.add(embeddings)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Quantized (GPTQ) Zephyr-7B model used as the answer generator.
model_name = "TheBloke/zephyr-7B-beta-GPTQ"

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

# device_map="balanced" spreads the weights across available devices.
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="balanced", trust_remote_code=False)
|
|
|
|
|
|
|
def retrieve_documents(query, top_k=3):
    """Return the metadata labels of the *top_k* nearest embeddings.

    The query embedding is approximated as the mean of all stored
    embeddings whose metadata label contains *query* as a substring
    (case-insensitive).  If nothing matches, the mean of ALL embeddings
    is used as a neutral fallback -- previously ``np.mean([])`` produced
    a NaN vector (plus a RuntimeWarning) that poisoned the FAISS search.

    Parameters
    ----------
    query : str
        Free-text user query.
    top_k : int
        Maximum number of labels to return (clamped to the index size).

    Returns
    -------
    list[str]
        Metadata labels of the nearest rows, closest first.
    """
    needle = query.lower()
    matched = [embeddings[i] for i, label in enumerate(metadata)
               if needle in label.lower()]
    if matched:
        query_embedding = np.mean(matched, axis=0)
    else:
        # Neutral fallback: centroid of the whole collection.
        query_embedding = embeddings.mean(axis=0)

    # Never ask FAISS for more neighbours than the index holds.
    k = min(top_k, index.ntotal)
    distances, indices = index.search(np.array([query_embedding]), k)
    return [metadata[idx] for idx in indices[0]]
|
|
|
|
|
def generate_response(query):
    """Answer *query* with the LLM, conditioned on retrieved documents.

    Builds a "Context / Question / Answer" prompt from the top retrieved
    labels, generates a continuation, and returns only the newly
    generated text (the prompt is stripped from the decoded output).

    Parameters
    ----------
    query : str
        Free-text user question.

    Returns
    -------
    str
        The model's generated answer.
    """
    retrieved_docs = retrieve_documents(query)
    context = " ".join(retrieved_docs)
    input_text = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"

    # Move the tokenized prompt onto the model's device: with
    # device_map="balanced" the weights live on accelerators, and CPU
    # input tensors would raise a device-mismatch error in generate().
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    # max_new_tokens bounds only the answer length; the previous
    # max_length=512 counted prompt tokens too, so a long context could
    # leave no room for any answer at all.
    output = model.generate(**inputs, max_new_tokens=512)

    # Decode only the tokens generated after the prompt, so the caller
    # gets the answer rather than prompt + answer.
    prompt_len = inputs["input_ids"].shape[1]
    answer = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    return answer
|
|
|
|
|
def gradio_interface(query):
    """Gradio callback: forward the user's query to generate_response."""
    return generate_response(query)
|
|
|
# Minimal Gradio UI: a single text box in, plain text out.
iface = gr.Interface(

    fn=gradio_interface,

    inputs=gr.Textbox(lines=2, placeholder="Enter your query here..."),

    outputs="text",

    title="RAG-based Course Search",

    description="Enter a query to search for relevant courses using Retrieval Augmented Generation."

)
|
|
|
# Launch the web UI only when the file is run as a script.  The guard
# previously read `_name_ == "_main_"` (single underscores), which are
# undefined names and raised NameError at startup.
if __name__ == "__main__":
    iface.launch()