coldlarry commited on
Commit
d332ff8
1 Parent(s): b5b98f8
Files changed (3) hide show
  1. Document_QA.py +57 -46
  2. app.py +9 -17
  3. requirements.txt +1 -0
Document_QA.py CHANGED
@@ -6,40 +6,43 @@ import pickle
6
  from tqdm import tqdm
7
  import argparse
8
  import os
 
9
 
10
- def create_embeddings(input):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  """Create embeddings for the provided input."""
12
  # input = ['ddd','aaa','ccccccccccccccc','ddddd']
13
  result = []
14
- # limit about 1000 tokens per request
15
- # 记录文章每行的长度
16
- # 0 [100]
17
- # 1 [200]
18
- # 2 [4100]
19
- # 3 [999]
20
- lens = [len(text) for text in input]
21
- query_len = 0
22
- start_index = 0
23
  tokens = 0
24
 
25
  def get_embedding(input_slice):
 
26
  embedding = openai.Embedding.create(model="text-embedding-ada-002", input=input_slice)
27
- #返回了(文字,embedding)和文字的token
28
  return [(text, data.embedding) for text, data in zip(input_slice, embedding.data)], embedding.usage.total_tokens
29
- #将文字的数量按照4096切分成多块,每一块去计算一次embedding,如果不足4096则一次计算所有文本的embedding
30
- for index, l in tqdm(enumerate(lens)):
31
- query_len += l
32
- if query_len > 4096:
33
- ebd, tk = get_embedding(input[start_index:index + 1])
34
- query_len = 0
35
- start_index = index + 1
36
- tokens += tk
37
- result.extend(ebd)
38
-
39
- if query_len > 0:
40
- ebd, tk = get_embedding(input[start_index:])
41
  tokens += tk
42
  result.extend(ebd)
 
43
  return result, tokens
44
 
45
  def create_embedding(text):
@@ -58,33 +61,35 @@ class QA():
58
  self.index = index
59
  #所有文字
60
  self.data = data
 
 
61
  def __call__(self, query):
62
  embedding = create_embedding(query)
63
  #输出与用户的问题相关的文字
64
- context = self.get_texts(embedding[1], limit)
65
  #将用户的问题和涉及的文字告诉gpt,并将答案返回
66
  answer = self.completion(query,context)
67
  return answer,context
68
- def get_texts(self,embeding,limit):
69
  _,text_index = self.index.search(np.array([embeding]),limit)
70
  context = []
71
  for i in list(text_index[0]):
72
- context.extend(self.data[i:i+5])
73
  # context = [self.data[i] for i in list(text_index[0])]
74
  #输出与用户的问题相关的文字
75
  return context
76
 
77
  def completion(self,query, context):
78
  """Create a completion."""
79
- lens = [len(text) for text in context]
80
 
81
- maximum = 3000
82
- for index, l in enumerate(lens):
83
- maximum -= l
84
- if maximum < 0:
85
- context = context[:index + 1]
86
- print("超过最大长度,截断到前", index + 1, "个片段")
87
- break
88
 
89
  text = "\n".join(f"{index}. {text}" for index, text in enumerate(context))
90
  response = openai.ChatCompletion.create(
@@ -100,24 +105,30 @@ class QA():
100
 
101
  if __name__ == '__main__':
102
  parser = argparse.ArgumentParser(description="Document QA")
103
- parser.add_argument("--input_file", default="input.txt", dest="input_file", type=str,help="输入文件路径")
104
- parser.add_argument("--file_embeding", default="input_embed.pkl", dest="file_embeding", type=str,help="文件embeding文件路径")
105
  parser.add_argument("--print_context", action='store_true',help="是否打印上下文")
106
 
107
 
108
  args = parser.parse_args()
109
 
110
- if os.path.isfile(args.file_embeding):
111
- data_embe = pickle.load(open(args.file_embeding,'rb'))
112
- else:
113
- with open(args.input_file,'r',encoding='utf-8') as f:
114
- texts = f.readlines()
115
- #按照行对文章进行切割
116
- texts = [text.strip() for text in texts if text.strip()]
117
- data_embe,tokens = create_embeddings(texts)
118
- pickle.dump(data_embe,open(args.file_embeding,'wb'))
119
- print("文本消耗 {} tokens".format(tokens))
 
 
 
 
120
 
 
 
121
  qa =QA(data_embe)
122
 
123
  limit = 10
 
6
  from tqdm import tqdm
7
  import argparse
8
  import os
9
+ from PyPDF2 import PdfReader
10
 
11
+ class Paper(object):
12
+
13
+ def __init__(self, pdf_path) -> None:
14
+ self._pdf_obj = PdfReader(pdf_path)
15
+ self._paper_meta = self._pdf_obj.metadata
16
+ self.texts = []
17
+
18
+ def iter_pages(self, iter_text_len: int = 1000):
19
+ page_idx = 0
20
+ for page in self._pdf_obj.pages:
21
+ txt = page.extract_text()
22
+ for i in range((len(txt) // iter_text_len) + 1):
23
+ yield page_idx, i, txt[i * iter_text_len:(i + 1) * iter_text_len]
24
+ page_idx += 1
25
+ def get_texts(self):
26
+ for (page_idx, part_idx, text) in self.iter_pages():
27
+ self.texts.append(text.strip())
28
+ return self.texts
29
+
30
+ def create_embeddings(inputs):
31
  """Create embeddings for the provided input."""
32
  # input = ['ddd','aaa','ccccccccccccccc','ddddd']
33
  result = []
 
 
 
 
 
 
 
 
 
34
  tokens = 0
35
 
36
  def get_embedding(input_slice):
37
+ input_slice = [input_slice]
38
  embedding = openai.Embedding.create(model="text-embedding-ada-002", input=input_slice)
 
39
  return [(text, data.embedding) for text, data in zip(input_slice, embedding.data)], embedding.usage.total_tokens
40
+
41
+ for i in range(0,len(inputs)):
42
+ ebd, tk = get_embedding(inputs[i])
 
 
 
 
 
 
 
 
 
43
  tokens += tk
44
  result.extend(ebd)
45
+
46
  return result, tokens
47
 
48
  def create_embedding(text):
 
61
  self.index = index
62
  #所有文字
63
  self.data = data
64
+ print("now all data is:\n",self.data)
65
+
66
  def __call__(self, query):
67
  embedding = create_embedding(query)
68
  #输出与用户的问题相关的文字
69
+ context = self.get_texts(embedding[1])
70
  #将用户的问题和涉及的文字告诉gpt,并将答案返回
71
  answer = self.completion(query,context)
72
  return answer,context
73
+ def get_texts(self,embeding,limit=5):
74
  _,text_index = self.index.search(np.array([embeding]),limit)
75
  context = []
76
  for i in list(text_index[0]):
77
+ context.extend(self.data[i:i+2])
78
  # context = [self.data[i] for i in list(text_index[0])]
79
  #输出与用户的问题相关的文字
80
  return context
81
 
82
  def completion(self,query, context):
83
  """Create a completion."""
84
+ # lens = [len(text) for text in context]
85
 
86
+ # maximum = 3000
87
+ # for index, l in enumerate(lens):
88
+ # maximum -= l
89
+ # if maximum < 0:
90
+ # context = context[:index + 1]
91
+ # print("超过最大长度,截断到前", index + 1, "个片段")
92
+ # break
93
 
94
  text = "\n".join(f"{index}. {text}" for index, text in enumerate(context))
95
  response = openai.ChatCompletion.create(
 
105
 
106
  if __name__ == '__main__':
107
  parser = argparse.ArgumentParser(description="Document QA")
108
+ parser.add_argument("--input_file", default="slimming-pages-1.pdf", dest="input_file", type=str,help="输入文件路径")
109
+ # parser.add_argument("--file_embeding", default="input_embed.pkl", dest="file_embeding", type=str,help="文件embeding文件路径")
110
  parser.add_argument("--print_context", action='store_true',help="是否打印上下文")
111
 
112
 
113
  args = parser.parse_args()
114
 
115
+ # if os.path.isfile(args.file_embeding):
116
+ # data_embe = pickle.load(open(args.file_embeding,'rb'))
117
+ # else:
118
+ # with open(args.input_file,'r',encoding='utf-8') as f:
119
+ # texts = f.readlines()
120
+ # #按照行对文章进行切割
121
+ # texts = [text.strip() for text in texts if text.strip()]
122
+ # data_embe,tokens = create_embeddings(texts)
123
+ # pickle.dump(data_embe,open(args.file_embeding,'wb'))
124
+ # print("文本消耗 {} tokens".format(tokens))
125
+
126
+ paper = Paper(args.input_file)
127
+ all_texts = paper.get_texts()
128
+
129
 
130
+ data_embe, tokens = create_embeddings(all_texts)
131
+ print("全部文本消耗 {} tokens".format(tokens))
132
  qa =QA(data_embe)
133
 
134
  limit = 10
app.py CHANGED
@@ -4,6 +4,8 @@ import openai
4
  # from gpt_reader.prompt import BASE_POINTS
5
  from Document_QA import QA
6
  from Document_QA import create_embeddings
 
 
7
 
8
  class GUI:
9
  def __init__(self):
@@ -14,27 +16,17 @@ class GUI:
14
  #load pdf and create all embedings
15
  def pdf_init(self, api_key, pdf_path):
16
  openai.api_key = api_key
17
- print("--------------pdf_path is:",pdf_path)
18
- with open(pdf_path,'r',encoding='utf-8') as f:
19
- texts = f.readlines()
20
- #按照行对文章进行切割
21
- texts = [text.strip() for text in texts if text.strip()]
22
- self.all_embedding,self.tokens = create_embeddings(texts)
23
  def get_answer(self, question):
24
  qa = QA(self.all_embedding)
25
  answer,context = qa(question)
26
  return answer.strip()
27
 
28
- # def analyse(self, api_key, pdf_file):
29
- # self.session = PaperReader(api_key, points_to_focus=BASE_POINTS)
30
- # return self.session.read_pdf_and_summarize(pdf_file)
31
-
32
- # def ask_question(self, question):
33
- # if self.session == "":
34
- # return "Please upload PDF file first!"
35
- # return self.session.question(question)
36
-
37
-
38
  with gr.Blocks() as demo:
39
  gr.Markdown(
40
  """
@@ -57,4 +49,4 @@ with gr.Blocks() as demo:
57
 
58
  if __name__ == "__main__":
59
  demo.title = "CHATGPT-PAPER-READER"
60
- demo.launch(debug=True,share=True) # add "share=True" to share CHATGPT-PAPER-READER app on Internet.
 
4
  # from gpt_reader.prompt import BASE_POINTS
5
  from Document_QA import QA
6
  from Document_QA import create_embeddings
7
+ from Document_QA import Paper
8
+ from PyPDF2 import PdfReader
9
 
10
  class GUI:
11
  def __init__(self):
 
16
  #load pdf and create all embedings
17
  def pdf_init(self, api_key, pdf_path):
18
  openai.api_key = api_key
19
+ pdf_reader = PdfReader(pdf_path)
20
+ paper = Paper(pdf_reader)
21
+ all_texts = paper.get_texts()
22
+ self.all_embedding, self.tokens = create_embeddings(all_texts)
23
+ print("全部文本消耗 {} tokens".format(self.tokens))
24
+
25
  def get_answer(self, question):
26
  qa = QA(self.all_embedding)
27
  answer,context = qa(question)
28
  return answer.strip()
29
 
 
 
 
 
 
 
 
 
 
 
30
  with gr.Blocks() as demo:
31
  gr.Markdown(
32
  """
 
49
 
50
  if __name__ == "__main__":
51
  demo.title = "CHATGPT-PAPER-READER"
52
+ demo.launch() # add "share=True" to share CHATGPT-PAPER-READER app on Internet.
requirements.txt CHANGED
@@ -2,3 +2,4 @@ numpy
2
  faiss-cpu
3
  tqdm
4
  openai
 
 
2
  faiss-cpu
3
  tqdm
4
  openai
5
+ PyPDF2