import multiprocessing as mp
from concurrent import futures
import concurrent.futures
import logging
from sqlalchemy import create_engine, text
import pandas as pd
from myindicator import myind
from multiprocessing import freeze_support, Value, Lock
from datetime import datetime as dt
import traceback
import dask.dataframe as dd
from dask import delayed, compute
from concurrent.futures import ThreadPoolExecutor
import threading
from itertools import islice

pd.set_option('display.max_columns', None)  # show all columns when printing DataFrames
def error(msg, *args):
    """Log through multiprocessing's logger so messages from worker processes are visible."""
    return mp.get_logger().error(msg, *args)
class LogExceptions(object):
    """Wrap a callable so exceptions raised inside a Pool worker are logged with a full traceback."""

    def __init__(self, target):
        self.__callable = target

    def __call__(self, *args, **kwargs):
        try:
            result = self.__callable(*args, **kwargs)
        except Exception:
            # Log the traceback here; multiprocessing would otherwise swallow it.
            error(traceback.format_exc())
            # Re-raise the original exception so the Pool worker can clean up.
            raise
        # It was fine, give a normal answer.
        return result
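
# Usage sketch: wrap the worker before submitting it, so a failing task still
# logs its traceback, e.g. pool.apply_async(LogExceptions(compare_hlfx), args=(df, j)).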
def err_call_back(err):
    # Note: traceback.print_exc() here would print the callback's own (empty)
    # traceback, not the worker's, so print the exception object instead.
    print(f'The problem is here ~ error: {err!r}')
def chunked_iterable(iterable, size):
    """Split an iterable into tuples of at most `size` items."""
    it = iter(iterable)
    while True:
        chunk = tuple(islice(it, size))
        if not chunk:
            return
        yield chunk
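
# For example, list(chunked_iterable(range(5), 2)) yields [(0, 1), (2, 3), (4,)].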
def compare_hlfx(df, u):
    """Run both hlfx implementations on `df` and report whether their outputs match.

    (Renamed from assert_frame_equal, which shadowed pandas.testing.assert_frame_equal.)
    """
    myind.get_macd_data(df)
    df_temp, trading_signals = myind.get_hlfx(df)
    try:
        df_temp_2, trading_signals_2 = myind.get_hlfx_optimization(df)
    except Exception as e:
        print('err', e)
        return df_temp
    try:
        print(f'{u} outputs equal:', df_temp.equals(df_temp_2))
    except Exception as e:
        print('err', e)
    return df_temp
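
# For a stricter comparison with a readable diff, pandas' own checker can be
# used instead of DataFrame.equals (it raises AssertionError on mismatch):
#   from pandas.testing import assert_frame_equal
#   assert_frame_equal(df_temp, df_temp_2, check_dtype=False)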
def t(df, u):
    """Row-by-row pandas concat (quadratic): re-copies the accumulated frame on every append."""
    st = dt.now()
    out = pd.DataFrame()
    for i in range(len(df)):
        out = pd.concat([out, df.loc[i].to_frame().T], axis=0)
    out['HL'] = 1  # add an HL column (out.loc['HL'] = 1 would append a row instead)
    print(f'{u} done, {dt.now() - st}')
def ts(df, u):
    """Collect the rows first, then perform a single concat lazily via dask.delayed."""
    print(f'{u} started')
    st = dt.now()
    t_list = []  # collect the single-row frames, then join them once
    for i in range(len(df)):
        t_list.append(df.loc[i].to_frame().T)
    # dd.concat expects dask collections, not pandas frames, so build the
    # one-shot concat as a delayed pandas call and materialise it with compute()
    lazy = delayed(pd.concat)(t_list, axis=0)
    lazy = lazy.assign(HL=1)  # add an HL column set to 1
    (result,) = compute(lazy)  # execute the deferred graph
    print(f'{u} done, {dt.now() - st}')
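
# A plain-pandas one-shot concat (no dask) is usually the fastest of the three
# strategies at this scale; a minimal sketch for comparison:
def ts_pandas(df, u):
    """Collect rows, then one pandas concat: avoids both the quadratic loop and dask overhead."""
    st = dt.now()
    rows = [df.loc[i].to_frame().T for i in range(len(df))]
    out = pd.concat(rows, axis=0)
    out['HL'] = 1
    print(f'{u} done, {dt.now() - st}')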
'''
# Legacy driver, kept for reference
if __name__ == '__main__':
    freeze_support()
    logger = mp.log_to_stderr()
    logger.setLevel(logging.DEBUG)
    engine = create_engine(
        'mysql+pymysql://root:r6kEwqWU9!v3@localhost:3307/qmt_stocks_whole?charset=utf8')
    df = pd.read_sql_table('000001.SZ_1d', con=engine)
    # pool = futures.ProcessPoolExecutor(max_workers=24)
    # pool.map(compare_hlfx(df, range(5000)))
    # with concurrent.futures.ProcessPoolExecutor(max_workers=24) as executor:
    #     for i in range(5000):
    #         # executor.submit(LogExceptions(compare_hlfx), df, i)
    #         executor.submit(compare_hlfx, df, i)
    #         print(i)
    pool = mp.Pool(24)
    for j in range(5000):
        # pool.apply_async(LogExceptions(compare_hlfx), args=(df, j))
        # pool.apply_async(func=compare_hlfx, args=(df, j), error_callback=err_call_back)
        # pool.apply_async(func=t, args=(df, j))
        pool.apply_async(func=ts, args=(df, j))
        # pool.map_async(func=compare_hlfx, iterable=[df], chunksize=1)
    pool.close()
    pool.join()
'''
if __name__ == '__main__':
    mp.freeze_support()
    logger = mp.log_to_stderr()
    # logger.setLevel(logging.DEBUG)

    engine = create_engine(
        'mysql+pymysql://root:r6kEwqWU9!v3@localhost:3307/qmt_stocks_whole?charset=utf8')
    df = pd.read_sql_table('000001.SZ_1d', con=engine)
    # Build a small synthetic universe by copying the one table under fake stock codes
    df_dict = {f'{i:06}.SZ_1d': df.copy() for i in range(1, 11)}
    print(len(df_dict))

    async_results = []

    # Number of tasks per round; tune as needed
    CHUNK_SIZE = 50

    for chunk in chunked_iterable(df_dict.items(), CHUNK_SIZE):
        print(f'chunk: {chunk[0][0]}-{chunk[-1][0]}')
        # min() keeps the pool no larger than the tasks in this chunk, capped at 24 cores
        with mp.Pool(processes=min(CHUNK_SIZE, len(chunk), 24)) as pool:
            for stock, df_stock in chunk:
                async_result = pool.apply_async(func=compare_hlfx, args=(df_stock, stock))
                async_results.append(async_result)
            # close/join inside the with-block: on exit the context manager
            # calls terminate(), which would kill still-running tasks
            pool.close()
            pool.join()

    # Collect the results in the main process
    for res in async_results:
        print(res.get())