|
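"""Mawared HR Assistant.

A Gradio chat app that answers questions about the Mawared HR system using
retrieval-augmented generation (RAG): document chunks are retrieved from a
Qdrant collection via BGE embeddings and passed to a Hugging Face
text-generation pipeline through a LangChain chain.

Expected environment variables: HF_TOKEN, QDRANT_URL, QDRANT_API_KEY.
"""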
import os |
|
from dotenv import load_dotenv |
|
from langchain_community.vectorstores import Qdrant |
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
from langchain.prompts import ChatPromptTemplate |
|
from langchain.schema.runnable import RunnablePassthrough |
|
from langchain.schema.output_parser import StrOutputParser |
|
from qdrant_client import QdrantClient, models |
|
from langchain_openai import ChatOpenAI |
|
import gradio as gr |
|
import logging |
|
from typing import List, Tuple |
|
from dataclasses import dataclass |
|
from datetime import datetime |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
from langchain_huggingface.llms import HuggingFacePipeline |
|
import spaces |
|
|
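# Basic logging so configuration and runtime errors surface in the app logs.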
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
|
|
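# One chat turn. Timestamps are recorded but not currently displayed.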
@dataclass |
|
class Message: |
|
role: str |
|
content: str |
|
timestamp: str |
|
|
|
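# In-memory conversation log used to build the {chat_history} prompt variable.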
class ChatHistory: |
|
def __init__(self): |
|
self.messages: List[Message] = [] |
|
|
|
def add_message(self, role: str, content: str): |
|
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") |
|
self.messages.append(Message(role=role, content=content, timestamp=timestamp)) |
|
|
|
def get_formatted_history(self, max_messages: int = 10) -> str: |
|
"""Returns the most recent conversation history formatted as a string""" |
|
recent_messages = self.messages[-max_messages:] if len(self.messages) > max_messages else self.messages |
|
formatted_history = "\n".join([ |
|
f"{msg.role}: {msg.content}" for msg in recent_messages |
|
]) |
|
return formatted_history |
|
|
|
def clear(self): |
|
self.messages = [] |
|
|
|
|
|
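# Load secrets from a local .env file, if present; a no-op when the
# variables are already set in the environment.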
load_dotenv() |
|
|
|
|
|
HF_TOKEN = os.getenv("HF_TOKEN") |
|
if not HF_TOKEN: |
|
logger.error("HF_TOKEN is not set in the environment variables.") |
|
exit(1) |
|
|
|
|
|
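# Embedding model for retrieval. BAAI/bge-large-en-v1.5 produces
# 1024-dimensional vectors, so the Qdrant collection below must be created
# with a matching vector size.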
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en-v1.5") |
|
|
|
|
|
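# Connect to the remote Qdrant instance; gRPC is preferred for lower latency.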
try: |
|
client = QdrantClient( |
|
url=os.getenv("QDRANT_URL"), |
|
api_key=os.getenv("QDRANT_API_KEY"), |
|
prefer_grpc=True |
|
) |
|
except Exception as e: |
|
logger.error("Failed to connect to Qdrant. Ensure QDRANT_URL and QDRANT_API_KEY are correctly set.") |
|
exit(1) |
|
|
|
|
|
collection_name = "mawared" |
|
|
|
|
|
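# Create the collection on first run. Qdrant raises if the name is already
# taken, which is treated as "already created" in the except branch below.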
try: |
|
client.create_collection( |
|
collection_name=collection_name, |
|
vectors_config=models.VectorParams( |
|
            size=1024,  # BAAI/bge-large-en-v1.5 embeddings are 1024-dimensional; must match the model above
|
distance=models.Distance.COSINE |
|
) |
|
) |
|
logger.info(f"Created new collection: {collection_name}") |
|
except Exception as e: |
|
if "already exists" in str(e): |
|
logger.info(f"Collection {collection_name} already exists, continuing...") |
|
else: |
|
logger.error(f"Error creating collection: {e}") |
|
exit(1) |
|
|
|
|
|
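# LangChain vector-store wrapper around the raw Qdrant client. (LangChain's
# newer langchain_qdrant package offers QdrantVectorStore, but the community
# wrapper used here still works.)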
db = Qdrant( |
|
client=client, |
|
collection_name=collection_name, |
|
embeddings=embeddings, |
|
) |
|
|
|
|
|
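# Retrieve the top-3 most similar chunks for each question.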
retriever = db.as_retriever( |
|
search_type="similarity", |
|
search_kwargs={"k": 3} |
|
)
|
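# Load the generation model. c4ai-command-r7b-12-2024 is a ~7B-parameter
# model, so loading it needs substantial RAM/VRAM; if the repo is gated,
# transformers will pick up HF_TOKEN from the environment for authentication.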
model_id = "CohereForAI/c4ai-command-r7b-12-2024" |
|
tokenizer = AutoTokenizer.from_pretrained(model_id) |
|
model = AutoModelForCausalLM.from_pretrained(model_id) |
|
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=8192)
|
llm = HuggingFacePipeline(pipeline=pipe)
|
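# Prompt template wiring the retrieved context, running chat history, and the
# current question into a single grounded instruction.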
template = """ |
|
You are an expert assistant specializing in the Mawared HR System. |
|
Your task is to provide accurate and contextually relevant answers based on the provided context and chat history. |
|
If you need more information, ask targeted clarifying questions. |
|
Ensure you provide detailed Numbered step by step to the user and be very accurate. |
|
Previous Conversation: |
|
{chat_history} |
|
Current Context: |
|
{context} |
|
Current Question: |
|
{question} |
|
Ask followup questions based on your provided asnwer to create a conversational flow, Only answer form the provided context and chat history , dont make up any information. |
|
answer only and only from the given context and knowledgebase. |
|
Answer: |
|
""" |
|
|
|
prompt = ChatPromptTemplate.from_template(template)
|
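# On Hugging Face ZeroGPU Spaces, @spaces.GPU reserves a GPU only while the
# decorated function runs (here up to 600 seconds). Note that this function
# merely builds the chain; the actual generation happens later in
# rag_chain.stream(), so you may want the streaming call inside a
# GPU-decorated function instead.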
@spaces.GPU(duration=600)
|
def create_rag_chain(chat_history: str): |
|
chain = ( |
|
{ |
|
"context": retriever, |
|
"question": RunnablePassthrough(), |
|
"chat_history": lambda x: chat_history |
|
} |
|
| prompt |
|
| llm |
|
| StrOutputParser() |
|
) |
|
return chain |
|
|
|
|
|
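# Module-level history: note this is shared across all concurrent sessions.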
chat_history = ChatHistory() |
|
|
|
|
|
|
|
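# Gradio callback: appends the user turn, runs the RAG chain with the
# formatted history, accumulates the streamed answer, and returns a cleared
# textbox plus the updated messages-format chat list.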
def ask_question_gradio(question, history): |
|
try: |
|
|
|
chat_history.add_message("user", question) |
|
|
|
|
|
formatted_history = chat_history.get_formatted_history() |
|
|
|
|
|
rag_chain = create_rag_chain(formatted_history) |
|
|
|
|
|
response = "" |
|
for chunk in rag_chain.stream(question): |
|
response += chunk |
|
|
|
|
|
chat_history.add_message("assistant", response) |
|
|
|
|
|
history.append({"role": "user", "content": question}) |
|
history.append({"role": "assistant", "content": response}) |
|
|
|
return "", history |
|
except Exception as e: |
|
logger.error(f"Error during question processing: {e}") |
|
return "", history + [{"role": "assistant", "content": "An error occurred. Please try again later."}] |
|
|
|
def clear_chat(): |
|
chat_history.clear() |
|
return [], "" |
|
|
|
|
|
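# Gradio UI: header image, chat window, and an input row with a clear button.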
with gr.Blocks(theme=gr.themes.Soft()) as iface: |
|
gr.Image("Image.jpg" , width=1200 , height=300 ,show_label=False, show_download_button=False) |
|
gr.Markdown("# Mawared HR Assistant") |
|
gr.Markdown("Ask questions about the Mawared HR system, and this assistant will provide answers based on the available context and conversation history.") |
|
|
|
|
|
|
|
chatbot = gr.Chatbot( |
|
height=400, |
|
show_label=False, |
|
type="messages" |
|
) |
|
|
|
with gr.Row(): |
|
question_input = gr.Textbox( |
|
label="Ask a question:", |
|
placeholder="Type your question here...", |
|
scale=25 |
|
) |
|
clear_button = gr.Button("Clear Chat", scale=1) |
|
|
|
question_input.submit( |
|
ask_question_gradio, |
|
inputs=[question_input, chatbot], |
|
outputs=[question_input, chatbot] |
|
) |
|
|
|
clear_button.click( |
|
clear_chat, |
|
outputs=[chatbot, question_input] |
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
iface.launch() |