# coding:utf-8
import time
from multiprocessing import freeze_support, Value, Lock
import backtrader as bt
from backtrader.feeds import PandasData
import backtrader.indicators as btind
from sqlalchemy import create_engine, text
import pymysql
from tqdm import tqdm
import concurrent.futures
import numpy as np
import pandas as pd
import platform
import datetime
from datetime import datetime as dt
from itertools import product
import psutil
import logging
import multiprocessing as mp
from itertools import islice

from func_timeout import func_set_timeout, FunctionTimedOut
from functools import partial


class MyPandasData(PandasData):
    """backtrader ``PandasData`` feed with seven extra indicator lines.

    ``lines`` declares the additional data lines exposed on the feed, and
    ``params`` gives each of them a default source column. backtrader treats
    an integer here as a column position; ``backtrader_test`` overrides every
    one of these with an explicit column name when it builds the feed.
    """
    # Extra lines: ``hl`` (the HL label, mapped to the integers 1-7 before the
    # feed is built), the dif/dea/macd columns and three RSI columns
    # (presumably MACD components and 6/12/24-period RSI, going by the names --
    # confirm against the database schema).
    lines = ('hl', 'dif', 'dea', 'macd', 'rsi_6', 'rsi_12', 'rsi_24',)
    # Default column positions for the extra lines (7-13); kept in the same
    # order as ``lines``.
    params = (('hl', 7),
              ('dif', 8),
              ('dea', 9),
              ('macd', 10),
              ('rsi_6', 11),
              ('rsi_12', 12),
              ('rsi_24', 13),
              )


class TestStrategy(bt.Strategy):
    """MA5-deviation strategy whose ``Volatility`` / ``rate`` params are swept.

    Entry (``buy``): today's low is more than ``rate`` percent below SMA5,
    SMA5 has fallen on each of the last two bars, and two bars ago
    SMA5 < SMA10 < SMA20.

    Exit (``close``) when any of these holds:
      * high < SMA5 and close more than ``Volatility`` percent below the high;
      * high > SMA5 > close;
      * close below ``pos_price`` (the previous bar's low at entry);
      * close < SMA5;
      * close more than ``rate`` percent above SMA5.

    NOTE(review): the first two exit conditions both imply ``close < SMA5``
    (close <= high on well-formed bars), so they are subsumed by the fourth
    one and ``Volatility`` cannot change any trade. Confirm whether sweeping
    ``Volatility`` is meaningful.
    NOTE(review): ``buy`` is issued on every bar where the entry condition
    holds, without checking ``self.position`` or a pending ``self.order``, so
    consecutive signal bars add to the position.
    """

    params = (
        ("num", 3),  # declared but not read anywhere in this class
        ('Volatility', 0),  # pull-back threshold used by next(), in percent
        ('rate', 3),  # MA5 deviation threshold used by next(), in percent
        # Keep the trailing commas: a single-entry params must stay a tuple.
    )

    def log(self, txt, dt=None):
        """Strategy log hook; currently a no-op because the print is disabled."""
        dt = dt or self.datas[0].datetime.date(0)
        # print('%s, %s' % (dt.isoformat(), txt))

    def __init__(self):
        try:
            data = self.datas[0]
            # Stop reference price; 0 means "no stop armed" (prices are positive,
            # so ``close < 0`` never triggers).
            self.pos_price = 0
            self.dataclose = data.close
            self.dataopen = data.open
            self.high = data.high
            self.low = data.low
            self.volume = data.volume
            self.hl = data.hl
            self.dif = data.dif
            self.dea = data.dea
            self.macd = data.macd
            self.rsi_6 = data.rsi_6
            self.rsi_12 = data.rsi_12
            self.rsi_24 = data.rsi_24
            self.sma5 = btind.MovingAverageSimple(data.close, period=5)
            self.sma10 = btind.MovingAverageSimple(data.close, period=10)
            self.sma20 = btind.MovingAverageSimple(data.close, period=20)
            # Not used by next(), but its 60-bar period still sets the
            # strategy's minimum period, i.e. when next() is first called.
            self.sma60 = btind.MovingAverageSimple(data.close, period=60)
        except Exception as e:
            # Report, then re-raise. Swallowing the error left a half-built
            # strategy whose next() later failed with an unrelated AttributeError.
            print(f'初始化错误{e}')
            raise

    def notify_order(self, order):
        """
        Handle order status changes.

        Arguments:
            order {object} -- the order whose status changed
        """
        if order.status in [order.Submitted, order.Accepted]:
            # Submitted/accepted orders need no action yet.
            return

        if order.status in [order.Completed]:
            if order.isbuy():
                self.buyprice = order.executed.price
                self.buycomm = order.executed.comm
            self.bar_executed = len(self)

        elif order.status in [order.Canceled, order.Margin, order.Rejected]:
            # Canceled / margin-rejected (e.g. insufficient cash) / rejected:
            # deliberately ignored.
            pass

        # Order handling finished; clear the reference.
        self.order = None

    def notify_trade(self, trade):
        """
        Trade result hook (only closed trades pass the guard).

        Arguments:
            trade {object} -- the trade whose state changed
        """
        if not trade.isclosed:
            return

    def next(self):
        rate = self.params.rate / 100
        vola = self.params.Volatility / 100
        close_px = self.dataclose[0]
        high_px = self.high[0]
        low_px = self.low[0]
        sma5 = self.sma5[0]

        # Entry: low far below SMA5 while the moving averages are stacked bearishly.
        deep_below_ma5 = low_px < sma5 * (1 - rate)
        bearish_trend = sma5 < self.sma5[-1] < self.sma5[-2] < self.sma10[-2] < self.sma20[-2]

        if deep_below_ma5 and bearish_trend:
            self.order = self.buy()
            self.pos_price = self.low[-1]  # previous bar's low becomes the stop
        elif ((high_px < sma5 and close_px < high_px * (1 - vola)) or
              (high_px > sma5 > close_px) or close_px < self.pos_price) \
                or (close_px < sma5) or (close_px > sma5 * (1 + rate)):
            # close() is a no-op in backtrader when there is no open position.
            self.order = self.close()
            self.pos_price = 0

    def stop(self):
        self.log(u'(MA趋势交易效果) Ending Value %.2f' % (self.broker.getvalue()))


def to_df(df, output_dir=None):
    """Sort the sweep results and write them to a timestamped CSV file.

    Rows are sorted ascending by 'MA5乖离率' then '当日回落' and re-indexed
    from 0; the caller's DataFrame is left untouched.

    Args:
        df: result table containing at least the 'MA5乖离率' and '当日回落'
            columns.
        output_dir: optional directory for the CSV. When omitted, the original
            host-specific locations are used: a macOS documents folder on the
            host named 'DanieldeMBP.lan', otherwise a fixed folder on drive C:.
    """
    import os

    print('开始存数据')
    df = df.sort_values(by=['MA5乖离率', '当日回落'], ascending=True).reset_index(drop=True)

    # %M is minutes. The original used %m (month) here, so files written within
    # the same hour got identical names and overwrote each other (mode='w').
    file_name = f"Ma5乖离7买入{dt.now().strftime('%Y%m%d%H%M%S')}.csv"
    if output_dir is not None:
        path = os.path.join(output_dir, file_name)
    elif platform.node() == 'DanieldeMBP.lan':
        path = f"/Users/daniel/Documents/策略/{file_name}"
    else:
        # Escaped backslashes: same path as before, without invalid escape sequences.
        path = "C:\\策略结果\\" + file_name

    df.to_csv(path, index=True, encoding='utf_8_sig', mode='w')
    print(f'结果:, \n, {df}')


def chunked_iterable(iterable, size):
    """Lazily yield consecutive tuples of at most ``size`` items from ``iterable``.

    The last tuple may be shorter; an empty input yields nothing.
    """
    source = iter(iterable)
    # iter(callable, sentinel) keeps calling the lambda until it returns the
    # empty tuple, i.e. until ``source`` is exhausted.
    yield from iter(lambda: tuple(islice(source, size)), ())


def query_database(table_name):
    """Read the whole MySQL table ``table_name`` into a DataFrame.

    A fresh engine is created per call (get_stock_data maps this over tables
    in a process pool). The engine is disposed afterwards so its pooled
    connection is released instead of lingering until the process exits.
    """
    # NOTE(review): credentials are hard-coded here and in get_stock_data;
    # consider moving them to configuration or environment variables.
    engine = create_engine('mysql+pymysql://root:r6kEwqWU9!v3@localhost:3307/qmt_stocks_tech?charset=utf8')
    try:
        return pd.read_sql_table(table_name, engine)
    finally:
        engine.dispose()


def _list_daily_tables():
    """Return the names of all qmt_stocks_tech tables whose name contains '1d'."""
    # NOTE(review): credentials are hard-coded here and in query_database;
    # consider moving them to configuration or environment variables.
    db = pymysql.connect(host='localhost',
                         user='root',
                         port=3307,
                         password='r6kEwqWU9!v3',
                         database='qmt_stocks_tech')
    try:
        with db.cursor() as cursor:
            # No args are passed, so pymysql sends the '%' wildcards verbatim.
            cursor.execute("show tables like '%1d%' ")
            return [row[0] for row in cursor.fetchall()]
    finally:
        # Closed even when the query fails (previously leaked in that case).
        db.close()


def _read_tables(table_list):
    """Read every table in ``table_list`` on a 16-process pool; return {name: DataFrame}."""
    with concurrent.futures.ProcessPoolExecutor(max_workers=16) as executor:
        frames = executor.map(query_database, table_list)
        return dict(zip(table_list, tqdm(frames, total=len(table_list))))


def get_stock_data(max_retries=None, retry_delay=5.0):
    """Load every '1d' table of qmt_stocks_tech into a dict of DataFrames.

    Args:
        max_retries: extra attempts allowed after a failure. ``None`` (the
            default) keeps the original behaviour of retrying forever.
        retry_delay: seconds to wait between attempts, so a down database is
            not hammered in a tight loop.

    Returns:
        dict mapping table name -> DataFrame.

    Raises:
        The last error, once ``max_retries`` is exhausted. KeyboardInterrupt
        and SystemExit are no longer swallowed, so the script can be stopped.
    """
    failures = 0
    while True:
        try:
            table_list = _list_daily_tables()
            print(f'开始数据库读取')
            data_dict = _read_tables(table_list)
            print(f'数据库读取完成')
            return data_dict
        except Exception as e:
            print(f'数据库读取错误{e}')
            failures += 1
            if max_retries is not None and failures > max_retries:
                raise
            time.sleep(retry_delay)


def backtrader_test(stock_data, stock_name, vot, rate):
    """Backtest TestStrategy on one stock and return its net P&L.

    Args:
        stock_data: per-stock DataFrame with the columns referenced below
            (time, *_back prices/volume, HL, dif, dea, macd, rsi_6/12/24).
            It is copied, never modified.
        stock_name: label used only in the error message.
        vot: value for the strategy's ``Volatility`` param (percent).
        rate: value for the strategy's ``rate`` param (percent).

    Returns:
        Final broker value minus the 100000.0 starting cash, or ``np.nan``
        if any step of the setup or run raised.
    """
    start_cash = 100000.0
    try:
        # Work on a copy. The original mutated the caller's frame, so a second
        # call on the same frame re-mapped already-numeric HL codes to NaN.
        stock_data = stock_data.copy()
        stock_data['time'] = pd.to_datetime(stock_data['time'])
        stock_data['HL'] = stock_data['HL'].map({'L': 1,
                                                 'LL': 2,
                                                 'L*': 3,
                                                 'H': 4,
                                                 'HH': 5,
                                                 'H*': 6,
                                                 '-': 7})
        cerebro = bt.Cerebro()
        # Exactly ONE strategy instance. The original also added a second,
        # default-parameter TestStrategy that traded on the same broker and
        # contaminated every parameter combination's result.
        cerebro.addstrategy(TestStrategy, Volatility=vot, rate=rate)
        data = MyPandasData(dataname=stock_data,
                            fromdate=datetime.datetime(2017, 1, 1),
                            todate=datetime.datetime(2022, 10, 30),
                            datetime='time',
                            open='open_back',
                            close='close_back',
                            high='high_back',
                            low='low_back',
                            volume='volume_back',
                            hl='HL',
                            dif='dif',
                            dea='dea',
                            macd='macd',
                            rsi_6='rsi_6',
                            rsi_12='rsi_12',
                            rsi_24='rsi_24',
                            )
        cerebro.adddata(data)
        cerebro.broker.setcash(start_cash)
        cerebro.addsizer(bt.sizers.FixedSize, stake=10000)
        cerebro.broker.setcommission(commission=0.001)
        cerebro.run()
    except Exception as e:
        # Per-stock failures become NaN so one bad stock does not abort the
        # sweep; Ctrl-C / SystemExit still propagate.
        print(f'{stock_name}回测错误{e}')
        return np.nan
    return cerebro.broker.getvalue() - start_cash


def tdf(tt, rate, Volatility):
    """Summarise one parameter combination's per-stock P&L values.

    Args:
        tt: iterable of per-stock net P&L values; NaN marks a stock whose
            backtest failed. NaNs are counted and excluded from every statistic.
        rate: MA5 deviation parameter, copied into 'MA5乖离率'.
        Volatility: pull-back parameter, copied into '当日回落'.

    Returns:
        pd.Series keyed by the result-table column names. When there are no
        valid (non-NaN) values, '盈利比例', '最大盈利' and '最大亏损' are NaN
        instead of raising.
    """
    tt = list(tt)  # allow any iterable; it is traversed several times
    nan_mask = np.isnan(np.asarray(tt, dtype=float))
    num_nan = int(nan_mask.sum())
    print(f'num_nan={num_nan}')

    filtered_result = [r for r, is_nan in zip(tt, nan_mask) if not is_nan]
    print(f'filtered_result={filtered_result}')

    profits = [r for r in filtered_result if r > 0]
    losses = [r for r in filtered_result if r < 0]
    num_profits = len(profits)
    num_losses = len(losses)

    # Guard the division: an all-NaN round used to raise ZeroDivisionError.
    profit_ratio = num_profits / len(filtered_result) if filtered_result else np.nan
    total_profit = sum(profits)
    avg_profit = total_profit / num_profits if num_profits else 0
    # Extremes over valid values only. max()/min() over a list containing NaN
    # are order-dependent (NaN first => NaN result).
    max_profit = max(filtered_result) if filtered_result else np.nan
    min_profit = min(profits) if num_profits else 0
    total_loss = sum(losses)
    avg_loss = total_loss / num_losses if num_losses else 0
    max_loss = min(filtered_result) if filtered_result else np.nan
    min_loss = max(losses) if num_losses else 0

    result_dict = {'MA5乖离率': rate, '当日回落': Volatility, '盈利个数': num_profits,
                   '盈利比例': profit_ratio, '总盈利': total_profit, '平均盈利': avg_profit,
                   '最大盈利': max_profit, '最小盈利': min_profit, '总亏损': total_loss,
                   '平均亏损': avg_loss, '最大亏损': max_loss, '最小亏损': min_loss, '未计算个股数': num_nan}
    return pd.Series(result_dict)


if __name__ == '__main__':
    # Route multiprocessing's own log records to stderr (DEBUG is verbose).
    logger = mp.log_to_stderr()
    logger.setLevel(logging.DEBUG)

    # Pin this process (and the workers it spawns) to CPUs 0-22. psutil does
    # not provide cpu_affinity on every OS (e.g. macOS), and requesting CPUs
    # outside the currently allowed set raises, so guard and intersect.
    pus = psutil.Process()
    if hasattr(pus, 'cpu_affinity'):
        cpu_list = [cpu for cpu in pus.cpu_affinity() if cpu < 23]
        print(cpu_list)
        if cpu_list:
            pus.cpu_affinity(cpu_list)
    start_time = dt.now()

    # Parameter grid to sweep exhaustively.
    Volatilitys = range(1, 10, 1)  # intraday pull-back, percent
    rates = range(3, 20, 1)  # MA5 deviation, percent
    all_combinations = list(product(Volatilitys, rates))
    print(f'共需计算{len(all_combinations)}次')

    stock_data_dict = get_stock_data()
    stock_names = list(stock_data_dict.keys())
    stock_frames = list(stock_data_dict.values())
    n_stocks = len(stock_frames)
    results = []

    columns = ['MA5乖离率', '当日回落', '盈利个数', '盈利比例', '总盈利', '平均盈利', '最大盈利', '最小盈利',
               '总亏损',
               '平均亏损', '最大亏损', '最小亏损', '未计算个股数']
    df = pd.DataFrame(columns=columns)
    # One summary Series per finished combination. Rebuilding df from this
    # list avoids concatenating onto an empty frame (deprecated in pandas).
    summary_rows = []

    err_list = []  # (Volatility, rate) combinations that failed every attempt

    MAX_WORKERS = 24
    RETRY_WORKERS = 20  # pool size used after a failure, as in the original code
    max_retries = 3

    # NOTE(review): every task pickles a full per-stock DataFrame to a worker,
    # and the whole dataset is re-sent for each combination; a pool
    # initializer that holds the data would avoid that cost.
    inner_executor = concurrent.futures.ProcessPoolExecutor(max_workers=MAX_WORKERS)
    try:
        for Volatility, rate in tqdm(all_combinations, desc='计算进度'):
            res = None
            for _attempt in range(1 + max_retries):
                try:
                    res = list(tqdm(
                        inner_executor.map(backtrader_test, stock_frames, stock_names,
                                           [Volatility] * n_stocks, [rate] * n_stocks),
                        total=n_stocks, desc='单轮计算进度'))
                except Exception as e:
                    print(f'计算错误{e}')
                    # A crashed worker leaves the pool permanently broken. Shut
                    # it down (the original leaked it and never shut down the
                    # replacement) and retry on a fresh pool.
                    inner_executor.shutdown(wait=False)
                    inner_executor = concurrent.futures.ProcessPoolExecutor(max_workers=RETRY_WORKERS)
                else:
                    break
            if res is None:
                err_list.append((Volatility, rate))
                print(f'Volatility={Volatility}, rate={rate} 重试{max_retries}次后仍失败,已跳过')
                continue
            results.append(res)
            summary_rows.append(tdf(res, rate, Volatility))
            df = pd.DataFrame(summary_rows, columns=columns)
            print(f'{rate}计算完成,共计算{len(res)}个股票')
            print(df)
    finally:
        inner_executor.shutdown()
    print('循环结束')
    if err_list:
        print(f'失败的参数组合: {err_list}')
    to_df(df)
    print(f'计算完成,共耗时{(dt.now() - start_time).total_seconds():.1f}秒')