Tuchuanhuhuhu commited on
Commit
0ce1a9f
·
1 Parent(s): 6a88a02

去除llama index,转而使用langchain。索引支持更多文件格式。

Browse files
ChuanhuChatbot.py CHANGED
@@ -15,7 +15,6 @@ from modules.models.models import get_model
15
 
16
  gr.Chatbot._postprocess_chat_messages = postprocess_chat_messages
17
  gr.Chatbot.postprocess = postprocess
18
- PromptHelper.compact_text_chunks = compact_text_chunks
19
 
20
  with open("assets/custom.css", "r", encoding="utf-8") as f:
21
  customCSS = f.read()
 
15
 
16
  gr.Chatbot._postprocess_chat_messages = postprocess_chat_messages
17
  gr.Chatbot.postprocess = postprocess
 
18
 
19
  with open("assets/custom.css", "r", encoding="utf-8") as f:
20
  customCSS = f.read()
modules/{llama_func.py → index_func.py} RENAMED
@@ -1,14 +1,6 @@
1
  import os
2
  import logging
3
 
4
- from llama_index import download_loader
5
- from llama_index import (
6
- Document,
7
- LLMPredictor,
8
- PromptHelper,
9
- QuestionAnswerPrompt,
10
- RefinePrompt,
11
- )
12
  import colorama
13
  import PyPDF2
14
  from tqdm import tqdm
@@ -40,6 +32,10 @@ def block_split(text):
40
 
41
 
42
  def get_documents(file_src):
 
 
 
 
43
  documents = []
44
  logging.debug("Loading documents...")
45
  logging.debug(f"file_src: {file_src}")
@@ -63,34 +59,39 @@ def get_documents(file_src):
63
  pdfReader = PyPDF2.PdfReader(pdfFileObj)
64
  for page in tqdm(pdfReader.pages):
65
  pdftext += page.extract_text()
66
- text_raw = pdftext
67
  elif file_type == ".docx":
68
  logging.debug("Loading Word...")
69
- DocxReader = download_loader("DocxReader")
70
- loader = DocxReader()
71
- text_raw = loader.load_data(file=filepath)[0].text
 
 
 
 
 
72
  elif file_type == ".epub":
73
  logging.debug("Loading EPUB...")
74
- EpubReader = download_loader("EpubReader")
75
- loader = EpubReader()
76
- text_raw = loader.load_data(file=filepath)[0].text
77
  elif file_type == ".xlsx":
78
  logging.debug("Loading Excel...")
79
  text_list = excel_to_string(filepath)
80
  for elem in text_list:
81
- documents.append(Document(elem))
82
  continue
83
  else:
84
  logging.debug("Loading text file...")
85
- with open(filepath, "r", encoding="utf-8") as f:
86
- text_raw = f.read()
 
87
  except Exception as e:
88
  logging.error(f"Error loading file: {filename}")
89
  pass
90
- text = add_space(text_raw)
91
- # text = block_split(text)
92
- # documents += text
93
- documents += [Document(text)]
94
  logging.debug("Documents loaded.")
95
  return documents
96
 
@@ -106,8 +107,7 @@ def construct_index(
106
  separator=" ",
107
  ):
108
  from langchain.chat_models import ChatOpenAI
109
- from langchain.embeddings.huggingface import HuggingFaceEmbeddings
110
- from llama_index import GPTVectorStoreIndex, ServiceContext, LangchainEmbedding, OpenAIEmbedding
111
 
112
  if api_key:
113
  os.environ["OPENAI_API_KEY"] = api_key
@@ -118,38 +118,26 @@ def construct_index(
118
  embedding_limit = None if embedding_limit == 0 else embedding_limit
119
  separator = " " if separator == "" else separator
120
 
121
- prompt_helper = PromptHelper(
122
- max_input_size=max_input_size,
123
- num_output=num_outputs,
124
- max_chunk_overlap=max_chunk_overlap,
125
- embedding_limit=embedding_limit,
126
- chunk_size_limit=600,
127
- separator=separator,
128
- )
129
  index_name = get_index_name(file_src)
130
- if os.path.exists(f"./index/{index_name}.json"):
 
 
 
 
 
 
 
131
  logging.info("找到了缓存的索引文件,加载中……")
132
- return GPTVectorStoreIndex.load_from_disk(f"./index/{index_name}.json")
133
  else:
134
  try:
135
  documents = get_documents(file_src)
136
- if local_embedding:
137
- embed_model = LangchainEmbedding(HuggingFaceEmbeddings(model_name = "sentence-transformers/distiluse-base-multilingual-cased-v2"))
138
- else:
139
- embed_model = OpenAIEmbedding()
140
  logging.info("构建索引中……")
141
  with retrieve_proxy():
142
- service_context = ServiceContext.from_defaults(
143
- prompt_helper=prompt_helper,
144
- chunk_size_limit=chunk_size_limit,
145
- embed_model=embed_model,
146
- )
147
- index = GPTVectorStoreIndex.from_documents(
148
- documents, service_context=service_context
149
- )
150
  logging.debug("索引构建完成!")
151
  os.makedirs("./index", exist_ok=True)
152
- index.storage_context.persist(f"./index/{index_name}")
153
  logging.debug("索引已保存至本地!")
154
  return index
155
 
@@ -157,10 +145,3 @@ def construct_index(
157
  logging.error("索引构建失败!", e)
158
  print(e)
159
  return None
160
-
161
-
162
- def add_space(text):
163
- punctuations = {",": ", ", "。": "。 ", "?": "? ", "!": "! ", ":": ": ", ";": "; "}
164
- for cn_punc, en_punc in punctuations.items():
165
- text = text.replace(cn_punc, en_punc)
166
- return text
 
1
  import os
2
  import logging
3
 
 
 
 
 
 
 
 
 
4
  import colorama
5
  import PyPDF2
6
  from tqdm import tqdm
 
32
 
33
 
34
  def get_documents(file_src):
35
+ from langchain.schema import Document
36
+ from langchain.text_splitter import TokenTextSplitter
37
+ text_splitter = TokenTextSplitter(chunk_size=500, chunk_overlap=30)
38
+
39
  documents = []
40
  logging.debug("Loading documents...")
41
  logging.debug(f"file_src: {file_src}")
 
59
  pdfReader = PyPDF2.PdfReader(pdfFileObj)
60
  for page in tqdm(pdfReader.pages):
61
  pdftext += page.extract_text()
62
+ texts = Document(page_content=pdftext, metadata={"source": filepath})
63
  elif file_type == ".docx":
64
  logging.debug("Loading Word...")
65
+ from langchain.document_loaders import UnstructuredWordDocumentLoader
66
+ loader = UnstructuredWordDocumentLoader(filepath)
67
+ texts = loader.load()
68
+ elif file_type == ".pptx":
69
+ logging.debug("Loading PowerPoint...")
70
+ from langchain.document_loaders import UnstructuredPowerPointLoader
71
+ loader = UnstructuredPowerPointLoader(filepath)
72
+ texts = loader.load()
73
  elif file_type == ".epub":
74
  logging.debug("Loading EPUB...")
75
+ from langchain.document_loaders import UnstructuredEPubLoader
76
+ loader = UnstructuredEPubLoader(filepath)
77
+ texts = loader.load()
78
  elif file_type == ".xlsx":
79
  logging.debug("Loading Excel...")
80
  text_list = excel_to_string(filepath)
81
  for elem in text_list:
82
+ documents.append(Document(page_content=elem, metadata={"source": filepath}))
83
  continue
84
  else:
85
  logging.debug("Loading text file...")
86
+ from langchain.document_loaders import TextLoader
87
+ loader = TextLoader(filepath, "utf8")
88
+ texts = loader.load()
89
  except Exception as e:
90
  logging.error(f"Error loading file: {filename}")
91
  pass
92
+
93
+ texts = text_splitter.split_documents(texts)
94
+ documents.extend(texts)
 
95
  logging.debug("Documents loaded.")
96
  return documents
97
 
 
107
  separator=" ",
108
  ):
109
  from langchain.chat_models import ChatOpenAI
110
+ from langchain.vectorstores import FAISS
 
111
 
112
  if api_key:
113
  os.environ["OPENAI_API_KEY"] = api_key
 
118
  embedding_limit = None if embedding_limit == 0 else embedding_limit
119
  separator = " " if separator == "" else separator
120
 
 
 
 
 
 
 
 
 
121
  index_name = get_index_name(file_src)
122
+ index_path = f"./index/{index_name}"
123
+ if local_embedding:
124
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
125
+ embeddings = HuggingFaceEmbeddings(model_name = "sentence-transformers/distiluse-base-multilingual-cased-v2")
126
+ else:
127
+ from langchain.embeddings import OpenAIEmbeddings
128
+ embeddings = OpenAIEmbeddings()
129
+ if os.path.exists(index_path):
130
  logging.info("找到了缓存的索引文件,加载中……")
131
+ return FAISS.load_local(index_path, embeddings)
132
  else:
133
  try:
134
  documents = get_documents(file_src)
 
 
 
 
135
  logging.info("构建索引中……")
136
  with retrieve_proxy():
137
+ index = FAISS.from_documents(documents, embeddings)
 
 
 
 
 
 
 
138
  logging.debug("索引构建完成!")
139
  os.makedirs("./index", exist_ok=True)
140
+ index.save_local(index_path)
141
  logging.debug("索引已保存至本地!")
142
  return index
143
 
 
145
  logging.error("索引构建失败!", e)
146
  print(e)
147
  return None
 
 
 
 
 
 
 
modules/models/base_model.py CHANGED
@@ -19,7 +19,7 @@ import aiohttp
19
  from enum import Enum
20
 
21
  from ..presets import *
22
- from ..llama_func import *
23
  from ..utils import *
24
  from .. import shared
25
  from ..config import retrieve_proxy
@@ -192,53 +192,20 @@ class BaseLLMModel:
192
  limited_context = False
193
  fake_inputs = real_inputs
194
  if files:
195
- from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
196
- from llama_index.indices.query.schema import QueryBundle
197
  from langchain.embeddings.huggingface import HuggingFaceEmbeddings
198
- from langchain.chat_models import ChatOpenAI
199
- from llama_index import (
200
- GPTSimpleVectorIndex,
201
- ServiceContext,
202
- LangchainEmbedding,
203
- OpenAIEmbedding,
204
- )
205
  limited_context = True
206
  msg = "加载索引中……"
207
  logging.info(msg)
208
- # yield chatbot + [(inputs, "")], msg
209
  index = construct_index(self.api_key, file_src=files)
210
  assert index is not None, "获取索引失败"
211
  msg = "索引获取成功,生成回答中……"
212
  logging.info(msg)
213
- if local_embedding or self.model_type != ModelType.OpenAI:
214
- embed_model = LangchainEmbedding(HuggingFaceEmbeddings(model_name = "sentence-transformers/distiluse-base-multilingual-cased-v2"))
215
- else:
216
- embed_model = OpenAIEmbedding()
217
- # yield chatbot + [(inputs, "")], msg
218
  with retrieve_proxy():
219
- prompt_helper = PromptHelper(
220
- max_input_size=4096,
221
- num_output=5,
222
- max_chunk_overlap=20,
223
- chunk_size_limit=600,
224
- )
225
- from llama_index import ServiceContext
226
-
227
- service_context = ServiceContext.from_defaults(
228
- prompt_helper=prompt_helper, embed_model=embed_model
229
- )
230
- query_object = GPTVectorStoreIndexQuery(
231
- index.index_struct,
232
- service_context=service_context,
233
- similarity_top_k=5,
234
- vector_store=index._vector_store,
235
- docstore=index._docstore,
236
- response_synthesizer=None
237
- )
238
- query_bundle = QueryBundle(real_inputs)
239
- nodes = query_object.retrieve(query_bundle)
240
- reference_results = [n.node.text for n in nodes]
241
- reference_results = add_source_numbers(reference_results, use_source=False)
242
  display_append = add_details(reference_results)
243
  display_append = "\n\n" + "".join(display_append)
244
  real_inputs = (
 
19
  from enum import Enum
20
 
21
  from ..presets import *
22
+ from ..index_func import *
23
  from ..utils import *
24
  from .. import shared
25
  from ..config import retrieve_proxy
 
192
  limited_context = False
193
  fake_inputs = real_inputs
194
  if files:
 
 
195
  from langchain.embeddings.huggingface import HuggingFaceEmbeddings
196
+ from langchain.vectorstores.base import VectorStoreRetriever
 
 
 
 
 
 
197
  limited_context = True
198
  msg = "加载索引中……"
199
  logging.info(msg)
 
200
  index = construct_index(self.api_key, file_src=files)
201
  assert index is not None, "获取索引失败"
202
  msg = "索引获取成功,生成回答中……"
203
  logging.info(msg)
 
 
 
 
 
204
  with retrieve_proxy():
205
+ retriever = VectorStoreRetriever(vectorstore=index, search_type="similarity_score_threshold",search_kwargs={"k":6, "score_threshold": 0.5})
206
+ relevant_documents = retriever.get_relevant_documents(real_inputs)
207
+ reference_results = [[d.page_content.strip("�"), os.path.basename(d.metadata["source"])] for d in relevant_documents]
208
+ reference_results = add_source_numbers(reference_results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  display_append = add_details(reference_results)
210
  display_append = "\n\n" + "".join(display_append)
211
  real_inputs = (
modules/models/models.py CHANGED
@@ -22,7 +22,7 @@ from enum import Enum
22
  import uuid
23
 
24
  from ..presets import *
25
- from ..llama_func import *
26
  from ..utils import *
27
  from .. import shared
28
  from ..config import retrieve_proxy, usage_limit
 
22
  import uuid
23
 
24
  from ..presets import *
25
+ from ..index_func import *
26
  from ..utils import *
27
  from .. import shared
28
  from ..config import retrieve_proxy, usage_limit
modules/overwrites.py CHANGED
@@ -7,7 +7,7 @@ import mdtex2html
7
  from gradio_client import utils as client_utils
8
 
9
  from modules.presets import *
10
- from modules.llama_func import *
11
  from modules.config import render_latex
12
 
13
  def compact_text_chunks(self, prompt: Prompt, text_chunks: List[str]) -> List[str]:
 
7
  from gradio_client import utils as client_utils
8
 
9
  from modules.presets import *
10
+ from modules.index_func import *
11
  from modules.config import render_latex
12
 
13
  def compact_text_chunks(self, prompt: Prompt, text_chunks: List[str]) -> List[str]:
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- gradio==3.30.0
2
  gradio_client==0.1.4
3
  mdtex2html
4
  pypinyin
@@ -16,3 +16,4 @@ pdfplumber
16
  pandas
17
  commentjson
18
  openpyxl
 
 
1
+ gradio==3.28.0
2
  gradio_client==0.1.4
3
  mdtex2html
4
  pypinyin
 
16
  pandas
17
  commentjson
18
  openpyxl
19
+ pandocs