Tuchuanhuhuhu committed on
Commit
2c14aaf
·
1 Parent(s): 29d1af8

分离准备用户输入的模块

Browse files
Files changed (1) hide show
  1. modules/base_model.py +38 -32
modules/base_model.py CHANGED
@@ -164,52 +164,35 @@ class BaseLLMModel:
164
  status_text = self.token_message()
165
  return chatbot, status_text
166
 
167
- def predict(
168
- self,
169
- inputs,
170
- chatbot,
171
- stream=False,
172
- use_websearch=False,
173
- files=None,
174
- reply_language="中文",
175
- should_check_token_count=True,
176
- ): # repetition_penalty, top_k
177
- from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
178
- from llama_index.indices.query.schema import QueryBundle
179
- from langchain.embeddings.huggingface import HuggingFaceEmbeddings
180
- from langchain.chat_models import ChatOpenAI
181
- from llama_index import (
182
- GPTSimpleVectorIndex,
183
- ServiceContext,
184
- LangchainEmbedding,
185
- OpenAIEmbedding,
186
- )
187
-
188
- logging.info(
189
- "输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL
190
- )
191
- if should_check_token_count:
192
- yield chatbot + [(inputs, "")], "开始生成回答……"
193
- if reply_language == "跟随问题语言(不稳定)":
194
- reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
195
  old_inputs = None
196
  display_reference = []
197
  limited_context = False
198
  if files:
 
 
 
 
 
 
 
 
 
 
199
  limited_context = True
200
  old_inputs = inputs
201
  msg = "加载索引中……(这可能需要几分钟)"
202
  logging.info(msg)
203
- yield chatbot + [(inputs, "")], msg
204
  index = construct_index(self.api_key, file_src=files)
205
  assert index is not None, "索引构建失败"
206
- msg = "索引构建完成,获取回答中……"
 
207
  if local_embedding or self.model_type != ModelType.OpenAI:
208
  embed_model = LangchainEmbedding(HuggingFaceEmbeddings())
209
  else:
210
  embed_model = OpenAIEmbedding()
211
- logging.info(msg)
212
- yield chatbot + [(inputs, "")], msg
213
  with retrieve_proxy():
214
  prompt_helper = PromptHelper(
215
  max_input_size=4096,
@@ -263,6 +246,29 @@ class BaseLLMModel:
263
  )
264
  else:
265
  display_reference = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
 
267
  if (
268
  self.need_api_key and
 
164
  status_text = self.token_message()
165
  return chatbot, status_text
166
 
167
+ def prepare_inputs(self, inputs, use_websearch, files, reply_language):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  old_inputs = None
169
  display_reference = []
170
  limited_context = False
171
  if files:
172
+ from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
173
+ from llama_index.indices.query.schema import QueryBundle
174
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
175
+ from langchain.chat_models import ChatOpenAI
176
+ from llama_index import (
177
+ GPTSimpleVectorIndex,
178
+ ServiceContext,
179
+ LangchainEmbedding,
180
+ OpenAIEmbedding,
181
+ )
182
  limited_context = True
183
  old_inputs = inputs
184
  msg = "加载索引中……(这可能需要几分钟)"
185
  logging.info(msg)
186
+ # yield chatbot + [(inputs, "")], msg
187
  index = construct_index(self.api_key, file_src=files)
188
  assert index is not None, "索引构建失败"
189
+ msg = "索引获取成功,生成回答中……"
190
+ logging.info(msg)
191
  if local_embedding or self.model_type != ModelType.OpenAI:
192
  embed_model = LangchainEmbedding(HuggingFaceEmbeddings())
193
  else:
194
  embed_model = OpenAIEmbedding()
195
+ # yield chatbot + [(inputs, "")], msg
 
196
  with retrieve_proxy():
197
  prompt_helper = PromptHelper(
198
  max_input_size=4096,
 
246
  )
247
  else:
248
  display_reference = ""
249
+ return limited_context, old_inputs, display_reference, inputs
250
+
251
+ def predict(
252
+ self,
253
+ inputs,
254
+ chatbot,
255
+ stream=False,
256
+ use_websearch=False,
257
+ files=None,
258
+ reply_language="中文",
259
+ should_check_token_count=True,
260
+ ): # repetition_penalty, top_k
261
+
262
+
263
+ logging.info(
264
+ "输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL
265
+ )
266
+ if should_check_token_count:
267
+ yield chatbot + [(inputs, "")], "开始生成回答……"
268
+ if reply_language == "跟随问题语言(不稳定)":
269
+ reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
270
+
271
+ limited_context, old_inputs, display_reference, inputs = self.prepare_inputs(inputs=inputs, use_websearch=use_websearch, files=files, reply_language=reply_language)
272
 
273
  if (
274
  self.need_api_key and