# Star import kept first so every explicit import below takes precedence over it.
from jqdatasdk import *

import os
import time
import traceback
from datetime import datetime as dt
from urllib.parse import quote

import pandas as pd
import pymysql
from sqlalchemy import create_engine
from xtquant import xtconstant
from xtquant import xtdata
from xtquant.xttrader import XtQuantTrader, XtQuantTraderCallback
from xtquant.xttype import StockAccount



# Credentials may be supplied through environment variables so they need not live
# in source control.
# NOTE(review): the fallbacks are the values that used to be hard-coded in this file;
# they should be rotated and removed once the environment variables are configured.
JQ_USER = os.environ.get('JQ_USER', '18616891214')
JQ_PASSWORD = os.environ.get('JQ_PASSWORD', 'Ea?*7f68nD.dafcW34d!')
MYSQL_HOST = os.environ.get('MYSQL_HOST', 'localhost')
MYSQL_PORT = int(os.environ.get('MYSQL_PORT', '3307'))
MYSQL_USER = os.environ.get('MYSQL_USER', 'root')
MYSQL_PASSWORD = os.environ.get('MYSQL_PASSWORD', 'r6kEwqWU9!v3')

# Log in to JoinQuant data.
auth(JQ_USER, JQ_PASSWORD)

# Start the trading system: userdata directory of the QMT mini client.
path = 'c:\\qmt\\userdata_mini'
# session_id identifies the trader session; each Python strategy must use a distinct id.
session_id = 20221123

# Bar period for get_bars queries; also used in the pool / result table names.
fre = '1d'

# Credentials are percent-escaped so characters such as '@', '/' or '?' in a
# password cannot break the URL (SQLAlchemy unquotes them again).
engine_hlfx_pool = create_engine(
    'mysql+pymysql://{}:{}@{}:{}/hlfx_pool?charset=utf8'.format(
        quote(MYSQL_USER, safe=''), quote(MYSQL_PASSWORD, safe=''), MYSQL_HOST, MYSQL_PORT))

db_pool = pymysql.connect(host=MYSQL_HOST,
                          user=MYSQL_USER,
                          port=MYSQL_PORT,
                          password=MYSQL_PASSWORD,
                          database='hlfx_pool')
cursor_pool = db_pool.cursor()

print(dt.now(), '开始寻找MA5趋势!')

# Most recent quote per stock code, merged from every subscription callback.
_latest_quotes = {}


def real_price(datas=None):
    """Quote callback for ``xtdata.subscribe_whole_quote`` and accessor for the cache.

    Args:
        datas: mapping of stock code -> quote fields as pushed by the subscription
            (callers read ``[code]['lastPrice']`` from it). When given, it is merged
            into the cache and returned unchanged, as before.

    Returns:
        ``datas`` when it is supplied; otherwise the cache holding the latest quote
        seen for every code, so ``real_price()[code]['lastPrice']`` works.
    """
    if datas is None:
        return _latest_quotes
    # A push may carry only the codes that changed, so merge rather than replace.
    _latest_quotes.update(datas)
    return datas


def XtTrader(new_keep_stock):
    """Buy every stock in ``new_keep_stock`` through QMT, then print trades and positions.

    Args:
        new_keep_stock: QMT-style codes (e.g. '600000.SH') to buy.

    Uses the module-level ``xt_trader`` on a fixed SECURITY account. For each code
    with a positive latest price, if available cash exceeds 2000, an order for half
    of the available cash (rounded down to a multiple of 100 shares) is placed at
    the latest price. The trader is stopped at the end.
    """
    account = StockAccount('920000207040', 'SECURITY')
    asset = xt_trader.query_stock_asset(account)
    xtdata.subscribe_whole_quote(new_keep_stock, callback=real_price)

    if asset:
        print("asset:")
        print(asset.account_type, asset.account_id, asset.cash, asset.frozen_cash, asset.market_value,
              asset.total_asset)

    # Start trading.
    if not asset:
        # Without an asset snapshot there is no cash figure to size orders with.
        print('资产查询失败, 跳过下单')
    else:
        # Latest prices straight from the full-tick snapshot, one call for all codes.
        ticks = (xtdata.get_full_tick(list(new_keep_stock)) if new_keep_stock else None) or {}
        for i in new_keep_stock:
            print(i)
            price = (ticks.get(i) or {}).get('lastPrice')
            print('price:', price)
            if price is None or price <= 0:
                # No usable quote: an order cannot be sized (and would divide by zero).
                continue
            print(asset.cash / price)
            if asset.cash > 2000:
                # Half of the available cash, rounded down to whole lots of 100 shares.
                volume = int((asset.cash / 2 / price) // 100 * 100)
                print('volume:', volume)
                if volume <= 0:
                    continue
                order_id = xt_trader.order_stock(account, i, xtconstant.STOCK_BUY, volume,
                                                 xtconstant.LATEST_PRICE, price, 'strategy1', 'order_test')
                print(order_id)
                # Re-query so the next order is sized against the cash that is left
                # (assuming the counter has already frozen this order's cash);
                # keep the previous snapshot if the query fails.
                asset = xt_trader.query_stock_asset(account) or asset

    print('今日成交:')
    for trades in xt_trader.query_stock_trades(account) or []:
        print(trades.stock_code, trades.traded_volume, trades.traded_price)

    positions = xt_trader.query_stock_positions(account) or []
    print("positions:", len(positions))
    if positions:
        print("last position:")
        print("{0} {1} {2}".format(positions[-1].account_id, positions[-1].stock_code, positions[-1].volume))

    print(positions)
    xt_trader.stop()

def Sell_Trader(stock, account, positions, volume):
    """Submit a latest-price sell order for ``volume`` shares of ``stock``.

    Args:
        stock: JoinQuant-style code ('600000.XSHG' / '000001.XSHE'); converted to
            the QMT form ('600000.SH' / '000001.SZ') for the order.
        account: the StockAccount to trade on.
        positions: unused; kept for call compatibility.
        volume: number of shares to sell.

    Returns:
        The order id returned by ``xt_trader.order_stock``.
    """
    qmt_code = stock.replace('XSHG', 'SH').replace('XSHE', 'SZ')
    print(type(qmt_code), qmt_code)
    # LATEST_PRICE orders ignore the price argument, hence 0.
    order_id = xt_trader.order_stock(account, qmt_code, xtconstant.STOCK_SELL,
                                     volume, xtconstant.LATEST_PRICE, 0, 'strategy1', 'order_test')
    print(order_id, stock)
    return order_id


# When True the strategy runs regardless of the clock; set to False to run only
# inside the trading-hours windows checked below.
FORCE_RUN = True
# Seconds between strategy passes.
LOOP_INTERVAL = 1800
# Seconds to wait outside trading hours or after a failed connection.
IDLE_INTERVAL = 60


def _to_jq_code(qmt_code):
    """Convert a QMT code such as '600000.SH' to the JoinQuant form '600000.XSHG'."""
    return qmt_code.replace('SH', 'XSHG').replace('SZ', 'XSHE')


def _to_qmt_code(jq_code):
    """Convert a JoinQuant code such as '600000.XSHG' to the QMT form '600000.SH'."""
    return jq_code.replace('XSHG', 'SH').replace('XSHE', 'SZ')


def _full_ticks(qmt_codes):
    """Return ``xtdata.get_full_tick`` for ``qmt_codes`` (any iterable), or {} if empty."""
    qmt_codes = list(qmt_codes)
    if not qmt_codes:
        return {}
    return xtdata.get_full_tick(qmt_codes) or {}


def _tick_price(ticks, qmt_code):
    """Return the positive 'lastPrice' of ``qmt_code`` in ``ticks``, or None if unavailable."""
    price = (ticks.get(qmt_code) or {}).get('lastPrice')
    if price is None or price <= 0:
        return None
    return price


def _moving_averages(df):
    """Return (MA5 one bar earlier, MA5, MA10, MA20) of ``df['close']``.

    The last row is excluded from every average: bars are requested with
    include_now=True, so the last row is presumably today's unfinished bar.
    Positional ``iloc`` slicing is used so the result does not depend on the index.
    """
    close = df['close']
    return (close.iloc[-7:-2].mean(), close.iloc[-6:-1].mean(),
            close.iloc[-11:-1].mean(), close.iloc[-21:-1].mean())


def _stop_trader(trader):
    """Stop ``trader``, printing (not raising) any error, e.g. when it was already stopped."""
    try:
        trader.stop()
    except Exception as exc:
        print(dt.now(), '停止交易连接出错:', repr(exc))


quote_seq = None  # id of the whole-market quote subscription; subscribe only once

while True:
    now_date = dt.now()
    date_morning_begin = now_date.replace(hour=9, minute=25, second=0)
    date_morning_end = now_date.replace(hour=11, minute=31, second=0)
    date_afternoon_begin = now_date.replace(hour=13, minute=0, second=0)
    date_afternoon_end = now_date.replace(hour=15, minute=0, second=0)
    in_session = (date_morning_begin < now_date < date_morning_end
                  or date_afternoon_begin < now_date < date_afternoon_end)
    if not (FORCE_RUN or in_session):
        # Outside trading hours: wait instead of spinning the CPU.
        time.sleep(IDLE_INTERVAL)
        continue

    if quote_seq is None:
        stocks = xtdata.get_stock_list_in_sector('沪深A股')
        quote_seq = xtdata.subscribe_whole_quote(stocks, callback=real_price)

    # A fresh trader per pass; it is always stopped in the finally clause below.
    xt_trader = XtQuantTrader(path, session_id)
    xt_trader.start()
    connect_result = xt_trader.connect()
    if connect_result != 0:
        print('连接失败')
        _stop_trader(xt_trader)
        time.sleep(IDLE_INTERVAL)
        continue
    print('QMTmini 已连接')

    try:
        account = StockAccount('920000207040', 'SECURITY')
        positions = xt_trader.query_stock_positions(account) or []
        print(positions)

        # Sell side: exit holdings whose MA5 trend broke or that ran >12% above MA5.
        position_ticks = _full_ticks(p.stock_code for p in positions)
        for i in positions:
            # can_use_volume excludes shares that cannot be sold yet (e.g. bought today).
            volume = i.can_use_volume
            if volume <= 0:
                continue
            stock = _to_jq_code(i.stock_code)
            price = _tick_price(position_ticks, i.stock_code)
            if price is None:
                print(stock, '无最新价, 跳过')
                continue
            try:
                df_stock = get_bars(stock, count=60, unit=fre,
                                    fields=['date', 'open', 'close', 'high', 'low', 'volume'],
                                    include_now=True, df=True)
            except Exception as exc:
                print(stock, '获取K线失败:', repr(exc))
                continue
            MA5_1, MA5, MA10, MA20 = _moving_averages(df_stock)
            if price < MA5 or MA5 < MA5_1 or price > MA5 * 1.12:
                print(MA5, MA5_1)
                Sell_Trader(stock, account, positions, volume)

        # Buy side: screen the stored pool for rising-MA5 candidates.
        for fre in ['1d']:
            print('开始:', fre)
            try:
                stock_pool = pd.read_sql_query(
                    'select value from `%s`' % fre, engine_hlfx_pool)
                # NOTE(review): no ORDER BY, so "last row" relies on the table's natural order.
                stock_pool = stock_pool.iloc[-1, 0].split(",")
            except Exception as exc:
                print(fre, '读取股票池失败:', repr(exc))
                continue
            print(stock_pool)

            pool_ticks = _full_ticks(_to_qmt_code(s) for s in stock_pool)
            results = []
            for stock in stock_pool:
                price = _tick_price(pool_ticks, _to_qmt_code(stock))
                if price is None:
                    continue
                try:
                    df_stock = get_bars(stock, count=60, unit=fre,
                                        fields=['date', 'open', 'close', 'high', 'low', 'volume'],
                                        include_now=True, df=True)
                except Exception as exc:
                    print(stock, '获取K线失败:', repr(exc))
                    continue
                if len(df_stock) < 2:
                    continue
                last_bar, prev_bar = df_stock.iloc[-1], df_stock.iloc[-2]
                MA5_1, MA5, MA10, MA20 = _moving_averages(df_stock)
                if (price > last_bar['open'] and price > MA5 and MA5 > MA5_1
                        and price < MA5 * 1.03 and MA20 < MA10
                        and last_bar['volume'] > prev_bar['volume']):
                    print(stock)
                    results.append(stock)
                elif price < MA5 or MA5 < MA5_1 or price > MA5 * 1.09:
                    # Only reported: the pool list is never mutated while iterating it.
                    print(stock, '已失败!')

            results = list(dict.fromkeys(results))  # de-duplicate, keep order
            print(results)
            if not results:
                continue

            num_industry = get_industry(results)
            print(num_industry)
            industry_names = [info['industry_name']
                              for classes in num_industry.values()
                              for info in classes.values()]
            # The three most frequent industry names among the candidates.
            counts = pd.Series(industry_names, dtype=object).value_counts()
            max_industry_list = set(counts.index[:3])
            results_industry = sorted(
                code for code, classes in num_industry.items()
                if any(info['industry_name'] in max_industry_list for info in classes.values()))
            print('suoyou:', set(results_industry))
            if not results_industry:
                continue
            results_industry_str = ','.join(results_industry)
            print(fre, '\n', results_industry_str)

            # Table names cannot be bound as parameters; fre comes from the fixed list above.
            sql = "INSERT INTO `MA5_%s` (date,value) VALUES (%%s, %%s)" % fre
            # The connection idles between passes; reconnect if the server dropped it.
            db_pool.ping(reconnect=True)
            cursor_pool.execute(sql, (dt.now().strftime('%Y-%m-%d %H:%M:%S'), results_industry_str))
            db_pool.commit()

            print(len(results_industry_str), results_industry_str)
            print(dt.now(), '数据库数据已赋值!')

            # Trade on the list just written rather than re-reading it from the table.
            new_keep_stock = [_to_qmt_code(s) for s in results_industry]
            print(new_keep_stock)
            XtTrader(new_keep_stock)
    except Exception:
        # Keep the bot alive across a failed pass; the traceback shows what went wrong.
        traceback.print_exc()
    finally:
        _stop_trader(xt_trader)
    time.sleep(LOOP_INTERVAL)