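"""MedKBQA (demo): automatic Q&A over a local medical knowledge base, built on
LangChain, ChatGLM, and the OpenAI API, and served through a Gradio Blocks UI."""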
import os
import gradio as gr
from loguru import logger
from transformers import AutoTokenizer, AutoModel
from duckduckgo_search import ddg
from langchain.llms import OpenAI
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA, LLMChain
from langchain.prompts import PromptTemplate
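# NOTE: these imports assume older langchain and duckduckgo_search releases
# (ddg(), FAISS.load_local(), RetrievalQA.from_chain_type()); newer versions
# of both libraries have renamed or removed some of these entry points.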

# Cache the tokenizer/model so the 6B-parameter model is loaded once,
# not on every chat turn.
_TOKENIZER, _MODEL = None, None

def load_model():
    global _TOKENIZER, _MODEL
    if _MODEL is None:
        _TOKENIZER = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
        # CPU inference; on a GPU use .half().cuda() instead of .float()
        _MODEL = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).quantize(bits=4, compile_parallel_kernel=True, parallel_num=2).float().eval()
    return _TOKENIZER, _MODEL

def chat_glm(input, history=None):
    if history is None:
        history = []
    tokenizer, model = load_model()
    response, history = model.chat(tokenizer, input, history)
    logger.debug("chatglm: {} -> {}", input, response)
    return history, history

def search_web(query):
    logger.debug("searchweb: {}", query)
    results = ddg(query)
    web_content = ''
    if results:
        for result in results:
            web_content += result['body'] + '\n'
    return web_content
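# Quick sanity check (assumes network access and a duckduckgo_search release
# where ddg() returns result dicts with a 'body' key):
#   print(search_web("aspirin contraindications")[:200])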

def search_vec(query):
    logger.debug("searchvec: {}", query)
    embedding_model_name = 'GanymedeNil/text2vec-large-chinese'
    vec_path = 'cache'  # directory holding the prebuilt FAISS index
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
    vector_store = FAISS.load_local(vec_path, embeddings)
    # RetrievalQA supersedes the deprecated VectorDBQA interface.
    qa = RetrievalQA.from_chain_type(llm=OpenAI(),
                                     chain_type="stuff",
                                     retriever=vector_store.as_retriever(),
                                     return_source_documents=True)
    result = qa({"query": query})
    return result['result']
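# The index under 'cache' is assumed to be prebuilt offline, e.g. (sketch;
# "kb.txt" is a hypothetical source document):
#   embeddings = HuggingFaceEmbeddings(model_name='GanymedeNil/text2vec-large-chinese')
#   docs = UnstructuredFileLoader("kb.txt").load()
#   docs = CharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
#   FAISS.from_documents(docs, embeddings).save_local("cache")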

def chat_gpt(input, use_web, use_vec, history=None):
    if history is None:
        history = []
    # history is kept out of the prompt (4097-token limit)
    context = "None"
    # The Radio components deliver the strings "True"/"False".
    if use_vec == "True":
        context = search_vec(input)
    elif use_web == "True":
        context = search_web(input)
    prompt_template = f"""Based on the known information below, answer the user's question professionally.
If any part of the answer is speculative, prefix that part with "According to my inference". Please answer in Chinese.
Known information:
{context}""" + """
Question:
{question}"""
    prompt = PromptTemplate(template=prompt_template, input_variables=["question"])
    llm = OpenAI(temperature=0.2)
    chain = LLMChain(llm=llm, prompt=prompt)
    return chain.run(input)
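# Usage sketch (hypothetical question):
#   answer = chat_gpt("What are the contraindications of aspirin?",
#                     use_web="False", use_vec="True")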

def predict(input,
            large_language_model,
            use_web,
            use_vec,
            openai_key,
            history=None):
    logger.debug("predict: model={}, use_web={}", large_language_model, use_web)
    if history is None:
        history = []
    if openai_key:
        os.environ['OPENAI_API_KEY'] = openai_key
    else:
        # Surface the reminder through the normal chatbot output rather than
        # returning a bare string, which gr.Chatbot cannot render.
        history.append((input, "You forgot the OpenAI API key"))
        return '', history, history
    if large_language_model == "GPT-3.5-turbo":
        resp = chat_gpt(input, use_web, use_vec, history)
    elif large_language_model == "ChatGLM-6B-int4":
        _, resp = chat_glm(input, history)
        resp = resp[-1][1]
    elif large_language_model == "Search Web":
        resp = search_web(input)
    elif large_language_model == "Search VectorStore":
        resp = search_vec(input)
    history.append((input, resp))
    return '', history, history
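# The 3-tuple maps onto outputs=[message, chatbot, state] below: clear the
# textbox, refresh the chat window, and persist the running history.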

def clear_session():
    # Reset both the chatbot window and the session state.
    return None, None

block = gr.Blocks()
with block as demo:
    gr.Markdown("""<h1><center>MedKBQA (demo)</center></h1>
    <center><font size=3>
    This project is built on LangChain, ChatGLM, and the OpenAI API, and provides automatic Q&A over a local medical knowledge base. <br>
    </font></center>
    """)
    with gr.Row():
        with gr.Column(scale=1):
            model_choose = gr.Accordion("Model selection")
            with model_choose:
                large_language_model = gr.Dropdown(
                    ["ChatGLM-6B-int4", "GPT-3.5-turbo", "Search Web", "Search VectorStore"],
                    label="large language model",
                    value="ChatGLM-6B-int4")
                use_web = gr.Radio(["True", "False"],
                                   label="Web Search",
                                   value="False")
                use_vec = gr.Radio(["True", "False"],
                                   label="VectorStore Search",
                                   value="False")
            openai_key = gr.Textbox(label="Enter your OpenAI API key", type="password")
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label='ChatLLM').style(height=600)
            message = gr.Textbox(label='Enter your question')
            state = gr.State()
            with gr.Row():
                clear_history = gr.Button("🧹 Clear chat history")
                send = gr.Button("🚀 Send")
    send.click(predict,
               inputs=[message, large_language_model, use_web, use_vec, openai_key, state],
               outputs=[message, chatbot, state])
    clear_history.click(fn=clear_session,
                        inputs=[],
                        outputs=[chatbot, state],
                        queue=False)
    message.submit(predict,
                   inputs=[message, large_language_model, use_web, use_vec, openai_key, state],
                   outputs=[message, chatbot, state])
gr.Markdown("""提醒:<br>
1. 使用时请先选择使用chatglm或者chatgpt进行问答. <br>
2. 使用chatgpt时需要输入您的api key.
""")
demo.queue().launch(server_name='0.0.0.0', share=False)