import gradio as gr
import torch
import os

from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.llms import HuggingFaceHub
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.prompts import PromptTemplate
#from langchain.chains import (
#    StuffDocumentsChain, LLMChain, ConversationalRetrievalChain
#)
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Static model name
llm_name = "meta-llama/Llama-2-7b-chat-hf"

# Static file path for multiple files
static_file_paths = [
    "IRM ISO_IEC_27001_2022(en).pdf",
    #"SCF - Cybersecurity & Data Privacy Capability Maturity Model (CP-CMM) (2023.4).pdf",
    #"AG_Level1_V2.0_Final_20211210.pdf",
    #"CIS_Controls_v8_v21.10.pdf",
    #"CSF PDF v11.1.0-1.pdf",
    #"ISO_31000_2018(en)-1.pdf",
    #"OWASP Application Security Verification Standard 4.0.3-en-1.pdf",
    #"NIST.CSWP.29.ipd The NIST Cybersecurity Framework 2.0 202308-1 (1).pdf",
    #"ISO_IEC_27002_2022(en)-1.pdf",
]
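# Additional PDFs can be appended to static_file_paths; every file in the list
# is loaded and indexed into the same Chroma vector store below.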

# Use cuda for faster processing
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load documents, split them into chunks, and build the vector store.
# This block must run: the retriever below reads from vectordb.
loaders = [PyPDFLoader(x) for x in static_file_paths]
pages = []
for loader in loaders:
    pages.extend(loader.load())
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=600,
    chunk_overlap=40,
)
doc_splits = text_splitter.split_documents(pages)
embedding = HuggingFaceEmbeddings()
vectordb = Chroma.from_documents(
    documents=doc_splits,
    embedding=embedding,
)

# Load model
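# Llama-2-7b-chat-hf is a gated checkpoint: HUGGINGFACEHUB_API_TOKEN must be set
# to a token whose account has accepted the Llama 2 license on the Hugging Face Hub,
# otherwise the download below fails.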
tokenizer = AutoTokenizer.from_pretrained(llm_name, token=os.environ['HUGGINGFACEHUB_API_TOKEN'],)
model = AutoModelForCausalLM.from_pretrained(llm_name, token=os.environ['HUGGINGFACEHUB_API_TOKEN'], torch_dtype=torch.float16)
model = model.to(device)

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512, device=device, token=os.environ['HUGGINGFACEHUB_API_TOKEN'])
hf = HuggingFacePipeline(pipeline=pipe)
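# HuggingFacePipeline wraps the local transformers pipeline so it can be used
# wherever LangChain expects an LLM, e.g. as the model for the retrieval chain below.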

# Set up template and memory
template = """You are a helpful and appreciative cybersecurity expert who gives comprehensive answers using lists, step-by-step instructions and other aids. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Question: {question}
Helpful Answer:
"""
prompt = PromptTemplate.from_template(template)
memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True
    )
retriever = vectordb.as_retriever()
qachain = ConversationalRetrievalChain.from_llm(
    hf,
    retriever=retriever,
    chain_type="stuff", 
    memory=memory,
    return_source_documents=True,
    combine_docs_chain_kwargs={
        "prompt": prompt,
    }
)
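# qachain condenses the incoming question using the chat history, retrieves the
# most relevant chunks from vectordb, and answers with the prompt above, also
# returning the source documents alongside the answer.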

def format_chat_history(message, chat_history):
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history

# Conversation with chatbot
def conversation(qa_chain, message, history):
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    response_sources = response["source_documents"]
    response_source1 = response_sources[0].page_content.strip()
    response_source2 = response_sources[1].page_content.strip()
    response_source1_page = response_sources[0].metadata["page"] + 1
    response_source2_page = response_sources[1].metadata["page"] + 1
    new_history = history + [(message, response_answer)]
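    # Return order must match the `outputs` list wired to submit_btn.click in demo().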
    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page

def demo():
    with gr.Blocks(theme="base") as demo:
        qa_chain = gr.State(qachain)

        gr.Markdown(
        """<center><h2>Context Chatbot</h2></center>
        <h3>Ask any questions about your PDF documents, along with follow-ups</h3>
        <i>When generating answers, the chatbot takes past questions into account (via conversational memory) and includes document references for clarity.</i>
        """)

        # Conversation with chatbot
        with gr.Tab("Step 3 - Conversation with chatbot"):
            chatbot = gr.Chatbot(height=600)
            with gr.Row():
                msg = gr.Textbox(placeholder="Type message", container=True)
            with gr.Row():
                submit_btn = gr.Button("Submit")
                clear_btn = gr.ClearButton([msg, chatbot])
            with gr.Accordion("Advanced - Document references", open=False):
                with gr.Row():
                    response_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                    response_source1_page = gr.Number(label="Page", scale=1)
                with gr.Row():
                    response_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
                    response_source2_page = gr.Number(label="Page", scale=1)

        # Preprocessing events
        #db_btn.click(initialize_database, outputs=[vector_db, db_progress])

        # Chatbot events
        submit_btn.click(conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[qa_chain, msg, chatbot, response_source1, response_source1_page, response_source2, response_source2_page],
            queue=False)
        clear_btn.click(
            lambda: [None, "", 0, "", 0],
            inputs=None,
            outputs=[chatbot, response_source1, response_source1_page, response_source2, response_source2_page],
            queue=False)

    demo.queue().launch(debug=True)

if __name__ == "__main__":
    demo()