ducknew committed · Commit 19d8254

Duplicate from ducknew/MedKBQA-LLM

Browse files:
- .gitattributes +35 -0
- README.md +14 -0
- app.py +162 -0
- cache/index.faiss +3 -0
- cache/index.pkl +3 -0
- requirements.txt +13 -0
- setting.toml +34 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
cache/index.faiss filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,14 @@
---
title: MedKBQA
emoji: 💯
colorFrom: pink
colorTo: red
sdk: gradio
sdk_version: 3.28.3
app_file: app.py
pinned: false
license: apache-2.0
duplicated_from: ducknew/MedKBQA-LLM
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,162 @@
import os

import gradio as gr
from duckduckgo_search import ddg
from loguru import logger
from transformers import AutoModel, AutoTokenizer

from langchain.chains import LLMChain, VectorDBQA
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.vectorstores import FAISS


def load_model():
    # The tokenizer comes from the full chatglm-6b repo, the weights from the
    # int4-quantized variant; .float() keeps inference on CPU.
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    # On GPU, use .half().cuda() instead of .float().
    model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True) \
        .quantize(bits=4, compile_parallel_kernel=True, parallel_num=2).float()
    model = model.eval()
    return tokenizer, model


def chat_glm(input, history=None):
    if history is None:
        history = []
    # Note: the model is reloaded on every message; caching the pair at module
    # level would avoid the repeated initialization cost.
    tokenizer, model = load_model()
    response, history = model.chat(tokenizer, input, history)
    logger.debug("chatglm: {} -> {}", input, response)
    return history, history


def search_web(query):
    logger.debug("searchweb: {}", query)
    results = ddg(query)
    web_content = ''
    if results:
        for result in results:
            web_content += result['body']
    return web_content


def search_vec(query):
    logger.debug("searchvec: {}", query)
    embedding_model_name = 'GanymedeNil/text2vec-large-chinese'
    vec_path = 'cache'
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
    vector_store = FAISS.load_local(vec_path, embeddings)

    qa = VectorDBQA.from_chain_type(llm=OpenAI(), chain_type="stuff",
                                    vectorstore=vector_store,
                                    return_source_documents=True)
    result = qa({"query": query})
    return result['result']


def chat_gpt(input, use_web, use_vec, history=None):
    if history is None:
        history = []
    # history = []  # 4097 tokens limit

    # The gr.Radio components deliver the strings "True"/"False", both of
    # which are truthy, so they must be compared explicitly.
    context = "无"  # "none"
    if use_vec == "True":
        context = search_vec(input)
    if use_web == "True":
        # Optionally augment the context with live web search results.
        context = context + "\n" + search_web(input)

    # Prompt (zh): answer the question concisely and professionally from the
    # given context; if the context is insufficient, say so; prefix any
    # speculation with "据我推测" ("I speculate that"); answer in Chinese.
    prompt_template = f"""基于以下已知信息,请简洁并专业地回答用户的问题。
如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息"。若答案中存在编造成分,请在该部分开头添加“据我推测”。另外,答案请使用中文。
已知内容:
{context}""" + """
问题:
{question}"""

    prompt = PromptTemplate(template=prompt_template, input_variables=["question"])

    llm = OpenAI(temperature=0.2)
    chain = LLMChain(llm=llm, prompt=prompt)
    result = chain.run(input)
    return result


def predict(input,
            large_language_model,
            use_web,
            use_vec,
            openai_key,
            history=None):
    logger.debug("predict.. {} {}", large_language_model, use_web)
    if history is None:
        history = []
    if openai_key:
        os.environ['OPENAI_API_KEY'] = openai_key
    else:
        # The Chatbot output expects a list of (user, bot) pairs, not a string.
        history.append((input, "You forgot the OpenAI API key"))
        return '', history, history

    if large_language_model == "GPT-3.5-turbo":
        resp = chat_gpt(input, use_web, use_vec, history)
    elif large_language_model == "ChatGLM-6B-int4":
        _, resp = chat_glm(input, history)
        resp = resp[-1][1]
    elif large_language_model == "Search Web":
        resp = search_web(input)
    elif large_language_model == "Search VectorStore":
        resp = search_vec(input)

    history.append((input, resp))
    return '', history, history


def clear_session():
    return '', None


block = gr.Blocks()
with block as demo:
    # Header (zh): "This project provides automatic Q&A over a local medical
    # knowledge base, built on LangChain, ChatGLM and the OpenAI API."
    gr.Markdown("""<h1><center>MedKBQA (demo)</center></h1>
    <center><font size=3>
    本项目基于LangChain、ChatGLM以及OpenAI接口, 提供基于本地医药知识的自动问答应用. <br>
    </center></font>
    """)
    with gr.Row():
        with gr.Column(scale=1):
            model_choose = gr.Accordion("模型选择")  # "Model selection"
            with model_choose:
                large_language_model = gr.Dropdown(
                    ["ChatGLM-6B-int4", "GPT-3.5-turbo", "Search Web", "Search VectorStore"],
                    label="large language model",
                    value="ChatGLM-6B-int4")
                use_web = gr.Radio(["True", "False"],
                                   label="Web Search",
                                   value="False")
                use_vec = gr.Radio(["True", "False"],
                                   label="VectorStore Search",
                                   value="False")
                openai_key = gr.Textbox(label="请输入OpenAI API key",  # "Enter your OpenAI API key"
                                        type="password")
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label='ChatLLM').style(height=600)
            message = gr.Textbox(label='请输入问题')  # "Enter your question"
            state = gr.State()

            with gr.Row():
                clear_history = gr.Button("🧹 清除历史对话")  # "Clear chat history"
                send = gr.Button("🚀 发送")  # "Send"

            send.click(predict,
                       inputs=[
                           message, large_language_model, use_web, use_vec,
                           openai_key, state
                       ],
                       outputs=[message, chatbot, state])
            clear_history.click(fn=clear_session,
                                inputs=[],
                                outputs=[chatbot, state],
                                queue=False)

            message.submit(predict,
                           inputs=[
                               message, large_language_model, use_web, use_vec,
                               openai_key, state
                           ],
                           outputs=[message, chatbot, state])
    # Reminder (zh): 1. pick chatglm or chatgpt before asking questions;
    # 2. chatgpt requires your API key.
    gr.Markdown("""提醒:<br>
    1. 使用时请先选择使用chatglm或者chatgpt进行问答. <br>
    2. 使用chatgpt时需要输入您的api key.
    """)

demo.queue().launch(server_name='0.0.0.0', share=False)
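Note on search_vec: VectorDBQA was deprecated in early langchain releases in favor of RetrievalQA. A minimal equivalent sketch, using the same FAISS store and embedding model as app.py (the example query is hypothetical):

# Sketch: search_vec() rewritten against the non-deprecated RetrievalQA API.
from langchain.chains import RetrievalQA
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.llms import OpenAI
from langchain.vectorstores import FAISS

embeddings = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")
vector_store = FAISS.load_local("cache", embeddings)

qa = RetrievalQA.from_chain_type(
    llm=OpenAI(),
    chain_type="stuff",                    # stuff retrieved chunks into one prompt
    retriever=vector_store.as_retriever(),
    return_source_documents=True,
)
result = qa({"query": "布洛芬有哪些适应症?"})  # hypothetical example query
print(result["result"])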
cache/index.faiss
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ad3bddc5b874aec3b07d734adf6253f53c490570c3c00927d8d77ca12251eb91
size 91779117
cache/index.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0b3171eebfa281ad94146ac42f6e0263c6b7e9a26de0616fb631253cf1d4d4df
size 2297073
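index.faiss and index.pkl are the pair of artifacts that FAISS.save_local writes, stored via Git LFS (hence the pointer files above). The corpus they were built from is not part of this commit; a minimal sketch of how such an index could be regenerated, assuming a hypothetical source file docs/knowledge.txt and illustrative chunking parameters:

# Sketch: rebuilding cache/index.faiss and cache/index.pkl from raw documents.
from langchain.document_loaders import UnstructuredFileLoader
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS

docs = UnstructuredFileLoader("docs/knowledge.txt").load()  # hypothetical corpus path
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)  # assumed sizes
chunks = splitter.split_documents(docs)

embeddings = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")
vector_store = FAISS.from_documents(chunks, embeddings)
vector_store.save_local("cache")  # writes cache/index.faiss and cache/index.pkl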
requirements.txt
ADDED
@@ -0,0 +1,13 @@
gradio
torch
transformers>=4.27.1
cpm_kernels
icetk
dynaconf
duckduckgo_search
faiss-cpu
sentence-transformers
langchain
loguru
openai
tiktoken
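The pins above are loose, but app.py depends on APIs that later releases removed (VectorDBQA and flat langchain imports, duckduckgo_search's ddg() helper, gradio's .style()). A hedged pinning sketch: only the gradio version is taken from README.md; the other bounds are era-appropriate assumptions, not tested constraints:

gradio==3.28.3          # matches sdk_version in README.md
langchain<0.1           # assumed bound: VectorDBQA removed in later releases
duckduckgo_search<4.0   # assumed bound: the ddg() helper was removed later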
setting.toml
ADDED
@@ -0,0 +1,34 @@
[models]

[models.llm."chatglm-6b"]
type = "chatglm"
path = "THUDM/chatglm-6b"
[models.llm."chatglm-6b-int8"]
type = "chatglm"
path = "THUDM/chatglm-6b-int8"
[models.llm."chatglm-6b-int4"]
type = "chatglm"
path = "THUDM/chatglm-6b-int4"
[models.llm."phoenix-inst-chat-7b"]
type = "phoenix"
path = "FreedomIntelligence/phoenix-inst-chat-7b"
[models.llm."phoenix-inst-chat-7b-int4"]
type = "phoenix"
path = "FreedomIntelligence/phoenix-inst-chat-7b-int4"

[models.embeddings]
[models.embeddings."text2vec-large-chinese"]
type = "default"
path = "GanymedeNil/text2vec-large-chinese"
[models.embeddings."text2vec-base"]
type = "default"
path = "shibing624/text2vec-base-chinese"
[models.embeddings."sentence-transformers"]
type = "default"
path = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
[models.embeddings."ernie-tiny"]
type = "default"
path = "nghuyong/ernie-3.0-nano-zh"
[models.embeddings."ernie-base"]
type = "default"
path = "nghuyong/ernie-3.0-base-zh"
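app.py never reads setting.toml (its model and embedding paths are hard-coded), but requirements.txt includes dynaconf, which loads TOML settings files. A minimal sketch of how this config could be consumed; the access pattern is illustrative, not taken from the repository:

# Sketch: loading setting.toml with dynaconf (listed in requirements.txt).
from dynaconf import Dynaconf

settings = Dynaconf(settings_files=["setting.toml"])

# Dynaconf keys are case-insensitive; nested tables behave like dicts.
llm_cfg = settings.models["llm"]["chatglm-6b-int4"]
print(llm_cfg["type"], llm_cfg["path"])  # chatglm THUDM/chatglm-6b-int4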