Spaces:
Running
Running
import logging | |
import re | |
import akshare as ak | |
import pandas as pd | |
from datetime import datetime, timedelta | |
import time # 导入标准库的 time 模块 | |
import os | |
import requests | |
import threading | |
import asyncio | |
import yfinance | |
logging.basicConfig(level=logging.INFO) | |
# 获取当前文件的目录 | |
base_dir = os.path.dirname(os.path.abspath(__file__)) | |
# 构建CSV文件的绝对路径 | |
nasdaq_100_path = os.path.join(base_dir, './model/nasdaq100.csv') | |
dow_jones_path = os.path.join(base_dir, './model/dji.csv') | |
sp500_path = os.path.join(base_dir, './model/sp500.csv') | |
nasdaq_composite_path = os.path.join(base_dir, './model/nasdaq_all.csv') | |
# 从CSV文件加载成分股数据 | |
nasdaq_100_stocks = pd.read_csv(nasdaq_100_path) | |
dow_jones_stocks = pd.read_csv(dow_jones_path) | |
sp500_stocks = pd.read_csv(sp500_path) | |
nasdaq_composite_stocks = pd.read_csv(nasdaq_composite_path) | |
def fetch_stock_us_spot_data_with_retries(): | |
# 定义重试间隔时间序列(秒) | |
retry_intervals = [10, 20, 60, 300, 600] | |
retry_index = 0 # 初始重试序号 | |
while True: | |
try: | |
# 尝试获取API数据 | |
symbols = ak.stock_us_spot_em() | |
return symbols # 成功获取数据后返回 | |
except Exception as e: | |
print(f"Error fetching data: {e}") | |
# 获取当前重试等待时间 | |
wait_time = retry_intervals[retry_index] | |
print(f"Retrying in {wait_time} seconds...") | |
time.sleep(wait_time) # 等待指定的秒数 | |
# 更新重试索引,但不要超出重试时间列表的范围 | |
retry_index = min(retry_index + 1, len(retry_intervals) - 1) | |
async def fetch_stock_us_spot_data_with_retries_async(): | |
retry_intervals = [10, 20, 60, 300, 600] | |
retry_index = 0 | |
while True: | |
try: | |
symbols = await asyncio.to_thread(ak.stock_us_spot_em) | |
return symbols | |
except Exception as e: | |
print(f"Error fetching data: {e}") | |
wait_time = retry_intervals[retry_index] | |
print(f"Retrying in {wait_time} seconds...") | |
await asyncio.sleep(wait_time) | |
retry_index = min(retry_index + 1, len(retry_intervals) - 1) | |
symbols = None | |
async def fetch_symbols(): | |
global symbols | |
# 异步获取数据 | |
symbols = await fetch_stock_us_spot_data_with_retries_async() | |
print("Symbols initialized:", symbols) | |
# 全局变量 | |
index_us_stock_index_INX = None | |
index_us_stock_index_DJI = None | |
index_us_stock_index_IXIC = None | |
index_us_stock_index_NDX = None | |
def update_stock_indices(): | |
global index_us_stock_index_INX, index_us_stock_index_DJI, index_us_stock_index_IXIC, index_us_stock_index_NDX | |
try: | |
index_us_stock_index_INX = ak.index_us_stock_sina(symbol=".INX") | |
index_us_stock_index_DJI = ak.index_us_stock_sina(symbol=".DJI") | |
index_us_stock_index_IXIC = ak.index_us_stock_sina(symbol=".IXIC") | |
index_us_stock_index_NDX = ak.index_us_stock_sina(symbol=".NDX") | |
print("Stock indices updated") | |
except Exception as e: | |
print(f"Error updating stock indices: {e}") | |
# 设置定时器,每隔12小时更新一次 | |
threading.Timer(12 * 60 * 60, update_stock_indices).start() | |
# 程序开始时立即更新一次 | |
update_stock_indices() | |
# 创建列名转换的字典 | |
column_mapping = { | |
'日期': 'date', | |
'开盘': 'open', | |
'收盘': 'close', | |
'最高': 'high', | |
'最低': 'low', | |
'成交量': 'volume', | |
'成交额': 'amount', | |
'振幅': 'amplitude', | |
'涨跌幅': 'price_change_percentage', | |
'涨跌额': 'price_change_amount', | |
'换手率': 'turnover_rate' | |
} | |
# 定义一个标准的列顺序 | |
standard_columns = ['date', 'open', 'close', 'high', 'low', 'volume', 'amount'] | |
# 定义查找函数 | |
def find_stock_entry(stock_code): | |
# 使用 str.endswith 来匹配股票代码 | |
matching_row = symbols[symbols['代码'].str.endswith(stock_code)] | |
# print(symbols) | |
if not matching_row.empty: | |
# print(f"股票代码 {stock_code} 找到, 代码为 {matching_row['代码'].values[0]}") | |
return matching_row['代码'].values[0] | |
else: | |
return "" | |
''' | |
# 示例调用 | |
# 测试函数 | |
result = find_stock_entry('AAPL') | |
if isinstance(result, pd.DataFrame) and not result.empty: | |
# 如果找到的结果不为空,获取代码列的值 | |
code_value = result['代码'].values[0] | |
print(code_value) | |
else: | |
print(result) | |
''' | |
def reduce_columns(df, columns_to_keep): | |
return df[columns_to_keep] | |
# 创建缓存字典 | |
_price_cache = {} | |
def get_last_minute_stock_price(symbol: str, max_retries=3) -> float: | |
"""获取股票最新价格,使用30分钟缓存,并包含重试机制""" | |
if not symbol: | |
return -1.0 | |
if symbol == "NONE_SYMBOL_FOUND": | |
return -1.0 | |
current_time = datetime.now() | |
# 检查缓存 | |
if symbol in _price_cache: | |
cached_price, cached_time = _price_cache[symbol] | |
# 如果缓存时间在30分钟内,直接返回缓存的价格 | |
if current_time - cached_time < timedelta(minutes=30): | |
return cached_price | |
# 重试机制 | |
for attempt in range(max_retries): | |
try: | |
# 缓存无效或不存在,从yfinance获取新数据 | |
stock_data = yfinance.download( | |
symbol, | |
period='1d', | |
interval='5m', | |
progress=False, # 禁用进度条 | |
timeout=10 # 设置超时时间 | |
) | |
if stock_data.empty: | |
print(f"Warning: Empty data received for {symbol}, attempt {attempt + 1}/{max_retries}") | |
if attempt == max_retries - 1: | |
return -1.0 | |
time.sleep(1) # 等待1秒后重试 | |
continue | |
latest_price = float(stock_data['Close'].iloc[-1]) | |
# 更新缓存 | |
_price_cache[symbol] = (latest_price, current_time) | |
return latest_price | |
except Exception as e: | |
print(f"Error fetching price for {symbol}, attempt {attempt + 1}/{max_retries}: {str(e)}") | |
if attempt == max_retries - 1: | |
return -1.0 | |
time.sleep(1) # 等待1秒后重试 | |
return -1.0 | |
# 返回个股历史数据 | |
def get_stock_history(symbol, news_date, retries=10): | |
# 定义重试间隔时间序列(秒) | |
retry_intervals = [10, 20, 60, 300, 600] | |
retry_count = 0 | |
# 如果传入的symbol不包含数字前缀,则通过 find_stock_entry 获取完整的symbol | |
if not any(char.isdigit() for char in symbol): | |
full_symbol = find_stock_entry(symbol) | |
if len(symbol) != 0 and full_symbol: | |
symbol = full_symbol | |
else: | |
symbol = "" | |
# 将news_date转换为datetime对象 | |
current_date = datetime.now() | |
# 计算start_date和end_date | |
start_date = (current_date - timedelta(days=60)).strftime("%Y%m%d") | |
end_date = current_date.strftime("%Y%m%d") | |
stock_hist_df = None | |
retry_index = 0 # 初始化重试索引 | |
while retry_count <= retries and len(symbol) != 0: # 无限循环重试 | |
try: | |
# 尝试获取API数据 | |
stock_hist_df = ak.stock_us_hist(symbol=symbol, period="daily", start_date=start_date, end_date=end_date, adjust="") | |
if stock_hist_df.empty: # 检查是否为空数据 | |
# print(f"No data for {symbol} on {news_date}.") | |
stock_hist_df = None # 将 DataFrame 设置为 None | |
break | |
except (requests.exceptions.Timeout, ConnectionError) as e: | |
print(f"Request timed out: {e}. Retrying...") | |
retry_count += 1 # 增加重试次数 | |
continue | |
except (TypeError, ValueError, BaseException) as e: | |
print(f"Error {e} scraping data for {symbol} on {news_date}. Break...") | |
# 可能是没数据,直接Break | |
break | |
# 如果发生异常,等待一段时间再重试 | |
wait_time = retry_intervals[retry_index] | |
print(f"Waiting for {wait_time} seconds before retrying...") | |
time.sleep(wait_time) | |
retry_index = (retry_index + 1) if retry_index < len(retry_intervals) - 1 else retry_index # 更新重试索引,不超过列表长度 | |
# 如果获取失败或数据为空,返回填充为0的 DataFrame | |
if stock_hist_df is None or stock_hist_df.empty: | |
# 构建一个空的 DataFrame,包含指定日期范围的空数据 | |
date_range = pd.date_range(start=start_date, end=end_date) | |
stock_hist_df = pd.DataFrame({ | |
'date': date_range, | |
'开盘': 0, | |
'收盘': 0, | |
'最高': 0, | |
'最低': 0, | |
'成交量': 0, | |
'成交额': 0, | |
'振幅': 0, | |
'涨跌幅': 0, | |
'涨跌额': 0, | |
'换手率': 0 | |
}) | |
# 使用rename方法转换列名 | |
stock_hist_df = stock_hist_df.rename(columns=column_mapping) | |
stock_hist_df = stock_hist_df.reindex(columns=standard_columns) | |
# 处理个股数据,保留所需列 | |
stock_hist_df = reduce_columns(stock_hist_df, standard_columns) | |
return stock_hist_df | |
# 统一列名 | |
stock_hist_df = stock_hist_df.rename(columns=column_mapping) | |
stock_hist_df = stock_hist_df.reindex(columns=standard_columns) | |
# 处理个股数据,保留所需列 | |
stock_hist_df = reduce_columns(stock_hist_df, standard_columns) | |
return stock_hist_df | |
''' | |
# 示例调用 | |
result = get_stock_history('AAPL', '20240214') | |
print(result) | |
''' | |
# result = get_stock_history('ATMU', '20231218') | |
# print(result) | |
# 返回个股所属指数历史数据 | |
def get_stock_index_history(symbol, news_date, force_index=0): | |
# 检查股票所属的指数 | |
if symbol in nasdaq_100_stocks['Symbol'].values or force_index == 1: | |
index_code = ".NDX" | |
index_data = index_us_stock_index_NDX | |
elif symbol in dow_jones_stocks['Symbol'].values or force_index == 2: | |
index_code = ".DJI" | |
index_data = index_us_stock_index_DJI | |
elif symbol in sp500_stocks['Symbol'].values or force_index == 3: | |
index_code = ".INX" | |
index_data = index_us_stock_index_INX | |
elif symbol in nasdaq_composite_stocks["Symbol"].values or symbol is None or symbol == "" or force_index == 4: | |
index_code = ".IXIC" | |
index_data = index_us_stock_index_IXIC | |
else: | |
# print(f"股票代码 {symbol} 不属于纳斯达克100、道琼斯工业、标准普尔500或纳斯达克综合指数。") | |
index_code = ".IXIC" | |
index_data = index_us_stock_index_IXIC | |
# 获取当前日期 | |
current_date = datetime.now() | |
# 计算 start_date 和 end_date | |
start_date = (current_date - timedelta(weeks=8)).strftime("%Y-%m-%d") | |
end_date = current_date.strftime("%Y-%m-%d") | |
# 确保 index_data['date'] 是 datetime 类型 | |
index_data['date'] = pd.to_datetime(index_data['date']) | |
# 从指数历史数据中提取指定日期范围的数据 | |
index_hist_df = index_data[(index_data['date'] >= start_date) & (index_data['date'] <= end_date)] | |
# 统一列名 | |
index_hist_df = index_hist_df.rename(columns=column_mapping) | |
index_hist_df = index_hist_df.reindex(columns=standard_columns) | |
# 处理个股数据,保留所需列 | |
index_hist_df = reduce_columns(index_hist_df, standard_columns) | |
return index_hist_df | |
''' | |
# 示例调用 | |
result = get_stock_index_history('AAPL', '20240214') | |
print(result) | |
''' | |
def find_stock_codes_or_names(entities): | |
""" | |
从给定的实体列表中检索股票代码或公司名称。 | |
:param entities: 命名实体识别结果列表,格式为 [('实体名称', '实体类型'), ...] | |
:return: 相关的股票代码列表 | |
""" | |
stock_codes = set() | |
# 合并所有股票字典并清理数据,确保都是字符串 | |
all_symbols = pd.concat([nasdaq_100_stocks['Symbol'], | |
dow_jones_stocks['Symbol'], | |
sp500_stocks['Symbol'], | |
nasdaq_composite_stocks['Symbol']]).dropna().astype(str).unique().tolist() | |
all_names = pd.concat([nasdaq_100_stocks['Name'], | |
nasdaq_composite_stocks['Name'], | |
sp500_stocks['Security'], | |
dow_jones_stocks['Company']]).dropna().astype(str).unique().tolist() | |
# 创建一个 Name 到 Symbol 的映射 | |
name_to_symbol = {} | |
for idx, name in enumerate(all_names): | |
if idx < len(all_symbols): | |
symbol = all_symbols[idx] | |
name_to_symbol[name.lower()] = symbol | |
# 查找实体映射到的股票代码 | |
for entity, entity_type in entities: | |
entity_lower = entity.lower() | |
entity_upper = entity.upper() | |
# 检查 Symbol 列 | |
if entity_upper in all_symbols: | |
stock_codes.add(entity_upper) | |
#print(f"Matched symbol: {entity_upper}") | |
# 检查 Name 列,确保完整匹配而不是部分匹配 | |
for name, symbol in name_to_symbol.items(): | |
# 使用正则表达式进行严格匹配 | |
pattern = rf'\b{re.escape(entity_lower)}\b' | |
if re.search(pattern, name): | |
stock_codes.add(symbol.upper()) | |
#print(f"Matched name/company: '{entity_lower}' in '{name}' -> {symbol.upper()}") | |
#print(f"Stock codes found: {stock_codes}") | |
if not stock_codes: | |
return ['NONE_SYMBOL_FOUND'] | |
return list(stock_codes) | |
def process_history(stock_history, target_date, history_days=30, following_days=3): | |
# 检查数据是否为空 | |
if stock_history.empty: | |
return create_empty_data(history_days), create_empty_data(following_days) | |
# 确保日期列存在并转换为datetime格式 | |
if 'date' not in stock_history.columns: | |
return create_empty_data(history_days), create_empty_data(following_days) | |
stock_history['date'] = pd.to_datetime(stock_history['date']) | |
target_date = pd.to_datetime(target_date) | |
# 按日期升序排序 | |
stock_history = stock_history.sort_values('date') | |
# 找到目标日期对应的索引 | |
target_row = stock_history[stock_history['date'] <= target_date] | |
if target_row.empty: | |
return create_empty_data(history_days), create_empty_data(following_days) | |
# 获取目标日期最近的行 | |
target_index = target_row.index[-1] | |
target_pos = stock_history.index.get_loc(target_index) | |
# 获取历史数据(包括目标日期) | |
start_pos = max(0, target_pos - history_days + 1) | |
previous_rows = stock_history.iloc[start_pos:target_pos + 1] | |
# 获取后续数据 | |
following_rows = stock_history.iloc[target_pos + 1:target_pos + following_days + 1] | |
# 删除日期列并确保数据完整性 | |
previous_rows = previous_rows.drop(columns=['date']) | |
following_rows = following_rows.drop(columns=['date']) | |
# 处理数据不足的情况 | |
previous_rows = handle_insufficient_data(previous_rows, history_days) | |
following_rows = handle_insufficient_data(following_rows, following_days) | |
return previous_rows.iloc[:, :6], following_rows.iloc[:, :6] | |
def create_empty_data(days): | |
return pd.DataFrame({ | |
'开盘': [-1] * days, | |
'收盘': [-1] * days, | |
'最高': [-1] * days, | |
'最低': [-1] * days, | |
'成交量': [-1] * days, | |
'成交额': [-1] * days | |
}) | |
def handle_insufficient_data(data, required_days): | |
current_rows = len(data) | |
if current_rows < required_days: | |
missing_rows = required_days - current_rows | |
empty_data = create_empty_data(missing_rows) | |
return pd.concat([empty_data, data]).reset_index(drop=True) | |
return data | |
if __name__ == "__main__": | |
# 测试函数 | |
result = find_stock_entry('AAPL') | |
print(f"find_stock_entry: {result}") | |
result = get_stock_history('AAPL', '20240214') | |
print(f"get_stock_history: {result}") | |
result = get_stock_index_history('AAPL', '20240214') | |
print(f"get_stock_index_history: {result}") | |
result = find_stock_codes_or_names([('苹果', 'ORG'), ('苹果公司', 'ORG')]) | |
print(f"find_stock_codes_or_names: {result}") | |
result = process_history(get_stock_history('AAPL', '20240214'), '20240214') | |
print(f"process_history: {result}") | |
result = process_history(get_stock_index_history('AAPL', '20240214'), '20240214') | |
print(f"process_history: {result}") | |
pass |