ducknew committed on
Commit 19d8254 · 0 parents

Duplicate from ducknew/MedKBQA-LLM

Files changed (7)
  1. .gitattributes +35 -0
  2. README.md +14 -0
  3. app.py +172 -0
  4. cache/index.faiss +3 -0
  5. cache/index.pkl +3 -0
  6. requirements.txt +13 -0
  7. setting.toml +34 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ cache/index.faiss filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: MedKBQA
+ emoji: 💯
+ colorFrom: pink
+ colorTo: red
+ sdk: gradio
+ sdk_version: 3.28.3
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ duplicated_from: ducknew/MedKBQA-LLM
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,172 @@
+ import os
+
+ import gradio as gr
+ from loguru import logger
+ from transformers import AutoTokenizer, AutoModel
+ from duckduckgo_search import ddg
+
+ from langchain import VectorDBQA
+ from langchain.llms import OpenAI
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+ from langchain.vectorstores import FAISS
+ from langchain.chains import LLMChain
+ from langchain.prompts import PromptTemplate
+
+ _MODEL_CACHE = None
+
+
+ def load_model():
+     # Load the tokenizer and the int4-quantised ChatGLM model once and reuse them;
+     # reloading a 6B-parameter model on every request would dominate latency.
+     global _MODEL_CACHE
+     if _MODEL_CACHE is None:
+         tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+         # On a GPU, use .half().cuda() instead of .float()
+         model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).quantize(bits=4, compile_parallel_kernel=True, parallel_num=2).float()
+         _MODEL_CACHE = (tokenizer, model.eval())
+     return _MODEL_CACHE
+
+
+ def chat_glm(input, history=None):
+     if history is None:
+         history = []
+
+     tokenizer, model = load_model()
+     response, history = model.chat(tokenizer, input, history)
+     logger.debug("chatglm: {} -> {}", input, response)
+     return history, history
+
+
+ def search_web(query):
+     logger.debug("searchweb: {}", query)
+     results = ddg(query)
+     web_content = ''
+     if results:
+         for result in results:
+             web_content += result['body']
+     return web_content
+
+
+ def search_vec(query):
+     logger.debug("searchvec: {}", query)
+     embedding_model_name = 'GanymedeNil/text2vec-large-chinese'
+     vec_path = 'cache'
+     embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
+     vector_store = FAISS.load_local(vec_path, embeddings)
+
+     qa = VectorDBQA.from_chain_type(llm=OpenAI(), chain_type="stuff", vectorstore=vector_store, return_source_documents=True)
+     result = qa({"query": query})
+     return result['result']
+
+
+ def chat_gpt(input, use_web, use_vec, history=None):
+     if history is None:
+         history = []  # history is not sent to the model, to stay inside the 4097-token limit
+
+     # gr.Radio returns the strings "True"/"False", so compare against the string;
+     # prefer the local vector store and fall back to web search when enabled.
+     context = "none"
+     if use_vec == "True":
+         context = search_vec(input)
+     elif use_web == "True":
+         context = search_web(input)
+     prompt_template = f"""Answer the user's question concisely and professionally, based on the known information below.
+ If no answer can be derived from it, reply "the question cannot be answered from the known information" or "not enough relevant information was provided". If any part of the answer is speculative, prefix that part with "I speculate that". Please answer in Chinese.
+ Known information:
+ {context}""" + """
+ Question:
+ {question}"""
+
+     prompt = PromptTemplate(template=prompt_template, input_variables=["question"])
+
+     llm = OpenAI(temperature=0.2)
+     chain = LLMChain(llm=llm, prompt=prompt)
+     return chain.run(input)
+
+
+ def predict(input,
+             large_language_model,
+             use_web,
+             use_vec,
+             openai_key,
+             history=None):
+     logger.debug("predict: {} web={} vec={}", large_language_model, use_web, use_vec)
+     if history is None:
+         history = []
+
+     # Only the OpenAI-backed paths need a key; ChatGLM and web search do not.
+     if large_language_model in ("GPT-3.5-turbo", "Search VectorStore") and not openai_key:
+         history.append((input, "Please enter your OpenAI API key first."))
+         return '', history, history
+     if openai_key:
+         os.environ['OPENAI_API_KEY'] = openai_key
+
+     if large_language_model == "GPT-3.5-turbo":
+         resp = chat_gpt(input, use_web, use_vec, history)
+     elif large_language_model == "ChatGLM-6B-int4":
+         _, resp = chat_glm(input, history)
+         resp = resp[-1][1]
+     elif large_language_model == "Search Web":
+         resp = search_web(input)
+     else:  # "Search VectorStore"
+         resp = search_vec(input)
+
+     history.append((input, resp))
+     return '', history, history
+
+
+ def clear_session():
+     return '', None
+
+
+ block = gr.Blocks()
+ with block as demo:
+     gr.Markdown("""<h1><center>MedKBQA (demo)</center></h1>
+     <center><font size=3>
+     This project builds on LangChain, ChatGLM and the OpenAI API to provide automatic question answering over a local medical knowledge base. <br>
+     </font></center>
+     """)
+     with gr.Row():
+         with gr.Column(scale=1):
+             model_choose = gr.Accordion("Model selection")
+             with model_choose:
+                 large_language_model = gr.Dropdown(
+                     ["ChatGLM-6B-int4", "GPT-3.5-turbo", "Search Web", "Search VectorStore"],
+                     label="large language model",
+                     value="ChatGLM-6B-int4")
+                 use_web = gr.Radio(["True", "False"],
+                                    label="Web Search",
+                                    value="False")
+                 use_vec = gr.Radio(["True", "False"],
+                                    label="VectorStore Search",
+                                    value="False")
+                 openai_key = gr.Textbox(label="Enter your OpenAI API key", type="password")
+         with gr.Column(scale=4):
+             chatbot = gr.Chatbot(label='ChatLLM').style(height=600)
+             message = gr.Textbox(label='Enter your question')
+             state = gr.State()
+
+             with gr.Row():
+                 clear_history = gr.Button("🧹 Clear history")
+                 send = gr.Button("🚀 Send")
+
+             send.click(predict,
+                        inputs=[
+                            message, large_language_model, use_web, use_vec, openai_key, state
+                        ],
+                        outputs=[message, chatbot, state])
+             clear_history.click(fn=clear_session,
+                                 inputs=[],
+                                 outputs=[chatbot, state],
+                                 queue=False)
+
+             message.submit(predict,
+                            inputs=[
+                                message, large_language_model, use_web, use_vec, openai_key, state
+                            ],
+                            outputs=[message, chatbot, state])
+     gr.Markdown("""Notes:<br>
+     1. Choose ChatGLM or GPT-3.5 before asking a question. <br>
+     2. An OpenAI API key is required when using GPT-3.5 or the vector store.
+     """)
+ demo.queue().launch(server_name='0.0.0.0', share=False)
cache/index.faiss ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ad3bddc5b874aec3b07d734adf6253f53c490570c3c00927d8d77ca12251eb91
+ size 91779117
cache/index.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0b3171eebfa281ad94146ac42f6e0263c6b7e9a26de0616fb631253cf1d4d4df
+ size 2297073
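Both cache files are stored as git-lfs pointers (per the cache/index.faiss rule and the *.pkl rule in .gitattributes), so the diff shows only the pointer's three fields, not binary data. A small illustrative sketch of reading such a pointer:

# Minimal sketch: parse a git-lfs pointer file as stored in the repository
# and report the checksum and byte size of the object it stands in for.
def read_lfs_pointer(path):
    fields = {}
    with open(path, encoding="utf-8") as fh:
        for line in fh:
            if not line.strip():
                continue
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    assert fields["version"] == "https://git-lfs.github.com/spec/v1"
    return fields

ptr = read_lfs_pointer("cache/index.faiss")
print(ptr["oid"], ptr["size"])  # sha256:ad3bddc5... 91779117 (about 92 MB)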
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ gradio
+ torch
+ transformers>=4.27.1
+ cpm_kernels
+ icetk
+ dynaconf
+ duckduckgo_search
+ faiss-cpu
+ sentence-transformers
+ langchain
+ loguru
+ openai
+ tiktoken
setting.toml ADDED
@@ -0,0 +1,34 @@
+ [models]
+
+ [models.llm."chatglm-6b"]
+ type = "chatglm"
+ path = "THUDM/chatglm-6b"
+ [models.llm."chatglm-6b-int8"]
+ type = "chatglm"
+ path = "THUDM/chatglm-6b-int8"
+ [models.llm."chatglm-6b-int4"]
+ type = "chatglm"
+ path = "THUDM/chatglm-6b-int4"
+ [models.llm."phoenix-inst-chat-7b"]
+ type = "phoenix"
+ path = "FreedomIntelligence/phoenix-inst-chat-7b"
+ [models.llm."phoenix-inst-chat-7b-int4"]
+ type = "phoenix"
+ path = "FreedomIntelligence/phoenix-inst-chat-7b-int4"
+
+ [models.embeddings]
+ [models.embeddings."text2vec-large-chinese"]
+ type = "default"
+ path = "GanymedeNil/text2vec-large-chinese"
+ [models.embeddings."text2vec-base"]
+ type = "default"
+ path = "shibing624/text2vec-base-chinese"
+ [models.embeddings."sentence-transformers"]
+ type = "default"
+ path = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
+ [models.embeddings."ernie-tiny"]
+ type = "default"
+ path = "nghuyong/ernie-3.0-nano-zh"
+ [models.embeddings."ernie-base"]
+ type = "default"
+ path = "nghuyong/ernie-3.0-base-zh"