230723 _bt.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. import time
  2. import os
  3. import traceback
  4. import numpy as np
  5. from sqlalchemy import create_engine
  6. import pandas as pd
  7. import pymysql
  8. import backtrader as bt
  9. import backtrader.indicators as btind
  10. import datetime
  11. import math
  12. from datetime import datetime as dt
  13. import multiprocessing as mp
  14. from multiprocessing import Pool, Lock, Value, freeze_support
  15. import concurrent.futures
  16. import functools
  17. from backtrader.feeds import PandasData
  18. import platform
  19. import psutil
  20. import logging
  21. lock = Lock()
  22. counter = Value('i', 0)
  23. # engine = create_engine('mysql+pymysql://root:r6kEwqWU9!v3@localhost:3307/qmt_stocks_tech?charset=utf8',
  24. # pool_size=5000, max_overflow=200)
  25. # db_pool = pymysql.connect(host='localhost',
  26. # user='root',
  27. # port=3307,
  28. # password='r6kEwqWU9!v3',
  29. # database='qmt_stocks_tech')
  30. class MyPandasData(PandasData):
  31. lines = ('hl', 'dif', 'dea', 'macd', 'rsi_6', 'rsi_12', 'rsi_24',)
  32. params = (('hl', 7),
  33. ('dif', 8),
  34. ('dea', 9),
  35. ('macd', 10),
  36. ('rsi_6', 11),
  37. ('rsi_12', 12),
  38. ('rsi_24', 13),
  39. )
  40. '''
  41. lines = ('change_pct', 'net_amount_main', 'net_pct_main', 'net_amount_xl', 'net_pct_xl', 'net_amount_l', 'net_pct_l'
  42. , 'net_amount_m', 'net_pct_m', 'net_amount_s', 'net_pct_s',)
  43. params = (('change_pct', 7),
  44. ('net_amount_main', 8),
  45. ('net_pct_main', 9),
  46. ('net_amount_xl', 10),
  47. ('net_pct_xl', 11),
  48. ('net_amount_l', 12),
  49. ('net_pct_l', 13),
  50. ('net_amount_m', 14),
  51. ('net_pct_m', 15),
  52. ('net_amount_s', 16),
  53. ('net_pct_s', 17),
  54. )
  55. '''
  56. class TestStrategy(bt.Strategy):
  57. params = (
  58. ("num", 3),
  59. ('Volatility', 0),
  60. ('rate', 3), # 注意要有逗号!!
  61. )
  62. def log(self, txt, dt=None):
  63. ''' Logging function for this strategy'''
  64. dt = dt or self.datas[0].datetime.date(0)
  65. # print('%s, %s' % (dt.isoformat(), txt))
  66. def __init__(self):
  67. # self.num = num
  68. # self.Volatility = Volatility/100
  69. # Keep a reference to the "close" line in the data[0] dataseries
  70. self.pos_price = 0
  71. self.dataclose = self.datas[0].close
  72. self.dataopen = self.datas[0].open
  73. self.high = self.datas[0].high
  74. self.low = self.datas[0].low
  75. self.volume = self.datas[0].volume
  76. self.hl = self.datas[0].hl
  77. self.dif = self.datas[0].dif
  78. self.dea = self.datas[0].dea
  79. self.macd = self.datas[0].macd
  80. self.rsi_6 = self.datas[0].rsi_6
  81. self.rsi_12 = self.datas[0].rsi_12
  82. self.rsi_24 = self.datas[0].rsi_24
  83. # self.change_pct = self.datas[0].change_pct
  84. # self.net_amount_main = self.datas[0].net_amount_main
  85. # self.net_pct_main = self.datas[0].net_pct_main
  86. # self.net_amount_xl = self.datas[0].net_amount_xl
  87. # self.net_pct_xl = self.datas[0].net_pct_xl
  88. # self.net_amount_l = self.datas[0].net_amount_l
  89. # self.net_pct_l = self.datas[0].net_pct_l
  90. self.sma5 = btind.MovingAverageSimple(self.datas[0].close, period=5)
  91. self.sma10 = btind.MovingAverageSimple(self.datas[0].close, period=10)
  92. self.sma20 = btind.MovingAverageSimple(self.datas[0].close, period=20)
  93. self.sma60 = btind.MovingAverageSimple(self.datas[0].close, period=60)
  94. def notify_order(self, order):
  95. """
  96. 订单状态处理
  97. Arguments:
  98. order {object} -- 订单状态
  99. """
  100. if order.status in [order.Submitted, order.Accepted]:
  101. # 如订单已被处理,则不用做任何事情
  102. return
  103. # 检查订单是否完成
  104. if order.status in [order.Completed]:
  105. if order.isbuy():
  106. self.buyprice = order.executed.price
  107. self.buycomm = order.executed.comm
  108. self.bar_executed = len(self)
  109. # 订单因为缺少资金之类的原因被拒绝执行
  110. elif order.status in [order.Canceled, order.Margin, order.Rejected]:
  111. pass
  112. # self.log('Order Canceled/Margin/Rejected')
  113. # 订单状态处理完成,设为空
  114. self.order = None
  115. def notify_trade(self, trade):
  116. """
  117. 交易成果
  118. Arguments:
  119. trade {object} -- 交易状态
  120. """
  121. if not trade.isclosed:
  122. return
  123. # 显示交易的毛利率和净利润
  124. # self.log('OPERATION PROFIT, GROSS %.2f, NET %.2f' % (trade.pnl, trade.pnlcomm))
  125. def next(self):
  126. # if len(self) > self.params.num:
  127. vola = self.params.Volatility / 100
  128. rate = self.params.rate / 100
  129. lowest = np.min(self.low.get(size=self.params.num))
  130. highest = np.max(self.high.get(size=self.params.num))
  131. if self.hl[-1] == 2 or self.hl[-1] == 1:
  132. m = -2
  133. # self.order = self.buy()
  134. # self.pos_price = self.low[-1]
  135. while True:
  136. if (self.hl[m] == 2 or self.hl[m] == 1) and self.macd[m] > self.macd[-1] \
  137. and self.dataclose[0] > self.sma5[0] \
  138. and self.dataclose[-1] > self.dataopen[-1] \
  139. and (self.sma10[-2] - self.sma5[-2]) < (self.sma10[-1] - self.sma5[-1]) \
  140. and self.low[-2] < self.sma5[-2] * (1 - rate) \
  141. and self.sma5[-1] < self.sma10[-1] < self.sma20[-1] < self.sma20[-2] < self.sma20[-3] \
  142. and lowest * (1 - vola) < self.low[-1] < lowest * (1 + vola):
  143. self.order = self.buy()
  144. self.pos_price = self.low[-1]
  145. break
  146. m -= 1
  147. if m + len(self) == 2:
  148. break
  149. # elif (self.hl[0] == 5 or self.dataclose[0] < self.sma5[0]):
  150. elif self.dataclose[0] < self.sma5[0] or self.sma5[0] < self.sma5[-1] \
  151. or self.dataclose[0] < self.pos_price or self.high[0] > self.sma5[0] * (1 + vola):
  152. self.order = self.close()
  153. self.pos_price = 0
  154. def stop(self):
  155. # pass
  156. self.log(u'(MA趋势交易效果) Ending Value %.2f' % (self.broker.getvalue()))
  157. def err_call_back(err):
  158. print(f'出错啦~ error:{str(err)}')
  159. traceback.format_exc(err)
  160. def to_df(df):
  161. print('开始存数据')
  162. # df = pd.DataFrame(list(lt),
  163. # columns=['周期', '波动率', 'MA5斜率', '盈利个数', '盈利比例', '总盈利', '平均盈利', '最大盈利',
  164. # '最小盈利', '总亏损', '平均亏损', '最大亏损', '最小亏损', '盈亏对比'])
  165. df.sort_values(by=['周期', '波动率', 'MA5斜率'], ascending=True, inplace=True)
  166. df = df.reset_index(drop=True)
  167. if platform.node() == 'DanieldeMBP.lan':
  168. df.to_csv(f"/Users/daniel/Documents/策略/策略穷举-均线粘连后底分型{dt.now().strftime('%Y%m%d%H%m%S')}.csv",
  169. index=True,
  170. encoding='utf_8_sig', mode='w')
  171. else:
  172. df.to_csv(f"C:\策略结果\策略穷举底分型_均线缠绕_只买一次{dt.now().strftime('%Y%m%d%H%m%S')}.csv", index=True,
  173. encoding='utf_8_sig', mode='w')
  174. print(f'结果:, \n, {df}')
  175. def backtrader(stock, result, result_change, result_change_fall, num, Volatility, rate, err_list):
  176. # global engine
  177. # global db_pool
  178. global lock
  179. sttime = dt.now()
  180. engine = create_engine('mysql+pymysql://root:r6kEwqWU9!v3@localhost:3307/qmt_stocks_tech?charset=utf8',
  181. pool_size=10, max_overflow=20)
  182. try:
  183. # cursor = db_pool.cursor()
  184. # sql_query = f"select * from `{stock}`"
  185. # stk_df = pd.read_sql_query(sql_query, engine)
  186. conn = engine.connect()
  187. # with engine.connect() as conn:
  188. stk_df = pd.read_sql_table(stock, conn)
  189. stk_df.time = pd.to_datetime(stk_df.time)
  190. conn.close()
  191. engine.dispose()
  192. # stk_df = stk
  193. except BaseException as e:
  194. print(f'{stock}读取有问题', e)
  195. else:
  196. pass
  197. try:
  198. # stk_df = stk_df[stk_df['HL'] != '-']
  199. try:
  200. stk_df['HL'] = stk_df['HL'].map({'L': 1,
  201. 'LL': 2,
  202. 'L*': 3,
  203. 'H': 4,
  204. 'HH': 5,
  205. 'H*': 6,
  206. '-': 7})
  207. except BaseException:
  208. print(f'{stock}数据不全,不做测试')
  209. finally:
  210. # print(f'{stock}读取通过')
  211. pass
  212. try:
  213. if len(stk_df) > 60:
  214. try:
  215. cerebro = bt.Cerebro()
  216. cerebro.addstrategy(TestStrategy, num=num, Volatility=Volatility, rate=rate)
  217. cerebro.addsizer(bt.sizers.FixedSize, stake=10000)
  218. data = MyPandasData(dataname=stk_df,
  219. fromdate=datetime.datetime(2017, 1, 1),
  220. todate=datetime.datetime(2022, 10, 30),
  221. datetime='time',
  222. open='open_back',
  223. close='close_back',
  224. high='high_back',
  225. low='low_back',
  226. volume='volume_back',
  227. hl='HL',
  228. dif='dif',
  229. dea='dea',
  230. macd='macd',
  231. rsi_6='rsi_6',
  232. rsi_12='rsi_12',
  233. rsi_24='rsi_24',
  234. )
  235. # print('取值完成')
  236. cerebro.adddata(data, name=stock)
  237. cerebro.broker.setcash(100000.0)
  238. cerebro.broker.setcommission(0.005)
  239. cerebro.addanalyzer(bt.analyzers.PyFolio)
  240. # 策略执行前的资金
  241. # print('启动资金: %.2f' % cerebro.broker.getvalue())
  242. # 策略执行
  243. cerebro.run()
  244. except BaseException as e:
  245. lock.acquire()
  246. err_list.append(stock)
  247. lock.release()
  248. # print(f'{num}天波动率为{Volatility}%MA5斜率为{rate}的{stock}错误')
  249. print(stock, 'cerebro错误', e)
  250. else:
  251. lock.acquire()
  252. if cerebro.broker.getvalue() > 100000.0:
  253. result_change.append(cerebro.broker.getvalue() - 100000)
  254. result.append(stock)
  255. # print('recode!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
  256. # print(result)
  257. elif cerebro.broker.getvalue() <= 100000.0:
  258. result_change_fall.append(cerebro.broker.getvalue() - 100000)
  259. lock.release()
  260. else:
  261. lock.acquire()
  262. err_list.append(stock)
  263. lock.release()
  264. # print('aaaaaaaaaaa')
  265. # print(result_change_fall)
  266. # print('最终资金: %.2f' % cerebro.broker.getvalue())
  267. # finally:
  268. # with lock:
  269. # counter.value += 1
  270. # logging.info('执行完成:(%d / %d) 进程号: %d --------------- %s', counter.value, len(table_list), os.getpid(), stock)
  271. # print(f'已计算{counter.value}/{len(table_list)}只股票')
  272. except BaseException as e:
  273. print(f'{stock}backtrader问题', e)
  274. finally:
  275. print(f'{stock}通过')
  276. # print(f'已计算{(len(result) + len(result_change_fall)+len(err_list))}/{len(table_list)}只股票')
  277. if len(result) * len(result_change) * len(result_change_fall) != 0:
  278. print(f'以{num}内最低值波动{Volatility}为支撑、MA5斜率为{rate}%,结果状态为:')
  279. print('正盈利的个股为:', len(result), '成功率为:', len(result) / len(table_list))
  280. print(
  281. f'总盈利:{np.sum(result_change)} 平均盈利:{np.mean(result_change) / len(result)},最大盈利:{np.max(result_change)}, 最小盈利:{np.min(result_change)}')
  282. print(
  283. f'总亏损:{np.sum(result_change_fall)},平均亏损:{np.mean(result_change_fall) / len(result_change_fall)},最大亏损:{np.min(result_change_fall)} 最小亏损:{np.max(result_change_fall)}')
  284. # '周期', '波动率', 'MA5斜率', '盈利个数', '盈利比例', '总盈利', '平均盈利', '最大盈利', '最小盈利', '总亏损', '平均亏损', '最大亏损', '最小亏损', '盈亏对比']
  285. list_date.append([num, Volatility, rate, len(result), len(result) / len(table_list), np.nansum(result_change),
  286. np.nanmean(result_change), np.nanmax(result_change), np.min(result_change),
  287. np.nansum(result_change_fall), np.nanmean(result_change_fall),
  288. np.nanmin(result_change_fall), np.nanmax(result_change_fall),
  289. len(result_change) / len(result_change_fall)])
  290. # to_df(list_date)
  291. endtime = dt.now()
  292. print(f'{num}天波动率为{Volatility}%MA5斜率为{rate},myPID is {os.getpid()}.本轮耗时为{endtime - sttime}')
  293. else:
  294. print('阿欧', len(result), len(result_change), len(result_change_fall), num, Volatility, rate, err_list)
  295. list_date.append([num, Volatility, rate, 0, len(result) / len(table_list), len(result),
  296. len(result), len(result), len(result), len(result), len(result), len(result), 0])
  297. # list_date.append([num, Volatility, rate, len(result), len(result) / len(table_list), np.nansum(result_change),
  298. # np.nanmean(result_change), np.nanmax(result_change), np.min(result_change),
  299. # np.nansum(result_change_fall), np.nanmean(result_change_fall),
  300. # np.nanmin(result_change_fall), np.nanmax(result_change_fall),
  301. # len(result_change) / len(result_change_fall)])
  302. # cerebro.plot()
  303. # df = pd.DataFrame(
  304. # columns=['周期', '波动率', 'MA5斜率', '盈利个数', '盈利比例', '总盈利', '平均盈利', '最大盈利', '最小盈利', '总亏损',
  305. # '平均亏损', '最大亏损', '最小亏损'])
  306. #
  307. if __name__ == '__main__':
  308. freeze_support()
  309. logger = mp.log_to_stderr()
  310. logger.setLevel(logging.INFO)
  311. starttime = dt.now()
  312. print(starttime)
  313. # pus = psutil.Process()
  314. fre = '1d'
  315. db = pymysql.connect(host='localhost',
  316. user='root',
  317. port=3307,
  318. password='r6kEwqWU9!v3',
  319. database='qmt_stocks_tech')
  320. cursor = db.cursor()
  321. cursor.execute("show tables like '%%%s%%' " % fre)
  322. table_list = [tuple[0] for tuple in cursor.fetchall()]
  323. cursor.close()
  324. db.close()
  325. # print(table_list)
  326. # table_list = table_list[0:500]
  327. print(f'计算个股数为:{len(table_list)}')
  328. list_date = []
  329. pddate = pd.DataFrame(columns=['周期', '波动率', 'MA5斜率', '盈利个数', '盈利比例', '总盈利',
  330. '平均盈利', '最大盈利', '最小盈利', '总亏损', '平均亏损',
  331. '最大亏损',
  332. '最小亏损', '盈亏对比'])
  333. engine = create_engine('mysql+pymysql://root:r6kEwqWU9!v3@localhost:3307/qmt_stocks_tech?charset=utf8',
  334. pool_size=10, max_overflow=20)
  335. stk_df = pd.read_sql_table(table_list[0], engine)
  336. engine.dispose()
  337. print(stk_df)
  338. for num in range(60, 80, 20):
  339. for Volatility in range(7, 12, 1):
  340. for rate in range(3, 13, 1):
  341. stattime = dt.now().strftime('%Y-%m-%d %H:%M:%S')
  342. print(stattime)
  343. # pool = mp.Pool()
  344. result = mp.Manager().list()
  345. result_change = mp.Manager().list()
  346. result_change_fall = mp.Manager().list()
  347. err_list = mp.Manager().list()
  348. print(os.getpid())
  349. print(num, Volatility, rate, result, result_change, result_change_fall, err_list)
  350. # 保存AsyncResult对象的列表
  351. async_results = []
  352. partial_func_list = []
  353. m = 0
  354. try:
  355. pool = mp.Pool(processes=8)
  356. for stock in table_list:
  357. async_result = pool.apply_async(func=backtrader,
  358. args=(
  359. stock, result, result_change,
  360. result_change_fall,
  361. num, Volatility, rate, err_list,),
  362. error_callback=err_call_back)
  363. m += 1
  364. async_results.append(async_result)
  365. # p.start()
  366. pool.close()
  367. time.sleep(1)
  368. pool.join()
  369. except BaseException as e:
  370. print(f'进程池报错{e}')
  371. print(f'共有{m}只股票')
  372. # 统计返回为 None 的结果数量
  373. none_count = 0
  374. for i, result_async in enumerate(async_results):
  375. _ = result_async.get() # 获取任务的结果
  376. if _ is None:
  377. none_count += 1
  378. print(f'{num}天波动率为{Volatility}%MA5斜率为{rate}')
  379. print(f"正确计算的有{none_count},错误的有{len(err_list)},共计算{len(async_results)}只股票")
  380. '''
  381. list_date = [num, Volatility, rate, len(result), len(result) / len(table_list),
  382. np.nansum(result_change),
  383. np.nanmean(result_change), np.nanmax(result_change), np.min(result_change),
  384. np.nansum(result_change_fall), np.nanmean(result_change_fall),
  385. np.nanmin(result_change_fall), np.nanmax(result_change_fall),
  386. len(result_change) / len(result_change_fall)]
  387. ld = pd.Series(list_date, index=['周期', '波动率', 'MA5斜率', '盈利个数', '盈利比例', '总盈利',
  388. '平均盈利', '最大盈利', '最小盈利', '总亏损', '平均亏损',
  389. '最大亏损', '最小亏损', '盈亏对比'])
  390. pddate = pd.concat([pddate, ld.to_frame().T], ignore_index=True)
  391. print(f'计算总数={len(result) + len(result_change_fall)}\n计数为:{none_count}')
  392. print(pddate)
  393. to_df(pddate)
  394. # time.sleep(1)
  395. '''
  396. # to_df(list_date)
  397. print(pddate)
  398. to_df(pddate)
  399. edtime = dt.now()
  400. print('总耗时:', edtime - starttime)
  401. # with concurrent.futures.ProcessPoolExecutor() as executor:
  402. # for stock_code in table_list:
  403. # partial_func = functools.partial(backtrader, table_list, stock_code, result, result_change,
  404. # result_change_fall, num, Volatility, rate, err_list)
  405. # partial_func_list.append(partial_func)
  406. # executor.submit(partial_func)
  407. # executor.submit(backtrader, table_list, stock_code, result, result_change,
  408. # result_change_fall, num, Volatility, rate, err_list,)
  409. # print(pool)
  410. # df.to_csv(r'C:\Users\Daniel\Documents\策略穷举2.csv', index=True)