File size: 5,246 Bytes
4f65819
 
 
 
 
 
 
a0e96fb
4f65819
 
 
 
 
 
 
 
 
 
0e80fa1
4f65819
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0e96fb
4f65819
 
 
 
 
 
 
 
 
 
f61f078
4f65819
 
 
 
 
 
 
 
 
 
 
a0e96fb
4f65819
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from transformers import AutoTokenizer, AutoModel
from duckduckgo_search import ddg
import time
import gradio as gr
import gc

def best_device():
    """Return the best available torch device name: 'cuda', then 'mps', then 'cpu'."""
    # Probe accelerators in preference order; fall back to CPU.
    probes = (
        ('cuda', torch.cuda.is_available),
        ('mps', torch.backends.mps.is_available),
    )
    for name, is_available in probes:
        if is_available():
            return name
    return 'cpu'
    
# ---- Module-level setup: runs once at import time, downloads/loads models. ----
device = best_device()
# Chinese sentence-embedding model used both to build and to query the FAISS index.
embeddings = HuggingFaceEmbeddings(model_name = 'GanymedeNil/text2vec-large-chinese', model_kwargs={'device': device})
# Load a pre-built local vector store; assumes ./text2vec/store was created with
# the same embedding model — TODO confirm the index-building script matches.
local_db = FAISS.load_local('./text2vec/store', embeddings)

# ChatGLM-6B, int4-quantized variant (small enough for consumer GPUs).
model_name = 'THUDM/chatglm-6b-int4'
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code = True)
if device == 'cuda':
    # fp16 on CUDA for speed/memory.
    model = AutoModel.from_pretrained(model_name, trust_remote_code = True).half().cuda().eval()
elif device == 'mps':
    # fp16 on Apple Silicon via the MPS backend.
    model = AutoModel.from_pretrained(model_name, trust_remote_code = True).half().to("mps").eval()
else:
    # CPU path needs fp32 (half precision is poorly supported on CPU).
    model = AutoModel.from_pretrained(model_name, trust_remote_code = True).float().eval()
    
def local_query(text, top_k = 3):
    """Return the concatenated text of the `top_k` most similar local documents.

    Args:
        text: the query string to embed and search with.
        top_k: maximum number of documents to include (default 3).

    Returns:
        One string: each selected document's content with ASCII spaces stripped,
        joined by newlines (trailing newline included); '' when nothing matches.
    """
    docs_and_scores = local_db.similarity_search_with_score(text)
    # FAISS scores are distances here: lower means more similar, so sort ascending.
    docs_and_scores.sort(key = lambda pair: pair[1])
    # Slice instead of counting through the whole list — the original kept
    # iterating (and counting) past top_k for no effect.
    return ''.join(doc.page_content.replace(' ', '') + '\n'
                   for doc, _score in docs_and_scores[:top_k])

def web_search(text, limit = 3):
    """Best-effort DuckDuckGo search; returns up to `limit` result bodies.

    Args:
        text: the search query.
        limit: maximum number of result snippets to include (default 3).

    Returns:
        Result bodies joined by newlines (trailing newline included), or ''
        when the search fails or returns nothing — network problems degrade
        to "no web context" rather than crashing the request.
    """
    try:
        results = ddg(text)
    except Exception as e:
        # Deliberately broad: any search failure is non-fatal. Log the actual
        # error too — the original bound `e` but never printed it.
        print(f"网络检索异常:{text}{e}")
        return ''
    if not results:
        return ''
    # Slice instead of counting through every result past `limit`.
    return ''.join(result['body'] + "\n" for result in results[:limit])

def ask_question(question, local_content = '', web_content = ''):
    """Build a grounded (or ungrounded) Chinese prompt and query ChatGLM.

    Args:
        question: the user's raw question.
        local_content: text retrieved from the local FAISS store ('' if unused).
        web_content: text retrieved from web search ('' if unused).

    Returns:
        The model's response string.

    BUG FIX: the original unconditionally wrapped `question` in the no-context
    prompt first, then embedded that *already wrapped* text inside a second
    prompt whenever context existed — sending nested, duplicated instructions
    to the model. The prompt is now built exactly once per case.
    """
    if web_content or local_content:
        # Order matters for byte-compatibility with the original: web first.
        known = '\n'.join(part for part in (web_content, local_content) if part)
        prompt = f'基于以下已知信息,简洁和专业的来回答我的问题。\n如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。\n已知信息是:\n{known}\n我的问题是:\n{question}'
    else:
        prompt = f'简洁和专业的来回答我的问题。\n如果你不知道答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。\n我的问题是:\n{question}'
    # Low temperature to keep answers close to the supplied context.
    response, history = model.chat(tokenizer, prompt, history = [], max_length = 10000, temperature = 0.1)
    return response

def _free_memory():
    # Reclaim Python garbage and, on CUDA, cached GPU allocations between requests.
    gc.collect()
    if best_device() == 'cuda':
        torch.cuda.empty_cache()

def _timestamp():
    # Human-readable local time used in all console log lines.
    return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

def on_click(question, kb_types):
    """Gradio click handler: answer `question` using the selected knowledge sources.

    Args:
        question: the user's question from the textbox.
        kb_types: list of checked options; may contain '结合本地数据' and/or
            '结合网络检索'. Empty list means "model only".

    Returns:
        The model's answer string (shown in the output textbox).
    """
    _free_memory()

    print("问题 [" + _timestamp() + "]: \n", question + "\n\n")
    local_content = local_query(question, 2) if '结合本地数据' in kb_types else ''
    web_content = web_search(question, 3) if '结合网络检索' in kb_types else ''
    result = ask_question(question, local_content, web_content)

    # Log which knowledge sources actually contributed (retrieval may come
    # back empty even when the box was checked).
    if local_content and web_content:
        label = '结合本地数据和网络检索'
    elif local_content:
        label = '结合本地数据'
    elif web_content:
        label = '结合网络检索'
    else:
        label = '仅用模型数据'
    print(label + ' [' + _timestamp() + ']: ')
    print(f'{result}\n\n----------------------------')

    _free_memory()
    return result

# ---- Gradio UI: one question box, knowledge-source checkboxes, one answer box. ----
with gr.Blocks() as block:
    gr.Markdown('<center><h1>LLM问答机器人测试</h1></center>')
    # Which knowledge sources to combine; unchecked = model-only answers.
    cg_type = gr.CheckboxGroup(['结合本地数据', '结合网络检索'], label = '知识库类型(不勾选则仅用模型数据):')
    tb_input = gr.Textbox(label = '输入问题(本地数据只有中国历史知识):')
    btn = gr.Button("测试", variant = 'primary')
    tb_output = gr.Textbox(label = 'AI回答:')
    btn.click(fn = on_click, inputs = [tb_input, cg_type], outputs = tb_output)
# Serialize requests: the 6B model handles one chat at a time.
# NOTE(review): `concurrency_count` was removed in Gradio 4.x — confirm the
# pinned gradio version still accepts it.
block.queue(concurrency_count = 1)
block.launch()