# NOTE(review): the three lines originally here ("Spaces:" / "Runtime error" /
# "Runtime error") were page chrome from a scraped Hugging Face Spaces listing,
# not Python code; commented out so the module parses.
# Standard library
import os
import time
from functools import lru_cache  # caches the expensive ChatGLM model load

# Third-party
import gradio as gr
from tqdm import tqdm
from loguru import logger
from transformers import AutoTokenizer, AutoModel
from duckduckgo_search import ddg_suggestions
from duckduckgo_search import ddg_translate, ddg, ddg_news
from langchain.document_loaders import UnstructuredFileLoader
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
from langchain.llms import OpenAI
from langchain.schema import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain, RetrievalQA, LLMChain
from langchain.prompts import PromptTemplate
from langchain.prompts.prompt import PromptTemplate  # NOTE: rebinds the same name as the line above; kept so the final binding matches the original file
from langchain import OpenAI, VectorDBQA  # NOTE: rebinds OpenAI; final binding was langchain.OpenAI in the original
@lru_cache(maxsize=1)
def load_model():
    """Load and cache the ChatGLM tokenizer and 4-bit quantized model.

    The original reloaded the full model on *every* chat turn (``chat_glm``
    calls this per message); ``lru_cache`` makes the load happen once per
    process while keeping the call signature unchanged.

    Returns:
        tuple: ``(tokenizer, model)`` with the model in eval mode on CPU.
    """
    # NOTE(review): tokenizer comes from "THUDM/chatglm-6b" but weights from
    # "THUDM/chatglm-6b-int4" — they share a vocab, but confirm intentional.
    tokenizer = AutoTokenizer.from_pretrained(
        "THUDM/chatglm-6b", trust_remote_code=True)
    # For GPU inference replace .float() with .half().cuda()
    model = AutoModel.from_pretrained(
        "THUDM/chatglm-6b-int4", trust_remote_code=True
    ).quantize(bits=4, compile_parallel_kernel=True, parallel_num=2).float()
    model = model.eval()
    return tokenizer, model
def chat_glm(input, history=None):
    """Run one ChatGLM chat turn.

    Args:
        input: The user's message for this turn.
        history: Prior ``(question, answer)`` pairs; ``None`` starts fresh.

    Returns:
        tuple: ``(history, history)`` — the updated history twice, matching
        the Gradio ``(chatbot, state)`` output convention used by callers.
    """
    if history is None:
        history = []
    tokenizer, model = load_model()
    response, history = model.chat(tokenizer, input, history)
    # Fix: loguru interpolates positional args into "{}" placeholders; the
    # original call had no placeholders, so input/response were dropped.
    logger.debug("chatglm: input={} response={}", input, response)
    return history, history
def search_web(query):
    """Search DuckDuckGo and return the concatenated result snippets.

    Args:
        query: Search string.

    Returns:
        str: All result bodies joined together; empty string when the
        search yields nothing.
    """
    # Fix: loguru needs a "{}" placeholder for the query to be logged.
    logger.debug("searchweb: {}", query)
    results = ddg(query)
    if not results:
        return ''
    # .get() skips malformed results instead of raising KeyError;
    # join avoids the original quadratic `+=` string build.
    return ''.join(result.get('body', '') for result in results)
def search_vec(query):
    """Answer a query against the local FAISS vector store via OpenAI.

    Args:
        query: The user's question.

    Returns:
        str: The QA chain's answer text.
    """
    # Fix: loguru needs a "{}" placeholder for the query to be logged.
    logger.debug("searchvec: {}", query)
    # Loading the embedding model and FAISS index is expensive; the original
    # re-loaded both on every call. Cache the vector store on the function,
    # but rebuild the (cheap) QA chain per call so OPENAI_API_KEY changes
    # made after the first query still take effect.
    vector_store = getattr(search_vec, "_vector_store", None)
    if vector_store is None:
        embedding_model_name = 'GanymedeNil/text2vec-large-chinese'
        vec_path = 'cache'
        embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
        vector_store = FAISS.load_local(vec_path, embeddings)
        search_vec._vector_store = vector_store
    qa = VectorDBQA.from_chain_type(
        llm=OpenAI(), chain_type="stuff",
        vectorstore=vector_store, return_source_documents=True)
    result = qa({"query": query})
    return result['result']
def chat_gpt(input, use_web, use_vec, history=None):
    """Answer via OpenAI, optionally grounded in vector-store / web context.

    Args:
        input: The user's question.
        use_web: Whether to prepend DuckDuckGo results to the context.
            May be a bool or the *string* "True"/"False" (gr.Radio value).
        use_vec: Whether to prepend vector-store results; same typing caveat.
        history: Unused here (kept for signature compatibility; the 4097
            token limit made passing history impractical — see original).

    Returns:
        str: The model's answer.
    """
    if history is None:
        history = []
    # Fix: the gr.Radio controls deliver the STRINGS "True"/"False"; the
    # string "False" is truthy, so `if use_vec:` always fired. Normalize.
    web_on = use_web in (True, "True")
    vec_on = use_vec in (True, "True")
    parts = []
    if vec_on:
        parts.append(search_vec(input))
    if web_on:
        # Fix: the original accepted use_web but never acted on it.
        parts.append(search_web(input))
    context = "\n".join(parts) if parts else "无"
    prompt_template = f"""基于以下已知信息,请专业地回答用户的问题。
若答案中存在编造成分,请在该部分开头添加“据我推测”。另外,答案请使用中文。
已知内容:
{context}""" + """
问题:
{question}"""
    prompt = PromptTemplate(template=prompt_template,
                            input_variables=["question"])
    llm = OpenAI(temperature=0.2)
    chain = LLMChain(llm=llm, prompt=prompt)
    return chain.run(input)
def predict(input,
            large_language_model,
            use_web,
            use_vec,
            openai_key,
            history=None):
    """Route one chat turn to the selected backend and update the history.

    Args:
        input: The user's message.
        large_language_model: Backend name from the UI dropdown.
        use_web: "True"/"False" radio value forwarded to chat_gpt.
        use_vec: "True"/"False" radio value forwarded to chat_gpt.
        openai_key: User-supplied OpenAI API key (may be empty).
        history: Prior ``(question, answer)`` pairs; ``None`` starts fresh.

    Returns:
        tuple: ``('', history, history)`` — clears the textbox and feeds
        the chatbot widget and state; on a missing key, an error message
        takes the history slots.
    """
    logger.debug("predict.. model={} use_web={}", large_language_model, use_web)
    if history is None:  # fix: was `history == None`
        history = []
    # Fix: `is not None` let the empty string from an untouched Textbox
    # through (and clobbered the env var with ""). Only the OpenAI-backed
    # backends actually need a key; ChatGLM and plain web search do not.
    if openai_key:
        os.environ['OPENAI_API_KEY'] = openai_key
    elif large_language_model in ("GPT-3.5-turbo", "Search VectorStore"):
        return '', "You forgot OpenAI API key", "You forgot OpenAI API key"
    if large_language_model == "GPT-3.5-turbo":
        resp = chat_gpt(input, use_web, use_vec, history)
    elif large_language_model == "ChatGLM-6B-int4":
        _, resp = chat_glm(input, history)
        resp = resp[-1][1]  # keep only the latest answer text
    elif large_language_model == "Search Web":
        resp = search_web(input)
    elif large_language_model == "Search VectorStore":
        resp = search_vec(input)
    else:
        # Fix: an unknown dropdown value previously raised NameError on resp.
        resp = f"Unsupported model: {large_language_model}"
    history.append((input, resp))
    return '', history, history
def clear_session():
    """Reset the chat UI: blank the message box and drop the stored history."""
    cleared_message, cleared_state = '', None
    return cleared_message, cleared_state
# ---- Gradio UI wiring -------------------------------------------------------
block = gr.Blocks()
with block as demo:
    # Page title / description banner.
    gr.Markdown("""<h1><center>MedKBQA(demo)</center></h1>
        <center><font size=3>
        本项目基于LangChain、ChatGLM以及Open AI接口, 提供基于本地医药知识的自动问答应用. <br>
        </center></font>
        """)
    with gr.Row():
        # Left column: model / search configuration.
        with gr.Column(scale=1):
            model_choose = gr.Accordion("模型选择")
            with model_choose:
                large_language_model = gr.Dropdown(
                    ["ChatGLM-6B-int4", "GPT-3.5-turbo", "Search Web", "Search VectorStore"],
                    label="large language model",
                    value="ChatGLM-6B-int4")
                # NOTE(review): these radios yield the STRINGS "True"/"False",
                # not booleans — downstream code must not truth-test directly.
                use_web = gr.Radio(
                    ["True", "False"], label="Web Search", value="False")
                use_vec = gr.Radio(
                    ["True", "False"], label="VectorStore Search", value="False")
            openai_key = gr.Textbox(label="请输入OpenAI API key", type="password")

        # Right column: chat transcript, input box, and action buttons.
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label='ChatLLM').style(height=600)
            message = gr.Textbox(label='请输入问题')
            state = gr.State()
            with gr.Row():
                clear_history = gr.Button("🧹 清除历史对话")
                send = gr.Button("🚀 发送")

            # Both the send button and pressing Enter route through predict.
            predict_inputs = [message, large_language_model, use_web,
                              use_vec, openai_key, state]
            predict_outputs = [message, chatbot, state]
            send.click(predict, inputs=predict_inputs, outputs=predict_outputs)
            message.submit(predict, inputs=predict_inputs,
                           outputs=predict_outputs)
            clear_history.click(fn=clear_session,
                                inputs=[],
                                outputs=[chatbot, state],
                                queue=False)

    gr.Markdown("""提醒:<br>
        1. 使用时请先选择使用chatglm或者chatgpt进行问答. <br>
        2. 使用chatgpt时需要输入您的api key. 
        """)
    demo.queue().launch(server_name='0.0.0.0', share=False)