olegperegudov committed on
Commit
11f324c
1 Parent(s): 13b81ea
Files changed (5)
  1. .gitignore +3 -0
  2. app.py +43 -2
  3. build_model.py +23 -0
  4. constants.py +45 -0
  5. utils.py +100 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+ data
+ env
+ model
app.py CHANGED
@@ -1,4 +1,45 @@
  import streamlit as st
 
- x = st.slider('Select a value')
- st.write(x, 'squared is', x * x)
+ import utils
+ from build_model import load_model
+
+ st.title("Buzzbot")
+
+ # Initialize the retriever, the model, and the conversation state once per session
+ if "retriever" not in st.session_state:
+     st.session_state["retriever"] = utils.build_retriever()
+ if "model" not in st.session_state:
+     st.session_state["model"] = load_model()
+ if "conversation" not in st.session_state:
+     st.session_state["conversation"] = utils.Conversation()
+
+ # Initialize chat history
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+ # Display chat messages from history on app rerun
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+         if message["role"] == "assistant":
+             st.caption(message["source_docs"])
+
+ # Accept user input
+ if user_input := st.chat_input("What is up?"):
+     # Add the user message to the chat history
+     st.session_state.messages.append({"role": "user", "content": user_input, "source_docs": None})
+     # Display the user message in a chat message container
+     with st.chat_message("user"):
+         st.markdown(user_input)
+
+     # Display the assistant response in a chat message container
+     with st.chat_message("assistant"):
+         with st.spinner(""):
+             answer, source_docs = utils.ask_question(
+                 user_input, st.session_state.conversation, st.session_state.model, st.session_state.retriever
+             )
+             st.write(answer)
+             # for source_doc in source_docs:
+             st.caption(source_docs)
+
+     st.session_state.messages.append({"role": "assistant", "content": answer, "source_docs": source_docs})
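A note on the flow above: the retriever, the LlamaCpp model, and the Conversation object are built once and cached in st.session_state, so Streamlit reruns do not rebuild them, and every chat-history entry stores the same three keys so the replay loop can read them unconditionally (user turns simply keep source_docs as None). The app itself is launched the usual way with streamlit run app.py. A small illustration of the two entry shapes kept in st.session_state.messages (the values here are made up):

# user turn: no source documents attached
{"role": "user", "content": "Какой у вас режим работы?", "source_docs": None}
# assistant turn: source_docs holds the file-name stems returned by utils.ask_question
{"role": "assistant", "content": "Мы работаем с 9 до 18.", "source_docs": ["faq"]}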
build_model.py ADDED
@@ -0,0 +1,23 @@
+ from langchain.callbacks.manager import CallbackManager
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+ from langchain_community.llms import LlamaCpp
+
+ import constants
+
+ callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
+
+
+ def load_model():
+     return LlamaCpp(
+         model_path=constants.MODEL_SAVE_PATH,
+         temperature=constants.TEMPERATURE,
+         max_tokens=constants.MAX_TOKENS,
+         top_p=constants.TOP_P,
+         # callback_manager=callback_manager,  # would stream tokens to stdout but won't attach them to the returned value
+         verbose=False,  # verbose is required when passing a callback manager; streaming is disabled here
+         n_gpu_layers=constants.N_GPU_LAYERS,
+         n_batch=constants.N_BATCH,
+         n_ctx=constants.N_CTX,
+         repeat_penalty=constants.REPEAT_PENALTY,
+         streaming=False,
+     )
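Since load_model returns a standard LangChain LLM, it can also be exercised outside Streamlit for a quick smoke test. A minimal sketch, assuming the GGUF file has already been downloaded to model/saiga_mistral_7b.Q4_K_M.gguf (the commit defines MODEL_URL but never fetches it, and the model directory is git-ignored):

from build_model import load_model

llm = load_model()  # loads the quantized Saiga model via llama.cpp
# raw completion for a prompt in the Saiga chat format defined in constants.py
print(llm.invoke("<s>user\nПривет!</s><s>bot\n"))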
constants.py ADDED
@@ -0,0 +1,45 @@
+ import os
+
+ import torch
+
+ # model path
+ MODEL_NAME = "saiga_mistral_7b.Q4_K_M.gguf"
+ MODEL_URL = f"https://huggingface.co/TheBloke/saiga_mistral_7b-GGUF/blob/main/{MODEL_NAME}"
+
+ # FOR PRODUCTION
+ CWD = os.path.dirname(os.path.realpath(__file__))
+ DATA_PATH = os.path.join(CWD, "data")
+ DOCS_PATH = os.path.join(DATA_PATH, "docs")
+ MODEL_PATH = os.path.join(CWD, "model")
+ MODEL_SAVE_PATH = os.path.join(MODEL_PATH, MODEL_NAME)
+
+ # RAG params
+ N_GPU_LAYERS = (
+     -1 if torch.cuda.is_available() else 0
+ )  # number of layers to offload to the GPU; -1 offloads all layers, 0 keeps everything on the CPU
+ N_BATCH = 1024  # should be between 1 and n_ctx; consider the amount of VRAM on your GPU
+
+ TEMPERATURE = 0.1  # sampling temperature; 0.1 is a good value for most cases
+ MAX_TOKENS = 1024  # maximum number of tokens to generate
+ TOP_P = 2  # nucleus-sampling threshold; values >= 1 effectively disable it
+ N_CTX = 2048  # context length, up to a maximum of 32k
+ CHUNK_SIZE = 750  # maximum number of characters per chunk when splitting documents
+ CHUNK_OVERLAP = 200  # overlap between consecutive chunks
+ SEARCH_TYPE = "mmr"
+ LAST_MESSAGES = 3  # number of most recent messages from the conversation history to include in the prompt
+ REPEAT_PENALTY = 1.1  # penalty for repeating tokens in the output
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # device for the embedding model
+
+ EMBED_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
+ VECTOR_STORE_PATH = os.path.join(DATA_PATH, "chroma_db")
+
+ # retriever config
+ SEARCH_KWARGS = {"k": 3, "score_threshold": 0.6}
+
+ DEFAULT_MESSAGE_TEMPLATE = "<s>{role}\n{content}</s>"
+ DEFAULT_RESPONSE_TEMPLATE = "<s>bot\n"
+ DEFAULT_SYSTEM_PROMPT = "Ты ассистент помощник, который отвечает на вопросы используя предоставленный контекст. \
+ В качестве контекста используются тексты из различных источников. \
+ Постарайся ответить на вопрос максимально точно. \
+ Для ответа используй только информацию из контекста и вопроса. Ничего не выдумывай. \
+ Если не можешь ответить на вопрос, напиши - 'Не хватает данных для ответа.' "
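The three template constants encode the Saiga chat format that utils.Conversation assembles later: every turn is wrapped as <s>{role}\n{content}</s> and the final prompt ends with the open <s>bot\n tag so the model continues as the bot. The system prompt is in Russian and roughly translates to: "You are an assistant that answers questions using the provided context. The context consists of texts from various sources. Try to answer the question as precisely as possible, use only information from the context and the question, and do not make anything up. If you cannot answer, write 'Не хватает данных для ответа.' (not enough data to answer)." A quick illustration of how a single turn is rendered:

import constants

turn = constants.DEFAULT_MESSAGE_TEMPLATE.format(role="user", content="Привет!")
assert turn == "<s>user\nПривет!</s>"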
utils.py ADDED
@@ -0,0 +1,100 @@
+ import os
+
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.document_loaders import DirectoryLoader
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import Chroma
+
+ import constants
+
+
+ class Conversation:
+     def __init__(
+         self,
+         message_template=constants.DEFAULT_MESSAGE_TEMPLATE,
+         system_prompt=constants.DEFAULT_SYSTEM_PROMPT,
+         response_template=constants.DEFAULT_RESPONSE_TEMPLATE,
+     ):
+         self.message_template = message_template
+         self.response_template = response_template
+         self.messages = [{"role": "system", "content": system_prompt}]
+
+     def add_user_message(self, message):
+         self.messages.append({"role": "user", "content": message})
+
+     def add_bot_message(self, message):
+         self.messages.append({"role": "bot", "content": message})
+
+     def get_conversation_history(self):
+         final_text = ""
+         # the system message first, then the last few messages (the system message itself is not duplicated)
+         context_and_last_few_messages = [self.messages[0]] + self.messages[1:][-constants.LAST_MESSAGES :]
+         for message in context_and_last_few_messages:
+             message_text = self.message_template.format(**message)
+             final_text += message_text
+         return final_text.strip()
+
+
+ def source_documents(relevant_docs):
+     source_docs = set()
+     for doc in relevant_docs:
+         fname = doc.metadata["source"]
+         fname_base = os.path.splitext(os.path.basename(fname))[0]
+         source_docs.add(fname_base)
+     return list(source_docs)
+
+
+ def load_raw_documents():
+     return DirectoryLoader(constants.DOCS_PATH, glob="*.txt").load()
+
+
+ def build_nodes(raw_documents):
+     return RecursiveCharacterTextSplitter(
+         chunk_size=constants.CHUNK_SIZE,
+         chunk_overlap=constants.CHUNK_OVERLAP,
+         length_function=len,
+         is_separator_regex=False,
+     ).split_documents(raw_documents)
+
+
+ def build_embeddings():
+     return HuggingFaceEmbeddings(model_name=constants.EMBED_MODEL_NAME, model_kwargs={"device": constants.DEVICE})
+
+
+ def build_db(nodes, embeddings):
+     return Chroma.from_documents(nodes, embeddings)
+
+
+ def build_retriever():
+     raw_documents = load_raw_documents()
+     nodes = build_nodes(raw_documents)
+     embeddings = build_embeddings()
+     db = build_db(nodes, embeddings)
+     return db.as_retriever(search_kwargs=constants.SEARCH_KWARGS, search_type=constants.SEARCH_TYPE)
+
+
+ def fetch_relevant_nodes(question, retriever):
+     relevant_docs = retriever.get_relevant_documents(question)
+     context = [doc.page_content for doc in relevant_docs]
+     source_docs = source_documents(relevant_docs)
+     context = list(set(context))  # remove duplicated strings from the context
+     return context, source_docs
+
+
+ def ask_question(question, conversation, model, retriever):
+
+     context, source_docs = fetch_relevant_nodes(question, retriever)
+
+     # add the user message to the conversation history
+     conversation.add_user_message(question)
+     conversation_history = conversation.get_conversation_history()
+     prompt = f"{conversation_history}\n\
+ {context}\n\
+ {constants.DEFAULT_RESPONSE_TEMPLATE}"
+
+
+     answer = model.invoke(prompt)
+     # add the bot answer to the conversation history
+     conversation.add_bot_message(answer)
+
+     return answer, source_docs
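Taken together, utils.py implements the whole RAG path: load the .txt files from data/docs, split them into overlapping chunks, embed them with the multilingual MiniLM model, index them in an in-memory Chroma store, retrieve with MMR, and have ask_question stitch the trimmed conversation history, the retrieved context, and the open bot tag into one prompt for the LLM. A minimal end-to-end sketch of the same call path without the Streamlit UI, assuming data/docs and the model file are in place (the question string is only an example):

import utils
from build_model import load_model

retriever = utils.build_retriever()   # loads, splits, embeds and indexes data/docs
model = load_model()                  # llama.cpp-backed Saiga model
conversation = utils.Conversation()   # starts with the system prompt only

answer, sources = utils.ask_question("О чём эти документы?", conversation, model, retriever)
print(answer)   # model answer as a plain string
print(sources)  # file-name stems of the retrieved source documents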