# 333.py (10 KB)
import concurrent.futures
import datetime
import logging
import multiprocessing as mp
import os
import time
from datetime import datetime as dt
from itertools import islice
from itertools import product
from multiprocessing import freeze_support, Value, Lock

import backtrader as bt
import backtrader.indicators as btind
import matplotlib
import pandas as pd
import psutil
import pymysql
from backtrader.feeds import PandasData
from sqlalchemy import create_engine, text
from tqdm import tqdm
  18. class MyPandasData(PandasData):
  19. lines = ('hl', 'dif', 'dea', 'macd', 'rsi_6', 'rsi_12', 'rsi_24',)
  20. params = (('hl', 7),
  21. ('dif', 8),
  22. ('dea', 9),
  23. ('macd', 10),
  24. ('rsi_6', 11),
  25. ('rsi_12', 12),
  26. ('rsi_24', 13),
  27. )
  28. '''
  29. lines = ('change_pct', 'net_amount_main', 'net_pct_main', 'net_amount_xl', 'net_pct_xl', 'net_amount_l', 'net_pct_l'
  30. , 'net_amount_m', 'net_pct_m', 'net_amount_s', 'net_pct_s',)
  31. params = (('change_pct', 7),
  32. ('net_amount_main', 8),
  33. ('net_pct_main', 9),
  34. ('net_amount_xl', 10),
  35. ('net_pct_xl', 11),
  36. ('net_amount_l', 12),
  37. ('net_pct_l', 13),
  38. ('net_amount_m', 14),
  39. ('net_pct_m', 15),
  40. ('net_amount_s', 16),
  41. ('net_pct_s', 17),
  42. )
  43. '''
  44. class TestStrategy(bt.Strategy):
  45. def log(self, txt, dt=None):
  46. # 记录策略的执行日志
  47. dt = dt or self.datas[0].datetime.date(0)
  48. # print('%s, %s' % (dt.isoformat(), txt))
  49. def __init__(self):
  50. # 保存收盘价的引用
  51. self.dataclose = self.datas[0].close
  52. def next(self):
  53. # 记录收盘价
  54. self.log('Close, %.2f' % self.dataclose[0])
  55. # 今天的收盘价 < 昨天收盘价
  56. if self.dataclose[0] < self.dataclose[-1]:
  57. # 昨天收盘价 < 前天的收盘价
  58. if self.dataclose[-1] < self.dataclose[-2]:
  59. # 买入
  60. self.log('买入, %.2f' % self.dataclose[0])
  61. self.buy()
  62. def t():
  63. print('tttt')
  64. def chunked_iterable(iterable, size):
  65. """将可迭代对象分割为指定大小的块"""
  66. it = iter(iterable)
  67. while True:
  68. chunk = tuple(islice(it, size))
  69. if not chunk:
  70. return
  71. yield chunk
  72. def query_database(table_name):
  73. engine = create_engine('mysql+pymysql://root:r6kEwqWU9!v3@localhost:3307/qmt_stocks_tech?charset=utf8')
  74. df = pd.read_sql_table(table_name, engine)
  75. return df
  76. def get_stock_data():
  77. while True:
  78. try:
  79. db = pymysql.connect(host='localhost',
  80. user='root',
  81. port=3307,
  82. password='r6kEwqWU9!v3',
  83. database='qmt_stocks_tech')
  84. cursor = db.cursor()
  85. cursor.execute("show tables like '%%%s%%' " % '1d')
  86. table_list = [tuple[0] for tuple in cursor.fetchall()]
  87. # table_list = table_list[0: 10]
  88. cursor.close()
  89. db.close()
  90. print(f'开始数据库读取')
  91. with concurrent.futures.ProcessPoolExecutor(max_workers=24) as executor:
  92. # 使用executor.map方法实现多进程并行查询数据库,得到每个表的数据,并存储在一个字典中
  93. data_dict = {table_name: df for table_name, df in
  94. tqdm(zip(table_list, executor.map(query_database, table_list)))}
  95. print(f'数据库读取完成')
  96. break
  97. except BaseException as e:
  98. print(f'数据库读取错误{e}')
  99. continue
  100. return data_dict
  101. def backtrader_test(stock_data, stock_name, num, vot, rate):
  102. cerebro = bt.Cerebro()
  103. stock_data.time = pd.to_datetime(stock_data.time)
  104. stock_data['HL'] = stock_data['HL'].map({'L': 1,
  105. 'LL': 2,
  106. 'L*': 3,
  107. 'H': 4,
  108. 'HH': 5,
  109. 'H*': 6,
  110. '-': 7})
  111. data = MyPandasData(dataname=stock_data,
  112. fromdate=datetime.datetime(2017, 1, 1),
  113. todate=datetime.datetime(2022, 10, 30),
  114. datetime='time',
  115. open='open_back',
  116. close='close_back',
  117. high='high_back',
  118. low='low_back',
  119. volume='volume_back',
  120. hl='HL',
  121. dif='dif',
  122. dea='dea',
  123. macd='macd',
  124. rsi_6='rsi_6',
  125. rsi_12='rsi_12',
  126. rsi_24='rsi_24',
  127. )
  128. cerebro.adddata(data)
  129. cerebro.addstrategy(TestStrategy)
  130. cerebro.broker.setcash(100000.0)
  131. cerebro.addsizer(bt.sizers.FixedSize, stake=100)
  132. cerebro.broker.setcommission(commission=0.001)
  133. cerebro.run()
  134. return cerebro.broker.getvalue() - 100000.0
  135. def bbt(stock_data_dict, num, Volatility, rate):
  136. # while True:
  137. # exception_flag = False
  138. async_results = []
  139. try:
  140. # 设置每一轮的任务数
  141. CHUNK_SIZE = 200 # 您可以根据需要进行调整
  142. for chunk in tqdm(chunked_iterable(stock_data_dict.items(), CHUNK_SIZE)):
  143. print(f'chunk:{chunk[0][0]}-{chunk[-1][0]}')
  144. with mp.Pool(processes=min(CHUNK_SIZE, len(chunk), 24)) as pool: # 使用最小值确保不会超出任务数或超过24核心
  145. for stock, df_stock in chunk:
  146. async_result = pool.apply_async(func=backtrader_test, args=(df_stock, stock, num, Volatility, rate))
  147. async_results.append(async_result)
  148. pool.close()
  149. pool.join()
  150. # with concurrent.futures.ProcessPoolExecutor(max_workers=18) as inner_executor:
  151. # print(f'开始计算{num},{Volatility},{rate}')
  152. # # 使用executor.map方法实现多进程并行计算不同参数组合的结果
  153. # results = [result for result in
  154. # inner_executor.map(backtrader_test, stock_data_dict.values(), stock_data_dict.keys(),
  155. # [num] * len(stock_data_dict),
  156. # [Volatility] * len(stock_data_dict), [rate] * len(stock_data_dict),
  157. # timeout=1200)]
  158. # except concurrent.futures.TimeoutError as e:
  159. # print(f'计算超时{e}')
  160. # results = []
  161. # exception_flag = True
  162. except BaseException as e:
  163. print(f'计算错误{e}')
  164. results = True
  165. outputs = [result.get() for result in async_results]
  166. print(outputs)
  167. return outputs
  168. if __name__ == '__main__':
  169. logger = mp.log_to_stderr()
  170. logger.setLevel(logging.DEBUG)
  171. cpu_list = list(range(24))
  172. pus = psutil.Process()
  173. pus.cpu_affinity(cpu_list)
  174. # 定义需要穷举的参数值
  175. nums = range(60, 80, 20)
  176. Volatilitys = range(5, 6, 1)
  177. rates = range(3, 4, 1)
  178. # 生成所有参数组合
  179. all_combinations = list(product(nums, Volatilitys, rates))
  180. print(f'共需计算{len(all_combinations)}次')
  181. # 获取数据
  182. stock_data_dict = get_stock_data()
  183. results = []
  184. # 获取stock_data_dict的第1个value,即第1个DataFrame
  185. # stock_data = next(iter(stock_data_dict.values()))
  186. # print(stock_data)
  187. for num, Volatility, rate in tqdm(all_combinations, desc='计算进度'):
  188. result = bbt(stock_data_dict, num, Volatility, rate)
  189. results.append(result)
  190. print(results, len(results), len(results[0]))
  191. df = pd.DataFrame(
  192. columns=['周期', '波动率', 'MA5斜率', '盈利个数', '盈利比例', '总盈利', '平均盈利', '最大盈利', '最小盈利', '总亏损',
  193. '平均亏损', '最大亏损', '最小亏损'])
  194. for tt in results:
  195. num_profits = len([r for r in tt if r > 0])
  196. num_losses = len([r for r in tt if r < 0])
  197. profit_ratio = num_profits / len(stock_data_dict)
  198. total_profit = sum([r for r in tt if r > 0])
  199. avg_profit = total_profit / num_profits if num_profits else 0
  200. max_profit = max(tt)
  201. min_profit = min([r for r in tt if r > 0]) if num_profits else 0
  202. total_loss = sum([r for r in tt if r < 0])
  203. avg_loss = total_loss / num_losses if num_losses else 0
  204. max_loss = min(tt)
  205. min_loss = max([r for r in tt if r < 0]) if num_losses else 0
  206. # Append the results into the DataFrame
  207. result_dict = {'周期': num, '波动率': Volatility, 'MA5斜率': rate, '盈利个数': num_profits,
  208. '盈利比例': profit_ratio, '总盈利': total_profit, '平均盈利': avg_profit,
  209. '最大盈利': max_profit, '最小盈利': min_profit, '总亏损': total_loss,
  210. '平均亏损': avg_loss, '最大亏损': max_loss, '最小亏损': min_loss}
  211. df_t = pd.Series(result_dict)
  212. print(df_t)
  213. df = pd.concat([df, df_t.to_frame().T], ignore_index=True)
  214. print(df)
  215. exit()
  216. num = 60
  217. Volatility = 5
  218. rate = 3
  219. i = 0
  220. st = dt.now()
  221. while True:
  222. i += 1
  223. try:
  224. results = bbt(stock_data_dict, num, Volatility, rate)
  225. except BaseException as e:
  226. print(f'计算错误{e}')
  227. break
  228. print(results)
  229. if results is True:
  230. print(f'计算错误,重新计算')
  231. continue
  232. else:
  233. print(f'第{i}次计算完成,耗时{dt.now() - st}')
  234. print(f'计算结果为{len(results)}')
  235. print(results)
  236. print(f'全部计算完成,共{len(results)}次')
  237. exit()
  238. getvalue = backtrader_test(stock_data)
  239. if getvalue > 100000:
  240. print('盈利')
  241. else:
  242. print('亏损')
  243. # 绘制图像
  244. # cerebro.plot()