Update app.py
app.py
CHANGED
@@ -1,13 +1,5 @@
 import spaces
 import subprocess
-
-subprocess.run(
-    'pip install flash-attn --no-build-isolation',
-    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
-    shell=True
-)
-
-
 import os
 import torch
 from dotenv import load_dotenv
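
The startup-time flash-attn install is removed in this hunk. A hypothetical alternative, not part of this commit, would be to install the package only when it is missing and to keep the existing environment when setting the skip-build flag:

# Hypothetical sketch, not part of the commit: install flash-attn only when absent.
import importlib.util
import os
import subprocess

if importlib.util.find_spec("flash_attn") is None:
    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        # Merge with os.environ so pip still sees PATH etc.; a bare env= would replace it.
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
        check=False,
    )
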
@@ -20,17 +12,14 @@ from qdrant_client import QdrantClient, models
 from langchain_openai import ChatOpenAI
 import gradio as gr
 import logging
-from typing import List, Tuple
+from typing import List, Tuple, Generator
 from dataclasses import dataclass
 from datetime import datetime
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from langchain_huggingface.llms import HuggingFacePipeline
-import re
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from langchain_huggingface.llms import HuggingFacePipeline
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline,BitsAndBytesConfig,TextIteratorStreamer
 from langchain_cerebras import ChatCerebras
-
-
+from queue import Queue
+from threading import Thread
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -51,7 +40,6 @@ class ChatHistory:
         self.messages.append(Message(role=role, content=content, timestamp=timestamp))
 
     def get_formatted_history(self, max_messages: int = 10) -> str:
-        """Returns the most recent conversation history formatted as a string"""
        recent_messages = self.messages[-max_messages:] if len(self.messages) > max_messages else self.messages
        formatted_history = "\n".join([
            f"{msg.role}: {msg.content}" for msg in recent_messages
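
For reference, a minimal usage sketch of the ChatHistory helper touched above (the sample messages are invented for illustration):

history = ChatHistory()
history.add_message("user", "How do I submit a leave request?")
history.add_message("assistant", "Open the self-service module and choose Leave Request.")
print(history.get_formatted_history(max_messages=10))
# user: How do I submit a leave request?
# assistant: Open the self-service module and choose Leave Request.
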
@@ -61,10 +49,9 @@ class ChatHistory:
     def clear(self):
         self.messages = []
 
-# Load environment variables
+# Load environment variables and setup (same as before)
 load_dotenv()
 
-# HuggingFace API Token
 HF_TOKEN = os.getenv("HF_TOKEN")
 C_apikey = os.getenv("C_apikey")
 OPENAPI_KEY = os.getenv("OPENAPI_KEY")
@@ -73,10 +60,9 @@ if not HF_TOKEN:
     logger.error("HF_TOKEN is not set in the environment variables.")
     exit(1)
 
-# HuggingFace Embeddings
 embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
-# Qdrant
+# Qdrant setup (same as before)
 try:
     client = QdrantClient(
         url=os.getenv("QDRANT_URL"),
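
The all-MiniLM-L6-v2 model produces 384-dimensional embeddings, which is why the collection created in the next hunk uses VectorParams(size=384). A quick sanity check, sketched against the same embeddings object:

vector = embeddings.embed_query("probe sentence")
assert len(vector) == 384, f"unexpected embedding size: {len(vector)}"
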
@@ -84,125 +70,61 @@ try:
         prefer_grpc=False
     )
 except Exception as e:
-    logger.error("Failed to connect to Qdrant.
+    logger.error("Failed to connect to Qdrant.")
     exit(1)
 
-# Define collection name
 collection_name = "mawared"
 
-# Try to create collection
 try:
     client.create_collection(
         collection_name=collection_name,
         vectors_config=models.VectorParams(
-            size=384,
+            size=384,
            distance=models.Distance.COSINE
        )
    )
-    logger.info(f"Created new collection: {collection_name}")
 except Exception as e:
-    if "already exists" in str(e):
-        logger.info(f"Collection {collection_name} already exists, continuing...")
-    else:
+    if "already exists" not in str(e):
         logger.error(f"Error creating collection: {e}")
         exit(1)
 
-# Create Qdrant vector store
 db = Qdrant(
     client=client,
     collection_name=collection_name,
     embeddings=embeddings,
 )
 
-# Create retriever
 retriever = db.as_retriever(
     search_type="similarity",
     search_kwargs={"k": 5}
 )
 
-# retriever = db.as_retriever(
-#     search_type="mmr",
-#     search_kwargs={"k": 5, "fetch_k": 10, "lambda_mult": 0.5}
-# )
-
-
-
-
-# retriever = db.as_retriever(
-#     search_type="similarity_score_threshold",
-#     search_kwargs={"k": 5, "score_threshold": 0.8}
-# )
-
-
-
-# Load model directly
-
-
-
-# Set up the LLM
-# llm = ChatOpenAI(
-#     base_url="https://api-inference.huggingface.co/v1/",
-#     temperature=0,
-#     api_key=HF_TOKEN,
-#     model="mistralai/Mistral-Nemo-Instruct-2407",
-#     max_tokens=None,
-#     timeout=None
-
-# )
-
-
-#llm = ChatOpenAI(
-#base_url="https://openrouter.ai/api/v1",
-#temperature=0.01,
-#api_key=OPENAPI_KEY,
-#model="google/gemini-2.0-flash-exp:free",
-#max_tokens=None,
-#timeout=None,
-#max_retries=3,
-
-#)
-
-
 llm = ChatCerebras(
-
-
+    model="llama-3.3-70b",
+    api_key=C_apikey,
+    streaming=True  # Enable streaming
 )
 
-
-
-
-
-
-
-# Create prompt template with chat history
 template = """
-You are
+You are a Friendly assistant specializing in the Mawared HR System.
+Your role is to provide precise and contextually relevant answers based on the retrieved context and chat history.
+Your top priority is user experience and satisfaction, only answer questions based on Mawared HR system and ignore everything else.
 
 Key Responsibilities:
 
 Use the given chat history and retrieved context to craft accurate and detailed responses.
 If necessary, ask specific and targeted clarifying questions to gather more information.
 Present step-by-step instructions in a clear, numbered format when applicable.
-
-
-Strictly use the information from the provided context and chat history. Avoid making up or fabricating any details.
-Do not reference the retrieval process, sources, pages, or documents in your responses.
-Maintain a conversational flow by asking relevant follow-up questions to engage the user and enhance the interaction.
-Inputs for Your Response:
+If you think you will not be able to provide a clear answer based on the user question, ask a clarifying question and ask for more details.
 
 Previous Conversation: {chat_history}
 Retrieved Context: {context}
 Current Question: {question}
-Answer:
-Your answers must be expressive, detailed, and fully address the user's needs without deviating from the provided information.
+Answer:
 """
 
 prompt = ChatPromptTemplate.from_template(template)
 
-# Create the RAG chain with chat history
-
-
-
 def create_rag_chain(chat_history: str):
     chain = (
         {
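
With streaming=True on ChatCerebras, the model can be consumed incrementally. A minimal sketch, assuming C_apikey is a valid Cerebras key; the prompt text is only illustrative:

# Stream tokens from the chat model as they arrive.
for chunk in llm.stream("Summarize the leave approval workflow in Mawared HR."):
    print(chunk.content, end="", flush=True)
print()
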
@@ -216,38 +138,62 @@ def create_rag_chain(chat_history: str):
     )
     return chain
 
-# Initialize chat history
 chat_history = ChatHistory()
 
-
+def process_stream(stream_queue: Queue, history: List[dict]) -> Generator[List[dict], None, None]:
+    """Process the streaming response and update the chat interface"""
+    current_response = ""
+
+    while True:
+        chunk = stream_queue.get()
+        if chunk is None:  # Signal that streaming is complete
+            break
+
+        current_response += chunk
+        new_history = history.copy()
+        new_history[-1]["content"] = current_response
+        yield new_history
+
 @spaces.GPU()
-def ask_question_gradio(question, history):
+def ask_question_gradio(question: str, history: List[dict]) -> Generator[tuple, None, None]:
     try:
-        # Add user question to chat history
         chat_history.add_message("user", question)
-
-        # Get formatted history
         formatted_history = chat_history.get_formatted_history()
-
-        # Create chain with current chat history
         rag_chain = create_rag_chain(formatted_history)
 
-        #
+        # Update history with user message
+        history.append({"role": "user", "content": question})
+        history.append({"role": "assistant", "content": ""})
+
+        # Create a queue for streaming responses
+        stream_queue = Queue()
+
+        # Function to process the stream in a separate thread
+        def stream_processor():
+            try:
+                for chunk in rag_chain.stream(question):
+                    stream_queue.put(chunk)
+                stream_queue.put(None)  # Signal completion
+            except Exception as e:
+                logger.error(f"Streaming error: {e}")
+                stream_queue.put(None)
+
+        # Start streaming in a separate thread
+        Thread(target=stream_processor).start()
+
+        # Yield updates to the chat interface
         response = ""
-        for
-            response
+        for updated_history in process_stream(stream_queue, history):
+            response = updated_history[-1]["content"]
+            yield "", updated_history
 
-        # Add
+        # Add final response to chat history
         chat_history.add_message("assistant", response)
 
-        # Update Gradio chat history
-        history.append({"role": "user", "content": question})
-        history.append({"role": "assistant", "content": response})
-
-        return "", history
     except Exception as e:
         logger.error(f"Error during question processing: {e}")
-
+        history.append({"role": "assistant", "content": "An error occurred. Please try again later."})
+        yield "", history
 
 def clear_chat():
     chat_history.clear()
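
The new handler pushes chunks from a worker thread into a Queue and drains them in process_stream; the thread keeps chain errors off the UI generator at the cost of extra moving parts. Because rag_chain.stream() is itself a lazy generator yielding string chunks (the committed code concatenates them directly), a simpler variant could stream without the thread and queue. A hedged sketch of that alternative, not what the commit does:

def ask_question_simple(question: str, history: list):
    """Stream a reply by iterating the chain directly (illustrative only)."""
    rag_chain = create_rag_chain(chat_history.get_formatted_history())
    history = history + [
        {"role": "user", "content": question},
        {"role": "assistant", "content": ""},
    ]
    partial = ""
    for chunk in rag_chain.stream(question):
        partial += chunk
        history[-1]["content"] = partial
        yield "", history
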
@@ -255,17 +201,15 @@ def clear_chat():
 
 # Gradio Interface
 with gr.Blocks(theme='Yntec/HaleyCH_Theme_Orange_Green') as iface:
-    gr.Image("Image.jpg"
+    gr.Image("Image.jpg", width=750, height=300, show_label=False, show_download_button=False)
     gr.Markdown("# Mawared HR Assistant 2.5.1")
     gr.Markdown('### Instructions')
-    gr.Markdown("Ask a question about MawaredHR and get a detailed answer
-
-
+    gr.Markdown("Ask a question about MawaredHR and get a detailed answer, if you get an error try again with same prompt, it's an API issue and we are working on it 😀")
 
     chatbot = gr.Chatbot(
         height=750,
         show_label=False,
-
+        bubble_full_width=False,
     )
 
     with gr.Row():
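
The handlers append OpenAI-style {"role": ..., "content": ...} dicts to the chat history. Depending on the Gradio version pinned for the Space, the Chatbot component may need to be told to expect that format; a hedged sketch (the type parameter exists in recent Gradio releases, verify against the installed version):

chatbot = gr.Chatbot(
    height=750,
    show_label=False,
    type="messages",  # accept {"role": ..., "content": ...} dicts
)
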
@@ -287,6 +231,5 @@ with gr.Blocks(theme='Yntec/HaleyCH_Theme_Orange_Green') as iface:
         outputs=[chatbot, question_input]
     )
 
-# Launch the Gradio App
 if __name__ == "__main__":
     iface.launch()