ducknew committed on
Commit
c0d8a2a
·
0 Parent(s):

Duplicate from ducknew/MedKBQA

Browse files
Files changed (7) hide show
  1. .gitattributes +35 -0
  2. README.md +14 -0
  3. app.py +156 -0
  4. cache/index.faiss +3 -0
  5. cache/index.pkl +3 -0
  6. requirements.txt +11 -0
  7. setting.toml +37 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ cache/index.faiss filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: MedKBQA
3
+ emoji: 💯
4
+ colorFrom: pink
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 3.28.3
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ duplicated_from: ducknew/MedKBQA
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import gradio as gr
4
+ from tqdm import tqdm
5
+ from loguru import logger
6
+ from transformers import AutoTokenizer,AutoModel
7
+ from duckduckgo_search import ddg_suggestions
8
+ from duckduckgo_search import ddg_translate, ddg, ddg_news
9
+
10
+ from langchain.document_loaders import UnstructuredFileLoader
11
+ from langchain.text_splitter import CharacterTextSplitter,RecursiveCharacterTextSplitter
12
+ from langchain.llms import OpenAI
13
+ from langchain.schema import Document
14
+ from langchain.embeddings import OpenAIEmbeddings
15
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
16
+ from langchain.vectorstores import FAISS
17
+ from langchain.chains import ConversationalRetrievalChain,RetrievalQA
18
+ from langchain.prompts import PromptTemplate
19
+ from langchain.prompts.prompt import PromptTemplate
20
+ from langchain.chat_models import ChatOpenAI
21
+
22
+
23
# Load the ChatGLM tokenizer and the 4-bit quantized model once at startup.
# Tokenizer is pulled from the same repo as the weights so vocab/config stay
# in sync (the original mixed "chatglm-6b" tokenizer with "chatglm-6b-int4"
# weights — they happen to share a tokenizer today, but matching repos is safer).
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).quantize(bits=4, compile_parallel_kernel=True, parallel_num=2).float()
model = model.eval()
26
+
27
+
28
def chat_glm(input, history=None):
    """Run one chat turn against the local ChatGLM model.

    Args:
        input: the user's message.
        history: prior (question, answer) tuples; None starts a fresh chat.

    Returns:
        (history, history) — the updated history twice, matching the gradio
        (chatbot, state) output convention used elsewhere in this app.
    """
    if history is None:
        history = []
    # model.chat returns (response_text, updated_history); the returned
    # history already contains the (input, response) pair for this turn.
    response, history = model.chat(tokenizer, input, history)
    # loguru formats positional args via str.format; the original message had
    # no {} placeholders, so input/response were silently dropped from the log.
    logger.info("chatglm: {} -> {}", input, response)
    return history, history
34
+
35
def search_web(query):
    """DuckDuckGo-search *query* and concatenate the result snippets.

    Returns an empty string when the search yields no results.
    """
    # Use a {} placeholder so the query actually appears in the log line
    # (loguru drops extra args when the message has no placeholders).
    logger.info("searchweb: {}", query)
    results = ddg(query)
    web_content = ''
    if results:
        for result in results:
            web_content += result['body']
    return web_content
43
+
44
def chat_gpt(input, use_web, history=None):
    """Answer *input* with gpt-3.5-turbo over the local FAISS knowledge base.

    Args:
        input: the user's question.
        use_web: whether to prepend DuckDuckGo snippets to the prompt. The
            gradio Radio passes the *string* "True"/"False"; both the string
            and a real bool are accepted here.
        history: prior (question, answer) tuples; None starts fresh.

    Returns:
        The answer string produced by the RetrievalQA chain.
    """
    if history is None:
        history = []
    embedding_model_name = 'GanymedeNil/text2vec-large-chinese'
    vec_path = 'cache'
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)

    # gr.Radio(["True", "False"]) yields strings, so a bare truthiness test
    # would treat "False" as enabled; compare against the literal instead.
    if use_web is True or use_web == "True":
        # Original called search_web(query) — `query` is undefined here
        # (the parameter is `input`), which raised NameError at runtime.
        web_content = search_web(input)
    else:
        web_content = None
    if web_content:
        prompt_template = f"""基于以下已知信息,简洁和专业的来回答用户的问题。
如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。
已知网络检索内容:{web_content}""" + """
已知内容:
{context}
问题:
{question}"""
    else:
        prompt_template = """基于以下已知信息,请简洁并专业地回答用户的问题。
如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息"。不允许在答案中添加编造成分。另外,答案请使用中文。
已知内容:
{context}
问题:
{question}"""

    prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
    vector_store = FAISS.load_local(vec_path, embeddings)

    qa = RetrievalQA.from_llm(
        llm=ChatOpenAI(temperature=0.7, model_name='gpt-3.5-turbo'),
        retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
        prompt=prompt,
        return_source_documents=True,
    )

    # Was qa({"query": query, ...}) — same undefined-name bug as above.
    result = qa({"query": input, "chat_history": history})
    logger.info("chatgpt: {} -> {}", input, result)
    # RetrievalQA's answer lives under the "result" key; "answer" (used by
    # ConversationalRetrievalChain) raised KeyError here.
    return result["result"]
84
+
85
def predict(input,
            large_language_model,
            use_web,
            openai_key,
            history=None):
    """Gradio callback: route the question to the selected LLM backend.

    Returns:
        ('', chatbot_history, state_history) — the empty string clears the
        message textbox; the two histories refresh the Chatbot and the State.
    """
    logger.info("predict.. {} {}", large_language_model, use_web)
    if openai_key:
        os.environ['OPENAI_API_KEY'] = openai_key
    elif large_language_model == "gpt-3.5-turbo":
        # Only the OpenAI backend needs a key. An empty gradio Textbox yields
        # '' (not None), which the original `is not None` check let through,
        # exporting an empty key and failing later inside the chain.
        return '', "You forgot OpenAI API key", "You forgot OpenAI API key"
    if history is None:
        history = []

    if large_language_model == "gpt-3.5-turbo":
        resp = chat_gpt(input, use_web, history)
        history.append((input, resp))
    elif large_language_model == "ChatGLM-6B-int4":
        # chat_glm returns the already-updated history (twice). The original
        # appended that tuple as the "answer", nesting the whole conversation
        # into a single malformed chatbot turn.
        history, _ = chat_glm(input, history)
    return '', history, history
105
+
106
def clear_session():
    """Reset the chat UI: blank the message box and drop the session state."""
    return "", None
108
+
109
# --- Gradio UI -------------------------------------------------------------
# Layout: a narrow settings column (model choice, web-search toggle, OpenAI
# key) next to a wide chat column (Chatbot, input textbox, session State).
# NOTE(review): the diff rendering flattened indentation; widget nesting below
# is reconstructed from the scale=1/scale=4 column structure — confirm against
# the original app.py layout.
block = gr.Blocks()
with block as demo:
    gr.Markdown("""<h1><center>MedKBQA(demo)</center></h1>
<center><font size=3>
本项目基于LangChain、ChatGLM以及Open AI接口, 提供基于本地医药知识的自动问答应用. <br>
</center></font>
""")
    with gr.Row():
        with gr.Column(scale=1):
            # Settings panel.
            model_choose = gr.Accordion("模型选择")
            with model_choose:
                large_language_model = gr.Dropdown(
                    ["ChatGLM-6B-int4", "gpt-3.5-turbo"],
                    label="large language model",
                    value="ChatGLM-6B-int4")
                # Radio yields the *string* "True"/"False", not a bool.
                use_web = gr.Radio(["True", "False"],
                                   label="Web Search",
                                   value="False")
                openai_key = gr.Textbox(label="请输入OpenAI API key", type="password")
        with gr.Column(scale=4):
            # Chat panel; State holds the (question, answer) history between calls.
            chatbot = gr.Chatbot(label='ChatLLM').style(height=600)
            message = gr.Textbox(label='请输入问题')
            state = gr.State()

            with gr.Row():
                clear_history = gr.Button("🧹 清除历史对话")
                send = gr.Button("🚀 发送")

            # Button click and textbox Enter both route through predict();
            # outputs clear the textbox and refresh chatbot + state.
            send.click(predict,
                       inputs=[
                           message, large_language_model, use_web, openai_key, state
                       ],
                       outputs=[message, chatbot, state])
            clear_history.click(fn=clear_session,
                                inputs=[],
                                outputs=[chatbot, state],
                                queue=False)

            message.submit(predict,
                           inputs=[
                               message, large_language_model, use_web, openai_key, state
                           ],
                           outputs=[message, chatbot, state])
    gr.Markdown("""提醒:<br>
1. 使用时请先选择使用chatglm或者chatgpt进行问答. <br>
2. 使用chatgpt时需要输入您的api key.
""")
# Bind to all interfaces so the Space container can route traffic in.
demo.queue().launch(server_name='0.0.0.0', share=False)
cache/index.faiss ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad3bddc5b874aec3b07d734adf6253f53c490570c3c00927d8d77ca12251eb91
3
+ size 91779117
cache/index.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0b3171eebfa281ad94146ac42f6e0263c6b7e9a26de0616fb631253cf1d4d4df
3
+ size 2297073
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ transformers>=4.27.1
4
+ cpm_kernels
5
+ icetk
6
+ dynaconf
7
+ duckduckgo_search
8
+ faiss-cpu
9
+ sentence-transformers
10
+ langchain
11
+ loguru
+ openai
setting.toml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [models]
2
+
3
+ [models.llm."chatglm-6b"]
4
+ type = "chatglm"
5
+ path = "THUDM/chatglm-6b"
6
+ [models.llm."chatglm-6b-int8"]
7
+ type = "chatglm"
8
+ path = "THUDM/chatglm-6b-int8"
9
+ [models.llm."chatglm-6b-int4"]
10
+ type = "chatglm"
11
+ path = "THUDM/chatglm-6b-int4"
12
+ [models.llm."phoenix-inst-chat-7b"]
13
+ type = "phoenix"
14
+ path = "FreedomIntelligence/phoenix-inst-chat-7b"
15
+ [models.llm."phoenix-inst-chat-7b-int4"]
16
+ type = "phoenix"
17
+ path = "FreedomIntelligence/phoenix-inst-chat-7b-int4"
18
+
19
+ [models.embeddings]
20
+ [models.embeddings."text2vec-large-chinese"]
21
+ type = "default"
22
+ path = "GanymedeNil/text2vec-large-chinese"
23
+ [models.embeddings."text2vec-base"]
24
+ type = "default"
25
+ path = "shibing624/text2vec-base-chinese"
29
+ [models.embeddings."sentence-transformers"]
30
+ type = "default"
31
+ path = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
32
+ [models.embeddings."ernie-tiny"]
33
+ type = "default"
34
+ path = "nghuyong/ernie-3.0-nano-zh"
35
+ [models.embeddings."ernie-base"]
36
+ type = "default"
37
+ path = "nghuyong/ernie-3.0-base-zh"