Shashi Kiran committed on
Commit
e7c0015
·
1 Parent(s): 4816820
Files changed (3) hide show
  1. app.py +28 -7
  2. faiss/index.faiss +0 -0
  3. faiss/index.pkl +3 -0
app.py CHANGED
@@ -3,20 +3,33 @@ import faiss
3
  import numpy as np
4
  import requests
5
  import torch
 
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
  from sentence_transformers import SentenceTransformer
8
 
9
  class CustomRetriever:
10
- def __init__(self, faiss_index_path: str):
11
- """Initializes the retriever by loading the FAISS index and setting up the embedding model."""
 
12
  self.index = faiss.read_index(faiss_index_path)
 
 
 
 
 
 
13
  self.embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
14
 
15
  def retrieve(self, query: str, top_k: int = 5):
16
  """Retrieve top-k relevant documents based on the query."""
17
  query_embedding = self.embedder.encode([query])
 
 
18
  distances, indices = self.index.search(np.array(query_embedding).astype('float32'), top_k)
19
- return [(index, distance) for index, distance in zip(indices[0], distances[0])]
 
 
 
20
 
21
 
22
  class CustomGenerator:
@@ -29,24 +42,32 @@ class CustomGenerator:
29
  """Generate a response using the retrieved documents and the user input."""
30
  context = "\n".join([f"Doc {i+1}: {doc}" for i, (doc, _) in enumerate(retrieved_docs)])
31
  prompt = f"Context:\n{context}\n\nUser: {user_input}\nBot:"
 
32
  inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True)
33
  with torch.no_grad():
34
  outputs = self.model.generate(inputs.input_ids, max_length=max_length, pad_token_id=self.tokenizer.eos_token_id)
 
35
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
36
  return response.split("Bot:")[-1].strip()
37
 
38
 
39
  def rag_chatbot(user_input):
40
  """The main RAG chatbot function to retrieve documents and generate a response."""
 
41
  top_k = 5 # Number of documents to retrieve
42
- retrieved_doc_ids = retriever.retrieve(user_input, top_k)
43
- retrieved_docs = [(f"Dummy content for doc {doc_id}", distance) for doc_id, distance in retrieved_doc_ids]
 
44
  response = generator.generate(user_input, retrieved_docs)
45
  return response
46
 
47
 
48
- FAISS_INDEX_PATH = ""
49
- retriever = CustomRetriever(faiss_index_path=FAISS_INDEX_PATH)
 
 
 
 
50
  generator = CustomGenerator()
51
 
52
  # Gradio UI
 
3
  import numpy as np
4
  import requests
5
  import torch
6
+ import pickle
7
  from transformers import AutoTokenizer, AutoModelForCausalLM
8
  from sentence_transformers import SentenceTransformer
9
 
10
class CustomRetriever:
    """Look up document content for a text query through a FAISS index.

    The query is embedded with a SentenceTransformer model, the FAISS index
    is searched for the nearest stored vectors, and each returned row id is
    resolved to document content through ``self.doc_metadata``.
    """

    def __init__(self, faiss_index_path: str, metadata_path: str):
        """Load the FAISS index, the document metadata and the query embedder.

        Args:
            faiss_index_path: Path of the serialized FAISS index file.
            metadata_path: Path of the pickle file holding the document
                metadata (mapping FAISS indices to document content).
        """
        # Load the FAISS index
        self.index = faiss.read_index(faiss_index_path)

        # Load the document metadata (mapping FAISS indices to document content).
        # SECURITY NOTE(review): pickle.load can execute arbitrary code stored
        # in the file; only point this at a metadata file you produced or trust.
        # NOTE(review): assumes the unpickled object can be subscripted with an
        # integer FAISS row id (list or int-keyed dict) - confirm against the
        # code that wrote the pickle.
        with open(metadata_path, 'rb') as file:
            self.doc_metadata = pickle.load(file)

        # Load the SentenceTransformer for embedding queries
        self.embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    def retrieve(self, query: str, top_k: int = 5):
        """Retrieve top-k relevant documents based on the query.

        Args:
            query: The user's question or search text.
            top_k: Maximum number of neighbours to request from the index.

        Returns:
            A list of ``(document, distance)`` tuples in the order returned by
            the index. It may contain fewer than ``top_k`` entries when the
            index cannot supply that many neighbours.
        """
        query_embedding = self.embedder.encode([query])

        # Search the FAISS index for top-k similar embeddings
        distances, indices = self.index.search(
            np.array(query_embedding).astype('float32'), top_k
        )

        # Retrieve the actual document content using the indices
        retrieved_docs = []
        for idx, distance in zip(indices[0], distances[0]):
            # FAISS fills unused result slots with the label -1. Indexing the
            # metadata with -1 would return the last document (list) or raise
            # KeyError (dict), so such slots are skipped.
            if idx < 0:
                continue
            # Cast the numpy integer to a plain int so the lookup behaves the
            # same for list- and dict-shaped metadata.
            retrieved_docs.append((self.doc_metadata[int(idx)], distance))
        return retrieved_docs
33
 
34
 
35
  class CustomGenerator:
 
42
  """Generate a response using the retrieved documents and the user input."""
43
  context = "\n".join([f"Doc {i+1}: {doc}" for i, (doc, _) in enumerate(retrieved_docs)])
44
  prompt = f"Context:\n{context}\n\nUser: {user_input}\nBot:"
45
+
46
  inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True)
47
  with torch.no_grad():
48
  outputs = self.model.generate(inputs.input_ids, max_length=max_length, pad_token_id=self.tokenizer.eos_token_id)
49
+
50
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
51
  return response.split("Bot:")[-1].strip()
52
 
53
 
54
def rag_chatbot(user_input):
    """Answer ``user_input`` by retrieving documents and generating a reply.

    Uses the module-level ``retriever`` to fetch the five closest documents
    and passes them, together with the user input, to the module-level
    ``generator``. Returns whatever ``generator.generate`` returns.
    """
    # Retrieval step: fetch the 5 most relevant documents for the query.
    documents = retriever.retrieve(user_input, 5)

    # Generation step: produce the reply from the query and those documents.
    return generator.generate(user_input, documents)
63
 
64
 
65
# Paths to your FAISS index and metadata files.
# They are resolved relative to this script so the app runs from any checkout
# or deployment of the repository, not only from one developer's machine.
from pathlib import Path

_FAISS_DIR = Path(__file__).resolve().parent / "faiss"
# Kept as plain strings: these values are handed to faiss.read_index and open().
FAISS_INDEX_PATH = str(_FAISS_DIR / "index.faiss")
METADATA_PATH = str(_FAISS_DIR / "index.pkl")

# Initialize retriever and generator
retriever = CustomRetriever(faiss_index_path=FAISS_INDEX_PATH, metadata_path=METADATA_PATH)
generator = CustomGenerator()
72
 
73
  # Gradio UI
faiss/index.faiss ADDED
Binary file (249 kB). View file
 
faiss/index.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2361a54ccf2ea94476514663e7f3cf028ab93e37d32c72faba6e0e01d4b77781
3
+ size 1423760