YH_backtrader.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. import os
  2. import numpy as np
  3. from sqlalchemy import create_engine
  4. import pandas as pd
  5. import pymysql
  6. import backtrader as bt
  7. import backtrader.indicators as btind
  8. import backtrader.analyzers as btanalyzers
  9. import datetime
  10. import math
  11. from datetime import datetime as dt
  12. import multiprocessing as mp
  13. from backtrader.feeds import PandasData
  14. from numba import jit, cuda, njit
  15. # import multiprocessing
  16. import matplotlib
  17. pd.set_option('display.max_columns', None) # 设置显示最大行
  18. # global result
  19. # result = pd.DataFrame(columns=['code', 'result', 'num', 'Volatility', 'rate'])
  20. class MyPandasData(PandasData):
  21. lines = ()
  22. params = ()
  23. '''
  24. lines = ('change_pct', 'net_amount_main', 'net_pct_main', 'net_amount_xl', 'net_pct_xl', 'net_amount_l', 'net_pct_l'
  25. , 'net_amount_m', 'net_pct_m', 'net_amount_s', 'net_pct_s',)
  26. params = (('change_pct', 7),
  27. ('net_amount_main', 8),
  28. ('net_pct_main', 9),
  29. ('net_amount_xl', 10),
  30. ('net_pct_xl', 11),
  31. ('net_amount_l', 12),
  32. ('net_pct_l', 13),
  33. ('net_amount_m', 14),
  34. ('net_pct_m', 15),
  35. ('net_amount_s', 16),
  36. ('net_pct_s', 17),
  37. )
  38. '''
  39. class TestStrategy(bt.Strategy):
  40. params = (
  41. ("num", 3),
  42. ('Volatility', 0),
  43. ('rate', 5), # 注意要有逗号!!
  44. )
  45. def log(self, txt, dt=None):
  46. ''' Logging function for this strategy'''
  47. dt = dt or self.datas[0].datetime.date(0)
  48. print('%s, %s' % (dt.isoformat(), txt))
  49. def notify_order(self, order):
  50. """
  51. 订单状态处理
  52. Arguments:
  53. order {object} -- 订单状态
  54. """
  55. if order.status in [order.Submitted, order.Accepted]:
  56. # 如订单已被处理,则不用做任何事情
  57. return
  58. # 检查订单是否完成
  59. if order.status in [order.Completed]:
  60. if order.isbuy():
  61. self.buyprice = order.executed.price
  62. self.buycomm = order.executed.comm
  63. self.bar_executed = len(self)
  64. # 订单因为缺少资金之类的原因被拒绝执行
  65. elif order.status in [order.Canceled, order.Margin, order.Rejected]:
  66. pass
  67. self.log('Order Canceled/Margin/Rejected')
  68. # 订单状态处理完成,设为空
  69. self.order = None
  70. def notify_trade(self, trade):
  71. """
  72. 交易成果
  73. Arguments:
  74. trade {object} -- 交易状态
  75. """
  76. if not trade.isclosed:
  77. return
  78. # 显示交易的毛利率和净利润
  79. # self.log('OPERATION PROFIT, GROSS %.2f, NET %.2f' % (trade.pnl, trade.pnlcomm))
  80. def __init__(self):
  81. # print('__init__', dt.now())
  82. # print(f'{self.params.num}天波动率为{self.params.Volatility}%乖离率为{self.params.rate}', 'myPID is ', os.getpid())
  83. self.dataclose = self.datas[0].close
  84. self.dataopen = self.datas[0].open
  85. self.high = self.datas[0].high
  86. self.low = self.datas[0].low
  87. self.volume = self.datas[0].volume
  88. # self.change_pct = self.datas[0].change_pct
  89. # self.net_amount_main = self.datas[0].net_amount_main
  90. # self.net_pct_main = self.datas[0].net_pct_main
  91. # self.net_amount_xl = self.datas[0].net_amount_xl
  92. # self.net_pct_xl = self.datas[0].net_pct_xl
  93. # self.net_amount_l = self.datas[0].net_amount_l
  94. # self.net_pct_l = self.datas[0].net_pct_l
  95. self.sma5 = btind.MovingAverageSimple(self.datas[0].close, period=5)
  96. self.sma10 = btind.MovingAverageSimple(self.datas[0].close, period=10)
  97. self.sma20 = btind.MovingAverageSimple(self.datas[0].close, period=20)
  98. self.yx = self.dataclose > self.dataopen
  99. self.lowest = min(self.dataclose.get(ago=-1, size=self.params.num))
  100. self.highest = max(self.high.get(ago=-1, size=self.params.num))
  101. self.vola = self.params.Volatility / 100
  102. self.rate = self.params.rate / 100
  103. # print('初始化完成', dt.now())
  104. # @njit
  105. def next(self):
  106. '''
  107. if self.yx and (self.dataclose[0] > self.dataclose[-1] > self.dataclose[-2])\
  108. and self.sma5[0] > self.sma10[0] and self.sma5[-1] < self.sma10[-1]:
  109. print('next', self.yx[0], (self.dataclose[0] > self.dataclose[-1] > self.dataclose[-2]),
  110. self.sma5[0] > self.sma10[0], self.sma5[-1] < self.sma10[-1])
  111. :return:
  112. '''
  113. if self.yx and ((self.lowest * (1 - self.vola)) < self.low[-2] < (self.lowest * (1 + self.vola)))\
  114. and self.dataclose[0] > self.dataclose[-1] > self.dataclose[-2] and self.dataclose[0] > self.sma5[0]:
  115. # print(f'buy, {self.lowest},{self.vola},{self.low[-2]}, {self.rate}')
  116. # self.log('BUY CREATE, %.2f' % self.dataclose[0])
  117. self.order = self.buy()
  118. elif self.dataclose < self.sma5[0]:
  119. self.order = self.close()
  120. # self.log('close<ma5 Close, %.2f' % self.dataclose[0])
  121. elif self.sma5[0] < self.sma10[0]:
  122. self.order = self.close()
  123. # self.log('ma5<ma10 Close, %.2f' % self.dataclose[0])
  124. elif self.dataclose[0] > self.sma5[0]*(1+self.rate):
  125. self.order = self.close()
  126. # self.log('close>rate Close, %.2f' % self.dataclose[0])
  127. '''
  128. if self.yx \
  129. and (((self.lowest[0] * (1 - self.vola)) < self.low[-2] < (self.lowest[0] * (1 + self.vola))) or (
  130. (self.lowest[0] * (1 - self.vola)) < self.low[-1] < (self.lowest[0] * (1 + self.vola)))) \
  131. and (self.dataclose[0] > self.sma5[0]) and self.sma5[0] > self.sma5[-1] \
  132. and (not self.position) and (self.sma5[0] > self.sma10[0]):
  133. # self.log('BUY CREATE, %.2f' % self.dataclose[0])
  134. self.order = self.buy()
  135. elif self.dataclose < self.sma5[0] or self.sma5[0] < self.sma10[0] \
  136. or (self.dataclose[0] > (self.sma5[0] * (1 + self.rate))) or \
  137. (((self.highest[0] * (1 - self.vola)) < self.high[-2] < (self.highest[0] * (1 + self.vola))) or (
  138. (self.highest[0] * (1 - self.vola)) < self.high[-1] < (self.highest[0] * (1 + self.vola)))):
  139. self.order = self.close()
  140. # self.log('Close, %.2f' % self.dataclose[0])
  141. '''
  142. def stop(self):
  143. # pass
  144. global result
  145. self.log(u'(MA趋势交易效果) Ending Value %.2f num %d Vol %d rate %d' % (self.broker.getvalue(),
  146. self.params.num,
  147. self.params.Volatility,
  148. self.params.rate))
  149. # self.log(f'time:{dt.now()}')
  150. # temp = pd.DataFrame(columns=['code', 'result', 'num', 'Volatility', 'rate'],
  151. # data=[self.getdatanames(), self.broker.getvalue(), self.params.num, self.params.rate])
  152. # result = pd.concat([result,temp],axis=0)
  153. def err_call_back(err):
  154. print(f'出错啦~ error:{str(err)}')
  155. def to_df(lt):
  156. df = pd.DataFrame(list(lt), columns=['周期', '波动率', '乖离率', '盈利个数', '盈利比例', '总盈利', '平均盈利', '最大盈利',
  157. '最小盈利', '总亏损', '平均亏损', '最大亏损', '最小亏损'])
  158. df.sort_values(by=['周期', '波动率', '乖离率'], ascending=True, inplace=True)
  159. df = df.reset_index(drop=True)
  160. df.to_csv(r'D:\Daniel\策略\策略穷举.csv', index=True, encoding='utf-8', mode='w')
  161. print(df)
  162. # 打印结果
  163. def get_my_analyzer(result):
  164. analyzer = {}
  165. # 返回参数
  166. analyzer['num'] = result.params.num
  167. analyzer['Volatility'] = result.params.Volatility
  168. analyzer['rate'] = result.params.rate
  169. # 提取年化收益
  170. analyzer['年化收益率'] = result.analyzers._Returns.get_analysis()['rnorm']
  171. analyzer['年化收益率(%)'] = result.analyzers._Returns.get_analysis()['rnorm100']
  172. # 提取最大回撤(习惯用负的做大回撤,所以加了负号)
  173. analyzer['最大回撤(%)'] = result.analyzers._DrawDown.get_analysis()['max']['drawdown'] * (-1)
  174. # 提取夏普比率
  175. analyzer['年化夏普比率'] = result.analyzers._SharpeRatio_A.get_analysis()['sharperatio']
  176. return analyzer
  177. def backtrader(table_list, result_change, result_change_fall, err_list):
  178. sttime = dt.now()
  179. engine = create_engine('mysql+pymysql://root:r6kEwqWU9!v3@localhost:3307/qmt_stocks_front?charset=utf8')
  180. cerebro = bt.Cerebro(stdstats=False)
  181. # cerebro.addobserver(bt.observers.Broker)
  182. # cerebro.addobserver(bt.observers.Trades)
  183. # cerebro.addobserver(bt.observers.BuySell)
  184. # cerebro.addobserver(bt.observers.DrawDown)
  185. # cerebro.addobserver(bt.observers.TimeReturn)
  186. # cerebro.addstrategy(TestStrategy)
  187. cerebro.addsizer(bt.sizers.FixedSize, stake=1000)
  188. cerebro.broker.setcash(100000.0)
  189. cerebro.broker.setcommission(0.005)
  190. for stock in table_list:
  191. print(stock)
  192. stk_df = pd.read_sql_table(stock, engine)
  193. stk_df.time = pd.to_datetime(stk_df.time)
  194. data = MyPandasData(dataname=stk_df,
  195. fromdate=datetime.datetime(2022, 1, 1),
  196. todate=datetime.datetime(2023, 2, 1),
  197. datetime='time',
  198. open='open',
  199. close='close',
  200. high='high',
  201. low='low',
  202. volume='volume',
  203. # change_pct='change_pct',
  204. # net_amount_main='net_amount_main',
  205. # net_pct_main='net_pct_main',
  206. # net_amount_xl='net_amount_xl',
  207. # net_pct_xl='net_pct_xl',
  208. # net_amount_l='net_amount_l',
  209. # net_pct_l='net_pct_l',
  210. # net_amount_m='net_amount_m',
  211. # net_pct_m='net_pct_m',
  212. # net_amount_s='net_amount_s',
  213. # net_pct_s='net_pct_s',
  214. )
  215. cerebro.adddata(data, name=stock)
  216. cerebro.optstrategy(TestStrategy, num=range(40, 130, 10), Volatility=range(5, 8), rate=range(5, 8))
  217. print('最优参定义', dt.now())
  218. # 添加分析指标
  219. # 返回年初至年末的年度收益率
  220. # cerebro.addanalyzer(bt.analyzers.AnnualReturn, _name='_AnnualReturn')
  221. # 计算最大回撤相关指标
  222. cerebro.addanalyzer(bt.analyzers.DrawDown, _name='_DrawDown')
  223. # 计算年化收益:日度收益
  224. cerebro.addanalyzer(bt.analyzers.Returns, _name='_Returns', tann=252)
  225. # 计算年化夏普比率:日度收益
  226. cerebro.addanalyzer(bt.analyzers.SharpeRatio, _name='_SharpeRatio', timeframe=bt.TimeFrame.Days, annualize=True,
  227. riskfreerate=0)
  228. # 计算夏普比率
  229. cerebro.addanalyzer(bt.analyzers.SharpeRatio_A, _name='_SharpeRatio_A')
  230. # 返回收益率时序
  231. cerebro.addanalyzer(bt.analyzers.TimeReturn, _name='_TimeReturn')
  232. # 策略执行前的资金
  233. # print('启动资金: %.2f' % cerebro.broker.getvalue())
  234. cerebro.addsizer(bt.sizers.PercentSizer, percents=10)
  235. cerebro.addanalyzer(btanalyzers.SharpeRatio, _name="sharpe")
  236. cerebro.addanalyzer(btanalyzers.DrawDown, _name="drawdown")
  237. cerebro.addanalyzer(btanalyzers.Returns, _name="returns")
  238. cerebro.addanalyzer(btanalyzers.TradeAnalyzer, _name='TradeAnalyzer')
  239. try:
  240. # 策略执行
  241. print('开始执行', dt.now())
  242. results = cerebro.run(maxcpus=None)
  243. print('回测结束', dt.now())
  244. except IndexError:
  245. err_list.append(stock)
  246. else:
  247. par_list = [[x[0].params.num,
  248. x[0].params.Volatility,
  249. x[0].params.rate,
  250. x[0].analyzers.returns.get_analysis()['rnorm100'],
  251. x[0].analyzers.drawdown.get_analysis()['max']['drawdown'],
  252. x[0].analyzers.sharpe.get_analysis()['sharperatio'],
  253. x[0].analyzers.TradeAnalyzer.get_analysis().won.total,
  254. ] for x in results]
  255. par_df = pd.DataFrame(par_list, columns=['num', 'Volatility', 'rate', 'return', 'drawdown', 'sharpe','TradeAnalyzer'])
  256. print(par_df)
  257. # par_df.to_csv('result.csv')
  258. ret = []
  259. for i in results:
  260. ret.append(get_my_analyzer(i[0]))
  261. pd.DataFrame(ret)
  262. '''
  263. if cerebro.broker.getvalue() > 100000.0:
  264. result_change.append((cerebro.broker.getvalue() / 10000 - 1))
  265. result.append(stock)
  266. # print('recode!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
  267. # print(result)
  268. else:
  269. result_change_fall.append((1 - cerebro.broker.getvalue() / 10000))
  270. # print('aaaaaaaaaaa')
  271. # print(result_change_fall)
  272. '''
  273. # if len(result) * len(result_change) * len(result_change_fall) != 0:
  274. # print(f'以{num}内最低值波动{Volatility}为支撑、乖离率为{rate}%,结果状态为:')
  275. # print('正盈利的个股为:', len(result_change), '成功率为:', len(result) / len(table_list))
  276. # print(
  277. # f'总盈利:{np.sum(result_change)} 平均盈利:{np.mean(result_change)},最大盈利:{np.max(result_change)}, 最小盈利:{np.min(result_change)}')
  278. # print(
  279. # f'总亏损:{np.sum(result_change_fall)},平均亏损:{np.mean(result_change_fall)},最大亏损:{np.min(result_change_fall)} 最小亏损:{np.max(result_change_fall)}')
  280. #
  281. # list_date.append([num, Volatility, rate, len(result), len(result) / len(table_list), np.nansum(result_change),
  282. # np.nanmean(result_change), np.nanmax(result_change), np.min(result_change),
  283. # np.nansum(result_change_fall), np.nanmean(result_change_fall),
  284. # np.nanmin(result_change_fall), np.nanmax(result_change_fall)])
  285. # to_df(list_date)
  286. # endtime = dt.now()
  287. # print(f'{num}天波动率为{Volatility}%乖离率为{rate},myPID is {os.getpid()}.本轮耗时为{endtime - sttime}')
  288. # else:
  289. # print(result, result_change, result_change_fall, num, Volatility, rate, err_list)
  290. # cerebro.plot()
  291. # df = pd.DataFrame(
  292. # columns=['周期', '波动率', '盈利个数', '盈利比例', '总盈利', '平均盈利', '最大盈利', '最小盈利', '总亏损',
  293. # '平均亏损', '最大亏损', '最小亏损'])
  294. if __name__ == '__main__':
  295. starttime = dt.now()
  296. print(starttime)
  297. fre = '1d'
  298. db = pymysql.connect(host='localhost',
  299. user='root',
  300. port=3307,
  301. password='r6kEwqWU9!v3',
  302. database='qmt_stocks')
  303. cursor = db.cursor()
  304. cursor.execute("show tables like '%%%s%%' " % fre)
  305. table_list = [tuple[0] for tuple in cursor.fetchall()]
  306. # print(table_list)
  307. table_list = table_list[0:2]
  308. result_change = []
  309. result_change_fall = []
  310. err_list = []
  311. stattime = dt.now()
  312. # print(f'{num}天波动率为{Volatility}%乖离率为{rate}')
  313. backtrader(table_list, result_change, result_change_fall, err_list)
  314. edtime = dt.now()
  315. print('总耗时:', edtime - starttime)
  316. # df.to_csv(r'C:\Users\Daniel\Documents\策略穷举2.csv', index=True)