230930_bt.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. # coding:utf-8
  2. import time
  3. from multiprocessing import freeze_support, Value, Lock
  4. import backtrader as bt
  5. from backtrader.feeds import PandasData
  6. import backtrader.indicators as btind
  7. from sqlalchemy import create_engine, text
  8. import pymysql
  9. from tqdm import tqdm
  10. import concurrent.futures
  11. import numpy as np
  12. import pandas as pd
  13. import platform
  14. import datetime
  15. from datetime import datetime as dt
  16. from itertools import product
  17. import psutil
  18. import logging
  19. import multiprocessing as mp
  20. from itertools import islice
  21. from func_timeout import func_set_timeout, FunctionTimedOut
  22. from functools import partial
  23. class MyPandasData(PandasData):
  24. lines = ('hl', 'dif', 'dea', 'macd', 'rsi_6', 'rsi_12', 'rsi_24',)
  25. params = (('hl', 7),
  26. ('dif', 8),
  27. ('dea', 9),
  28. ('macd', 10),
  29. ('rsi_6', 11),
  30. ('rsi_12', 12),
  31. ('rsi_24', 13),
  32. )
  33. class TestStrategy(bt.Strategy):
  34. params = (
  35. ("num", 3),
  36. ('Volatility', 0),
  37. ('rate', 3), # 注意要有逗号!!
  38. )
  39. def log(self, txt, dt=None):
  40. # 记录策略的执行日志
  41. dt = dt or self.datas[0].datetime.date(0)
  42. # print('%s, %s' % (dt.isoformat(), txt))
  43. def __init__(self):
  44. try:
  45. self.pos_price = 0
  46. self.dataclose = self.datas[0].close
  47. self.dataopen = self.datas[0].open
  48. self.high = self.datas[0].high
  49. self.low = self.datas[0].low
  50. self.volume = self.datas[0].volume
  51. self.hl = self.datas[0].hl
  52. self.dif = self.datas[0].dif
  53. self.dea = self.datas[0].dea
  54. self.macd = self.datas[0].macd
  55. self.rsi_6 = self.datas[0].rsi_6
  56. self.rsi_12 = self.datas[0].rsi_12
  57. self.rsi_24 = self.datas[0].rsi_24
  58. self.sma5 = btind.MovingAverageSimple(self.datas[0].close, period=5)
  59. self.sma10 = btind.MovingAverageSimple(self.datas[0].close, period=10)
  60. self.sma20 = btind.MovingAverageSimple(self.datas[0].close, period=20)
  61. self.sma60 = btind.MovingAverageSimple(self.datas[0].close, period=60)
  62. # self.sma_vol = btind.MovingAverageSimple(self.datas[0].close, period=Volatility)
  63. except BaseException as e:
  64. print(f'初始化错误{e}')
  65. def notify_order(self, order):
  66. """
  67. 订单状态处理
  68. Arguments:
  69. order {object} -- 订单状态
  70. """
  71. if order.status in [order.Submitted, order.Accepted]:
  72. # 如订单已被处理,则不用做任何事情
  73. return
  74. # 检查订单是否完成
  75. if order.status in [order.Completed]:
  76. if order.isbuy():
  77. self.buyprice = order.executed.price
  78. self.buycomm = order.executed.comm
  79. self.bar_executed = len(self)
  80. # 订单因为缺少资金之类的原因被拒绝执行
  81. elif order.status in [order.Canceled, order.Margin, order.Rejected]:
  82. pass
  83. # self.log('Order Canceled/Margin/Rejected')
  84. # 订单状态处理完成,设为空
  85. self.order = None
  86. def notify_trade(self, trade):
  87. """
  88. 交易成果
  89. Arguments:
  90. trade {object} -- 交易状态
  91. """
  92. if not trade.isclosed:
  93. return
  94. # 显示交易的毛利率和净利润
  95. # self.log('OPERATION PROFIT, GROSS %.2f, NET %.2f' % (trade.pnl, trade.pnlcomm))
  96. def next(self):
  97. if self.volume[-1] < self.volume[0] and self.sma5[0] < self.dataclose[0]\
  98. and self.dataclose[0] > self.sma20[0] \
  99. and (self.hl[0] == 1 or self.hl[0] == 2 or self.hl[0] == 3):
  100. self.order = self.buy()
  101. self.pos_price = self.dataclose[0]
  102. elif (self.hl[0] == 5 or self.dataclose[0] < self.sma5[0]):
  103. self.order = self.close()
  104. self.pos_price = 0
  105. def stop(self):
  106. # pass
  107. self.log(u'(MA趋势交易效果) Ending Value %.2f' % (self.broker.getvalue()))
  108. def to_df(df):
  109. print('开始存数据')
  110. df.sort_values(by=['MA5乖离率', '当日回落'], ascending=True, inplace=True)
  111. df = df.reset_index(drop=True)
  112. if platform.node() == 'DanieldeMBP.lan':
  113. df.to_csv(f"/Users/daniel/Documents/策略/Ma5乖离7买入{dt.now().strftime('%Y%m%d%H%m%S')}.csv",
  114. index=True,
  115. encoding='utf_8_sig', mode='w')
  116. else:
  117. df.to_csv(f"C:\策略结果\Ma5乖离7买入{dt.now().strftime('%Y%m%d%H%m%S')}.csv", index=True,
  118. encoding='utf_8_sig', mode='w')
  119. print(f'结果:, \n, {df}')
  120. def chunked_iterable(iterable, size):
  121. """将可迭代对象分割为指定大小的块"""
  122. it = iter(iterable)
  123. while True:
  124. chunk = tuple(islice(it, size))
  125. if not chunk:
  126. return
  127. yield chunk
  128. def query_database(table_name):
  129. engine = create_engine('mysql+pymysql://root:r6kEwqWU9!v3@localhost:3307/qmt_stocks_tech?charset=utf8')
  130. df = pd.read_sql_table(table_name, engine)
  131. return df
  132. def get_stock_data():
  133. while True:
  134. try:
  135. db = pymysql.connect(host='localhost',
  136. user='root',
  137. port=3307,
  138. password='r6kEwqWU9!v3',
  139. database='qmt_stocks_tech')
  140. cursor = db.cursor()
  141. cursor.execute("show tables like '%%%s%%' " % '1d')
  142. table_list = [tuple[0] for tuple in cursor.fetchall()]
  143. # table_list = table_list[0: 10]
  144. cursor.close()
  145. db.close()
  146. print(f'开始数据库读取')
  147. with concurrent.futures.ProcessPoolExecutor(max_workers=16) as executor:
  148. # 使用executor.map方法实现多进程并行查询数据库,得到每个表的数据,并存储在一个字典中
  149. data_dict = {table_name: df for table_name, df in
  150. tqdm(zip(table_list, executor.map(query_database, table_list)))}
  151. print(f'数据库读取完成')
  152. break
  153. except BaseException as e:
  154. print(f'数据库读取错误{e}')
  155. continue
  156. return data_dict
  157. def backtrader_test(stock_data, stock_name, vot):
  158. # print(f'开始回测{stock_name}')
  159. try:
  160. cerebro = bt.Cerebro()
  161. stock_data.time = pd.to_datetime(stock_data.time)
  162. stock_data['HL'] = stock_data['HL'].map({'L': 1,
  163. 'LL': 2,
  164. 'L*': 3,
  165. 'H': 4,
  166. 'HH': 5,
  167. 'H*': 6,
  168. '-': 7})
  169. cerebro.addstrategy(TestStrategy, Volatility=vot)
  170. data = MyPandasData(dataname=stock_data,
  171. fromdate=datetime.datetime(2017, 1, 1),
  172. todate=datetime.datetime(2022, 10, 30),
  173. datetime='time',
  174. open='open_back',
  175. close='close_back',
  176. high='high_back',
  177. low='low_back',
  178. volume='volume_back',
  179. hl='HL',
  180. dif='dif',
  181. dea='dea',
  182. macd='macd',
  183. rsi_6='rsi_6',
  184. rsi_12='rsi_12',
  185. rsi_24='rsi_24',
  186. )
  187. cerebro.adddata(data)
  188. cerebro.addstrategy(TestStrategy)
  189. cerebro.broker.setcash(100000.0)
  190. cerebro.addsizer(bt.sizers.FixedSize, stake=10000)
  191. cerebro.broker.setcommission(commission=0.001)
  192. cerebro.run()
  193. except BaseException as e:
  194. print(f'{stock_name}回测错误{e}')
  195. return np.nan
  196. # print(cerebro.broker.getvalue() - 100000.0)
  197. # print(stock_name)
  198. else:
  199. return cerebro.broker.getvalue() - 100000.0
  200. def tdf(tt, Volatility):
  201. num_nan = np.isnan(tt).sum() # Count NaN values
  202. print(f'num_nan={num_nan}')
  203. filtered_result = [r for r in tt if not np.isnan(r)] # Filter out NaN values
  204. print(f'filtered_result={filtered_result}')
  205. # Calculate statistics
  206. num_profits = len([r for r in tt if r > 0])
  207. num_losses = len([r for r in tt if r < 0])
  208. profit_ratio = num_profits / (len(filtered_result))
  209. total_profit = sum([r for r in tt if r > 0])
  210. avg_profit = total_profit / num_profits if num_profits else 0
  211. max_profit = max(tt)
  212. min_profit = min([r for r in tt if r > 0]) if num_profits else 0
  213. total_loss = sum([r for r in tt if r < 0])
  214. avg_loss = total_loss / num_losses if num_losses else 0
  215. max_loss = min(tt)
  216. min_loss = max([r for r in tt if r < 0]) if num_losses else 0
  217. # Append the results into the DataFrame
  218. result_dict = {'基准均线': Volatility, '盈利个数': num_profits,
  219. '盈利比例': profit_ratio, '总盈利': total_profit, '平均盈利': avg_profit,
  220. '最大盈利': max_profit, '最小盈利': min_profit, '总亏损': total_loss,
  221. '平均亏损': avg_loss, '最大亏损': max_loss, '最小亏损': min_loss, '未计算个股数': num_nan}
  222. df_t = pd.Series(result_dict)
  223. return df_t
  224. if __name__ == '__main__':
  225. logger = mp.log_to_stderr()
  226. logger.setLevel(logging.DEBUG)
  227. cpu_list = list(range(0, 23))
  228. print(cpu_list)
  229. pus = psutil.Process()
  230. pus.cpu_affinity(cpu_list)
  231. start_time = dt.now()
  232. # 定义需要穷举的参数值
  233. Volatility = range(5, 500, 5) # 当日回撤
  234. # rates = range(3, 20, 1) # 乖离率
  235. # 生成所有参数组合
  236. all_combinations = list(product(Volatility))
  237. print(f'共需计算{len(all_combinations)}次')
  238. # 获取数据
  239. stock_data_dict = get_stock_data()
  240. results = []
  241. df = pd.DataFrame(
  242. columns=['MA5乖离率', '当日回落', '盈利个数', '盈利比例', '总盈利', '平均盈利', '最大盈利', '最小盈利',
  243. '总亏损',
  244. '平均亏损', '最大亏损', '最小亏损', '未计算个股数'])
  245. err_list = []
  246. # 设置每一轮的任务数
  247. CHUNK_SIZE = 200 # 您可以根据需要进行调整
  248. timeout = 120
  249. max_retries = 3
  250. with concurrent.futures.ProcessPoolExecutor(max_workers=24) as inner_executor:
  251. for Volatility in tqdm(all_combinations, desc='计算进度'):
  252. while True:
  253. try:
  254. # 使用executor.map方法实现多进程并行计算不同参数组合的结果
  255. res = [result for result in tqdm(
  256. inner_executor.map(backtrader_test, stock_data_dict.values(), stock_data_dict.keys(),
  257. [Volatility] * len(stock_data_dict)),
  258. desc='单轮计算进度')]
  259. except BaseException as e:
  260. print(f'计算错误{e}')
  261. inner_executor = concurrent.futures.ProcessPoolExecutor(max_workers=24)
  262. else:
  263. results.append(res)
  264. df_t = tdf(res, Volatility)
  265. df = pd.concat([df, df_t.to_frame().T], ignore_index=True)
  266. break
  267. # time.sleep(1)
  268. print(f'{Volatility}计算完成,共计算{len(res)}个股票')
  269. print(df)
  270. print('循环结束')
  271. to_df(df)
  272. print(f'计算完成,共耗时{dt.now() - start_time}秒')