coldlarry commited on
Commit
b4a95f7
1 Parent(s): 5347587

Add application file

Browse files
Files changed (2) hide show
  1. Document_QA.py +149 -0
  2. app.py +59 -0
Document_QA.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import openai
3
+ import faiss
4
+ import numpy as np
5
+ import pickle
6
+ from tqdm import tqdm
7
+ import argparse
8
+ import os
9
+
10
+ def create_embeddings(input):
11
+ """Create embeddings for the provided input."""
12
+ # input = ['ddd','aaa','ccccccccccccccc','ddddd']
13
+ result = []
14
+ # limit about 1000 tokens per request
15
+ # 记录文章每行的长度
16
+ # 0 [100]
17
+ # 1 [200]
18
+ # 2 [4100]
19
+ # 3 [999]
20
+ lens = [len(text) for text in input]
21
+ query_len = 0
22
+ start_index = 0
23
+ tokens = 0
24
+
25
+ def get_embedding(input_slice):
26
+ embedding = openai.Embedding.create(model="text-embedding-ada-002", input=input_slice)
27
+ #返回了(文字,embedding)和文字的token
28
+ return [(text, data.embedding) for text, data in zip(input_slice, embedding.data)], embedding.usage.total_tokens
29
+ #将文字的数量按照4096切分成多块,每一块去计算一次embedding,如果不足4096则一次计算所有文本的embedding
30
+ for index, l in tqdm(enumerate(lens)):
31
+ query_len += l
32
+ if query_len > 4096:
33
+ ebd, tk = get_embedding(input[start_index:index + 1])
34
+ query_len = 0
35
+ start_index = index + 1
36
+ tokens += tk
37
+ result.extend(ebd)
38
+
39
+ if query_len > 0:
40
+ ebd, tk = get_embedding(input[start_index:])
41
+ tokens += tk
42
+ result.extend(ebd)
43
+ return result, tokens
44
+
45
+ def create_embedding(text):
46
+ """Create an embedding for the provided text."""
47
+ embedding = openai.Embedding.create(model="text-embedding-ada-002", input=text)
48
+ return text, embedding.data[0].embedding
49
+
50
+ class QA():
51
+ def __init__(self,data_embe) -> None:
52
+ d = 1536
53
+ index = faiss.IndexFlatL2(d)
54
+ embe = np.array([emm[1] for emm in data_embe])
55
+ data = [emm[0] for emm in data_embe]
56
+ index.add(embe)
57
+ #所有emdding
58
+ self.index = index
59
+ #所有文字
60
+ self.data = data
61
+ def __call__(self, query):
62
+ embedding = create_embedding(query)
63
+ #输出与用户的问题相关的文字
64
+ context = self.get_texts(embedding[1], limit)
65
+ #将用户的问题和涉及的文字告诉gpt,并将答案返回
66
+ answer = self.completion(query,context)
67
+ return answer,context
68
+ def get_texts(self,embeding,limit):
69
+ _,text_index = self.index.search(np.array([embeding]),limit)
70
+ context = []
71
+ for i in list(text_index[0]):
72
+ context.extend(self.data[i:i+5])
73
+ # context = [self.data[i] for i in list(text_index[0])]
74
+ #输出与用户的问题相关的文字
75
+ return context
76
+
77
+ def completion(self,query, context):
78
+ """Create a completion."""
79
+ lens = [len(text) for text in context]
80
+
81
+ maximum = 3000
82
+ for index, l in enumerate(lens):
83
+ maximum -= l
84
+ if maximum < 0:
85
+ context = context[:index + 1]
86
+ print("超过最大长度,截断到前", index + 1, "个片段")
87
+ break
88
+
89
+ text = "\n".join(f"{index}. {text}" for index, text in enumerate(context))
90
+ response = openai.ChatCompletion.create(
91
+ model="gpt-3.5-turbo",
92
+ messages=[
93
+ {'role': 'system',
94
+ 'content': f'你是一个有帮助的AI文章助手,从下文中提取有用的内容进行回答,不能回答不在下文提到的内容,相关性从高到底排序:\n\n{text}'},
95
+ {'role': 'user', 'content': query},
96
+ ],
97
+ )
98
+ print("使用的tokens:", response.usage.total_tokens)
99
+ return response.choices[0].message.content
100
+
101
+ if __name__ == '__main__':
102
+ parser = argparse.ArgumentParser(description="Document QA")
103
+ parser.add_argument("--input_file", default="input.txt", dest="input_file", type=str,help="输入文件路径")
104
+ parser.add_argument("--file_embeding", default="input_embed.pkl", dest="file_embeding", type=str,help="文件embeding文件路径")
105
+ parser.add_argument("--print_context", action='store_true',help="是否打印上下文")
106
+
107
+
108
+ args = parser.parse_args()
109
+
110
+ if os.path.isfile(args.file_embeding):
111
+ data_embe = pickle.load(open(args.file_embeding,'rb'))
112
+ else:
113
+ with open(args.input_file,'r',encoding='utf-8') as f:
114
+ texts = f.readlines()
115
+ #按照行对文章进行切割
116
+ texts = [text.strip() for text in texts if text.strip()]
117
+ data_embe,tokens = create_embeddings(texts)
118
+ pickle.dump(data_embe,open(args.file_embeding,'wb'))
119
+ print("文本消耗 {} tokens".format(tokens))
120
+
121
+ qa =QA(data_embe)
122
+
123
+ limit = 10
124
+ while True:
125
+ query = input("请输入查询(help可查看指令):")
126
+ if query == "quit":
127
+ break
128
+ elif query.startswith("limit"):
129
+ try:
130
+ limit = int(query.split(" ")[1])
131
+ print("已设置limit为", limit)
132
+ except Exception as e:
133
+ print("设置limit失败", e)
134
+ continue
135
+ elif query == "help":
136
+ print("输入limit [数字]设置limit")
137
+ print("输入quit退出")
138
+ continue
139
+ answer,context = qa(query)
140
+ if args.print_context:
141
+ print("已找到相关片段:")
142
+ for text in context:
143
+ print('\t', text)
144
+ print("=====================================")
145
+ print("回答如下\n\n")
146
+ print(answer.strip())
147
+ print("=====================================")
148
+
149
+
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import openai
3
+ # from gpt_reader.pdf_reader import PaperReader
4
+ # from gpt_reader.prompt import BASE_POINTS
5
+ from Document_QA import QA
6
+ from Document_QA import create_embeddings
7
+
8
+ class GUI:
9
+ def __init__(self):
10
+ self.api_key = ""
11
+ self.session = ""
12
+ self.all_embedding =None
13
+ self.tokens = 0
14
+ #load pdf and create all embedings
15
+ def pdf_init(self, api_key, pdf_path):
16
+ openai.api_key = api_key
17
+ with open(pdf_path,'r',encoding='utf-8') as f:
18
+ texts = f.readlines()
19
+ #按照行对文章进行切割
20
+ texts = [text.strip() for text in texts if text.strip()]
21
+ self.all_embedding,self.tokens = create_embeddings(texts)
22
+ def get_answer(self, question):
23
+ qa = QA(self.all_embedding)
24
+ answer,context = qa(question)
25
+ return answer.strip()
26
+
27
+ # def analyse(self, api_key, pdf_file):
28
+ # self.session = PaperReader(api_key, points_to_focus=BASE_POINTS)
29
+ # return self.session.read_pdf_and_summarize(pdf_file)
30
+
31
+ # def ask_question(self, question):
32
+ # if self.session == "":
33
+ # return "Please upload PDF file first!"
34
+ # return self.session.question(question)
35
+
36
+
37
+ with gr.Blocks() as demo:
38
+ gr.Markdown(
39
+ """
40
+ # CHATGPT-PAPER-READER
41
+ """)
42
+
43
+ with gr.Tab("Upload PDF File"):
44
+ pdf_input = gr.File(label="PDF File")
45
+ api_input = gr.Textbox(label="OpenAI API Key")
46
+ #result = gr.Textbox(label="PDF Summary")
47
+ upload_button = gr.Button("Start Analyse")
48
+ with gr.Tab("Ask question about your PDF"):
49
+ question_input = gr.Textbox(label="Your Question", placeholder="Authors of this paper?")
50
+ answer = gr.Textbox(label="Answer")
51
+ ask_button = gr.Button("Ask")
52
+
53
+ app = GUI()
54
+ upload_button.click(fn=app.pdf_init, inputs=[api_input, pdf_input])
55
+ ask_button.click(app.get_answer, inputs=question_input, outputs=answer)
56
+
57
+ if __name__ == "__main__":
58
+ demo.title = "CHATGPT-PAPER-READER"
59
+ demo.launch() # add "share=True" to share CHATGPT-PAPER-READER app on Internet.