import os

import gradio as gr
import nltk
import torch
from duckduckgo_search import ddg  # requires duckduckgo_search < 3.0; newer releases expose DDGS instead
from langchain.chains import RetrievalQA
from langchain.document_loaders import UnstructuredFileLoader
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Chroma

# Model used for both document and query embeddings.
embedding_model_name = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"

DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# Text-generation LLM. Despite the name, the endpoint is zephyr-7b-beta, not Qwen 2.
# RetrievalQA below expects a LangChain LLM, so the endpoint is wrapped with
# HuggingFaceHub (which reads HUGGINGFACEHUB_API_TOKEN from the environment)
# instead of a raw huggingface_hub.InferenceClient.
qwen_text_gen = HuggingFaceHub(repo_id="HuggingFaceH4/zephyr-7b-beta")
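# One-time setup for UnstructuredFileLoader, which uses nltk tokenizers under the hood
# (hedged: the exact resources needed depend on the installed `unstructured` version):
#   nltk.download("punkt")
#   nltk.download("averaged_perceptron_tagger")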
# Search the web and concatenate the snippet text of every hit.
def search_web(query):
    results = ddg(query)
    web_content = ''
    if results:
        for result in results:
            web_content += result['body']
    return web_content
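# Example (hedged, old duckduckgo_search API): ddg("Ming dynasty") returns a list of
# dicts such as {"title": ..., "href": ..., "body": ...}; only the "body" snippets are
# joined here, so the returned string carries no source attribution.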
# Build (or rebuild) the on-disk knowledge vector store from an uploaded file.
def init_knowledge_vector_store(file):
    if file is None:
        return
    filepath = file.name
    distilbert_embedding = HuggingFaceBgeEmbeddings(model_name=embedding_model_name)
    # mode="elements" splits the document into fine-grained elements (titles, paragraphs, ...).
    loader = UnstructuredFileLoader(filepath, mode="elements")
    docs = loader.load()
    Chroma.from_documents(docs, distilbert_embedding, persist_directory="./vector_store")
# Re-open the persisted vector store with the same embedding function used to build it.
def get_knowledge_vector_store():
    distilbert_embedding = HuggingFaceBgeEmbeddings(model_name=embedding_model_name)
    vector_store = Chroma(embedding_function=distilbert_embedding, persist_directory="./vector_store")
    return vector_store
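# Quick sanity check (hedged sketch; the query string is illustrative only):
#   store = get_knowledge_vector_store()
#   store.similarity_search("Yongle Emperor", k=3)  # -> list of langchain Documents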
# Build a RetrievalQA chain over the vector store and return the answer text.
def get_knowledge_based_answer(query, qwen_text_gen, vector_store, VECTOR_SEARCH_TOP_K, web_content):
    if web_content:
        prompt_template = f"""Answer the user's question based on the following known information.
Known web search content: {web_content} """ + """
Known Content:
{context}
question:
{question}"""
    else:
        prompt_template = """Answer the user's question based on the known information.
Known Content:
{context}
question:
{question}"""
    prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
    knowledge_chain = RetrievalQA.from_llm(
        llm=qwen_text_gen,
        retriever=vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K}),
        prompt=prompt
    )
    # Pass retrieved documents to the LLM as raw page content, with no extra wrapping.
    knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(
        input_variables=["page_content"],
        template="{page_content}"
    )
    knowledge_chain.return_source_documents = True
    result = knowledge_chain.invoke({"query": query})
    return result['result']
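# With return_source_documents=True the chain returns a dict like
# {"query": ..., "result": ..., "source_documents": [...]}; only the answer text is used.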
# Function to clear session
def clear_session():
    return '', None
# Gradio callback: answer one message and append the turn to the chat history.
# The LLM comes from the module-level qwen_text_gen, since Gradio event inputs
# may only be UI components.
def predict(input, VECTOR_SEARCH_TOP_K, use_web, history=None):
    if history is None:
        history = []
    vector_store = get_knowledge_vector_store()
    if use_web == 'True':  # the Radio component yields the string "True"/"False"
        web_content = search_web(query=input) or ""
    else:
        web_content = ''
    resp = get_knowledge_based_answer(
        query=input,
        qwen_text_gen=qwen_text_gen,
        vector_store=vector_store,
        VECTOR_SEARCH_TOP_K=VECTOR_SEARCH_TOP_K,
        web_content=web_content,
    )
    history.append((input, resp))
    return '', history, history
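# Chatbot history format (classic gr.Chatbot): a list of (user_message, bot_message)
# tuples; the same list is returned twice to update both the Chatbot and the State.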
# Gradio interface setup
block = gr.Blocks()
with block as demo:
    gr.Markdown("<h1><center>Chat History</center></h1>")
    with gr.Row():
        with gr.Column(scale=1):
            file = gr.File(label='Please upload txt, md, docx type files', file_types=['.txt', '.md', '.docx'])
            get_vs = gr.Button("Generate Knowledge Base")
            get_vs.click(init_knowledge_vector_store, inputs=[file])
            use_web = gr.Radio(["True", "False"], label="Web Search", value="False")
            VECTOR_SEARCH_TOP_K = gr.Slider(1, 10, value=5, step=1, label="vector search top k", interactive=True)
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label='Ming History Knowledge Question and Answer Assistant', height=600)
            message = gr.Textbox(label='Please enter your question')
            state = gr.State()
            with gr.Row():
                clear_history = gr.Button("Clear history conversation")
                send = gr.Button("Send")
            send.click(predict,
                       inputs=[message, VECTOR_SEARCH_TOP_K, use_web, state],
                       outputs=[message, chatbot, state])
            clear_history.click(fn=clear_session, inputs=[], outputs=[chatbot, state], queue=False)
            message.submit(predict,
                           inputs=[message, VECTOR_SEARCH_TOP_K, use_web, state],
                           outputs=[message, chatbot, state])
demo.queue().launch(share=False)