Spaces:
Sleeping
Sleeping
File size: 5,837 Bytes
b4a95f7 d332ff8 b4a95f7 d332ff8 a34348d d332ff8 b4a95f7 d332ff8 b4a95f7 d332ff8 b4a95f7 d332ff8 b4a95f7 d332ff8 b4a95f7 d332ff8 b4a95f7 d332ff8 b4a95f7 d332ff8 b4a95f7 d332ff8 b4a95f7 d332ff8 b4a95f7 d332ff8 b4a95f7 d332ff8 b4a95f7 d332ff8 b4a95f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import openai
import faiss
import numpy as np
import pickle
from tqdm import tqdm
import argparse
import os
from PyPDF2 import PdfReader
class Paper(object):
    """Expose the text of a PDF as a list of fixed-size character chunks.

    The wrapped reader only needs a ``metadata`` attribute and a ``pages``
    iterable whose items provide ``extract_text()``.
    """

    def __init__(self, pdf_obj: "PdfReader | str | os.PathLike") -> None:
        """Store the reader and its metadata.

        Args:
            pdf_obj: An opened ``PdfReader``. As a convenience a filesystem
                path is also accepted and opened with ``PdfReader``.
        """
        if isinstance(pdf_obj, (str, os.PathLike)):
            # A bare path has no ``metadata``/``pages``; open it first.
            pdf_obj = PdfReader(pdf_obj)
        self._pdf_obj = pdf_obj
        self._paper_meta = self._pdf_obj.metadata
        self.texts = []

    def iter_pages(self, iter_text_len: int = 1000):
        """Yield ``(page_idx, part_idx, text)`` for every non-empty chunk.

        Args:
            iter_text_len: Maximum number of characters per chunk.

        Yields:
            Tuples of the zero-based page index, the zero-based chunk index
            within that page, and the chunk text (not stripped).

        Raises:
            ValueError: If ``iter_text_len`` is not positive.
        """
        if iter_text_len <= 0:
            raise ValueError("iter_text_len must be a positive integer")
        for page_idx, page in enumerate(self._pdf_obj.pages):
            # Guard against extractors that hand back nothing for a page.
            txt = page.extract_text() or ""
            # Stepping by the chunk size never produces the empty trailing
            # chunk that ``len // size + 1`` did for empty pages and for
            # lengths that are an exact multiple of the chunk size.
            starts = range(0, len(txt), iter_text_len)
            for part_idx, start in enumerate(starts):
                yield page_idx, part_idx, txt[start:start + iter_text_len]

    def get_texts(self):
        """Return the stripped, non-empty chunks of the whole document.

        The list is rebuilt on every call and stored in ``self.texts``, so
        calling this method repeatedly does not duplicate chunks.
        """
        stripped = (text.strip() for _, _, text in self.iter_pages())
        # Whitespace-only chunks carry no content worth embedding.
        self.texts = [text for text in stripped if text]
        return self.texts
def create_embeddings(inputs):
    """Create one embedding per non-blank input text.

    Each text is sent to the ``text-embedding-ada-002`` model in its own
    request. Blank or whitespace-only entries are skipped instead of being
    sent to the service.

    Args:
        inputs: Iterable of text strings.

    Returns:
        A tuple ``(result, tokens)`` where ``result`` is a list of
        ``(text, embedding)`` pairs in input order and ``tokens`` is the sum
        of ``usage.total_tokens`` over all requests.
    """
    result = []
    tokens = 0
    for text in inputs:
        if not text or not text.strip():
            continue
        response = openai.Embedding.create(
            model="text-embedding-ada-002",
            input=[text],
        )
        # One input per request, so the first data item belongs to ``text``.
        result.append((text, response.data[0].embedding))
        tokens += response.usage.total_tokens
    return result, tokens
def create_embedding(text):
    """Embed a single text with the ``text-embedding-ada-002`` model.

    Returns:
        A ``(text, embedding)`` tuple, where ``embedding`` is taken from the
        first data item of the response.
    """
    response = openai.Embedding.create(
        input=text,
        model="text-embedding-ada-002",
    )
    vector = response.data[0].embedding
    return text, vector
class QA():
    """Answer questions about a document using embedding search and chat.

    The text chunks are indexed in a flat L2 FAISS index; a query is embedded,
    the nearest chunks are retrieved and passed to ``gpt-3.5-turbo`` as
    context.
    """

    def __init__(self, data_embe) -> None:
        """Build the search index.

        Args:
            data_embe: Iterable of ``(text, embedding)`` pairs.

        Raises:
            ValueError: If no embeddings are supplied or they do not form a
                two-dimensional matrix (e.g. vectors of differing length).
        """
        data_embe = list(data_embe)
        data = [pair[0] for pair in data_embe]
        # FAISS operates on float32 matrices; numpy would default to float64.
        embe = np.array([pair[1] for pair in data_embe], dtype="float32")
        if embe.ndim != 2 or embe.shape[0] == 0:
            raise ValueError(
                "data_embe must contain at least one (text, embedding) pair "
                "with equally sized embeddings"
            )
        # Take the dimensionality from the data rather than hard-coding it.
        index = faiss.IndexFlatL2(embe.shape[1])
        index.add(embe)
        # All embeddings.
        self.index = index
        # All text chunks, aligned by position with the index rows.
        self.data = data
        print("now all data is:\n", self.data)

    def __call__(self, query, limit=5):
        """Answer ``query`` and return ``(answer, context)``.

        Args:
            query: The user's question.
            limit: Number of nearest chunks to retrieve from the index.
        """
        _, query_embedding = create_embedding(query)
        # Chunks of the document related to the user's question.
        context = self.get_texts(query_embedding, limit)
        # Send the question plus the related chunks to the chat model.
        answer = self.completion(query, context)
        return answer, context

    def get_texts(self, embeding, limit=5):
        """Return the chunks nearest to ``embeding`` plus their successors.

        For every hit the matching chunk and the chunk that follows it are
        included. Each chunk position is returned at most once, in order of
        first appearance.

        Args:
            embeding: The query embedding vector.
            limit: Number of nearest neighbours to request from the index.
        """
        query_matrix = np.array([embeding], dtype="float32")
        _, text_index = self.index.search(query_matrix, limit)
        context = []
        seen = set()
        for hit in text_index[0]:
            hit = int(hit)
            if hit < 0:
                # FAISS pads the result with -1 when fewer than ``limit``
                # vectors are available; a negative slice start would pick
                # unrelated chunks.
                continue
            for pos in range(hit, min(hit + 2, len(self.data))):
                if pos not in seen:
                    seen.add(pos)
                    context.append(self.data[pos])
        return context

    def completion(self, query, context):
        """Ask the chat model to answer ``query`` using only ``context``.

        Args:
            query: The user's question.
            context: List of text chunks, most relevant first.

        Returns:
            The content of the first choice of the chat completion.
        """
        # NOTE: the context is not truncated here; callers control its size
        # through the retrieval limit.
        numbered = "\n".join(
            f"{position}. {chunk}" for position, chunk in enumerate(context)
        )
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {'role': 'system',
                 'content': f'你是一个有帮助的AI文章助手,从下文中提取有用的内容进行回答,不能回答不在下文提到的内容,相关性从高到底排序:\n\n{numbered}'},
                {'role': 'user', 'content': query},
            ],
        )
        print("使用的tokens:", response.usage.total_tokens)
        return response.choices[0].message.content
if __name__ == '__main__':
    # Interactive command-line entry point: embed a PDF, then answer
    # questions about it until the user types "quit".
    parser = argparse.ArgumentParser(description="Document QA")
    parser.add_argument("--input_file", default="slimming-pages-1.pdf", dest="input_file", type=str,help="输入文件路径")
    parser.add_argument("--print_context", action='store_true',help="是否打印上下文")
    args = parser.parse_args()

    # Paper reads ``metadata`` and ``pages`` from its argument, so it needs an
    # opened PdfReader rather than the raw path string.
    paper = Paper(PdfReader(args.input_file))
    all_texts = paper.get_texts()
    data_embe, tokens = create_embeddings(all_texts)
    print("全部文本消耗 {} tokens".format(tokens))

    qa = QA(data_embe)
    # Number of nearest chunks to retrieve. Starts at QA.get_texts' default,
    # which is the value that was effectively used before the "limit" command
    # was wired through to the search.
    limit = 5
    while True:
        query = input("请输入查询(help可查看指令):")
        command = query.strip()
        if not command:
            # Nothing to embed or answer.
            continue
        if command == "quit":
            break
        elif command == "limit" or command.startswith("limit "):
            try:
                new_limit = int(command.split()[1])
                if new_limit <= 0:
                    raise ValueError("limit must be a positive integer")
            except (IndexError, ValueError) as e:
                print("设置limit失败", e)
            else:
                limit = new_limit
                print("已设置limit为", limit)
            continue
        elif command == "help":
            print("输入limit [数字]设置limit")
            print("输入quit退出")
            continue

        # Same steps as QA.__call__, but with the user-selected limit.
        _, query_embedding = create_embedding(query)
        context = qa.get_texts(query_embedding, limit)
        answer = qa.completion(query, context)
        if args.print_context:
            print("已找到相关片段:")
            for text in context:
                print('\t', text)
            print("=====================================")
        print("回答如下\n\n")
        print(answer.strip())
        print("=====================================")
|