# NOTE: this script uses the pre-1.0 openai Python API
# (openai.Embedding.create / openai.ChatCompletion.create).
import openai
import faiss
import numpy as np
import pickle
from tqdm import tqdm
import argparse
import os
from PyPDF2 import PdfReader
class Paper(object):
    def __init__(self, pdf_obj: PdfReader) -> None:
        self._pdf_obj = pdf_obj
        self._paper_meta = self._pdf_obj.metadata
        self.texts = []

    def iter_pages(self, iter_text_len: int = 1000):
        # Walk every page and yield slices of at most iter_text_len characters,
        # tagged with the page index and the slice index within that page.
        for page_idx, page in enumerate(self._pdf_obj.pages):
            txt = page.extract_text()
            for i in range((len(txt) // iter_text_len) + 1):
                yield page_idx, i, txt[i * iter_text_len:(i + 1) * iter_text_len]

    def get_texts(self):
        # Flatten all page slices into a single list of stripped chunks,
        # skipping empty slices (the embedding API rejects empty input).
        for (page_idx, part_idx, text) in self.iter_pages():
            if text.strip():
                self.texts.append(text.strip())
        return self.texts
def create_embeddings(inputs):
    """Create embeddings for the provided input chunks, one request per chunk."""
    result = []
    tokens = 0

    def get_embedding(input_slice):
        # The API expects a list of inputs; wrap the single chunk.
        input_slice = [input_slice]
        embedding = openai.Embedding.create(model="text-embedding-ada-002", input=input_slice)
        return [(text, data.embedding) for text, data in zip(input_slice, embedding.data)], embedding.usage.total_tokens

    for text in inputs:
        ebd, tk = get_embedding(text)
        tokens += tk
        result.extend(ebd)
    return result, tokens

def create_embedding(text):
    """Create an embedding for a single piece of text."""
    embedding = openai.Embedding.create(model="text-embedding-ada-002", input=text)
    return text, embedding.data[0].embedding
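
# create_embeddings above issues one API request per chunk. The same
# Embedding endpoint also accepts a list of inputs, so a batched variant can
# cut down on round trips. A minimal sketch (create_embeddings_batched and
# batch_size are illustrative names, not part of the original script):
def create_embeddings_batched(inputs, batch_size=16):
    """Sketch: embed chunks in batches of batch_size per request."""
    result = []
    tokens = 0
    for i in range(0, len(inputs), batch_size):
        batch = inputs[i:i + batch_size]
        embedding = openai.Embedding.create(model="text-embedding-ada-002", input=batch)
        # Response items come back in input order, so zip pairs them correctly.
        result.extend((text, data.embedding) for text, data in zip(batch, embedding.data))
        tokens += embedding.usage.total_tokens
    return result, tokens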
class QA():
    def __init__(self, data_embe) -> None:
        d = 1536  # dimensionality of text-embedding-ada-002 vectors
        index = faiss.IndexFlatL2(d)
        embe = np.array([emm[1] for emm in data_embe]).astype('float32')  # faiss expects float32
        data = [emm[0] for emm in data_embe]
        index.add(embe)
        # All embeddings
        self.index = index
        # All text chunks
        self.data = data
        print("now all data is:\n", self.data)

    def __call__(self, query, limit=5):
        embedding = create_embedding(query)
        # Retrieve the text chunks relevant to the user's question
        context = self.get_texts(embedding[1], limit)
        # Hand the question and the retrieved chunks to GPT and return its answer
        answer = self.completion(query, context)
        return answer, context

    def get_texts(self, embeding, limit=5):
        _, text_index = self.index.search(np.array([embeding]).astype('float32'), limit)
        context = []
        for i in list(text_index[0]):
            # Take the matched chunk plus the one after it for extra continuity
            context.extend(self.data[i:i + 2])
        # context = [self.data[i] for i in list(text_index[0])]
        # Return the text chunks relevant to the user's question
        return context

    def completion(self, query, context):
        """Create a completion."""
        # lens = [len(text) for text in context]
        # maximum = 3000
        # for index, l in enumerate(lens):
        #     maximum -= l
        #     if maximum < 0:
        #         context = context[:index + 1]
        #         print("Over the maximum length; truncated to the first", index + 1, "chunks")
        #         break
        text = "\n".join(f"{index}. {text}" for index, text in enumerate(context))
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {'role': 'system',
                 'content': f'You are a helpful AI document assistant. Extract useful material from the text below to answer; do not answer with anything not mentioned in it. Passages are ordered from most to least relevant:\n\n{text}'},
                {'role': 'user', 'content': query},
            ],
        )
        print("Tokens used:", response.usage.total_tokens)
        return response.choices[0].message.content
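
# Usage sketch (the query string is illustrative): given embeddings produced
# by create_embeddings above,
#   qa = QA(data_embe)
#   answer, context = qa("What is the paper's main contribution?", limit=5)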
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Document QA")
    parser.add_argument("--input_file", default="slimming-pages-1.pdf", dest="input_file", type=str, help="Path to the input file")
    # parser.add_argument("--file_embeding", default="input_embed.pkl", dest="file_embeding", type=str, help="Path to the cached embedding file")
    parser.add_argument("--print_context", action='store_true', help="Whether to print the retrieved context")
    args = parser.parse_args()
    # if os.path.isfile(args.file_embeding):
    #     data_embe = pickle.load(open(args.file_embeding, 'rb'))
    # else:
    #     with open(args.input_file, 'r', encoding='utf-8') as f:
    #         texts = f.readlines()
    #     # Split the document by line
    #     texts = [text.strip() for text in texts if text.strip()]
    #     data_embe, tokens = create_embeddings(texts)
    #     pickle.dump(data_embe, open(args.file_embeding, 'wb'))
    #     print("The text consumed {} tokens".format(tokens))
    # args.input_file is a path, so open it with PdfReader before passing it to Paper
    paper = Paper(PdfReader(args.input_file))
    all_texts = paper.get_texts()
    data_embe, tokens = create_embeddings(all_texts)
    print("The full text consumed {} tokens".format(tokens))
    qa = QA(data_embe)
    limit = 10
    while True:
        query = input("Enter a query (type help to see the commands): ")
        if query == "quit":
            break
        elif query.startswith("limit"):
            try:
                limit = int(query.split(" ")[1])
                print("limit set to", limit)
            except Exception as e:
                print("Failed to set limit:", e)
            continue
        elif query == "help":
            print("Type limit [number] to set the retrieval limit")
            print("Type quit to exit")
            continue
        answer, context = qa(query, limit)
        if args.print_context:
            print("Relevant chunks found:")
            for text in context:
                print('\t', text)
        print("=====================================")
        print("The answer is as follows\n\n")
        print(answer.strip())
        print("=====================================")