# MedKBQA-LLM / app.py
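"""Gradio demo for MedKBQA: question answering over a local medical knowledge
base, backed by ChatGLM-6B-int4 (run locally) or GPT-3.5 through the OpenAI
API, with optional DuckDuckGo web search and FAISS vector-store retrieval."""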
import os
from functools import lru_cache

import gradio as gr
from loguru import logger
from transformers import AutoTokenizer, AutoModel
from duckduckgo_search import ddg
from langchain import OpenAI, VectorDBQA
from langchain.chains import LLMChain
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.vectorstores import FAISS
@lru_cache(maxsize=1)
def load_model():
    """Load the ChatGLM tokenizer and int4 model once; later calls reuse the cache."""
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    # CPU inference; for GPU, use .half().cuda() instead of .float()
    model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True) \
        .quantize(bits=4, compile_parallel_kernel=True, parallel_num=2).float()
    model = model.eval()
    return tokenizer, model
def chat_glm(input, history=None):
    """Answer with the local ChatGLM model, threading the chat history through."""
    if history is None:
        history = []
    tokenizer, model = load_model()
    response, history = model.chat(tokenizer, input, history)
    # loguru interpolates positional args into {} placeholders
    logger.debug("chatglm: {} -> {}", input, response)
    return history, history
def search_web(query):
    """Concatenate the snippets of the DuckDuckGo results for the query."""
    logger.debug("searchweb: {}", query)
    results = ddg(query)
    web_content = ''
    if results:
        for result in results:
            web_content += result['body']
    return web_content
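# Note: ``ddg`` exists only in older duckduckgo_search releases; later versions
# removed it in favor of the DDGS class, so this code assumes a pinned old version.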
def search_vec(query):
    """Answer the query from the local FAISS knowledge base via a stuff-chain QA."""
    logger.debug("searchvec: {}", query)
    embedding_model_name = 'GanymedeNil/text2vec-large-chinese'
    vec_path = 'cache'
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
    vector_store = FAISS.load_local(vec_path, embeddings)
    qa = VectorDBQA.from_chain_type(llm=OpenAI(),
                                    chain_type="stuff",
                                    vectorstore=vector_store,
                                    return_source_documents=True)
    result = qa({"query": query})
    return result['result']
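# The FAISS index under ``cache`` is assumed to be prebuilt. A minimal sketch of
# one way to build it (the document path, chunk sizes, and this helper's name
# are assumptions, not part of the original app):
def build_vector_store(doc_path, vec_path='cache'):
    from langchain.document_loaders import UnstructuredFileLoader
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    # Load the raw document, split it into overlapping chunks, embed, and save.
    docs = UnstructuredFileLoader(doc_path).load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    chunks = splitter.split_documents(docs)
    embeddings = HuggingFaceEmbeddings(model_name='GanymedeNil/text2vec-large-chinese')
    FAISS.from_documents(chunks, embeddings).save_local(vec_path)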
def chat_gpt(input, use_web, use_vec, history=None):
    """Answer with OpenAI, optionally grounding on vector-store and/or web context."""
    if history is None:
        history = []
    # history is intentionally not fed to the chain (4097-token context limit)
    parts = []
    if use_vec:
        parts.append(search_vec(input))
    if use_web:
        parts.append(search_web(input))
    context = '\n'.join(parts) if parts else "无"  # "无" = "none"
    # Chinese prompt, roughly: "Based on the known information below, answer the
    # user's question professionally. Prefix any fabricated part with '据我推测'
    # ('I speculate that'). Answer in Chinese."
    prompt_template = f"""基于以下已知信息,请专业地回答用户的问题。
若答案中存在编造成分,请在该部分开头添加“据我推测”。另外,答案请使用中文。
已知内容:
{context}""" + """
问题:
{question}"""
    prompt = PromptTemplate(template=prompt_template, input_variables=["question"])
    llm = OpenAI(temperature=0.2)
    chain = LLMChain(llm=llm, prompt=prompt)
    return chain.run(input)
def predict(input,
            large_language_model,
            use_web,
            use_vec,
            openai_key,
            history=None):
    """Route the question to the selected backend and append the answer to history."""
    logger.debug("predict: {} use_web={}", large_language_model, use_web)
    if history is None:
        history = []
    # an empty Textbox yields '', not None, so test for falsiness
    if not openai_key:
        history.append((input, "You forgot the OpenAI API key"))
        return '', history, history
    os.environ['OPENAI_API_KEY'] = openai_key
    # the Radio components return the strings "True"/"False", not booleans
    use_web = use_web == "True"
    use_vec = use_vec == "True"
    if large_language_model == "GPT-3.5-turbo":
        resp = chat_gpt(input, use_web, use_vec, history)
    elif large_language_model == "ChatGLM-6B-int4":
        _, resp = chat_glm(input, history)
        resp = resp[-1][1]
    elif large_language_model == "Search Web":
        resp = search_web(input)
    elif large_language_model == "Search VectorStore":
        resp = search_vec(input)
    history.append((input, resp))
    return '', history, history
def clear_session():
    return '', None
with gr.Blocks() as demo:
    gr.Markdown("""<h1><center>MedKBQA (demo)</center></h1>
    <center><font size=3>
    A question-answering demo over a local medical knowledge base, built on LangChain, ChatGLM, and the OpenAI API.<br>
    </font></center>
    """)
    with gr.Row():
        with gr.Column(scale=1):
            model_choose = gr.Accordion("Model selection")
            with model_choose:
                large_language_model = gr.Dropdown(
                    ["ChatGLM-6B-int4", "GPT-3.5-turbo", "Search Web", "Search VectorStore"],
                    label="large language model",
                    value="ChatGLM-6B-int4")
                use_web = gr.Radio(["True", "False"],
                                   label="Web Search",
                                   value="False")
                use_vec = gr.Radio(["True", "False"],
                                   label="VectorStore Search",
                                   value="False")
                openai_key = gr.Textbox(label="Enter your OpenAI API key", type="password")
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label='ChatLLM').style(height=600)
            message = gr.Textbox(label='Enter your question')
            state = gr.State()
            with gr.Row():
                clear_history = gr.Button("🧹 Clear history")
                send = gr.Button("🚀 Send")
    send.click(predict,
               inputs=[
                   message, large_language_model, use_web, use_vec, openai_key, state
               ],
               outputs=[message, chatbot, state])
    clear_history.click(fn=clear_session,
                        inputs=[],
                        outputs=[chatbot, state],
                        queue=False)
    message.submit(predict,
                   inputs=[
                       message, large_language_model, use_web, use_vec, openai_key, state
                   ],
                   outputs=[message, chatbot, state])
    gr.Markdown("""Notes:<br>
    1. Choose ChatGLM or GPT for question answering before you start.<br>
    2. An OpenAI API key is required when using GPT.
    """)
demo.queue().launch(server_name='0.0.0.0', share=False)