import torch
import gradio as gr
from duckduckgo_search import DDGS
from langchain.chains import RetrievalQA
from langchain.document_loaders import UnstructuredFileLoader
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Chroma

# Embedding model used for the vector store. Note: this checkpoint is a
# sentiment-classification fine-tune, not a model trained for sentence
# embeddings; it loads because sentence-transformers falls back to mean
# pooling, but a dedicated embedding model would likely retrieve better.
embedding_model_name = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# Text-generation LLM. RetrievalQA expects a LangChain LLM object, so the
# model is wrapped via HuggingFaceHub rather than a raw InferenceClient.
# (Despite the variable name, the checkpoint is Zephyr-7B, not Qwen 2.)
# HuggingFaceHub reads the HUGGINGFACEHUB_API_TOKEN environment variable.
qwen_text_gen = HuggingFaceHub(
    repo_id="HuggingFaceH4/zephyr-7b-beta",
    model_kwargs={"max_new_tokens": 512},
)

# Search the web with DuckDuckGo and concatenate the result snippets
def search_web(query):
    with DDGS() as ddgs:
        results = ddgs.text(query, max_results=10)
    web_content = ''
    if results:
        for result in results:
            web_content += result['body'] + '\n'
    return web_content
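
# Each DDGS text hit is a dict shaped roughly like
#   {"title": "...", "href": "https://...", "body": "snippet text"},
# so search_web("Ming dynasty") returns the hit snippets joined by
# newlines. (A sketch of the expected shape based on duckduckgo_search's
# documented fields.)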

# Build the knowledge vector store from an uploaded file
def init_knowledge_vector_store(file):
    if file is None:
        return
    filepath = file.name
    distilbert_embedding = HuggingFaceBgeEmbeddings(
        model_name=embedding_model_name,
        model_kwargs={"device": DEVICE},
    )
    # Split the file into elements and index them into a persistent Chroma store
    loader = UnstructuredFileLoader(filepath, mode="elements")
    docs = loader.load()
    Chroma.from_documents(docs, distilbert_embedding, persist_directory="./vector_store")

# Reopen the persisted knowledge vector store
def get_knowledge_vector_store():
    distilbert_embedding = HuggingFaceBgeEmbeddings(
        model_name=embedding_model_name,
        model_kwargs={"device": DEVICE},
    )
    vector_store = Chroma(embedding_function=distilbert_embedding, persist_directory="./vector_store")
    return vector_store
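
# Note: Chroma.from_documents above writes the index under ./vector_store;
# with chromadb >= 0.4 persistence is automatic, so reopening the same
# directory here recovers the index between calls. (Older chromadb
# versions required an explicit .persist() call.)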

# Answer a query through RetrievalQA, optionally grounding on web search results
def get_knowledge_based_answer(query, qwen_text_gen, vector_store, VECTOR_SEARCH_TOP_K, web_content):
    if web_content:
        prompt_template = f"""Answer the user's question based on the following known information.
Known web search content: {web_content}""" + """
Known content:
{context}
Question:
{question}"""
    else:
        prompt_template = """Answer the user's question based on the known information.
Known content:
{context}
Question:
{question}"""
    prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

    knowledge_chain = RetrievalQA.from_llm(
        llm=qwen_text_gen,
        retriever=vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K}),
        prompt=prompt,
    )

    # Pass each retrieved document into the prompt as raw page content
    knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
        input_variables=["page_content"],
        template="{page_content}",
    )

    # Also return the retrieved source documents alongside the answer
    knowledge_chain.return_source_documents = True

    result = knowledge_chain.invoke({"query": query})
    return result['result']
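
# Rough standalone usage (a sketch; assumes the vector store was already
# built via init_knowledge_vector_store):
#   vs = get_knowledge_vector_store()
#   answer = get_knowledge_based_answer(
#       "Who founded the Ming dynasty?", qwen_text_gen, vs,
#       VECTOR_SEARCH_TOP_K=5, web_content="")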

# Reset the chatbot display and the conversation state
def clear_session():
    return None, None

# Handle one chat turn: retrieve, optionally search the web, then answer
def predict(input, VECTOR_SEARCH_TOP_K, use_web, history=None):
    if history is None:
        history = []
    vector_store = get_knowledge_vector_store()
    if use_web == 'True':
        web_content = search_web(query=input)
        if web_content is None:
            web_content = ""
    else:
        web_content = ''

    resp = get_knowledge_based_answer(
        query=input,
        qwen_text_gen=qwen_text_gen,
        vector_store=vector_store,
        VECTOR_SEARCH_TOP_K=VECTOR_SEARCH_TOP_K,
        web_content=web_content,
    )
    history.append((input, resp))
    return '', history, history

# Gradio interface setup
with gr.Blocks() as demo:
    gr.Markdown("<h1><center>Chat History</center></h1>")
    with gr.Row():
        with gr.Column(scale=1):
            file = gr.File(label='Please upload txt, md, or docx files', file_types=['.txt', '.md', '.docx'])
            get_vs = gr.Button("Generate Knowledge Base")
            get_vs.click(init_knowledge_vector_store, inputs=[file])

            use_web = gr.Radio(["True", "False"], label="Web Search", value="False")

            VECTOR_SEARCH_TOP_K = gr.Slider(1, 10, value=5, step=1, label="Vector search top k", interactive=True)

        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label='Ming History Knowledge Question and Answer Assistant', height=600)
            message = gr.Textbox(label='Please enter your question')
            state = gr.State()

            with gr.Row():
                clear_history = gr.Button("Clear history conversation")
                send = gr.Button("Send")
                # Only Gradio components may appear in inputs; the LLM is
                # used as a module-level global inside predict
                send.click(predict,
                           inputs=[message, VECTOR_SEARCH_TOP_K, use_web, state],
                           outputs=[message, chatbot, state])
                clear_history.click(fn=clear_session, inputs=[], outputs=[chatbot, state], queue=False)

                message.submit(predict,
                               inputs=[message, VECTOR_SEARCH_TOP_K, use_web, state],
                               outputs=[message, chatbot, state])

demo.queue().launch(share=False)
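
# To run locally (a sketch; pin versions that match the classic langchain
# and pre-v5 duckduckgo_search APIs used above):
#   pip install gradio torch langchain chromadb duckduckgo-search \
#       sentence-transformers "unstructured[docx]"
#   export HUGGINGFACEHUB_API_TOKEN=hf_...   # required by HuggingFaceHub
#   python app.py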