Daemontatox committed
Commit 78fd9fa · verified · 1 Parent(s): 7b48faf

Update app.py

Files changed (1):
  1. app.py +62 -119
app.py CHANGED
@@ -1,13 +1,5 @@
 import spaces
 import subprocess
-
-subprocess.run(
-    'pip install flash-attn --no-build-isolation',
-    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
-    shell=True
-)
-
-
 import os
 import torch
 from dotenv import load_dotenv
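The removed startup hook installed flash-attn at runtime. If it is ever restored, note that `env=` on `subprocess.run` replaces the entire environment rather than extending it, so the spawned shell can lose `PATH` and fail to find `pip`. A minimal sketch of the safer pattern, assuming the same package and skip-build flag as the removed code:

```python
import os
import subprocess

# Merge the flag into the inherited environment instead of replacing it,
# so the shell keeps PATH and can locate pip.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': 'TRUE'},
    shell=True,
    check=False,  # a failed wheel build should not crash the Space at import
)
```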
@@ -20,17 +12,14 @@ from qdrant_client import QdrantClient, models
 from langchain_openai import ChatOpenAI
 import gradio as gr
 import logging
-from typing import List, Tuple
+from typing import List, Tuple, Generator
 from dataclasses import dataclass
 from datetime import datetime
-from transformers import AutoTokenizer, AutoModelForCausalLM ,pipeline
-from langchain_huggingface.llms import HuggingFacePipeline
-import re
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from langchain_huggingface.llms import HuggingFacePipeline
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline,BitsAndBytesConfig,TextIteratorStreamer
 from langchain_cerebras import ChatCerebras
-
-
+from queue import Queue
+from threading import Thread
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -51,7 +40,6 @@ class ChatHistory:
         self.messages.append(Message(role=role, content=content, timestamp=timestamp))
 
     def get_formatted_history(self, max_messages: int = 10) -> str:
-        """Returns the most recent conversation history formatted as a string"""
         recent_messages = self.messages[-max_messages:] if len(self.messages) > max_messages else self.messages
         formatted_history = "\n".join([
             f"{msg.role}: {msg.content}" for msg in recent_messages
@@ -61,10 +49,9 @@ class ChatHistory:
     def clear(self):
         self.messages = []
 
-# Load environment variables
+# Load environment variables and setup (same as before)
 load_dotenv()
 
-# HuggingFace API Token
 HF_TOKEN = os.getenv("HF_TOKEN")
 C_apikey = os.getenv("C_apikey")
 OPENAPI_KEY = os.getenv("OPENAPI_KEY")
@@ -73,10 +60,9 @@ if not HF_TOKEN:
     logger.error("HF_TOKEN is not set in the environment variables.")
     exit(1)
 
-# HuggingFace Embeddings
 embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
-# Qdrant Client Setup
+# Qdrant setup (same as before)
 try:
     client = QdrantClient(
         url=os.getenv("QDRANT_URL"),
@@ -84,125 +70,61 @@ try:
         prefer_grpc=False
     )
 except Exception as e:
-    logger.error("Failed to connect to Qdrant. Ensure QDRANT_URL and QDRANT_API_KEY are correctly set.")
+    logger.error("Failed to connect to Qdrant.")
     exit(1)
 
-# Define collection name
 collection_name = "mawared"
 
-# Try to create collection
 try:
     client.create_collection(
         collection_name=collection_name,
         vectors_config=models.VectorParams(
-            size=384,  # GTE-large embedding size
+            size=384,
             distance=models.Distance.COSINE
         )
    )
-    logger.info(f"Created new collection: {collection_name}")
 except Exception as e:
-    if "already exists" in str(e):
-        logger.info(f"Collection {collection_name} already exists, continuing...")
-    else:
+    if "already exists" not in str(e):
        logger.error(f"Error creating collection: {e}")
        exit(1)
 
-# Create Qdrant vector store
 db = Qdrant(
     client=client,
     collection_name=collection_name,
     embeddings=embeddings,
 )
 
-# Create retriever
 retriever = db.as_retriever(
     search_type="similarity",
     search_kwargs={"k": 5}
 )
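The dropped `# GTE-large embedding size` comment was misleading: 384 is the dimension of the `all-MiniLM-L6-v2` embeddings configured above (GTE-large is 1024-dimensional). The rewritten error handling still infers "collection already exists" from the exception text; a sketch of an explicit check instead, assuming a qdrant-client version that provides `collection_exists`:

```python
# Create the collection only when it is missing, instead of parsing
# the error string raised by a duplicate create_collection call.
if not client.collection_exists(collection_name):
    client.create_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(
            size=384,  # all-MiniLM-L6-v2 produces 384-dimensional vectors
            distance=models.Distance.COSINE,
        ),
    )
```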
-# retriever = db.as_retriever(
-#     search_type="mmr",
-#     search_kwargs={"k": 5, "fetch_k": 10, "lambda_mult": 0.5}
-# )
-
-# retriever = db.as_retriever(
-#     search_type="similarity_score_threshold",
-#     search_kwargs={"k": 5, "score_threshold": 0.8}
-# )
-
-# Load model directly
-
-# Set up the LLM
-# llm = ChatOpenAI(
-#     base_url="https://api-inference.huggingface.co/v1/",
-#     temperature=0,
-#     api_key=HF_TOKEN,
-#     model="mistralai/Mistral-Nemo-Instruct-2407",
-#     max_tokens=None,
-#     timeout=None
-# )
-
-# llm = ChatOpenAI(
-#     base_url="https://openrouter.ai/api/v1",
-#     temperature=0.01,
-#     api_key=OPENAPI_KEY,
-#     model="google/gemini-2.0-flash-exp:free",
-#     max_tokens=None,
-#     timeout=None,
-#     max_retries=3,
-# )
-
 llm = ChatCerebras(
-    model="llama-3.3-70b",
-    api_key=C_apikey
+    model="llama-3.3-70b",
+    api_key=C_apikey,
+    streaming=True  # Enable streaming
 )
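With `streaming=True`, the model can be consumed incrementally through LangChain's standard `.stream()` interface, which the RAG chain built below streams through as well. A quick sketch (the question text is illustrative):

```python
# Each chunk is a message chunk; .content carries the text delta.
for chunk in llm.stream("How do I submit a leave request?"):
    print(chunk.content, end="", flush=True)
```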
 
-# Create prompt template with chat history
 template = """
-You are an expert assistant specializing in the Mawared HR System. Your role is to provide precise and contextually relevant answers based on the retrieved context and chat history.
+You are a friendly assistant specializing in the Mawared HR System.
+Your role is to provide precise and contextually relevant answers based on the retrieved context and chat history.
+Your top priority is user experience and satisfaction; only answer questions about the Mawared HR system and ignore everything else.
 
 Key Responsibilities:
 
 Use the given chat history and retrieved context to craft accurate and detailed responses.
 If necessary, ask specific and targeted clarifying questions to gather more information.
 Present step-by-step instructions in a clear, numbered format when applicable.
-Rules for Responses:
-
-Strictly use the information from the provided context and chat history. Avoid making up or fabricating any details.
-Do not reference the retrieval process, sources, pages, or documents in your responses.
-Maintain a conversational flow by asking relevant follow-up questions to engage the user and enhance the interaction.
-Inputs for Your Response:
+If you think you will not be able to provide a clear answer based on the user's question, ask a clarifying question and request more details.
 
 Previous Conversation: {chat_history}
 Retrieved Context: {context}
 Current Question: {question}
-Answer:{{answer}}
-Your answers must be expressive, detailed, and fully address the user's needs without deviating from the provided information.
+Answer:
 """
 
 prompt = ChatPromptTemplate.from_template(template)
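Dropping `Answer:{{answer}}` also fixes a subtle template bug: in an f-string style LangChain template, doubled braces render as literal braces, so every prompt previously ended with the stray text `Answer:{answer}`. For illustration, with hypothetical values:

```python
# Hypothetical inputs, only to show how the three placeholders are filled.
rendered = prompt.format(
    chat_history="user: How do I request leave?",
    context="Leave requests are created from the Self Service menu.",
    question="Who approves it?",
)
```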
 
-# Create the RAG chain with chat history
 def create_rag_chain(chat_history: str):
     chain = (
         {
@@ -216,38 +138,62 @@ def create_rag_chain(chat_history: str):
     )
     return chain
 
-# Initialize chat history
 chat_history = ChatHistory()
 
-# Gradio Function
+def process_stream(stream_queue: Queue, history: List[dict]) -> Generator[List[dict], None, None]:
+    """Process the streaming response and update the chat interface"""
+    current_response = ""
+
+    while True:
+        chunk = stream_queue.get()
+        if chunk is None:  # Signal that streaming is complete
+            break
+
+        current_response += chunk
+        new_history = history.copy()
+        new_history[-1]["content"] = current_response
+        yield new_history
+
 @spaces.GPU()
-def ask_question_gradio(question, history):
+def ask_question_gradio(question: str, history: List[dict]) -> Generator[tuple, None, None]:
     try:
-        # Add user question to chat history
         chat_history.add_message("user", question)
-
-        # Get formatted history
         formatted_history = chat_history.get_formatted_history()
-
-        # Create chain with current chat history
         rag_chain = create_rag_chain(formatted_history)
 
-        # Generate response
+        # Update history with user message
+        history.append({"role": "user", "content": question})
+        history.append({"role": "assistant", "content": ""})
+
+        # Create a queue for streaming responses
+        stream_queue = Queue()
+
+        # Function to process the stream in a separate thread
+        def stream_processor():
+            try:
+                for chunk in rag_chain.stream(question):
+                    stream_queue.put(chunk)
+                stream_queue.put(None)  # Signal completion
+            except Exception as e:
+                logger.error(f"Streaming error: {e}")
+                stream_queue.put(None)
+
+        # Start streaming in a separate thread
+        Thread(target=stream_processor).start()
+
+        # Yield updates to the chat interface
         response = ""
-        for chunk in rag_chain.stream(question):
-            response += chunk
+        for updated_history in process_stream(stream_queue, history):
+            response = updated_history[-1]["content"]
+            yield "", updated_history
 
-        # Add assistant response to chat history
+        # Add final response to chat history
        chat_history.add_message("assistant", response)
 
-        # Update Gradio chat history
-        history.append({"role": "user", "content": question})
-        history.append({"role": "assistant", "content": response})
-
-        return "", history
     except Exception as e:
         logger.error(f"Error during question processing: {e}")
-        return "", history + [{"role": "assistant", "content": "An error occurred. Please try again later."}]
+        history.append({"role": "assistant", "content": "An error occurred. Please try again later."})
+        yield "", history
 
 def clear_chat():
     chat_history.clear()
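The queue-and-thread split exists because `rag_chain.stream()` blocks while producing chunks: the worker thread feeds the queue, and the Gradio callback stays a generator that yields a partial history per chunk; Gradio re-renders the outputs on each `yield`. A self-contained sketch of the same pattern with a fake token source standing in for the chain:

```python
import time
import gradio as gr

def stream_reply(message, history):
    # A generator callback streams: each yield re-renders the outputs.
    history = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": ""},
    ]
    for token in ["Streaming ", "works ", "like ", "this."]:
        time.sleep(0.1)  # stand-in for model latency
        history[-1]["content"] += token
        yield "", history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    box = gr.Textbox()
    box.submit(stream_reply, inputs=[box, chatbot], outputs=[box, chatbot])

if __name__ == "__main__":
    demo.launch()
```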
@@ -255,17 +201,15 @@ def clear_chat():
 
 # Gradio Interface
 with gr.Blocks(theme='Yntec/HaleyCH_Theme_Orange_Green') as iface:
-    gr.Image("Image.jpg" , width=750 , height=300 ,show_label=False, show_download_button=False)
+    gr.Image("Image.jpg", width=750, height=300, show_label=False, show_download_button=False)
     gr.Markdown("# Mawared HR Assistant 2.5.1")
     gr.Markdown('### Instructions')
-    gr.Markdown("Ask a question about MawaredHR and get a detailed answer , if you get an error try again with same prompt , its an Api issue and we are working on it 😀")
-
-
+    gr.Markdown("Ask a question about Mawared HR and get a detailed answer. If you get an error, try again with the same prompt; it's an API issue and we are working on it 😀")
 
     chatbot = gr.Chatbot(
         height=750,
         show_label=False,
-        type="messages"  # Using the new messages format
+        bubble_full_width=False,
     )
 
     with gr.Row():
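Note: the new handlers append `{"role": ..., "content": ...}` dicts to the history, but this hunk also removes `type="messages"` from the `Chatbot`; without it, current Gradio releases expect the legacy tuple format and may reject dict messages. Keeping both settings looks like the safer combination:

```python
chatbot = gr.Chatbot(
    height=750,
    show_label=False,
    type="messages",        # required for role/content dict histories
    bubble_full_width=False,
)
```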
@@ -287,6 +231,5 @@ with gr.Blocks(theme='Yntec/HaleyCH_Theme_Orange_Green') as iface:
         outputs=[chatbot, question_input]
     )
 
-# Launch the Gradio App
 if __name__ == "__main__":
     iface.launch()
 