230218_backtrader.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. import os
  2. import traceback
  3. import numpy as np
  4. from sqlalchemy import create_engine
  5. import pandas as pd
  6. import pymysql
  7. import backtrader as bt
  8. import backtrader.indicators as btind
  9. import datetime
  10. import math
  11. from datetime import datetime as dt
  12. import multiprocessing as mp
  13. from backtrader.feeds import PandasData
  14. # import multiprocessing
  15. # import matplotlib
  class MyPandasData(PandasData):
      """PandasData feed with extra lines for HL pivot codes, MACD and RSI.

      The default source of each extra line is a positional column index
      (hl -> 7, dif -> 8, ..., rsi_24 -> 13); ``backtrader()`` in this module
      overrides them with explicit column names when it builds the feed.
      """

      lines = ('hl', 'dif', 'dea', 'macd', 'rsi_6', 'rsi_12', 'rsi_24',)
      # Same pairs as spelling them out: (('hl', 7), ('dif', 8), ..., ('rsi_24', 13)).
      params = tuple((name, column) for column, name in enumerate(lines, start=7))

      # Disabled alternative: money-flow columns (indices 7..17) instead of the indicator columns.
      # lines = ('change_pct', 'net_amount_main', 'net_pct_main', 'net_amount_xl', 'net_pct_xl',
      #          'net_amount_l', 'net_pct_l', 'net_amount_m', 'net_pct_m', 'net_amount_s', 'net_pct_s',)
      # params = tuple((name, column) for column, name in enumerate(lines, start=7))
  class TestStrategy(bt.Strategy):
      """Swing-low rebound strategy on HL pivot labels, SMA5, RSI and volume.

      Entry (``next``): the bar two back carries HL code 2 (``'LL'`` in the
      mapping used by ``backtrader()``), close > SMA5 and SMA5 is rising, the low
      one or two bars back lies strictly within +/- ``Volatility`` percent of the
      lowest low of the last ``num`` bars, RSI-6 > RSI-12 with RSI-12 < 40,
      RSI-6 turned up this bar after falling on the previous one, and volume did
      not shrink. There is no open-position check, so repeated signals can add
      to an existing position.

      Exit: the current bar carries HL code 5 (``'HH'``) and the high two bars
      back lies strictly within +/- ``Volatility`` percent of the highest high
      of the last ``num`` bars.
      """

      params = (
          ("num", 3),         # look-back window (bars) for the lowest low / highest high
          ('Volatility', 0),  # tolerance band around those extremes, in percent
          ('rate', 5),        # percent; only the disabled variant in next() used it
          # Keep the trailing commas: every entry must be a tuple.
      )

      def log(self, txt, dt=None):
          """Logging hook for this strategy; printing is currently disabled."""
          dt = dt or self.datas[0].datetime.date(0)
          # print('%s, %s' % (dt.isoformat(), txt))

      def __init__(self):
          # Pending order handle and details of the last completed buy; defined
          # up front so they exist before the first order/fill.
          self.order = None
          self.buyprice = None
          self.buycomm = None
          self.bar_executed = None
          # Shortcuts to the lines of the first (only) data feed.
          self.dataclose = self.datas[0].close
          self.dataopen = self.datas[0].open
          self.high = self.datas[0].high
          self.low = self.datas[0].low
          self.volume = self.datas[0].volume
          self.hl = self.datas[0].hl
          self.dif = self.datas[0].dif
          self.dea = self.datas[0].dea
          self.macd = self.datas[0].macd
          self.rsi_6 = self.datas[0].rsi_6
          self.rsi_12 = self.datas[0].rsi_12
          self.rsi_24 = self.datas[0].rsi_24
          # Money-flow lines (change_pct, net_amount_*, net_pct_*) are disabled in the feed.
          self.sma5 = btind.MovingAverageSimple(self.datas[0].close, period=5)
          self.sma10 = btind.MovingAverageSimple(self.datas[0].close, period=10)
          self.sma20 = btind.MovingAverageSimple(self.datas[0].close, period=20)

      def notify_order(self, order):
          """Track order status changes.

          Arguments:
              order {object} -- the order whose status changed
          """
          if order.status in [order.Submitted, order.Accepted]:
              # Submitted/accepted but not executed yet: nothing to do.
              return
          if order.status in [order.Completed]:
              if order.isbuy():
                  self.buyprice = order.executed.price
                  self.buycomm = order.executed.comm
                  self.bar_executed = len(self)
          elif order.status in [order.Canceled, order.Margin, order.Rejected]:
              # Cancelled, or rejected e.g. for lack of cash; logging is disabled.
              pass
          # The order is no longer pending.
          self.order = None

      def notify_trade(self, trade):
          """Trade notification hook; only closed trades are of interest.

          Arguments:
              trade {object} -- the trade whose state changed
          """
          if not trade.isclosed:
              return
          # self.log('OPERATION PROFIT, GROSS %.2f, NET %.2f' % (trade.pnl, trade.pnlcomm))

      def next(self):
          """Evaluate the entry/exit rules on the current bar."""
          if len(self) <= self.params.num:
              return  # not enough history for the look-back window yet
          vola = self.params.Volatility / 100
          # Extremes over the last `num` bars (current bar included).
          lowest = np.min(self.low.get(size=self.params.num))
          highest = np.max(self.high.get(size=self.params.num))
          low_floor, low_ceil = lowest * (1 - vola), lowest * (1 + vola)
          high_floor, high_ceil = highest * (1 - vola), highest * (1 + vola)
          if (self.hl[-2] == 2
                  and self.dataclose[0] > self.sma5[0] > self.sma5[-1]
                  and (low_floor < self.low[-2] < low_ceil or low_floor < self.low[-1] < low_ceil)
                  and self.rsi_6[0] > self.rsi_12[0]
                  and self.rsi_12[0] < 40
                  and self.rsi_6[0] > self.rsi_6[-1]
                  and self.rsi_6[-1] < self.rsi_6[-2]
                  and self.volume[0] >= self.volume[-1]):
              self.order = self.buy()
          elif self.hl[0] == 5 and high_floor < self.high[-2] < high_ceil:
              self.order = self.close()
          # Disabled earlier variant (it also used the `rate` param):
          #   buy   if close > open, the low one or two bars back is inside the band around
          #         the num-bar lowest low, close > SMA5, SMA5 rising, no open position
          #         and SMA5 > SMA10;
          #   close if close < SMA5, SMA5 < SMA10, close > SMA5 * (1 + rate), or the high one
          #         or two bars back is inside the band around the num-bar highest high.

      def stop(self):
          """Report the final portfolio value through log() (printing is disabled)."""
          self.log(u'(MA趋势交易效果) Ending Value %.2f' % (self.broker.getvalue()))
  def err_call_back(err):
      """Error callback for ``Pool.apply_async``: report a task that raised.

      Arguments:
          err {BaseException} -- the exception raised by the task
      """
      print(f'出错啦~ error:{str(err)}')
      # traceback.format_exc() formats the *currently handled* exception (there
      # is none inside a callback) and its argument is a frame limit, so passing
      # `err` to it raises TypeError. Print err's own traceback instead.
      traceback.print_exception(type(err), err, err.__traceback__)
  def to_df(lt, out_path=None):
      """Sort the parameter-sweep summary rows, write them to CSV and print them.

      Arguments:
          lt {iterable} -- rows of 13 values in the column order below
              (e.g. the shared list filled by backtrader())
          out_path {str} -- destination CSV; defaults to a dated
              ``策略穷举YYYYMMDD.csv`` in the hard-coded Windows folder

      Returns:
          the sorted DataFrame that was written
      """
      df = pd.DataFrame(list(lt), columns=['周期', '波动率', '量能增长率', '盈利个数', '盈利比例', '总盈利', '平均盈利', '最大盈利',
                                           '最小盈利', '总亏损', '平均亏损', '最大亏损', '最小亏损'])
      df.sort_values(by=['周期', '波动率', '量能增长率'], ascending=True, inplace=True)
      df = df.reset_index(drop=True)
      if out_path is None:
          # Raw f-string: the backslashes are path separators, not escape sequences.
          out_path = rf"D:\Daniel\策略\策略穷举{dt.now().strftime('%Y%m%d')}.csv"
          # out_path = f"/Users/daniel/Documents/策略/策略穷举{dt.now().strftime('%Y%m%d')}.csv"
      df.to_csv(out_path, index=True, encoding='utf-8', mode='w')
      print(df)
      return df
  def _load_stock_frame(stock, engine):
      """Read one stock's table and replace its HL pivot labels with numeric codes.

      Returns the DataFrame, or None (after printing a notice) when the table has
      no ``HL`` column. Labels outside the mapping become NaN.
      """
      stk_df = pd.read_sql_table(stock, engine)
      stk_df.time = pd.to_datetime(stk_df.time)
      try:
          stk_df['HL'] = stk_df['HL'].map({'L': 1,
                                           'LL': 2,
                                           'L*': 3,
                                           'H': 4,
                                           'HH': 5,
                                           'H*': 6,
                                           '-': 7})
      except KeyError:
          # A missing HL column is the expected failure; anything else should surface.
          print(stock, 'HL 可能没有')
          return None
      return stk_df


  def _run_backtest(stk_df, stock, num, Volatility, rate, start_cash):
      """Backtest TestStrategy on one stock's frame.

      Returns the final broker value, or None when ``cerebro.run()`` raised IndexError.
      """
      cerebro = bt.Cerebro()
      cerebro.addstrategy(TestStrategy, num=num, Volatility=Volatility, rate=rate)
      cerebro.addsizer(bt.sizers.FixedSize, stake=10000)
      data = MyPandasData(dataname=stk_df,
                          fromdate=datetime.datetime(2017, 1, 1),
                          todate=datetime.datetime(2022, 10, 30),
                          datetime='time',
                          open='open_back',
                          close='close_back',
                          high='high_back',
                          low='low_back',
                          volume='volume_back',
                          hl='HL',
                          dif='dif',
                          dea='dea',
                          macd='macd',
                          rsi_6='rsi_6',
                          rsi_12='rsi_12',
                          rsi_24='rsi_24',
                          )
      cerebro.adddata(data, name=stock)
      cerebro.broker.setcash(start_cash)
      cerebro.broker.setcommission(0.005)
      cerebro.addanalyzer(bt.analyzers.PyFolio)
      try:
          cerebro.run()
      except IndexError:
          return None
      return cerebro.broker.getvalue()


  def backtrader(list_date, table_list, result, result_change, result_change_fall, num, Volatility, rate, err_list,
                 start_cash=100000.0):
      """Backtest one (num, Volatility, rate) combination over every stock table.

      Arguments:
          list_date {list} -- shared list; one 13-value summary row is appended
              when there is at least one winning and one losing stock
          table_list {list} -- table names to backtest
          result {list} -- receives names of stocks that ended above start_cash
          result_change {list} -- receives their fractional gains (value / start_cash - 1)
          result_change_fall {list} -- receives fractional losses (1 - value / start_cash)
              of the remaining tested stocks
          num, Volatility, rate -- TestStrategy parameters
          err_list {list} -- receives names of stocks whose run raised IndexError
          start_cash {float} -- starting broker cash and the base of the ratios
      """
      print(f'{num}天波动率为{Volatility}%量能增长率为{rate}', 'myPID is ', os.getpid())
      sttime = dt.now()
      # NOTE(review): credentials are hard-coded; consider moving them to config/env.
      engine = create_engine('mysql+pymysql://root:r6kEwqWU9!v3@localhost:3307/qmt_stocks_tech?charset=utf8')
      try:
          for stock in table_list:
              stk_df = _load_stock_frame(stock, engine)
              if stk_df is None or len(stk_df) <= 60:
                  continue  # no HL column, or too little history to test
              final_value = _run_backtest(stk_df, stock, num, Volatility, rate, start_cash)
              if final_value is None:
                  err_list.append(stock)
              elif final_value > start_cash:
                  # Ratios are relative to the starting cash (previously divided by 10000
                  # while the account started with 100000).
                  result_change.append(final_value / start_cash - 1)
                  result.append(stock)
              else:
                  result_change_fall.append(1 - final_value / start_cash)
      finally:
          engine.dispose()

      if result and result_change and result_change_fall:
          print(f'以{num}内最低值波动{Volatility}为支撑、量能增长率为{rate}%,结果状态为:')
          print('正盈利的个股为:', len(result_change), '成功率为:', len(result) / len(table_list))
          print(
              f'总盈利:{np.sum(result_change)} 平均盈利:{np.mean(result_change)},最大盈利:{np.max(result_change)}, 最小盈利:{np.min(result_change)}')
          # Losses are non-negative fractions, so the largest loss is the max.
          print(
              f'总亏损:{np.sum(result_change_fall)},平均亏损:{np.mean(result_change_fall)},最大亏损:{np.max(result_change_fall)} 最小亏损:{np.min(result_change_fall)}')
          list_date.append([num, Volatility, rate, len(result), len(result) / len(table_list),
                            np.nansum(result_change), np.nanmean(result_change),
                            np.nanmax(result_change), np.nanmin(result_change),
                            np.nansum(result_change_fall), np.nanmean(result_change_fall),
                            np.nanmax(result_change_fall), np.nanmin(result_change_fall)])
          to_df(list_date)
          endtime = dt.now()
          print(f'{num}天波动率为{Volatility}%量能增长率为{rate},myPID is {os.getpid()}.本轮耗时为{endtime - sttime}')
      else:
          print(result, result_change, result_change_fall, num, Volatility, rate, err_list)
  251. else:
  252. print(result, result_change, result_change_fall, num, Volatility, rate, err_list)
  253. # cerebro.plot()
  254. df = pd.DataFrame(
  255. columns=['周期', '波动率', '盈利个数', '盈利比例', '总盈利', '平均盈利', '最大盈利', '最小盈利', '总亏损',
  256. '平均亏损', '最大亏损', '最小亏损'])
  257. if __name__ == '__main__':
  258. starttime = dt.now()
  259. print(starttime)
  260. # engine = create_engine('mysql+pymysql://root:r6kEwqWU9!v3@localhost:3307/hlfx?charset=utf8', poolclass=NullPool)
  261. # stocks = pd.read_sql_query(
  262. # 'select value from MA5_1d', engine_hlfx)
  263. fre = '1d'
  264. db = pymysql.connect(host='localhost',
  265. user='root',
  266. port=3307,
  267. password='r6kEwqWU9!v3',
  268. database='qmt_stocks_tech')
  269. cursor = db.cursor()
  270. cursor.execute("show tables like '%%%s%%' " % fre)
  271. table_list = [tuple[0] for tuple in cursor.fetchall()]
  272. # print(table_list)
  273. # table_list = table_list[0:100]
  274. list_date = mp.Manager().list()
  275. thread_list = []
  276. pool = mp.Pool(processes=mp.cpu_count())
  277. for num in range(60, 180, 20):
  278. for Volatility in range(5, 8, 1):
  279. for rate in range(5, 12, 1):
  280. step = math.ceil(len(table_list) / mp.cpu_count())
  281. result = []
  282. result_change = []
  283. result_change_fall = []
  284. err_list = []
  285. print(f'{num}天波动率为{Volatility}%量能增长率为{rate}')
  286. # for i in range(0, len(table_list), step):
  287. stattime = dt.now()
  288. # thd = threading.local()
  289. # print(i)
  290. # p = mp.Process(target=backtrader, args=(df, table_list, result, result_change, result_change_fall,
  291. # num, Volatility, rate, err_list))
  292. # thread_list.append(p)
  293. pool.apply_async(func=backtrader,
  294. args=(list_date, table_list, result, result_change, result_change_fall,
  295. num, Volatility, rate, err_list,), error_callback=err_call_back)
  296. # p.start()
  297. # p.join()
  298. # print(thread_list)
  299. # for thread in thread_list:
  300. # thread.start()
  301. # for thread in thread_list:
  302. # thread.join()
  303. pool.close()
  304. pool.join()
  305. edtime = dt.now()
  306. print('总耗时:', edtime - starttime)
  307. # df.to_csv(r'C:\Users\Daniel\Documents\策略穷举2.csv', index=True)