yuxj committed on
Commit
4f65819
1 Parent(s): 9fb44c2
Files changed (2)
  1. app.py +101 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,101 @@
+ import time
+
+ import gradio as gr
+ import torch
+ from duckduckgo_search import ddg
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.vectorstores import FAISS
+ from transformers import AutoTokenizer, AutoModel
+
+
+ def best_device():
+     # Prefer CUDA, then Apple Silicon (MPS), then fall back to CPU.
+     if torch.cuda.is_available():
+         return 'cuda'
+     if torch.backends.mps.is_available():
+         return 'mps'
+     return 'cpu'
+
+
+ device = best_device()
+
+ # Chinese sentence-embedding model; must match the model used to build the index.
+ embeddings = HuggingFaceEmbeddings(model_name='GanymedeNil/text2vec-large-chinese',
+                                    model_kwargs={'device': device})
+ local_db = FAISS.load_local('/kaggle/input/text2vec', embeddings)
+
+ # ChatGLM-6B (int4-quantized): half precision on CUDA/MPS, float32 on CPU.
+ model_name = 'THUDM/chatglm-6b-int4'
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ if device == 'cuda':
+     model = AutoModel.from_pretrained(model_name, trust_remote_code=True).half().cuda().eval()
+ elif device == 'mps':
+     model = AutoModel.from_pretrained(model_name, trust_remote_code=True).half().to('mps').eval()
+ else:
+     model = AutoModel.from_pretrained(model_name, trust_remote_code=True).float().eval()
+
+
+ def local_query(text, top_k=3):
+     # Return the top_k closest chunks from the local FAISS index,
+     # concatenated with spaces stripped (lower score = closer match).
+     docs_and_scores = local_db.similarity_search_with_score(text)
+     docs_and_scores.sort(key=lambda x: x[1])
+     local_content = ''
+     for doc, _score in docs_and_scores[:top_k]:
+         local_content += doc.page_content.replace(' ', '') + '\n'
+     return local_content
+
+
+ def web_search(text, limit=3):
+     # Return up to `limit` DuckDuckGo result snippets; on any error, log it
+     # and fall back to an empty string so the app degrades gracefully.
+     web_content = ''
+     try:
+         results = ddg(text)
+         if results:
+             for result in results[:limit]:
+                 web_content += result['body'] + '\n'
+     except Exception as e:
+         print(f'Web search failed for "{text}": {e}')
+     return web_content
+
+
+ def ask_question(question, local_content='', web_content=''):
+     # Build a grounded prompt. The Chinese template asks the model to answer
+     # concisely and professionally in Chinese, to reply "根据已知信息无法回答该问题"
+     # ("the question cannot be answered from the known information") or
+     # "没有提供足够的相关信息" ("not enough relevant information was provided")
+     # when the context is insufficient, and never to fabricate content.
+     # Note: the original wrapped `question` in the plain template first and then
+     # embedded that wrapped text inside the context template; the plain question
+     # is used here instead so the prompt is not double-wrapped.
+     context = f'{web_content}\n{local_content}'.strip()
+     if context:
+         prompt = (f'基于以下已知信息,简洁和专业的来回答我的问题。\n'
+                   f'如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",'
+                   f'不允许在答案中添加编造成分,答案请使用中文。\n'
+                   f'已知信息是:\n{context}\n我的问题是:\n{question}')
+     else:
+         prompt = (f'简洁和专业的来回答我的问题。\n'
+                   f'如果你不知道答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",'
+                   f'不允许在答案中添加编造成分,答案请使用中文。\n'
+                   f'我的问题是:\n{question}')
+     response, _history = model.chat(tokenizer, prompt, history=[], max_length=10000, temperature=0.1)
+     return response
+
+
+ def on_click(question, kb_types):
+     # Free cached GPU memory before and after each request.
+     if best_device() == 'cuda':
+         torch.cuda.empty_cache()
+
+     timestamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
+     print(f'Question [{timestamp}]:\n{question}\n')
+
+     local_content = ''
+     if '结合本地数据' in kb_types:  # "use local data" checked
+         local_content = local_query(question, 2)
+     web_content = ''
+     if '结合网络检索' in kb_types:  # "use web search" checked
+         web_content = web_search(question, 3)
+
+     # Log which knowledge sources actually returned content.
+     timestamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
+     if local_content and web_content:
+         print(f'Using local data and web search [{timestamp}]:')
+     elif local_content:
+         print(f'Using local data [{timestamp}]:')
+     elif web_content:
+         print(f'Using web search [{timestamp}]:')
+     else:
+         print(f'Using model knowledge only [{timestamp}]:')
+
+     result = ask_question(question, local_content, web_content)
+     print(f'{result}\n\n----------------------------')
+
+     if best_device() == 'cuda':
+         torch.cuda.empty_cache()
+     return result
+
+
+ # Gradio UI: a checkbox group selects the knowledge sources, one textbox takes
+ # the question, another shows the answer. Labels stay in Chinese since the app
+ # targets Chinese-speaking users.
+ with gr.Blocks() as block:
+     gr.Markdown('<center><h1>LLM问答机器人测试</h1></center>')  # "LLM Q&A bot test"
+     cg_type = gr.CheckboxGroup(['结合本地数据', '结合网络检索'],  # "use local data" / "use web search"
+                                label='知识库类型(不勾选则仅用模型数据):')  # "knowledge-base type (none checked = model data only)"
+     tb_input = gr.Textbox(label='输入问题(本地数据只有中国历史知识):')  # "enter a question (local data covers Chinese history only)"
+     btn = gr.Button('测试', variant='primary')  # "Test"
+     tb_output = gr.Textbox(label='AI回答:')  # "AI answer"
+     btn.click(fn=on_click, inputs=[tb_input, cg_type], outputs=tb_output)
+ block.queue(concurrency_count=1)
+ block.launch()
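
Note: app.py loads a prebuilt index from /kaggle/input/text2vec, but the commit does not include the indexing step. The following is a minimal sketch of how such an index could be built with the same embedding model; the ./corpus path, the loader choice, and the chunking parameters are illustrative assumptions, not part of the commit.

from langchain.document_loaders import DirectoryLoader, TextLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS

# Hypothetical corpus location; any directory of plain-text files would do.
loader = DirectoryLoader('./corpus', glob='*.txt', loader_cls=TextLoader)
docs = loader.load()
# Chunk size/overlap are placeholder values, not taken from the commit.
chunks = CharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
# Must be the same embedding model that app.py uses at query time.
embeddings = HuggingFaceEmbeddings(model_name='GanymedeNil/text2vec-large-chinese')
db = FAISS.from_documents(chunks, embeddings)
db.save_local('text2vec')  # later loaded via FAISS.load_local(...)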
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ transformers
+ sentencepiece
+ cpm_kernels
+ accelerate
+ langchain
+ unstructured
+ sentence_transformers
+ duckduckgo_search
+ gradio
+ faiss-cpu
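
None of these dependencies are pinned, which matters for reproducing this commit: the ddg helper that app.py imports was removed from later duckduckgo_search releases, and the concurrency_count argument to queue() was dropped in Gradio 4 (replaced by default_concurrency_limit). Pinning duckduckgo_search and gradio to the releases current at commit time keeps the app working as written. Alternatively, web_search could be rewritten against the newer DDGS client; a sketch, assuming a duckduckgo_search release that ships DDGS:

from duckduckgo_search import DDGS

def web_search(text, limit=3):
    # Same contract as the original: up to `limit` result snippets,
    # or an empty string if the search fails.
    web_content = ''
    try:
        with DDGS() as ddgs:
            for result in ddgs.text(text, max_results=limit):
                web_content += result['body'] + '\n'
    except Exception as e:
        print(f'Web search failed for "{text}": {e}')
    return web_content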