Spaces:
Running
Running
优化预测函数的输入文本打印逻辑,增加文本长度信息;改进长文本处理函数,考虑特殊标记长度以保持句子完整性
Browse files- app.py +2 -2
- preprocess.py +9 -3
app.py
CHANGED
@@ -52,8 +52,8 @@ async def predict(request: PredictRequest):
|
|
52 |
try:
|
53 |
input_text = request.text # FastAPI 会自动解析为 PredictRequest 对象
|
54 |
affected_stock_codes = request.stock_codes
|
55 |
-
print("Input Text:", input_text)
|
56 |
-
print("Input stock codes:", affected_stock_codes)
|
57 |
return predict(input_text, affected_stock_codes)
|
58 |
except Exception as e:
|
59 |
return {"error": str(e)}
|
|
|
52 |
try:
|
53 |
input_text = request.text # FastAPI 会自动解析为 PredictRequest 对象
|
54 |
affected_stock_codes = request.stock_codes
|
55 |
+
print(f"Input Text Length: {len(input_text)}, Start with: {input_text[:200] if len(input_text) > 200 else input_text}")
|
56 |
+
print("Input stock codes:", affected_stock_codes)
|
57 |
return predict(input_text, affected_stock_codes)
|
58 |
except Exception as e:
|
59 |
return {"error": str(e)}
|
preprocess.py
CHANGED
@@ -10,6 +10,7 @@ import pandas as pd
|
|
10 |
import time
|
11 |
|
12 |
# 如果使用 spaCy 进行 NLP 处理
|
|
|
13 |
import spacy
|
14 |
|
15 |
# 如果使用某种情感分析工具,比如 Hugging Face 的模型
|
@@ -225,7 +226,7 @@ def get_document_vector(words, model = word2vec_model):
|
|
225 |
# 函数:获取情感得分
|
226 |
def process_long_text(text, tokenizer, max_length=512):
|
227 |
"""
|
228 |
-
|
229 |
"""
|
230 |
import nltk
|
231 |
try:
|
@@ -239,15 +240,19 @@ def process_long_text(text, tokenizer, max_length=512):
|
|
239 |
nltk.download('punkt_tab')
|
240 |
|
241 |
|
|
|
|
|
|
|
|
|
|
|
242 |
sentences = nltk.sent_tokenize(text)
|
243 |
segments = []
|
244 |
current_segment = ""
|
245 |
|
246 |
for sentence in sentences:
|
247 |
-
print(f"Processing sentence: {sentence}")
|
248 |
# 检查添加当前句子后是否会超过最大长度
|
249 |
test_segment = current_segment + " " + sentence if current_segment else sentence
|
250 |
-
if len(tokenizer.tokenize(test_segment)) > max_length:
|
251 |
if current_segment:
|
252 |
segments.append(current_segment.strip())
|
253 |
current_segment = sentence
|
@@ -340,6 +345,7 @@ def get_sentiment_score(text):
|
|
340 |
return 0.0
|
341 |
|
342 |
|
|
|
343 |
def get_stock_info(stock_code: str, history_days=30):
|
344 |
# 获取股票代码和新闻日期
|
345 |
|
|
|
10 |
import time
|
11 |
|
12 |
# 如果使用 spaCy 进行 NLP 处理
|
13 |
+
from regex import R
|
14 |
import spacy
|
15 |
|
16 |
# 如果使用某种情感分析工具,比如 Hugging Face 的模型
|
|
|
226 |
# 函数:获取情感得分
|
227 |
def process_long_text(text, tokenizer, max_length=512):
|
228 |
"""
|
229 |
+
将长文本分段并保持句子完整性,同时考虑特殊标记的长度
|
230 |
"""
|
231 |
import nltk
|
232 |
try:
|
|
|
240 |
nltk.download('punkt_tab')
|
241 |
|
242 |
|
243 |
+
# 计算特殊标记占用的长度(CLS, SEP等)
|
244 |
+
special_tokens_count = tokenizer.num_special_tokens_to_add()
|
245 |
+
# 实际可用于文本的最大长度
|
246 |
+
effective_max_length = max_length - special_tokens_count
|
247 |
+
|
248 |
sentences = nltk.sent_tokenize(text)
|
249 |
segments = []
|
250 |
current_segment = ""
|
251 |
|
252 |
for sentence in sentences:
|
|
|
253 |
# 检查添加当前句子后是否会超过最大长度
|
254 |
test_segment = current_segment + " " + sentence if current_segment else sentence
|
255 |
+
if len(tokenizer.tokenize(test_segment)) > effective_max_length:
|
256 |
if current_segment:
|
257 |
segments.append(current_segment.strip())
|
258 |
current_segment = sentence
|
|
|
345 |
return 0.0
|
346 |
|
347 |
|
348 |
+
|
349 |
def get_stock_info(stock_code: str, history_days=30):
|
350 |
# 获取股票代码和新闻日期
|
351 |
|