from xtquant import xtdata
from datetime import datetime as dt
import pandas as pd
import math
from sqlalchemy import create_engine, text
import multiprocessing as mp
import os
from apscheduler.schedulers.blocking import BlockingScheduler
import traceback
import psutil


# Print every DataFrame column instead of truncating wide frames with "...".
pd.set_option('display.max_columns', None)


# NOTE(review): presumably the QMT mini-client userdata directory; it is not
# used by the functions in this script — confirm before relying on it.
path = 'C:\\qmt\\userdata_mini'

# Fields requested from xtdata.get_market_data for every daily bar.
field = ['time', 'open', 'close', 'high', 'low', 'volume', 'amount']
# Logical CPU count (download_data currently calls mp.cpu_count() itself).
cpu_count = mp.cpu_count()




def err_call_back(err):
    """Report an exception raised by a ``Pool.apply_async`` task.

    Used as the pool's ``error_callback``. Multiprocessing calls it with the
    exception instance, outside of any ``except`` block. In that context
    ``traceback.print_exc()`` has no active exception and prints only
    ``NoneType: None``. Instead, print the traceback attached to ``err``.
    For pool tasks this usually includes a chained remote traceback from the
    worker process.

    :param err: the exception instance raised by the task.
    """
    print(f'问题在这里~ error:{str(err)}')
    traceback.print_exception(type(err), err, err.__traceback__)


# Output column order for each adjusted frame: time first, then OHLC, volume, amount.
_BAR_COLUMNS = ('time', 'open', 'high', 'low', 'close', 'volume', 'amount')


def _frame_from_market_data(data, stock, suffix):
    """Flatten one stock's ``xtdata.get_market_data`` result into one row per bar.

    ``data`` maps each field name to a DataFrame indexed by stock code. Every
    non-time column is renamed ``<field>_<suffix>``. The ``time`` column
    (epoch milliseconds) is converted to naive local ``datetime`` objects.
    """
    frame = pd.concat([data[name].loc[stock] for name in _BAR_COLUMNS], axis=1)
    frame.columns = ['time'] + [f'{name}_{suffix}' for name in _BAR_COLUMNS[1:]]
    frame['time'] = frame['time'].apply(lambda ms: dt.fromtimestamp(ms / 1000.0))
    return frame.reset_index(drop=True)


def _fetch_adjusted(stock, dividend_type):
    """Fetch all daily bars for ``stock`` with the given adjustment ('back' or 'front').

    Columns are suffixed with the adjustment name, e.g. ``open_back``.
    """
    data = xtdata.get_market_data(field, [stock], '1d', end_time='', count=-1,
                                  dividend_type=dividend_type)
    return _frame_from_market_data(data, stock, dividend_type)


def to_sql(stock_list):
    """Store back- and front-adjusted daily bars for each stock in MySQL.

    Each stock gets its own table ``<code>_1d``, which is replaced wholesale.
    A failure for one stock (fetch, merge or write) is printed and skipped,
    so the rest of the chunk is still stored. The function runs inside a pool
    worker, so it creates, and always disposes, its own engine.

    :param stock_list: stock codes to process (one chunk from download_data).
    """
    print(f'{dt.now()}开始循环入库! MyPid is {os.getpid()}')
    # NOTE(review): database credentials are hard-coded; consider config/env vars.
    eng_w = create_engine('mysql+pymysql://root:r6kEwqWU9!v3@localhost:3307/qmt_stocks_whole?charset=utf8',
                          pool_recycle=3600, pool_pre_ping=True, pool_size=1)
    stored = 0
    try:
        for stock in stock_list:
            try:
                df_back = _fetch_adjusted(stock, 'back')
                df_front = _fetch_adjusted(stock, 'front')
                # merge_asof requires both sides sorted on 'time'; this assumes
                # xtdata returns bars in ascending date order — TODO confirm.
                df = pd.merge_asof(df_back, df_front, 'time')
                df.to_sql('%s_1d' % stock, con=eng_w, index=False, if_exists='replace', chunksize=20000)
            except Exception as e:  # keep going with the remaining stocks
                print(stock, e)
            else:
                stored += 1
    finally:
        eng_w.dispose()
    print(f'Pid:{os.getpid()}已经完工了.应入库{len(stock_list)},共入库{stored}支个股')


def download_data(processes=None):
    """Download daily history for all SH/SZ A-shares and load it into MySQL.

    The sorted stock list is split into one contiguous chunk per worker
    process. Each chunk is handed to ``to_sql`` through a process pool, and
    task failures are reported by ``err_call_back``.

    :param processes: number of worker processes; defaults to ``mp.cpu_count()``.
    """
    stock_list = xtdata.get_stock_list_in_sector('沪深A股')
    if not stock_list:
        # An empty list would give step == 0 below, and range() would raise ValueError.
        print(dt.now(), '未获取到沪深A股列表,跳过下载入库!')
        return
    stock_list.sort()
    print(dt.now(), '开始下载!')
    xtdata.download_history_data2(stock_list=stock_list, period='1d', start_time='', end_time='')
    print(dt.now(), '下载完成,准备入库!')

    workers = processes or mp.cpu_count()
    step = math.ceil(len(stock_list) / workers)
    with mp.Pool(processes=workers) as pool:
        for start in range(0, len(stock_list), step):
            pool.apply_async(func=to_sql, args=(stock_list[start:start + step],),
                             error_callback=err_call_back)
        pool.close()
        pool.join()  # wait for every chunk before reporting completion

    print(f'今日数据下载完毕 {dt.now()}')


if __name__ == '__main__':
    # `field` and `cpu_count` are defined at module level. They must be, because
    # under the spawn start method (the Windows default) pool workers re-import
    # this module without running this guard. So they are not redefined here.
    #
    # Optional CPU pinning for this process, e.g.:
    #     psutil.Process().cpu_affinity([12, 13, 14, 15, 16, 17, 18, 19])
    #
    # For an immediate one-off run instead of waiting for the schedule,
    # call download_data() here.

    scheduler = BlockingScheduler()
    # Run Monday-Friday (cron day_of_week 0-4) at 20:05 Asia/Shanghai time.
    # NOTE(review): max_instances=10 allows overlapping runs that would replace
    # the same tables concurrently if one run overruns; confirm this is intended.
    scheduler.add_job(func=download_data, trigger='cron', day_of_week='0-4', hour='20', minute='05',
                      timezone="Asia/Shanghai", max_instances=10)
    try:
        scheduler.start()
    except (KeyboardInterrupt, SystemExit):
        # Ctrl+C or interpreter shutdown: exit quietly.
        pass