File size: 2,752 Bytes
c81853b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import streamlit as st
from langchain.llms import LlamaCpp
from langchain.memory import ConversationBufferMemory
from langchain.chains import RetrievalQA
from langchain.embeddings import FastEmbedEmbeddings
from langchain.vectorstores import Chroma
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler 
from langchain import hub

def init_retriever():
    """
    Build the retrieval-augmented QA chain used by the chatbot.

    Loads the local llama-2-13b-chat GGUF model through LlamaCpp with a
    streaming-to-stdout callback, opens the persisted Chroma store under
    ./vectordb/ using FastEmbed (BAAI/bge-small-en-v1.5) embeddings, pulls the
    "rlm/rag-prompt-llama" prompt from the LangChain hub, and combines them into
    a RetrievalQA chain with a ConversationBufferMemory attached.

    Returns:
        The configured RetrievalQA chain.
    """
    stdout_callbacks = CallbackManager([StreamingStdOutCallbackHandler()])

    # NOTE(review): max_tokens equals n_ctx, so a full-length answer leaves no
    # room in the context window for the RAG prompt and retrieved chunks --
    # confirm this budget is intended.
    local_llm = LlamaCpp(
        model_path="./models/llama-2-13b-chat.Q4_K_S.gguf",
        n_ctx=4000,
        max_tokens=4000,
        f16_kv=True,
        callback_manager=stdout_callbacks,
        verbose=True,
    )

    embedder = FastEmbedEmbeddings(
        model_name="BAAI/bge-small-en-v1.5",
        cache_dir="./embedding_model/",
    )
    vector_store = Chroma(
        persist_directory="./vectordb/",
        embedding_function=embedder,
    )
    prompt_template = hub.pull("rlm/rag-prompt-llama")

    chain = RetrievalQA.from_chain_type(
        local_llm,
        retriever=vector_store.as_retriever(),
        chain_type_kwargs={"prompt": prompt_template},
    )
    # Both are attached as attributes after construction rather than passed
    # to from_chain_type.
    # NOTE(review): the hub RAG prompt presumably has no chat-history slot, so
    # this memory may only record turns without feeding them back -- confirm.
    chain.callback_manager = stdout_callbacks
    chain.memory = ConversationBufferMemory()
    return chain

# Streamlit re-executes this whole script on every interaction, so the
# expensive model / vector-store setup runs only on the first run of a browser
# session; later reruns reuse the chain cached in session_state.
if "retriever" not in st.session_state:
    st.session_state["retriever"] = init_retriever()

# Function to render an image with rounded edges using CSS
def add_rounded_edges(image_path="./randstad_featuredimage.png", radius=30):
    """Render an image with rounded corners.

    ``st.image`` offers no way to attach a CSS class to the ``<img>`` it emits,
    so a ``.rounded-img`` style rule alone never matches anything. For a local
    file, the image is therefore inlined as a base64 data URI inside an
    ``<img class="rounded-img">`` tag that the injected rule can target.
    Anything that is not an existing local file path, such as a URL, falls back
    to plain ``st.image`` without rounding.

    Args:
        image_path: Path (str or os.PathLike) to the image to display.
        radius: Corner radius in pixels.
    """
    import base64
    import mimetypes
    import os
    from pathlib import Path

    st.markdown(
        f'<style>.rounded-img{{border-radius: {radius}px; overflow: hidden;}}</style>',
        unsafe_allow_html=True,
    )

    is_local_file = isinstance(image_path, (str, os.PathLike)) and Path(image_path).is_file()
    if not is_local_file:
        # Cannot inline it ourselves; keep the original rendering path.
        st.image(image_path, use_column_width=True, output_format='auto')
        return

    path = Path(image_path)
    mime = mimetypes.guess_type(path.name)[0] or "image/png"
    encoded = base64.b64encode(path.read_bytes()).decode("ascii")
    # width: 100% mirrors use_column_width=True from the st.image fallback.
    st.markdown(
        f'<img class="rounded-img" src="data:{mime};base64,{encoded}" style="width: 100%;">',
        unsafe_allow_html=True,
    )

# add side bar
with st.sidebar:
    # add Randstad logo
    add_rounded_edges()

st.title("💬 HR Chatbot")
st.caption("🚀 A chatbot powered by Local LLM")

# Add clear chat button
if st.button("Clear Chat History"):
    # Reset the rendered transcript and the chain's own conversation memory
    # together. RetrievalQA has no ``clean()`` method, so memory must be
    # cleared directly, and here rather than in the prompt branch, which does
    # not run on the rerun triggered by the button click.
    st.session_state.messages = []
    chain_memory = getattr(st.session_state.retriever, "memory", None)
    if chain_memory is not None:
        chain_memory.clear()

if "messages" not in st.session_state:
    st.session_state.messages = [{"role": "assistant", "content": "How can I help you?"}]

# Replay the transcript on every rerun (Streamlit redraws from scratch).
for msg in st.session_state.messages:
    st.chat_message(msg["role"]).write(msg["content"])

if prompt := st.chat_input():
    st.session_state.messages.append({"role": "user", "content": prompt})
    st.chat_message("user").write(prompt)
    chain = st.session_state.retriever
    # RetrievalQA takes a single query string. Passing the list of message
    # dicts would send that list to the retriever and prompt instead of the
    # user's question.
    msg = chain.run(prompt)
    st.session_state.messages.append({"role": "assistant", "content": msg})
    st.chat_message("assistant").write(msg)