chayanbhansali committed
Commit c9dfe43 · verified · 1 parent: 6f2ae3a

Update app.py

Files changed (1):
  app.py +6 -5
app.py CHANGED
@@ -10,7 +10,7 @@ class RAGChatbot:
                  model_name="facebook/opt-350m",
                  embedding_model="all-MiniLM-L6-v2"):
         # Initialize tokenizer and model
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
         # self.bnb_config = BitsAndBytesConfig(
         #     load_in_8bit=True,       # Enable 8-bit loading
         #     llm_int8_threshold=6.0,  # Threshold for mixed-precision computation
@@ -51,7 +51,7 @@ class RAGChatbot:
         self.documents.extend(chunks)

         # Generate embeddings
-        self.embeddings = self.embedding_model.encode(self.documents)
+        self.embeddings = self.embedding_model.encode(self.documents, batch_size=32, show_progress_bar=True)
         return f"Loaded {len(self.documents)} text chunks from {len(file_paths)} files"

     def retrieve_relevant_context(self, query, top_k=3):
@@ -71,11 +71,12 @@ class RAGChatbot:
         return " ".join([self.documents[i] for i in top_indices])

     def generate_response(self, query, context):
-        # Construct prompt with context
-        full_prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
+        # Construct prompt with truncated context
+        truncated_context = " ".join(context.split()[:100])
+        full_prompt = f"Context: {truncated_context}\n\nQuestion: {query}\n\nAnswer:"

         # Generate response
-        inputs = self.tokenizer(full_prompt, return_tensors="pt").to(self.model.device)
+        inputs = self.tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True).to(self.model.device)
         outputs = self.model.generate(**inputs, max_new_tokens=150)
         response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
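Note on the embedding change: the second hunk switches load_documents to batched encoding. A minimal standalone sketch of what that call does, assuming sentence-transformers is installed (the corpus below is invented for illustration; the model name matches the class default):

from sentence_transformers import SentenceTransformer

embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
documents = [f"document chunk number {i}" for i in range(100)]

# batch_size=32 bounds peak memory on large corpora, and show_progress_bar
# surfaces encoding progress -- the same arguments this commit adds
embeddings = embedding_model.encode(documents, batch_size=32, show_progress_bar=True)
print(embeddings.shape)  # (100, 384): all-MiniLM-L6-v2 produces 384-dim vectors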
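Note on retrieval: the diff shows only the signature and return line of retrieve_relevant_context. A plausible standalone reconstruction, assuming cosine similarity over the precomputed embeddings (a sketch, not code from this commit):

import numpy as np

def retrieve_relevant_context(query, embedding_model, embeddings, documents, top_k=3):
    # Embed the query with the same model used for the documents, then
    # rank documents by cosine similarity and keep the top_k matches
    query_emb = embedding_model.encode([query])[0]
    sims = embeddings @ query_emb / (
        np.linalg.norm(embeddings, axis=1) * np.linalg.norm(query_emb) + 1e-10
    )
    top_indices = np.argsort(sims)[-top_k:][::-1]
    return " ".join([documents[i] for i in top_indices])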
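Note on prompt truncation: to see what the new word-level cap plus tokenizer-level truncation do to an oversized prompt, a minimal sketch (assumes transformers is installed and the facebook/opt-350m tokenizer can be downloaded; the 100-word cap mirrors the value hard-coded in generate_response):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m", use_fast=True)

context = "retrieved passage text " * 300            # deliberately oversized
truncated_context = " ".join(context.split()[:100])  # word-level cap, as in the diff
full_prompt = f"Context: {truncated_context}\n\nQuestion: What is this about?\n\nAnswer:"

# truncation=True additionally clips at the tokenizer's model_max_length when
# the checkpoint defines one (OPT's context window is 2048 tokens);
# padding=True is a no-op for a single sequence but harmless
inputs = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True)
print(inputs["input_ids"].shape)  # bounded sequence length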