# coding:utf-8
import time

from datetime import datetime as dt
import socket
import pandas as pd
import numpy as np
from sqlalchemy import create_engine, text
from jqdatasdk import *
import pymysql
import multiprocessing as mp
from multiprocessing import freeze_support
import concurrent.futures
import math
import talib as ta
import os
import traceback
import random
import logging
from myindicator import myind
import psutil
from tqdm import tqdm
from itertools import islice
from func_timeout import func_set_timeout, FunctionTimedOut
from apscheduler.schedulers.blocking import BlockingScheduler

# Show every row and column when a DataFrame is printed (no truncation)
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

# Configure root logging: INFO and above, with timestamp and level in each line
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Create the SQLAlchemy connection pools:
#   engine      -> database qmt_stocks_whole
#   engine_tech -> database qmt_stocks_tech (tech_anal writes its result tables here via to_sql)
# pool_recycle=3600 makes SQLAlchemy replace pooled connections older than one hour.
# NOTE(review): credentials are hard-coded in source; consider moving them to
# configuration / environment variables.
# NOTE(review): these are module globals, so each process that imports this module
# presumably gets its own engine objects; pool_size=100 + max_overflow=20 applies per
# engine — confirm this fits the MySQL max_connections budget.
engine = create_engine(
    'mysql+pymysql://root:r6kEwqWU9!v3@localhost:3307/qmt_stocks_whole?charset=utf8', pool_recycle=3600, pool_size=100,
    max_overflow=20)
engine_tech = create_engine(
    'mysql+pymysql://root:r6kEwqWU9!v3@localhost:3307/qmt_stocks_tech?charset=utf8', pool_size=100, pool_recycle=3600,
    max_overflow=20)


# engine_tech2 = create_engine(
#     'mysql+pymysql://root:r6kEwqWU9!v3@localhost:3308/qmt_stocks_tech?charset=utf8', pool_size=100, max_overflow=20)


def err_call_back(err):
    """Log an exception raised by a pool task (``error_callback`` of ``apply_async``).

    The callback receives the exception object; it is NOT called inside an
    ``except`` block, so ``traceback.print_exc()`` would only print
    ``NoneType: None``. Passing the exception as ``exc_info`` logs its traceback,
    including the remote traceback multiprocessing chains onto it.

    :param err: the exception raised in the worker.
    """
    logging.error(f'进程池出错~ error:{str(err)}', exc_info=err)


def tech_anal(stock, df_stock, fre, hlfx_pool, hlfx_pool_daily, err_list):
    """Compute technical indicators for one stock and store them in ``qmt_stocks_tech``.

    Intended to run as a multiprocessing pool task. The indicator frame is
    written with ``to_sql`` to a table named *stock* (replacing it). The HLFX
    signal returned by ``myind.get_hlfx`` updates the shared pools:
    ``t_signals == 1`` adds the code, ``t_signals == 2`` removes it.

    :param stock: source table name; ``stock[0:9]`` is the code stored in the
        shared lists (presumably e.g. ``'000001.SZ_1d'`` -> ``'000001.SZ'`` —
        TODO confirm naming scheme).
    :param df_stock: DataFrame of that table; needs a ``time`` column, which is
        the merge key for the ``get_hlfx`` output. It is modified in place by
        the ``myind`` helpers (assumption based on their return values being
        ignored — confirm).
    :param fre: bar frequency label; unused in this function.
    :param hlfx_pool: shared list of codes with an active HLFX signal.
    :param hlfx_pool_daily: shared list of codes that got a new signal this run.
    :param err_list: shared list collecting codes whose processing failed.
    :returns: None; results go to the database and the shared lists.
    """
    code = stock[0:9]
    t_signals = 0

    try:
        try:
            df = df_stock
            row_count = len(df)
        except Exception:
            print(f"{stock}读取有问题")
            traceback.print_exc()
            err_list.append(code)
        else:
            # NOTE(review): an earlier ``df.dropna(axis=0, how='any')`` discarded its
            # result, so NaN rows were never dropped; that no-op call was removed and
            # rows are still kept to preserve output — confirm whether dropping is wanted.
            if row_count == 0:
                err_list.append(code)
                print(f'{stock}数据为空')
            else:
                # Compute the technical indicators
                print(f'{stock}开始计算技术指标')
                try:
                    myind.get_macd_data(df)
                    myind.get_ris(df)
                    myind.get_bias(df)
                    myind.get_wilr(df)
                    df = df.round(2)
                    df_temp, t_signals = myind.get_hlfx(df)
                    df = pd.merge(df, df_temp, on='time', how='left')
                    # Assign back instead of chained fillna(inplace=True), which is a
                    # no-op under pandas Copy-on-Write.
                    df['HL'] = df['HL'].fillna(value='-')
                    df = df.reset_index(drop=True)
                    df = df.replace([np.inf, -np.inf], np.nan)
                    df = df.round(2)
                except Exception as e:
                    print(f'{stock}计算有问题', e)
                    traceback.print_exc()
                    # Record calculation failures like read/store failures.
                    err_list.append(code)
                else:
                    # Store in the database
                    try:
                        df.to_sql(stock, con=engine_tech, index=False, if_exists='replace')
                    except Exception as e:
                        print(f'{stock}存储有问题', e)
                        traceback.print_exc()
                        err_list.append(code)
        finally:
            # Membership must be tested with the same 9-char code that is stored;
            # comparing the full table name never matched, leaving `remove` dead.
            if code in hlfx_pool and t_signals == 2:
                hlfx_pool.remove(code)
            elif code not in hlfx_pool and t_signals == 1:
                hlfx_pool.append(code)
                hlfx_pool_daily.append(code)
            print(f'{stock}计算完成!')

    except Exception as e:
        logging.error(f'子进程{os.getpid()}问题在这里~~ error:{str(e)}')
        traceback.print_exc()

    # Release pooled connections held by this worker after every task.
    engine.dispose()
    engine_tech.dispose()
    # engine_tech2.dispose()


def query_database(table_name):
    """Read the whole table *table_name* from ``qmt_stocks_whole`` into a DataFrame.

    Used as the worker function of the process pool in ``get_stock_data``. Each
    call creates its own engine, which is always disposed, even when the read
    raises.

    :param table_name: name of the table to read.
    :returns: DataFrame with all rows of the table.
    """
    local_engine = create_engine('mysql+pymysql://root:r6kEwqWU9!v3@localhost:3307/qmt_stocks_whole?charset=utf8')
    try:
        return pd.read_sql_table(table_name, local_engine)
    finally:
        local_engine.dispose()


def get_stock_data():
    """Load every ``qmt_stocks_whole`` table whose name contains ``1d``.

    Table names come from ``SHOW TABLES LIKE '%1d%'``. The tables are then read
    in parallel with ``query_database`` in a 24-process pool. On any error the
    message is printed and the whole load is retried after 30 seconds,
    indefinitely. KeyboardInterrupt / SystemExit are not caught.

    :returns: dict mapping table name -> DataFrame of that table.
    """
    while True:
        try:
            db = pymysql.connect(host='localhost',
                                 user='root',
                                 port=3307,
                                 password='r6kEwqWU9!v3',
                                 database='qmt_stocks_whole')
            try:
                with db.cursor() as cursor:
                    # No query args are passed, so pymysql sends the '%' wildcards verbatim.
                    cursor.execute("show tables like '%1d%'")
                    table_list = [row[0] for row in cursor.fetchall()]
            finally:
                # Close the connection even when listing the tables fails.
                db.close()

            print(f'开始数据库读取')
            with concurrent.futures.ProcessPoolExecutor(max_workers=24) as executor:
                # Read all tables in parallel; map() preserves table_list order.
                frames = executor.map(query_database, table_list)
                data_dict = dict(tqdm(zip(table_list, frames), total=len(table_list)))
            print(f'数据库读取完成')
            return data_dict
        except Exception as e:
            print(f'数据库读取错误{e}')
            time.sleep(30)


# 分割列表
def split_list(lst, num_parts):
    """Split *lst* into *num_parts* contiguous slices of near-equal length.

    Slice lengths differ by at most one; the first ``len(lst) % num_parts``
    slices each receive the extra element. ``num_parts == 0`` raises
    ZeroDivisionError.
    """
    base, extra = divmod(len(lst), num_parts)
    # bounds[k] is the start index of slice k; the last entry closes the final slice
    bounds = [0]
    for idx in range(num_parts):
        bounds.append(bounds[-1] + base + (1 if idx < extra else 0))
    return [lst[lo:hi] for lo, hi in zip(bounds, bounds[1:])]


def chunked_iterable(iterable, size):
    """Lazily yield successive tuples of at most *size* items taken from *iterable*.

    Every tuple except possibly the last holds exactly *size* items; an empty
    input yields nothing.
    """
    source = iter(iterable)
    # iter(callable, sentinel) keeps pulling chunks until an empty tuple appears
    yield from iter(lambda: tuple(islice(source, size)), ())


# 多进程实现技术指标计算
def _read_stock_list():
    """Return the security list stored in the last row read from ``hlfx_pool.stocks_list``.

    The ``securities`` column holds a comma-separated string, which is split.
    On failure the error is printed and the read is retried after 30 seconds.
    """
    while True:
        conn_engine_hlfx_pool = None
        try:
            conn_engine_hlfx_pool = create_engine(
                'mysql+pymysql://root:r6kEwqWU9!v3@localhost:3307/hlfx_pool?charset=utf8')
            with conn_engine_hlfx_pool.connect() as con_engine_hlfx_pool:
                # NOTE(review): no ORDER BY, so "last row" relies on the server's row
                # order — confirm this really yields the newest list.
                stocks = pd.read_sql_query(
                    text("select securities from stocks_list"),
                    con=con_engine_hlfx_pool).iloc[-1, 0].split(",")
        except Exception as e:
            print(f'股票列表读取错误{e}')
        else:
            print(f'股票列表长度为{len(stocks)}')
            return stocks
        finally:
            if conn_engine_hlfx_pool is not None:
                conn_engine_hlfx_pool.dispose()
        # Back off instead of hammering the database in a tight loop.
        time.sleep(30)


def _run_chunk(chunk, fre, num_cpus, hlfx_pool, hlfx_pool_daily, err_list, timeout, max_retries):
    """Run ``tech_anal`` for every ``(stock, df_stock)`` pair of *chunk* in a fresh pool.

    Results are awaited one after another, each for up to *timeout* seconds.
    On a timeout the pool is terminated and the whole chunk is resubmitted, at
    most *max_retries* times. Returns the AsyncResult objects of the last
    attempt only (tasks killed by termination never become ready).
    """
    retries = 0
    while True:
        print(f'chunk:{chunk[0][0]}-{chunk[-1][0]}')
        chunk_results = []
        # Never start more workers than there are tasks or CPUs.
        with mp.Pool(processes=min(len(chunk), num_cpus)) as pool:
            for stock, df_stock in chunk:
                print('**************', stock)
                chunk_results.append(pool.apply_async(
                    func=tech_anal,
                    args=(stock, df_stock, fre, hlfx_pool, hlfx_pool_daily, err_list),
                    error_callback=err_call_back))
            try:
                for async_result in chunk_results:
                    async_result.get(timeout=timeout)
            except (mp.TimeoutError, FunctionTimedOut) as e:
                retries += 1
                if retries > max_retries:
                    print(f"Chunk still timing out after {max_retries} retries, giving up on it")
                    return chunk_results
                where = 'pool' if isinstance(e, mp.TimeoutError) else 'worker'
                print(f"Timeout occurred in {where}. Retry {retries}/{max_retries}...")
                # Leaving the with-block terminates this pool before the retry.
                continue
            except Exception as e:
                print(f"Error occurred: {e}")
                return chunk_results
            pool.close()
            pool.join()
            return chunk_results


def _save_pool_results(fre, results_list, results_list_daily):
    """Insert one timestamped row into ``hlfx_pool.<fre>`` and ``hlfx_pool.daily_<fre>``.

    Both inserts share one timestamp; commit happens only after both succeed.
    An insert/commit error is printed, not raised.
    """
    db_pool = pymysql.connect(host='localhost',
                              user='root',
                              port=3307,
                              password='r6kEwqWU9!v3',
                              database='hlfx_pool')
    now = dt.now().strftime('%Y-%m-%d %H:%M:%S')
    try:
        with db_pool.cursor() as cursor:
            # Values are bound as parameters; table names cannot be, so they are
            # built from the internal *fre* label and back-quoted.
            cursor.execute(f"INSERT INTO `{fre}` (date,value) VALUES (%s,%s)", (now, results_list))
            cursor.execute(f"INSERT INTO `daily_{fre}` (date,value) VALUES (%s,%s)",
                           (now, results_list_daily))
        db_pool.commit()
    except Exception as e:
        print(f'1d存入有问题', e)
    finally:
        print(f"results_list_daily:{results_list_daily}")
        db_pool.close()


def ind():
    """Compute the 1d technical indicators for all stocks using multiple processes.

    Steps: read the stock list (only its length is reported), load every ``*1d*``
    table via ``get_stock_data``, run ``tech_anal`` on them in chunks of 200,
    print a summary, then store the comma-joined, de-duplicated HLFX pools.
    """
    start_time = dt.now()
    fre = '1d'
    num_cpus = mp.cpu_count()

    print(
        f"{socket.gethostname()}共有{num_cpus}个核心\n{start_time.strftime('%Y-%m-%d %H:%M:%S')}开始计算{fre}技术指标")

    # NOTE(review): the list is only used for the length printout above; the
    # stocks actually processed come from get_stock_data().
    _read_stock_list()

    chunk_size = 200  # tasks per pool round; adjust as needed
    timeout = 120     # seconds to wait for each single task result
    max_retries = 3   # resubmissions allowed per chunk after a timeout

    # A single Manager process hosts all three shared lists and is shut down on exit.
    with mp.Manager() as manager:
        err_list, hlfx_pool, hlfx_pool_daily = manager.list(), manager.list(), manager.list()

        stock_data_dict = get_stock_data()

        async_results = []
        for chunk in chunked_iterable(stock_data_dict.items(), chunk_size):
            async_results.extend(_run_chunk(chunk, fre, num_cpus, hlfx_pool, hlfx_pool_daily,
                                            err_list, timeout, max_retries))

        # tech_anal returns None; count tasks that finished that way. Results from a
        # terminated pool are never ready, so they must not be waited on.
        none_count = sum(1 for r in async_results
                         if r.ready() and r.successful() and r.get() is None)

        print(
            f"共计算{none_count}/{len(async_results)},\n当日信号:{len(hlfx_pool_daily)},\n持续检测为:{len(hlfx_pool)}, \n错误列表:{err_list}")

        # Copy the shared lists out before the manager shuts down.
        results_list = ','.join(set(hlfx_pool))
        results_list_daily = ','.join(set(hlfx_pool_daily))

    _save_pool_results(fre, results_list, results_list_daily)

    end_time = dt.now()
    print(f"运行时间:{end_time - start_time}")


if __name__ == '__main__':
    # freeze_support() must come straight after the __main__ guard so frozen
    # (e.g. PyInstaller) child processes bootstrap before any other work.
    freeze_support()
    logger = mp.log_to_stderr()
    logger.setLevel(logging.DEBUG)
    # Pin this process to logical CPUs 0-22, keeping only CPUs it is already
    # allowed to use, so hosts with fewer than 23 CPUs do not fail.
    # cpu_affinity is not available on every platform (e.g. macOS).
    pus = psutil.Process()
    if hasattr(pus, 'cpu_affinity'):
        cpu_list = [cpu for cpu in pus.cpu_affinity() if cpu < 23]
        if cpu_list:
            pus.cpu_affinity(cpu_list)

    ind()