parkerjj committed on
Commit
5c10677
·
1 Parent(s): a743ea2

优化预测函数的输入文本打印逻辑,增加文本长度信息;改进长文本处理函数,考虑特殊标记长度以保持句子完整性

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. preprocess.py +9 -3
app.py CHANGED
@@ -52,8 +52,8 @@ async def predict(request: PredictRequest):
52
  try:
53
  input_text = request.text # FastAPI 会自动解析为 PredictRequest 对象
54
  affected_stock_codes = request.stock_codes
55
- print("Input text:", input_text[:200] if len(input_text) > 200 else input_text)
56
- print("Affected stock codes:", affected_stock_codes)
57
  return predict(input_text, affected_stock_codes)
58
  except Exception as e:
59
  return {"error": str(e)}
 
52
  try:
53
  input_text = request.text # FastAPI 会自动解析为 PredictRequest 对象
54
  affected_stock_codes = request.stock_codes
55
+ print(f"Input Text Length: {len(input_text)}, Start with: {input_text[:200] if len(input_text) > 200 else input_text}")
56
+ print("Input stock codes:", affected_stock_codes)
57
  return predict(input_text, affected_stock_codes)
58
  except Exception as e:
59
  return {"error": str(e)}
preprocess.py CHANGED
@@ -10,6 +10,7 @@ import pandas as pd
10
  import time
11
 
12
  # 如果使用 spaCy 进行 NLP 处理
 
13
  import spacy
14
 
15
  # 如果使用某种情感分析工具,比如 Hugging Face 的模型
@@ -225,7 +226,7 @@ def get_document_vector(words, model = word2vec_model):
225
  # 函数:获取情感得分
226
  def process_long_text(text, tokenizer, max_length=512):
227
  """
228
- 将长文本分段并保持句子完整性
229
  """
230
  import nltk
231
  try:
@@ -239,15 +240,19 @@ def process_long_text(text, tokenizer, max_length=512):
239
  nltk.download('punkt_tab')
240
 
241
 
 
 
 
 
 
242
  sentences = nltk.sent_tokenize(text)
243
  segments = []
244
  current_segment = ""
245
 
246
  for sentence in sentences:
247
- print(f"Processing sentence: {sentence}")
248
  # 检查添加当前句子后是否会超过最大长度
249
  test_segment = current_segment + " " + sentence if current_segment else sentence
250
- if len(tokenizer.tokenize(test_segment)) > max_length:
251
  if current_segment:
252
  segments.append(current_segment.strip())
253
  current_segment = sentence
@@ -340,6 +345,7 @@ def get_sentiment_score(text):
340
  return 0.0
341
 
342
 
 
343
  def get_stock_info(stock_code: str, history_days=30):
344
  # 获取股票代码和新闻日期
345
 
 
10
  import time
11
 
12
  # 如果使用 spaCy 进行 NLP 处理
13
+ from regex import R
14
  import spacy
15
 
16
  # 如果使用某种情感分析工具,比如 Hugging Face 的模型
 
226
  # 函数:获取情感得分
227
  def process_long_text(text, tokenizer, max_length=512):
228
  """
229
+ 将长文本分段并保持句子完整性,同时考虑特殊标记的长度
230
  """
231
  import nltk
232
  try:
 
240
  nltk.download('punkt_tab')
241
 
242
 
243
+ # 计算特殊标记占用的长度(CLS, SEP等)
244
+ special_tokens_count = tokenizer.num_special_tokens_to_add()
245
+ # 实际可用于文本的最大长度
246
+ effective_max_length = max_length - special_tokens_count
247
+
248
  sentences = nltk.sent_tokenize(text)
249
  segments = []
250
  current_segment = ""
251
 
252
  for sentence in sentences:
 
253
  # 检查添加当前句子后是否会超过最大长度
254
  test_segment = current_segment + " " + sentence if current_segment else sentence
255
+ if len(tokenizer.tokenize(test_segment)) > effective_max_length:
256
  if current_segment:
257
  segments.append(current_segment.strip())
258
  current_segment = sentence
 
345
  return 0.0
346
 
347
 
348
+
349
  def get_stock_info(stock_code: str, history_days=30):
350
  # 获取股票代码和新闻日期
351