Spaces:
Sleeping
Sleeping
Tuchuanhuhuhu
commited on
Commit
·
2c14aaf
1
Parent(s):
29d1af8
分离准备用户输入的模块
Browse files- modules/base_model.py +38 -32
modules/base_model.py
CHANGED
@@ -164,52 +164,35 @@ class BaseLLMModel:
|
|
164 |
status_text = self.token_message()
|
165 |
return chatbot, status_text
|
166 |
|
167 |
-
def
|
168 |
-
self,
|
169 |
-
inputs,
|
170 |
-
chatbot,
|
171 |
-
stream=False,
|
172 |
-
use_websearch=False,
|
173 |
-
files=None,
|
174 |
-
reply_language="中文",
|
175 |
-
should_check_token_count=True,
|
176 |
-
): # repetition_penalty, top_k
|
177 |
-
from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
|
178 |
-
from llama_index.indices.query.schema import QueryBundle
|
179 |
-
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
180 |
-
from langchain.chat_models import ChatOpenAI
|
181 |
-
from llama_index import (
|
182 |
-
GPTSimpleVectorIndex,
|
183 |
-
ServiceContext,
|
184 |
-
LangchainEmbedding,
|
185 |
-
OpenAIEmbedding,
|
186 |
-
)
|
187 |
-
|
188 |
-
logging.info(
|
189 |
-
"输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL
|
190 |
-
)
|
191 |
-
if should_check_token_count:
|
192 |
-
yield chatbot + [(inputs, "")], "开始生成回答……"
|
193 |
-
if reply_language == "跟随问题语言(不稳定)":
|
194 |
-
reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
|
195 |
old_inputs = None
|
196 |
display_reference = []
|
197 |
limited_context = False
|
198 |
if files:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
199 |
limited_context = True
|
200 |
old_inputs = inputs
|
201 |
msg = "加载索引中……(这可能需要几分钟)"
|
202 |
logging.info(msg)
|
203 |
-
yield chatbot + [(inputs, "")], msg
|
204 |
index = construct_index(self.api_key, file_src=files)
|
205 |
assert index is not None, "索引构建失败"
|
206 |
-
msg = "
|
|
|
207 |
if local_embedding or self.model_type != ModelType.OpenAI:
|
208 |
embed_model = LangchainEmbedding(HuggingFaceEmbeddings())
|
209 |
else:
|
210 |
embed_model = OpenAIEmbedding()
|
211 |
-
|
212 |
-
yield chatbot + [(inputs, "")], msg
|
213 |
with retrieve_proxy():
|
214 |
prompt_helper = PromptHelper(
|
215 |
max_input_size=4096,
|
@@ -263,6 +246,29 @@ class BaseLLMModel:
|
|
263 |
)
|
264 |
else:
|
265 |
display_reference = ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
266 |
|
267 |
if (
|
268 |
self.need_api_key and
|
|
|
164 |
status_text = self.token_message()
|
165 |
return chatbot, status_text
|
166 |
|
167 |
+
def prepare_inputs(self, inputs, use_websearch, files, reply_language):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
old_inputs = None
|
169 |
display_reference = []
|
170 |
limited_context = False
|
171 |
if files:
|
172 |
+
from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
|
173 |
+
from llama_index.indices.query.schema import QueryBundle
|
174 |
+
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
175 |
+
from langchain.chat_models import ChatOpenAI
|
176 |
+
from llama_index import (
|
177 |
+
GPTSimpleVectorIndex,
|
178 |
+
ServiceContext,
|
179 |
+
LangchainEmbedding,
|
180 |
+
OpenAIEmbedding,
|
181 |
+
)
|
182 |
limited_context = True
|
183 |
old_inputs = inputs
|
184 |
msg = "加载索引中……(这可能需要几分钟)"
|
185 |
logging.info(msg)
|
186 |
+
# yield chatbot + [(inputs, "")], msg
|
187 |
index = construct_index(self.api_key, file_src=files)
|
188 |
assert index is not None, "索引构建失败"
|
189 |
+
msg = "索引获取成功,生成回答中……"
|
190 |
+
logging.info(msg)
|
191 |
if local_embedding or self.model_type != ModelType.OpenAI:
|
192 |
embed_model = LangchainEmbedding(HuggingFaceEmbeddings())
|
193 |
else:
|
194 |
embed_model = OpenAIEmbedding()
|
195 |
+
# yield chatbot + [(inputs, "")], msg
|
|
|
196 |
with retrieve_proxy():
|
197 |
prompt_helper = PromptHelper(
|
198 |
max_input_size=4096,
|
|
|
246 |
)
|
247 |
else:
|
248 |
display_reference = ""
|
249 |
+
return limited_context, old_inputs, display_reference, inputs
|
250 |
+
|
251 |
+
def predict(
|
252 |
+
self,
|
253 |
+
inputs,
|
254 |
+
chatbot,
|
255 |
+
stream=False,
|
256 |
+
use_websearch=False,
|
257 |
+
files=None,
|
258 |
+
reply_language="中文",
|
259 |
+
should_check_token_count=True,
|
260 |
+
): # repetition_penalty, top_k
|
261 |
+
|
262 |
+
|
263 |
+
logging.info(
|
264 |
+
"输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL
|
265 |
+
)
|
266 |
+
if should_check_token_count:
|
267 |
+
yield chatbot + [(inputs, "")], "开始生成回答……"
|
268 |
+
if reply_language == "跟随问题语言(不稳定)":
|
269 |
+
reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
|
270 |
+
|
271 |
+
limited_context, old_inputs, display_reference, inputs = self.prepare_inputs(inputs=inputs, use_websearch=use_websearch, files=files, reply_language=reply_language)
|
272 |
|
273 |
if (
|
274 |
self.need_api_key and
|