import akshare as ak
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
from sklearn.linear_model import LinearRegression
import warnings

# Use a CJK-capable font so the Chinese titles, labels and legends render.
plt.rcParams["font.family"] = ["SimHei"]
# SimHei lacks the Unicode minus glyph; without this, negative tick labels
# render as empty boxes. Fall back to the ASCII hyphen instead.
plt.rcParams["axes.unicode_minus"] = False
# NOTE(review): this silences every warning process-wide (including pandas
# chained-assignment / deprecation warnings); consider narrowing it to
# specific categories or modules.
warnings.filterwarnings('ignore')

# ---------------------- Strategy configuration ----------------------
# Backtest window: a fixed start date up to the moment the script runs.
START_DATE = pd.to_datetime("2022-01-01")
END_DATE = datetime.now()

# Buy rule: the day's return must stay below this fraction (0.01 == 1%).
BUY_DAILY_RISE_LIMIT = 0.01
# Moving-average lengths, in rows (trading days).
MA5_WINDOW, MA20_WINDOW, MA60_WINDOW = 5, 20, 60

# Weights of the three momentum factors.
WEIGHT_SLOPE_R2 = 0.25
WEIGHT_RISE = 0.17
WEIGHT_VOLUME_RATIO = 0.5

# Sell rule: take profit once the 20-day rise reaches this fraction (0.16 == 16%).
SELL_RISE_THRESHOLD = 0.16

# Tradable universe: display name -> exchange code. Insertion order matters,
# because funds are scanned in this order and momentum ties keep the earlier one.
FUND_POOL = {
    # Commodity ETFs
    "黄金ETF": "518880",
    "豆粕ETF": "159985",
    "石油LOF": "162719",
    # Overseas-market ETFs
    "纳斯达克ETF": "513300",
    "恒生ETF": "159920",
    "德国ETF": "513030",
    "日经ETF": "513520",
}
# --------------------------------------------------------


def get_fund_data(fund_code):
    """Download daily bars for one fund from Sina (via akshare) and clean them.

    Args:
        fund_code: exchange code such as "518880". Codes starting with "5" are
            requested from the "sh" market, every other code from "sz".

    Returns:
        DataFrame with columns date (Timestamp floored to the day), open,
        high, low, close, volume, date_str ("%Y-%m-%d") and original_dates.
        Rows are restricted to [START_DATE, END_DATE] with volume > 0 and
        close > 0, sorted by date, with a fresh 0..n-1 index.
        Returns None when the download fails, the response is empty or lacks
        an expected column, or fewer than MA60_WINDOW rows remain.
    """
    market = "sh" if fund_code.startswith("5") else "sz"
    symbol = f"{market}{fund_code}"

    try:
        raw = ak.fund_etf_hist_sina(symbol=symbol)
    except Exception as e:
        print(f"❌ 获取{fund_code}数据失败: {e}")
        return None

    columns = ["date", "open", "high", "low", "close", "volume"]
    # An empty or reshaped response would otherwise raise KeyError below,
    # outside the try block, and abort the whole run instead of skipping one fund.
    if raw is None or raw.empty or any(col not in raw.columns for col in columns):
        print(f"❌ 获取{fund_code}数据失败: 返回数据为空或缺少字段")
        return None

    df = raw[columns].copy()
    df["date"] = pd.to_datetime(df["date"]).dt.floor("D")  # drop any time-of-day part
    df["date_str"] = df["date"].dt.strftime("%Y-%m-%d")  # string key used to align funds
    df = df[(df["date"] >= START_DATE) & (df["date"] <= END_DATE)]
    df = df[(df["volume"] > 0) & (df["close"] > 0)]  # drop non-trading / invalid rows
    df = df.sort_values("date").reset_index(drop=True)

    # Per-row duplicate of date_str. Nothing in this file reads it; it is kept
    # only so the returned columns stay unchanged for any external caller.
    df["original_dates"] = df["date_str"].tolist()
    return df if len(df) >= MA60_WINDOW else None


def calculate_indicators(df):
    """Return a copy of `df` with the indicator columns used by the strategy.

    No rows are dropped. Warm-up rows, where a rolling window is not yet full,
    hold 0 instead of NaN.

    Added columns:
        ma5 / ma20 / ma60: rolling means of close over MA5/MA20/MA60_WINDOW rows.
        amount: close * volume; avg_amount5 / avg_amount20: its 5/20-row means.
        daily_return, rise_5d, rise_10d, rise_20d: 1/5/10/20-row pct change of close.
        avg_volume5 / avg_volume18: 5/18-row mean volume.
        volume_ratio: avg_volume5 / avg_volume18 (a zero denominator is replaced by 1).
        slope / r_squared: slope and R^2 of a linear fit of the 25 closes that
            precede each row (the row's own close is excluded); 0 for the first 25 rows.
        slope_r2: slope * r_squared.

    Args:
        df: DataFrame with at least `close` and `volume` columns; values are
            assigned positionally, so any index works.
    """
    df_copy = df.copy()
    close = df_copy["close"]
    volume = df_copy["volume"]

    # Moving averages
    df_copy["ma5"] = close.rolling(window=MA5_WINDOW).mean().fillna(0)
    df_copy["ma20"] = close.rolling(window=MA20_WINDOW).mean().fillna(0)
    df_copy["ma60"] = close.rolling(window=MA60_WINDOW).mean().fillna(0)

    # Traded amount
    df_copy["amount"] = close * volume
    df_copy["avg_amount5"] = df_copy["amount"].rolling(window=5).mean().fillna(0)
    df_copy["avg_amount20"] = df_copy["amount"].rolling(window=20).mean().fillna(0)

    # Returns over several horizons
    df_copy["daily_return"] = close.pct_change().fillna(0)
    df_copy["rise_5d"] = close.pct_change(5).fillna(0)
    df_copy["rise_10d"] = close.pct_change(10).fillna(0)
    df_copy["rise_20d"] = close.pct_change(20).fillna(0)

    # Volume
    df_copy["avg_volume5"] = volume.rolling(window=5).mean().fillna(0)
    df_copy["avg_volume18"] = volume.rolling(window=18).mean().fillna(0)
    df_copy["volume_ratio"] = df_copy["avg_volume5"] / df_copy["avg_volume18"].replace(0, 1)  # avoid /0

    # 25-row regression trend. NOTE(review): the slope is in raw price units,
    # so it is not directly comparable across funds with different price levels.
    window = 25
    n_rows = len(df_copy)
    slopes = np.zeros(n_rows)
    r_squared = np.zeros(n_rows)
    x = np.arange(window).reshape(-1, 1)  # loop-invariant design matrix
    closes = close.to_numpy()
    for i in range(window, n_rows):
        y = closes[i - window:i]
        model = LinearRegression().fit(x, y)
        slopes[i] = model.coef_[0]
        r_squared[i] = model.score(x, y)
    # Assign whole columns. A chained write such as df["slope"].iloc[i] = v is
    # silently ignored under pandas Copy-on-Write (the pandas 3.0 default).
    df_copy["slope"] = slopes
    df_copy["r_squared"] = r_squared
    df_copy["slope_r2"] = df_copy["slope"] * df_copy["r_squared"]

    return df_copy


def check_buy_conditions(df, index):
    """Decide whether the fund at positional row `index` may be bought.

    Every rule must hold:
      * MA60_WINDOW <= index < len(df), i.e. enough history and in range;
      * the day's return is not at or above BUY_DAILY_RISE_LIMIT;
      * the 20-row average amount is non-zero and the 5-row average amount
        is strictly below it;
      * moving averages are stacked bullishly: ma5 > ma20 > ma60 > 0.

    Returns a plain bool.
    """
    if not MA60_WINDOW <= index < len(df):
        return False

    def at(column):
        # Positional lookup of one indicator on the evaluated row.
        return df[column].iloc[index]

    # The `and` chain short-circuits in the same order as the rules above.
    return bool(
        not (at("daily_return") >= BUY_DAILY_RISE_LIMIT)
        and at("avg_amount20") != 0
        and not (at("avg_amount5") >= at("avg_amount20"))
        and at("ma5") > at("ma20") > at("ma60") > 0
    )


def calculate_momentum(df, index):
    """Score the momentum of the fund at positional row `index`.

    The score is the weighted sum of three clipped factors:
      * slope_r2 clipped to [-10, 10], times WEIGHT_SLOPE_R2;
      * rise_10d + rise_5d clipped to [-0.5, 0.5], times WEIGHT_RISE;
      * volume_ratio clipped to [0, 5], times WEIGHT_VOLUME_RATIO.

    Returns -inf when `index` is below 25 (the length of the regression
    window behind slope_r2) or past the end of `df`.
    """
    lookback = 25
    if not lookback <= index < len(df):
        return -np.inf

    trend = np.clip(df["slope_r2"].iloc[index], -10, 10)
    recent_rise = np.clip(df["rise_10d"].iloc[index] + df["rise_5d"].iloc[index], -0.5, 0.5)
    volume_boost = np.clip(df["volume_ratio"].iloc[index], 0, 5)

    return (
        trend * WEIGHT_SLOPE_R2
        + recent_rise * WEIGHT_RISE
        + volume_boost * WEIGHT_VOLUME_RATIO
    )


def _first_row_by_date(df):
    """Map each `date_str` value of `df` to the position (for `.iloc`) of its first row."""
    rows = {}
    for pos, date_str in enumerate(df["date_str"]):
        rows.setdefault(date_str, pos)
    return rows


def backtest_strategy(fund_data_dict):
    """Simulate a single-position rotation strategy over the funds' common dates.

    The first common date only records the initial cash. On each later date:
      1. every fund passing `check_buy_conditions` is scored by `calculate_momentum`;
      2. an existing holding is switched out when it no longer qualifies, or
         when another qualified fund has strictly higher momentum (switching
         only happens while at least one fund qualifies);
      3. otherwise the holding is sold once its 20-day rise reaches
         SELL_RISE_THRESHOLD (take profit);
      4. when flat, all cash buys whole shares of the highest-momentum
         qualified fund. Ties go to the earlier fund in `fund_data_dict` order,
         and a fund sold for take profit that same day is skipped.
    Trades execute at the day's close; fees and lot sizes are not modelled.

    Args:
        fund_data_dict: fund name -> DataFrame from `calculate_indicators`
            (needs `date_str`, `close`, `rise_20d` plus the columns read by
            the buy/momentum helpers).

    Returns:
        DataFrame of events with columns date, action, fund, price, shares,
        cash and assets, sorted by date (same-day events keep their order).

    Raises:
        ValueError: if the funds share no trading date (or none are given).
    """
    if not fund_data_dict:
        raise ValueError("❌ 无共同日期")

    initial_cash = 100000.0
    cash = initial_cash
    position = None  # (fund_name, shares, buy_price, buy_date) while holding, else None
    history = []

    # Build each fund's date -> row lookup once, so the daily loop does O(1)
    # lookups instead of scanning whole columns for every date.
    date_rows = {name: _first_row_by_date(df) for name, df in fund_data_dict.items()}
    common_dates = sorted(set.intersection(*(set(rows) for rows in date_rows.values())))
    if not common_dates:
        raise ValueError("❌ 无共同日期")
    # Every common date has a row in every fund by construction, so no
    # further per-date validation is needed.
    print(f"✅ 有效回测日期数量: {len(common_dates)}")
    common_dates_dt = [pd.to_datetime(d) for d in common_dates]

    history.append({
        "date": common_dates_dt[0],
        "action": "初始",
        "fund": None,
        "price": None,
        "shares": 0,
        "cash": cash,
        "assets": cash,
    })

    for date_str, date in zip(common_dates[1:], common_dates_dt[1:]):
        # 1. Score every fund that meets the buy conditions today.
        qualified_funds = []
        for fund_name, df in fund_data_dict.items():
            idx = date_rows[fund_name][date_str]
            if check_buy_conditions(df, idx):
                qualified_funds.append((fund_name, calculate_momentum(df, idx), idx))

        # 2. Mark the holding (if any) to market at today's close.
        current_assets = cash
        hold_df = hold_idx = None
        if position:
            hold_df = fund_data_dict[position[0]]
            hold_idx = date_rows[position[0]][date_str]
            current_assets += position[1] * hold_df["close"].iloc[hold_idx]

        # 3. Switch out when the holding dropped out of the qualified set or is beaten.
        need_switch = False
        if position and qualified_funds:
            best_momentum = max(momentum for _, momentum, _ in qualified_funds)
            hold_momentum = next(
                (momentum for name, momentum, _ in qualified_funds if name == position[0]),
                None,
            )
            need_switch = hold_momentum is None or best_momentum > hold_momentum

        # 4. Otherwise take profit once the holding's 20-day rise hits the threshold.
        take_profit_fund = None
        if position and not need_switch:
            if hold_df["rise_20d"].iloc[hold_idx] >= SELL_RISE_THRESHOLD:
                need_switch = True
                take_profit_fund = position[0]

        traded_today = False

        # 5. Sell the whole holding at today's close.
        if need_switch:
            hold_name, hold_shares, _, _ = position
            sell_price = hold_df["close"].iloc[hold_idx]
            cash += hold_shares * sell_price
            history.append({
                "date": date,
                "action": "卖出",
                "fund": hold_name,
                "price": sell_price,
                "shares": hold_shares,
                "cash": cash,
                "assets": cash,
            })
            position = None
            traded_today = True

        # 6. When flat, buy the best qualified fund. A fund just sold for take
        #    profit is excluded today; otherwise it would be bought straight back
        #    at the same close, turning the take-profit rule into a no-op.
        if position is None:
            candidates = [f for f in qualified_funds if f[0] != take_profit_fund]
            if candidates:
                # max() keeps the first of equal maxima, i.e. pool order on ties.
                best_fund, _, best_idx = max(candidates, key=lambda f: f[1])
                buy_price = fund_data_dict[best_fund]["close"].iloc[best_idx]
                if buy_price > 0 and cash > 0:
                    buy_shares = int(cash / buy_price)
                    if buy_shares > 0:
                        cash -= buy_shares * buy_price
                        position = (best_fund, buy_shares, buy_price, date)
                        history.append({
                            "date": date,
                            "action": "买入",
                            "fund": best_fund,
                            "price": buy_price,
                            "shares": buy_shares,
                            "cash": cash,
                            "assets": cash + buy_shares * buy_price,
                        })
                        traded_today = True

        # 7. On days without a trade, record a holding snapshot.
        if not traded_today:
            history.append({
                "date": date,
                "action": "持仓",
                "fund": position[0] if position else None,
                "price": None,
                "shares": position[1] if position else 0,
                "cash": cash,
                "assets": current_assets,
            })

    # Stable sort so a same-day sell stays ahead of the following buy.
    return pd.DataFrame(history).sort_values("date", kind="stable").reset_index(drop=True)


def evaluate_strategy(history_df):
    """Print a performance summary for a backtest history and return key metrics.

    Args:
        history_df: event log as returned by `backtest_strategy`, with columns
            date, action, fund, price and assets, sorted by date.

    Returns:
        dict with total_return, annual_return, turnover_rate, win_rate and
        max_drawdown (all in percent), or None when there are fewer than two
        rows to compare.
    """
    if len(history_df) < 2:
        return None

    initial_assets = history_df["assets"].iloc[0]
    final_assets = history_df["assets"].iloc[-1]
    total_return = (final_assets - initial_assets) / initial_assets * 100

    start_date = history_df["date"].iloc[0]
    end_date = history_df["date"].iloc[-1]
    years = (end_date - start_date).days / 365.25
    # Compound annual growth rate; a zero-length window reports 0.
    annual_return = (pow(final_assets / initial_assets, 1 / years) - 1) * 100 if years > 0 else 0

    buy_count = len(history_df[history_df["action"] == "买入"])
    sell_count = len(history_df[history_df["action"] == "卖出"])
    # NOTE(review): this is sells per history row (a day can have several rows),
    # not a true turnover ratio; the printed label is kept unchanged.
    turnover_rate = sell_count / len(history_df) * 100

    # Pair each sell with the most recent earlier buy of the same fund.
    trades = []
    for _, sell_row in history_df[history_df["action"] == "卖出"].iterrows():
        buy_rows = history_df[
            (history_df["fund"] == sell_row["fund"]) &
            (history_df["action"] == "买入") &
            (history_df["date"] < sell_row["date"])
        ]
        if not buy_rows.empty:
            buy_row = buy_rows.iloc[-1]
            trades.append((sell_row["price"] - buy_row["price"]) / buy_row["price"])

    win_rate = (len([p for p in trades if p > 0]) / len(trades) * 100) if trades else 0
    avg_profit = (np.mean(trades) * 100) if trades else 0

    cumulative_max = history_df["assets"].cummax()
    drawdown = (history_df["assets"] - cumulative_max) / cumulative_max
    max_drawdown = drawdown.min() * 100

    print("="*60)
    # Derive the period label from the data instead of hard-coding a year.
    print(f"策略表现指标（{start_date.year}年至今）")
    print("="*60)
    print(f"初始资金: {initial_assets:.2f} 元")
    print(f"最终资金: {final_assets:.2f} 元")
    print(f"总收益率: {total_return:.2f}%")
    print(f"年化收益率: {annual_return:.2f}%")
    print(f"交易频率: 买入 {buy_count} 次, 卖出 {sell_count} 次 (日均换手率: {turnover_rate:.2f}%)")
    print(f"胜率: {win_rate:.2f}%")
    print(f"平均每次交易收益率: {avg_profit:.2f}%")
    print(f"最大回撤: {max_drawdown:.2f}%")
    print("="*60)

    return {
        "total_return": total_return,
        "annual_return": annual_return,
        "turnover_rate": turnover_rate,
        "win_rate": win_rate,
        "max_drawdown": max_drawdown
    }


def plot_return_curve(history_df):
    """Plot total asset value over time, marking buys and sells.

    Args:
        history_df: event log from `backtest_strategy` (columns date, action,
            assets). Buys are drawn as red up-triangles and sells as green
            down-triangles. The title's year is taken from the first row.
    """
    plt.figure(figsize=(12, 6))
    plt.plot(history_df["date"], history_df["assets"], label="策略资产", linewidth=2)

    buy_points = history_df[history_df["action"] == "买入"]
    sell_points = history_df[history_df["action"] == "卖出"]
    plt.scatter(buy_points["date"], buy_points["assets"], color="red", marker="^", label="买入")
    plt.scatter(sell_points["date"], sell_points["assets"], color="green", marker="v", label="卖出")

    # Derive the period label from the data instead of hard-coding a year.
    title = "策略收益率曲线"
    if not history_df.empty:
        title = f"{history_df['date'].iloc[0].year}年至今{title}"
    plt.title(title)
    plt.xlabel("日期")
    plt.ylabel("资产价值 (元)")
    plt.grid(alpha=0.3)
    plt.legend()
    plt.gcf().autofmt_xdate()
    plt.tight_layout()
    plt.show()


def main():
    """Load every fund in FUND_POOL, run the backtest, then print metrics and plot.

    Funds whose data cannot be loaded (or is too short) are skipped; the
    backtest needs at least two usable funds.
    """
    since_year = START_DATE.year  # keep messages in sync with the configured start date
    print(f"正在获取{since_year}年至今的基金数据...")
    fund_data_dict = {}

    for fund_name, fund_code in FUND_POOL.items():
        df = get_fund_data(fund_code)
        if df is None:
            print(f"❌ 跳过 {fund_name} ({fund_code}): 数据不足")
            continue
        fund_data_dict[fund_name] = calculate_indicators(df)
        print(f"✅ 已加载 {fund_name} ({fund_code}): {len(df)} 条记录")

    if len(fund_data_dict) < 2:
        print("❌ 有效基金不足2只，无法回测")
        return

    print(f"\n开始策略回测（{since_year}年至今）...")
    try:
        history_df = backtest_strategy(fund_data_dict)
    except Exception as e:  # top-level boundary: report and stop instead of a traceback
        print(f"回测失败: {str(e)}")
        return

    evaluate_strategy(history_df)
    plot_return_curve(history_df)


# Run the pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()