chayanbhansali commited on
Commit
eef9fc0
·
verified ·
1 Parent(s): d346fac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -5
app.py CHANGED
@@ -11,9 +11,13 @@ class RAGChatbot:
11
  embedding_model="all-MiniLM-L6-v2"):
12
  # Initialize tokenizer and model
13
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
 
14
  self.model = AutoModelForCausalLM.from_pretrained(
15
  model_name,
16
- torch_dtype=torch.float16,
17
  device_map="auto"
18
  )
19
 
@@ -88,12 +92,19 @@ class RAGChatbot:
88
  # Generate response
89
  response = self.generate_response(query, context)
90
 
91
- # Append to history and return as list of tuples
92
- updated_history = history + [[query, response]]
 
 
 
93
  return updated_history, ""
94
  except Exception as e:
95
  error_response = f"An error occurred: {str(e)}"
96
- return history + [[query, error_response]], ""
 
 
 
 
97
 
98
  # Create Gradio interface
99
  def create_interface():
@@ -108,7 +119,7 @@ def create_interface():
108
 
109
  status_output = gr.Textbox(label="Load Status")
110
 
111
- chatbot = gr.Chatbot()
112
  msg = gr.Textbox(label="Enter your query")
113
  submit_btn = gr.Button("Send")
114
  clear_btn = gr.Button("Clear Chat")
 
11
  embedding_model="all-MiniLM-L6-v2"):
12
  # Initialize tokenizer and model
13
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
14
+ self.bnb_config = BitsAndBytesConfig(
15
+ load_in_8bit=True, # Enable 8-bit loading
16
+ llm_int8_threshold=6.0, # Threshold for mixed-precision computation
17
+ )
18
  self.model = AutoModelForCausalLM.from_pretrained(
19
  model_name,
20
+ quantization_config=self.bnb_config,
21
  device_map="auto"
22
  )
23
 
 
92
  # Generate response
93
  response = self.generate_response(query, context)
94
 
95
+ # Append to history using messages format
96
+ updated_history = history + [
97
+ {"role": "user", "content": query},
98
+ {"role": "assistant", "content": response}
99
+ ]
100
  return updated_history, ""
101
  except Exception as e:
102
  error_response = f"An error occurred: {str(e)}"
103
+ updated_history = history + [
104
+ {"role": "user", "content": query},
105
+ {"role": "assistant", "content": error_response}
106
+ ]
107
+ return updated_history, ""
108
 
109
  # Create Gradio interface
110
  def create_interface():
 
119
 
120
  status_output = gr.Textbox(label="Load Status")
121
 
122
+ chatbot = gr.Chatbot(type="messages") # Specify message type
123
  msg = gr.Textbox(label="Enter your query")
124
  submit_btn = gr.Button("Send")
125
  clear_btn = gr.Button("Clear Chat")