File size: 6,150 Bytes
c0d8a2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354b9c7
c0d8a2a
 
 
354b9c7
c0d8a2a
d2a326b
 
354b9c7
d2a326b
 
 
c0d8a2a
 
 
 
 
b8f4d75
 
c0d8a2a
354b9c7
c0d8a2a
 
 
354b9c7
c0d8a2a
 
 
 
 
7732e76
c0d8a2a
354b9c7
 
c0d8a2a
 
 
 
354b9c7
 
 
c6608e8
c0d8a2a
354b9c7
 
 
 
c0d8a2a
354b9c7
 
c6608e8
06e8e4e
 
354b9c7
 
 
 
 
 
 
 
 
5beead5
354b9c7
c0d8a2a
 
 
 
354b9c7
c0d8a2a
 
354b9c7
c0d8a2a
 
 
 
 
 
 
558756a
354b9c7
c0d8a2a
354b9c7
 
558756a
 
354b9c7
 
c0d8a2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354b9c7
c0d8a2a
 
 
 
 
354b9c7
 
 
c0d8a2a
 
 
 
 
 
 
 
 
 
 
 
354b9c7
c0d8a2a
 
 
 
 
 
 
 
 
354b9c7
c0d8a2a
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import os
import time
import gradio as gr
from tqdm import tqdm
from loguru import logger
from transformers import AutoTokenizer,AutoModel
from duckduckgo_search import ddg_suggestions
from duckduckgo_search import ddg_translate, ddg, ddg_news

from langchain.document_loaders import UnstructuredFileLoader
from langchain.text_splitter import CharacterTextSplitter,RecursiveCharacterTextSplitter
from langchain.llms import OpenAI
from langchain.schema import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain,RetrievalQA,LLMChain
from langchain.prompts import PromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain import OpenAI,VectorDBQA

def load_model():
    """Load the ChatGLM-6B-int4 tokenizer and model for CPU inference.

    The pair is memoized on the function object so that repeated chat turns
    (``chat_glm`` calls this on every request) do not re-download and
    re-initialize the multi-gigabyte model each time.

    Returns:
        tuple: ``(tokenizer, model)`` with the model in eval mode, float32 on CPU.
    """
    cached = getattr(load_model, "_cached", None)
    if cached is not None:
        return cached
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    # For GPU inference use .half().cuda() instead of .float()
    model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).quantize(bits=4, compile_parallel_kernel=True, parallel_num=2).float()
    model = model.eval()
    load_model._cached = (tokenizer, model)
    return tokenizer, model


def chat_glm(input, history=None):
    """Run one conversational turn against the local ChatGLM model.

    Args:
        input: The user's utterance.
        history: Prior ``(query, response)`` pairs; defaults to a fresh list.

    Returns:
        tuple: ``(history, history)`` — duplicated so the caller can feed both
        the chatbot widget and the session state.
    """
    if history is None:
        history = []

    tokenizer, model = load_model()
    response, history = model.chat(tokenizer, input, history)
    # loguru interpolates positional args into `{}` placeholders; the previous
    # `logger.debug("chatglm:", input, response)` silently dropped both values.
    logger.debug("chatglm: {} {}", input, response)
    return history, history

def search_web(query):
    """Search DuckDuckGo and return the concatenated result snippets.

    Args:
        query: Search query string.

    Returns:
        str: All result bodies joined together; empty string on no results.
    """
    # loguru needs `{}` placeholders for positional args to appear in the log.
    logger.debug("searchweb: {}", query)
    results = ddg(query)
    if not results:
        return ''
    # join is O(n) vs. quadratic `+=`; .get guards results lacking a 'body' key
    return ''.join(result.get('body', '') for result in results)

def search_vec(query):
    """Answer a query from the local FAISS vector store via an OpenAI QA chain.

    Loads HuggingFace embeddings and the persisted FAISS index from ``cache/``,
    then runs a "stuff"-type VectorDBQA chain over the retrieved documents.

    Args:
        query: The user's question.

    Returns:
        str: The chain's answer text (source documents are retrieved but only
        the 'result' field is returned).
    """
    # loguru needs `{}` placeholders for positional args to appear in the log.
    logger.debug("searchvec: {}", query)
    embedding_model_name = 'GanymedeNil/text2vec-large-chinese'
    vec_path = 'cache'
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
    vector_store = FAISS.load_local(vec_path, embeddings)

    qa = VectorDBQA.from_chain_type(llm=OpenAI(), chain_type="stuff", vectorstore=vector_store, return_source_documents=True)
    result = qa({"query": query})
    return result['result']

def chat_gpt(input, use_web, use_vec, history=None):
    """Answer a question with OpenAI, optionally grounded on the vector store.

    Args:
        input: The user's question.
        use_web: Accepted for interface compatibility; not used in this path.
        use_vec: When truthy, retrieve grounding context via ``search_vec``.
        history: Accepted for interface compatibility; not forwarded to the
            chain (avoids exceeding the 4097-token context limit).

    Returns:
        str: The model's answer.
    """
    if history is None:
        history = []

    context = "无"
    if use_vec:
        context = search_vec(input)
    # Feed `context` through a template variable rather than f-string
    # interpolation: braces inside retrieved documents would otherwise be
    # parsed by PromptTemplate as spurious input variables and crash.
    prompt_template = """基于以下已知信息,请专业地回答用户的问题。
        若答案中存在编造成分,请在该部分开头添加“据我推测”。另外,答案请使用中文。
        已知内容:
        {context}
        问题:
        {question}"""

    prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

    llm = OpenAI(temperature=0.2)
    chain = LLMChain(llm=llm, prompt=prompt)
    result = chain.run({"context": context, "question": input})
    return result

def predict(input,
            large_language_model,
            use_web,
            use_vec,
            openai_key,
            history=None):
    """Dispatch a user message to the selected backend and update chat history.

    Args:
        input: The user's message.
        large_language_model: One of "GPT-3.5-turbo", "ChatGLM-6B-int4",
            "Search Web", "Search VectorStore".
        use_web: "True"/"False" string from the gr.Radio widget (or a bool).
        use_vec: "True"/"False" string from the gr.Radio widget (or a bool).
        openai_key: OpenAI API key; required (exported to the environment).
        history: Prior ``(message, response)`` pairs; defaults to a fresh list.

    Returns:
        tuple: ('', history, history) — clears the textbox and updates both
        the chatbot widget and the session state.
    """
    logger.debug("predict.. {} {}", large_language_model, use_web)
    # Reject None *and* empty string: an empty key would silently poison the
    # environment and fail later inside the OpenAI client.
    if not openai_key:
        return '', "You forgot OpenAI API key", "You forgot OpenAI API key"
    os.environ['OPENAI_API_KEY'] = openai_key
    if history is None:
        history = []

    # gr.Radio(["True", "False"]) yields the *string* value; both "True" and
    # "False" are truthy, so convert to real booleans before branching.
    if isinstance(use_web, str):
        use_web = use_web == "True"
    if isinstance(use_vec, str):
        use_vec = use_vec == "True"

    if large_language_model == "GPT-3.5-turbo":
        resp = chat_gpt(input, use_web, use_vec, history)
    elif large_language_model == "ChatGLM-6B-int4":
        _, resp = chat_glm(input, history)
        resp = resp[-1][1]
    elif large_language_model == "Search Web":
        resp = search_web(input)
    elif large_language_model == "Search VectorStore":
        resp = search_vec(input)
    else:
        # Previously an unmatched choice raised NameError on `resp`.
        resp = f"Unknown model choice: {large_language_model}"

    history.append((input, resp))
    return '', history, history

def clear_session():
    """Reset the chat UI: blank the input textbox and drop the session state."""
    cleared_textbox, cleared_state = '', None
    return cleared_textbox, cleared_state

# --- Gradio UI wiring: left column = model/config controls, right = chat ---
block = gr.Blocks()
with block as demo:
    # Page header (rendered as HTML-in-markdown)
    gr.Markdown("""<h1><center>MedKBQA(demo)</center></h1>
    <center><font size=3>
    本项目基于LangChain、ChatGLM以及Open AI接口, 提供基于本地医药知识的自动问答应用. <br>
    </center></font>
    """)
    with gr.Row():
        with gr.Column(scale=1):
            # Backend selector: local ChatGLM, OpenAI, raw web search, or vector store
            model_choose = gr.Accordion("模型选择")
            with model_choose:
                large_language_model = gr.Dropdown(
                    ["ChatGLM-6B-int4","GPT-3.5-turbo","Search Web","Search VectorStore"],
                    label="large language model",
                    value="ChatGLM-6B-int4")
            # NOTE(review): these radios yield the *strings* "True"/"False",
            # which are both truthy — downstream code must compare explicitly.
            use_web = gr.Radio(["True", "False"],
                    label="Web Search",
                    value="False")
            use_vec = gr.Radio(["True", "False"],
                    label="VectorStore Search",
                    value="False")
            openai_key = gr.Textbox(label="请输入OpenAI API key", type="password")
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label='ChatLLM').style(height=600)
            message = gr.Textbox(label='请输入问题')
            # Session-scoped chat history, threaded through predict()
            state = gr.State()

            with gr.Row():
                clear_history = gr.Button("🧹 清除历史对话")
                send = gr.Button("🚀 发送")

                # Button click and textbox Enter both route to predict();
                # outputs clear the textbox and refresh chatbot + state.
                send.click(predict,
                           inputs=[
                               message, large_language_model, use_web, use_vec, openai_key, state
                           ],
                           outputs=[message, chatbot, state])
                clear_history.click(fn=clear_session,
                                    inputs=[],
                                    outputs=[chatbot, state],
                                    queue=False)

                message.submit(predict,
                               inputs=[
                                   message, large_language_model, use_web, use_vec, openai_key, state
                               ],
                               outputs=[message, chatbot, state])
    gr.Markdown("""提醒:<br>
    1. 使用时请先选择使用chatglm或者chatgpt进行问答. <br>
    2. 使用chatgpt时需要输入您的api key.
    """)
# Listen on all interfaces; queue() serializes long-running model calls
demo.queue().launch(server_name='0.0.0.0', share=False)