Ver código fonte

修正backtrader代码,设置highest价位卖出

Daniel 2 anos atrás
pai
commit
d7fa95b1a5
1 arquivos alterados com 37 adições e 33 exclusões
  1. 37 33
      futures_backtrader.py

+ 37 - 33
futures_backtrader.py

@@ -5,15 +5,18 @@ import pymysql
 import backtrader as bt
 import backtrader.indicators as btind
 import datetime
-import threading
+import math
 from datetime import datetime as dt
 import multiprocessing as mp
 from backtrader.feeds import PandasData
 
 # import multiprocessing
-import matplotlib
+# import matplotlib
 
 class MyPandasData(PandasData):
+    lines = ()
+    params = ()
+    '''
     lines = ('change_pct', 'net_amount_main', 'net_pct_main', 'net_amount_xl', 'net_pct_xl', 'net_amount_l', 'net_pct_l'
              , 'net_amount_m', 'net_pct_m', 'net_amount_s', 'net_pct_s',)
     params = (('change_pct', 7),
@@ -28,7 +31,7 @@ class MyPandasData(PandasData):
               ('net_amount_s', 16),
               ('net_pct_s', 17),
               )
-
+    '''
 
 class TestStrategy(bt.Strategy):
     params = (
@@ -50,13 +53,13 @@ class TestStrategy(bt.Strategy):
         self.high = self.datas[0].high
         self.low = self.datas[0].low
         self.volume = self.datas[0].volume
-        self.change_pct = self.datas[0].change_pct
-        self.net_amount_main = self.datas[0].net_amount_main
-        self.net_pct_main = self.datas[0].net_pct_main
-        self.net_amount_xl = self.datas[0].net_amount_xl
-        self.net_pct_xl = self.datas[0].net_pct_xl
-        self.net_amount_l = self.datas[0].net_amount_l
-        self.net_pct_l = self.datas[0].net_pct_l
+        # self.change_pct = self.datas[0].change_pct
+        # self.net_amount_main = self.datas[0].net_amount_main
+        # self.net_pct_main = self.datas[0].net_pct_main
+        # self.net_amount_xl = self.datas[0].net_amount_xl
+        # self.net_pct_xl = self.datas[0].net_pct_xl
+        # self.net_amount_l = self.datas[0].net_amount_l
+        # self.net_pct_l = self.datas[0].net_pct_l
         self.sma5 = btind.MovingAverageSimple(self.datas[0].close, period=5)
         self.sma10 = btind.MovingAverageSimple(self.datas[0].close, period=10)
         self.sma20 = btind.MovingAverageSimple(self.datas[0].close, period=20)
@@ -109,19 +112,19 @@ class TestStrategy(bt.Strategy):
         # and (self.net_amount_main[-1] > 0) and (self.net_amount_main[0] > 0)
         if len(self) > self.params.num:
             lowest = np.min(self.low.get(size=self.params.num))
+            highest = np.max(self.high.get(size = self.params.num))
             vola = self.params.Volatility/100
             rate = self.params.rate/100
             # print(f'{self.params.num}日天最低值:{lowest},波动率为{self.params.Volatility/100}')
             if (self.dataclose[0] > self.dataopen[0]) \
-                    and(((lowest*(1-vola)) < self.low[-2] < (lowest*(1+vola))) or((lowest*(1-vola)) < self.low[-1] < (lowest*(1+vola))))\
+                    and (((lowest*(1-vola)) < self.low[-2] < (lowest*(1+vola))) or ((lowest*(1-vola)) < self.low[-1] < (lowest*(1+vola))))\
                     and (self.dataclose[0] > self.sma5[0]) and self.sma5[0] > self.sma5[-1] \
-                    and (not self.position) and (self.sma5[0] > self.sma10[0]) \
-                    and (self.net_pct_main[-2] > 5) \
-                    and (self.change_pct[0] < 5):
+                    and (not self.position) and (self.sma5[0] > self.sma10[0]):
                 # self.log('BUY CREATE, %.2f' % self.dataclose[0])
                 self.order = self.buy()
-            elif self.dataclose < self.sma5[0]  or self.sma5[0] < self.sma10[0]\
-                    or (self.dataclose[0] > (self.sma5[0] * (1+rate))):
+            elif self.dataclose < self.sma5[0] or self.sma5[0] < self.sma10[0] \
+                    or (self.dataclose[0] > (self.sma5[0] * (1+rate))) or \
+                (((highest*(1-vola)) < self.high[-2] < (highest*(1+vola))) or ((highest*(1-vola)) < self.high[-1] < (highest*(1+vola)))):
                 self.order = self.close()
                 # self.log('Close, %.2f' % self.dataclose[0])
 
@@ -132,11 +135,11 @@ class TestStrategy(bt.Strategy):
 
 
 def backtrader(table_list, result, result_change,result_change_fall, num, Volatility, rate,err_list):
-    engine = create_engine('mysql+pymysql://root:r6kEwqWU9!v3@localhost:3307/stocks_data?charset=utf8')
+    engine = create_engine('mysql+pymysql://root:r6kEwqWU9!v3@localhost:3307/qmt_stocks?charset=utf8')
     for stock in table_list:
         # print(stock)
         stk_df = pd.read_sql_table(stock, engine)
-        stk_df.date = pd.to_datetime(stk_df.date)
+        stk_df.time = pd.to_datetime(stk_df.time)
         if len(stk_df) > 60:
             cerebro = bt.Cerebro()
 
@@ -146,23 +149,23 @@ def backtrader(table_list, result, result_change,result_change_fall, num, Volati
             data = MyPandasData(dataname=stk_df,
                                 fromdate=datetime.datetime(2010,1,1),
                                 todate=datetime.datetime(2022, 10, 30),
-                                datetime='date',
+                                datetime='time',
                                 open='open',
                                 close='close',
                                 high='high',
                                 low='low',
                                 volume='volume',
-                                change_pct='change_pct',
-                                net_amount_main='net_amount_main',
-                                net_pct_main='net_pct_main',
-                                net_amount_xl='net_amount_xl',
-                                net_pct_xl='net_pct_xl',
-                                net_amount_l='net_amount_l',
-                                net_pct_l='net_pct_l',
-                                net_amount_m='net_amount_m',
-                                net_pct_m='net_pct_m',
-                                net_amount_s='net_amount_s',
-                                net_pct_s='net_pct_s',
+                                # change_pct='change_pct',
+                                # net_amount_main='net_amount_main',
+                                # net_pct_main='net_pct_main',
+                                # net_amount_xl='net_amount_xl',
+                                # net_pct_xl='net_pct_xl',
+                                # net_amount_l='net_amount_l',
+                                # net_pct_l='net_pct_l',
+                                # net_amount_m='net_amount_m',
+                                # net_pct_m='net_pct_m',
+                                # net_amount_s='net_amount_s',
+                                # net_pct_s='net_pct_s',
                                 )
             # print('取值完成')
 
@@ -201,7 +204,7 @@ if __name__ == '__main__':
                          user='root',
                          port=3307,
                          password='r6kEwqWU9!v3',
-                         database='stocks_data')
+                         database='qmt_stocks')
     cursor = db.cursor()
     cursor.execute("show tables like '%%%s%%' " % fre)
     table_list = [tuple[0] for tuple in cursor.fetchall()]
@@ -217,7 +220,7 @@ if __name__ == '__main__':
     for num in range(60, 140, 20):
         for Volatility in range(8, 10, 1):
             for rate in range(7, 10, 1):
-                step = 1000
+                step = math.ceil(len(table_list) / mp.cpu_count())
                 thread_list = []
                 result = mp.Manager().list()
                 result_change = mp.Manager().list()
@@ -253,6 +256,7 @@ if __name__ == '__main__':
                                                 np.min(result_change_fall), np.max(result_change_fall)]
                 print(df)
                 print('每轮耗时:', endtime-stattime)
+                df.to_csv(r'D:\Daniel\Documents\策略穷举.csv', index=True)
     edtime = dt.now()
     print('总耗时:', edtime - sttime)
-    df.to_csv(r'C:\Users\Daniel\Documents\策略穷举2.csv', index=True)
+    # df.to_csv(r'C:\Users\Daniel\Documents\策略穷举2.csv', index=True)