Spaces: Runtime error
Shashi Kiran committed · Commit e7c0015 · 1 Parent(s): 4816820
first
Browse files:
- app.py +28 -7
- faiss/index.faiss +0 -0
- faiss/index.pkl +3 -0
app.py CHANGED
@@ -3,20 +3,33 @@ import faiss
 import numpy as np
 import requests
 import torch
+import pickle
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from sentence_transformers import SentenceTransformer
 
 class CustomRetriever:
-    def __init__(self, faiss_index_path: str):
-        """Initializes the retriever by loading the FAISS index and
+    def __init__(self, faiss_index_path: str, metadata_path: str):
+        """Initializes the retriever by loading the FAISS index and document metadata."""
+        # Load the FAISS index
         self.index = faiss.read_index(faiss_index_path)
+
+        # Load the document metadata (mapping FAISS indices to document content)
+        with open(metadata_path, 'rb') as file:
+            self.doc_metadata = pickle.load(file)
+
+        # Load the SentenceTransformer for embedding queries
         self.embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
 
     def retrieve(self, query: str, top_k: int = 5):
         """Retrieve top-k relevant documents based on the query."""
         query_embedding = self.embedder.encode([query])
+
+        # Search the FAISS index for top-k similar embeddings
         distances, indices = self.index.search(np.array(query_embedding).astype('float32'), top_k)
-
+
+        # Retrieve the actual document content using the indices
+        retrieved_docs = [(self.doc_metadata[idx], distance) for idx, distance in zip(indices[0], distances[0])]
+        return retrieved_docs
 
 
 class CustomGenerator:
@@ -29,24 +42,32 @@ class CustomGenerator:
         """Generate a response using the retrieved documents and the user input."""
         context = "\n".join([f"Doc {i+1}: {doc}" for i, (doc, _) in enumerate(retrieved_docs)])
         prompt = f"Context:\n{context}\n\nUser: {user_input}\nBot:"
+
         inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True)
         with torch.no_grad():
             outputs = self.model.generate(inputs.input_ids, max_length=max_length, pad_token_id=self.tokenizer.eos_token_id)
+
         response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
         return response.split("Bot:")[-1].strip()
 
 
 def rag_chatbot(user_input):
     """The main RAG chatbot function to retrieve documents and generate a response."""
+    # Step 1: Retrieve relevant documents
     top_k = 5  # Number of documents to retrieve
-
-
+    retrieved_docs = retriever.retrieve(user_input, top_k)
+
+    # Step 2: Generate a response using the documents
     response = generator.generate(user_input, retrieved_docs)
     return response
 
 
-
-
+# Paths to your FAISS index and metadata files
+FAISS_INDEX_PATH = r"C:\Users\schandrappa\Downloads\Banking_Regulations_Compliance_ChatBOT\faiss\index.faiss"
+METADATA_PATH = r"C:\Users\schandrappa\Downloads\Banking_Regulations_Compliance_ChatBOT\faiss\index.pkl"
+
+# Initialize retriever and generator
+retriever = CustomRetriever(faiss_index_path=FAISS_INDEX_PATH, metadata_path=METADATA_PATH)
 generator = CustomGenerator()
 
 # Gradio UI
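The hunks above call into two pieces of app.py that sit outside the changed ranges: the `CustomGenerator` constructor (which must set `self.tokenizer` and `self.model`) and the Gradio UI that the trailing comment introduces. A minimal sketch of what those unchanged parts plausibly look like; the model name `distilgpt2` and the widget labels are illustrative assumptions, not taken from this commit:

# Sketch of the unchanged parts of app.py referenced by the diff.
# AutoTokenizer/AutoModelForCausalLM are already imported at the top of app.py;
# the model name and UI labels below are assumptions, not from the commit.
import gradio as gr

class CustomGenerator:
    def __init__(self, model_name: str = "distilgpt2"):  # hypothetical model choice
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.model.eval()

    # generate(self, user_input, retrieved_docs, max_length=...) as shown in the hunk above

# Gradio UI: wire rag_chatbot into a simple text-in/text-out interface
demo = gr.Interface(
    fn=rag_chatbot,
    inputs=gr.Textbox(label="Question"),
    outputs=gr.Textbox(label="Answer"),
)
demo.launch()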
faiss/index.faiss ADDED
Binary file (249 kB)
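Note that app.py now hard-codes absolute Windows paths (C:\Users\schandrappa\...), while this commit ships the index files inside the repo under faiss/; on the Space itself those absolute paths will not resolve, which is consistent with the "Runtime error" status shown at the top. A sketch of a portable alternative that resolves the files relative to app.py:

# Hypothetical replacement for the absolute Windows paths in app.py:
# resolve the index files relative to the repo, where this commit adds them.
import os

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
FAISS_INDEX_PATH = os.path.join(BASE_DIR, "faiss", "index.faiss")
METADATA_PATH = os.path.join(BASE_DIR, "faiss", "index.pkl")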
faiss/index.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2361a54ccf2ea94476514663e7f3cf028ab93e37d32c72faba6e0e01d4b77781
+size 1423760
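index.pkl is stored through Git LFS, so only the pointer file appears in the diff; at runtime CustomRetriever unpickles it and looks documents up by FAISS row id. A minimal sketch of how such an index/metadata pair could be produced with the same all-MiniLM-L6-v2 embedder, assuming the metadata is a {row_id: text} mapping; the corpus below is a placeholder, as the real documents are not part of this commit:

# Hypothetical build script for faiss/index.faiss and faiss/index.pkl.
# The documents list is a placeholder; the real corpus is not in this commit.
import os
import pickle

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

documents = [
    "Placeholder banking regulation clause.",
    "Placeholder compliance requirement.",
]

embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
embeddings = np.asarray(embedder.encode(documents), dtype='float32')

# Row i of the index corresponds to documents[i], so a plain
# {row_id: text} dict is enough for CustomRetriever.doc_metadata.
index = faiss.IndexFlatL2(embeddings.shape[1])  # 384 dimensions for all-MiniLM-L6-v2
index.add(embeddings)

os.makedirs("faiss", exist_ok=True)
faiss.write_index(index, "faiss/index.faiss")
with open("faiss/index.pkl", "wb") as f:
    pickle.dump({i: doc for i, doc in enumerate(documents)}, f)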