import multiprocessing as mp
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
import logging
import threading
import traceback
from datetime import datetime as dt
from itertools import islice

import dask.dataframe as dd
import pandas as pd
from dask import delayed, compute
from sqlalchemy import create_engine, text

from myindicator import myind

pd.set_option('display.max_columns', None)  # show all columns when printing DataFrames

def error(msg, *args):
    return mp.get_logger().error(msg, *args)

class LogExceptions(object):
    def __init__(self, callable):
        self.__callable = callable

    def __call__(self, *args, **kwargs):
        try:
            result = self.__callable(*args, **kwargs)
        except Exception:
            # Here we add some debugging help. If multiprocessing's
            # debugging is on, it will arrange to log the traceback.
            error(traceback.format_exc())
            # Re-raise the original exception so the Pool worker can
            # clean up.
            raise
        # It was fine, give a normal answer.
        return result

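# Illustrative sketch (hypothetical usage, not invoked anywhere in this file):
# wrapping the worker in LogExceptions makes the child's traceback appear in
# the parent's multiprocessing log instead of staying hidden until .get():
#
#   pool.apply_async(LogExceptions(assert_frame_equal), args=(df, j))
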
def err_call_back(err):
    print(f'The problem is here ~ error: {str(err)}')
    # NOTE: this callback runs in the parent process, where no exception is
    # active, so print_exc() has nothing useful to print; the worker's
    # traceback is only available through `err` itself.
    traceback.print_exc()

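# Illustrative note: err_call_back is meant to be passed as the error_callback
# of Pool.apply_async so a failed task is reported as soon as it completes,
# mirroring the commented-out call in the disabled main block below:
#
#   pool.apply_async(func=assert_frame_equal, args=(df, j),
#                    error_callback=err_call_back)
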
def chunked_iterable(iterable, size):
    """Split an iterable into chunks (tuples) of at most `size` items."""
    it = iter(iterable)
    while True:
        chunk = tuple(islice(it, size))
        if not chunk:
            return
        yield chunk

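# Illustrative sketch (helper defined for demonstration only, never called):
# chunked_iterable yields fixed-size tuples, with a shorter final chunk when
# the length is not a multiple of `size`.
def _demo_chunked_iterable():
    assert list(chunked_iterable(range(7), 3)) == [(0, 1, 2), (3, 4, 5), (6,)]
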
# NOTE: this function shadows pandas.testing.assert_frame_equal; it benchmarks
# get_hlfx against get_hlfx_optimization and compares their outputs.
def assert_frame_equal(df, u):
    # print(f'{u} start')
    data_temp = df
    st = dt.now()
    myind.get_macd_data(df)
    df_temp, trading_signals = myind.get_hlfx(df)
    try:
        df_temp_2, trading_signals_2 = myind.get_hlfx_optimization(df)
        # myind.get_ddfx(df, data_temp, u)
        # print(f'get_ddfx', u, u, u)
        # df_temp, t_signals = myind.get_hlfx(df)
    except BaseException as e:
        print('err', e)
    # print('df_temp', df_temp)
    # print('df_temp_2', df_temp_2)
    try:
        print('tttttt', df_temp.equals(df_temp_2))
    except BaseException as e:
        print('err', e)
    return df_temp

def t(df, u):
    st = dt.now()
    t = pd.DataFrame()
    for i in range(len(df)):
        # row-by-row concat is quadratic in len(df); kept as the slow baseline
        t = pd.concat([t.copy(), df.loc[i].to_frame().T], axis=0)
    t.loc['HL'] = 1  # appends a row labelled 'HL' (ts adds an HL *column* instead)
    print(f'{u} done, {dt.now() - st}')

def ts(df, u):
    print(f'{u} start')
    st = dt.now()
    t_list = []  # collect every row so the concat runs only once
    for i in range(len(df)):
        t_list.append(df.loc[i].to_frame().T)
    # NOTE: dd.concat expects Dask collections; if this raises, wrap each part
    # with dd.from_pandas(part, npartitions=1) first.
    t = dd.concat(t_list, axis=0)  # one concat instead of one per row
    t = t.assign(HL=1)  # add an HL column set to 1 on the Dask DataFrame
    result = compute(t)  # trigger the lazy computation
    # print(result)
    print(f'{u} done, {dt.now() - st}')

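# Illustrative sketch (an assumption: this may be what the otherwise unused
# `delayed` import was for): the same one-shot concat expressed with
# dask.delayed, materialising the rows with a single compute() call.
def ts_delayed(df, u):
    st = dt.now()
    make_row = delayed(lambda i: df.loc[i].to_frame().T)
    rows = compute(*[make_row(i) for i in range(len(df))])
    t = pd.concat(rows, axis=0)  # one pandas concat over all rows
    t = t.assign(HL=1)
    print(f'{u} done, {dt.now() - st}')
    return t
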
'''
# Earlier main function, kept disabled for reference.
if __name__ == '__main__':
    freeze_support()
    logger = mp.log_to_stderr()
    logger.setLevel(logging.DEBUG)
    engine = create_engine(
        'mysql+pymysql://root:r6kEwqWU9!v3@localhost:3307/qmt_stocks_whole?charset=utf8')
    df = pd.read_sql_table('000001.SZ_1d', con=engine)
    # pool = futures.ProcessPoolExecutor(max_workers=24)
    # pool.map(assert_frame_equal(df, range(5000)))
    # with concurrent.futures.ProcessPoolExecutor(max_workers=24) as executor:
    #     for i in range(5000):
    #         # executor.submit(LogExceptions(assert_frame_equal), df, i)
    #         executor.submit(assert_frame_equal, df, i)
    #         print(i)
    pool = mp.Pool(24)
    for j in range(5000):
        # pool.apply_async(LogExceptions(assert_frame_equal), args=(df, j))
        # pool.apply_async(func=assert_frame_equal, args=(df, j), error_callback=err_call_back)
        # pool.apply_async(func=t, args=(df, j))
        pool.apply_async(func=ts, args=(df, j))
        # pool.map_async(func=assert_frame_equal, iterable=[df], chunksize=1)
        # print(j)
    pool.close()
    pool.join()
'''

if __name__ == '__main__':
    mp.freeze_support()
    logger = mp.log_to_stderr()
    # logger.setLevel(logging.DEBUG)
    engine = create_engine('mysql+pymysql://root:r6kEwqWU9!v3@localhost:3307/qmt_stocks_whole?charset=utf8')
    df = pd.read_sql_table('000001.SZ_1d', con=engine)
    # print(df)
    # a, b = assert_frame_equal(df, 1)
    # print(a, b)
    # exit()
    # ten copies of the same frame stand in for ten stocks
    df_dict = {f'{i:06}.SZ_1d': df.copy() for i in range(1, 11)}
    print(len(df_dict))
    # exit()
    async_results = []
    # number of tasks submitted per round
    CHUNK_SIZE = 50  # adjust as needed
    for chunk in chunked_iterable(df_dict.items(), CHUNK_SIZE):
        print(f'chunk:{chunk[0][0]}-{chunk[-1][0]}')
        # cap the pool at the chunk size, the task count, and 24 cores
        with mp.Pool(processes=min(CHUNK_SIZE, len(chunk), 24)) as pool:
            for stock, df_stock in chunk:
                async_result = pool.apply_async(func=assert_frame_equal, args=(df_stock, stock))
                async_results.append(async_result)
            pool.close()
            pool.join()
    exit()  # NOTE: debugging stop; the result loop below never runs while this is here
    # collect the results in the parent process
    for res in async_results:
        print(res.get())
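
    # Illustrative sketch (assuming the exit() above is a temporary debugging
    # stop): once it is removed, .get() both returns the worker's result and
    # re-raises any exception the worker hit, so a defensive loop looks like:
    #
    #   for res in async_results:
    #       try:
    #           print(res.get(timeout=60))
    #       except Exception as e:
    #           print('worker failed:', e)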