import akshare as ak
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
import os
import pickle
import warnings
# Silence every warning for the whole process.
warnings.filterwarnings('ignore')

# Use a Chinese-capable font so the Chinese chart titles and labels render.
plt.rcParams["font.family"] = ["SimHei", ]
plt.rcParams["axes.unicode_minus"] = False  # work around the minus sign not displaying

# Cache configuration
CACHE_DIR = "etf_data_cache"  # directory holding one pickle file per ETF symbol
CACHE_EXPIRE_DAYS = 1  # a cache file is reused while its age in whole days is below this

# Trading-cost configuration (intended to mimic real-market fees)
COMMISSION_RATE = 0.0005  # commission rate (0.05%)
MIN_COMMISSION = 5  # minimum commission per trade (5 yuan)
# NOTE(review): confirm that stamp duty really applies to ETF trades in the target market.
STAMP_DUTY_RATE = 0.001  # stamp duty, charged on sells only (0.1%)


class ETFReboundStrategy:
    """Backtest a short-term "gap down, then rebound" strategy on a pool of ETFs.

    A buy signal fires for an ETF when today's open is below
    ``drop_threshold`` times the highest high of the previous
    ``lookback_days`` bars and today's close is above ``rise_threshold``
    times today's open.  A holding is sold once it has been held for
    ``max_days`` bars, or after ``limit_days`` bars as soon as the close
    falls below the previous close.  All trades are executed at the close.

    State kept on the instance:
        cash           -- available cash
        positions      -- ``{symbol: (buy_price, quantity)}``
        trade_records  -- list of dicts, one per executed trade
        daily_assets   -- list of dicts, one per simulated trading day

    The state is not reset between ``backtest`` calls; use a fresh instance
    for each run.
    """

    def __init__(self):
        """Set the ETF pool, strategy parameters and starting capital."""
        # Pool order doubles as buy priority: the first symbol with a signal wins.
        self.etf_pool = {
            "sz159536": "中证2000",
            "sz159629": "中证1000",
            "sz159922": "中证500",
            "sz159919": "沪深300",
            "sz159783": "双创50"
        }

        # Strategy parameters
        self.limit_days = 2         # earliest holding day on which a falling close triggers a sell
        self.max_days = 5           # holding day on which the position is sold unconditionally
        self.drop_threshold = 0.98  # open / max(previous highs) must be below this
        self.rise_threshold = 1.01  # close / open must be above this
        self.lookback_days = 3      # number of previous bars whose highs are compared

        # Account state
        self.total_assets = 100000  # initial capital, also the base for return rates
        self.cash = self.total_assets
        self.positions = {}         # {symbol: (buy_price, quantity)}
        self.trade_records = []
        self.daily_assets = []      # one valuation record per simulated day

        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(CACHE_DIR, exist_ok=True)

    # ---------------------- cache helpers ----------------------
    def _get_cache_path(self, symbol):
        """Return the pickle path used to cache *symbol*'s price history."""
        return os.path.join(CACHE_DIR, f"{symbol}.pkl")

    def _is_cache_valid(self, cache_path):
        """Return True when *cache_path* exists and is younger than CACHE_EXPIRE_DAYS."""
        if not os.path.exists(cache_path):
            return False
        modify_time = datetime.fromtimestamp(os.path.getmtime(cache_path))
        return (datetime.now() - modify_time).days < CACHE_EXPIRE_DAYS

    def _load_cache(self, symbol):
        """Return the cached frame for *symbol*, or None when absent, stale or unreadable."""
        cache_path = self._get_cache_path(symbol)
        if not self._is_cache_valid(cache_path):
            return None
        # SECURITY: pickle.load executes code embedded in the file; the cache
        # directory must only ever contain files written by this program.
        try:
            with open(cache_path, "rb") as f:
                return pickle.load(f)
        except (OSError, EOFError, pickle.UnpicklingError, AttributeError, ImportError) as e:
            # A truncated or incompatible cache file should trigger a fresh
            # download instead of aborting the run.
            print(f"读取{symbol}缓存失败，将重新获取: {e}")
            return None

    def _save_cache(self, symbol, data):
        """Pickle *data* into the cache file for *symbol*."""
        cache_path = self._get_cache_path(symbol)
        with open(cache_path, "wb") as f:
            pickle.dump(data, f)

    # ---------------------- data access ----------------------
    @staticmethod
    def _slice_by_date(df, start_date=None, end_date=None):
        """Return the rows of *df* within [start_date, end_date], re-indexed from 0."""
        if start_date:
            df = df[df["date"] >= pd.to_datetime(start_date)]
        if end_date:
            df = df[df["date"] <= pd.to_datetime(end_date)]
        return df.reset_index(drop=True)

    @staticmethod
    def _last_close(etf_data, as_of):
        """Return the latest close dated on or before *as_of*, or None.

        Relies on *etf_data* being sorted by ascending date, which
        ``get_etf_data`` guarantees for freshly downloaded frames.
        """
        if etf_data is None or len(etf_data) == 0:
            return None
        closes = etf_data.loc[etf_data["date"] <= pd.to_datetime(as_of), "close"]
        return None if closes.empty else closes.iloc[-1]

    @staticmethod
    def _has_bar(etf_data, date):
        """Return True when *etf_data* contains a row dated exactly *date*."""
        return bool((etf_data["date"] == pd.to_datetime(date)).any())

    def get_etf_data(self, symbol, start_date=None, end_date=None, force_refresh=False):
        """Return the daily price history for *symbol*, optionally date-filtered.

        The local cache is used unless *force_refresh* is true or the cache is
        missing/stale/unreadable; otherwise the history is downloaded through
        ``ak.fund_etf_hist_sina`` and cached.

        Returns:
            A DataFrame with columns date, open, high, low, close, volume
            sorted by date, or None when the download or its processing fails.
        """
        if not force_refresh:
            cached_data = self._load_cache(symbol)
            if cached_data is not None:
                return self._slice_by_date(cached_data, start_date, end_date)

        try:
            df = ak.fund_etf_hist_sina(symbol=symbol)
            # Normalise Chinese column headers if the source returns them;
            # rename() ignores keys that are not present.
            df = df.rename(columns={
                "日期": "date",
                "开盘价": "open",
                "最高价": "high",
                "最低价": "low",
                "收盘价": "close",
                "成交量": "volume"
            })
            df["date"] = pd.to_datetime(df["date"])
            df = df[["date", "open", "high", "low", "close", "volume"]].sort_values("date")
            try:
                self._save_cache(symbol, df)
            except OSError as e:
                # Failing to write the cache must not discard freshly fetched data.
                print(f"写入{symbol}缓存失败: {e}")
            return self._slice_by_date(df, start_date, end_date)
        except Exception as e:
            # Best effort by design: report the failure and let callers skip the symbol.
            print(f"获取{symbol}数据失败: {e}")
            return None

    def batch_get_etf_data(self, start_date=None, end_date=None, force_refresh=False):
        """Return ``{symbol: frame-or-None}`` for every symbol in the pool."""
        return {
            symbol: self.get_etf_data(symbol, start_date, end_date, force_refresh)
            for symbol in self.etf_pool
        }

    def get_recent_data(self, symbol, end_date, count, etf_data=None):
        """Return the last *count* rows dated on or before *end_date*.

        When *etf_data* is None the history is loaded through ``get_etf_data``.
        Returns None when fewer than *count* rows are available.
        """
        data = etf_data if etf_data is not None else self.get_etf_data(symbol, end_date=end_date)
        if data is None or len(data) == 0:
            return None
        # Boolean indexing already yields a new frame, so no defensive copy is needed.
        recent_data = data[data["date"] <= pd.to_datetime(end_date)].tail(count)
        return recent_data if len(recent_data) == count else None

    # ---------------------- valuation and costs ----------------------
    def calculate_position_value(self, etf_data_dict, current_date):
        """Return the market value of all holdings as of *current_date*.

        Each holding is valued at its latest close on or before the date, so
        a missing bar does not make the holding look worthless.  Holdings
        without any usable price data are skipped.
        """
        position_value = 0
        for symbol, (_buy_price, quantity) in self.positions.items():
            price = self._last_close(etf_data_dict.get(symbol), current_date)
            if price is not None:
                position_value += price * quantity
        return position_value

    def calculate_trade_cost(self, trade_value, is_buy):
        """Return the fees for a trade worth *trade_value*.

        Commission (with a floor of MIN_COMMISSION) is charged on both sides;
        stamp duty is added on sells only.
        """
        commission = max(trade_value * COMMISSION_RATE, MIN_COMMISSION)
        stamp_duty = 0 if is_buy else trade_value * STAMP_DUTY_RATE
        return commission + stamp_duty

    # ---------------------- signals ----------------------
    def check_rebound_conditions(self, symbol, current_date, etf_data):
        """Return True when *symbol* shows a buy signal on *current_date*.

        Requires ``lookback_days`` previous bars plus a bar dated exactly
        *current_date*; without today's bar there is nothing to trade on.
        """
        hist_data = self.get_recent_data(symbol, current_date, self.lookback_days + 1, etf_data)
        if hist_data is None:
            return False
        today_data = hist_data.iloc[-1]
        if today_data["date"] != pd.to_datetime(current_date):
            # The newest row is stale (no bar today): do not signal on old prices.
            return False
        pre_high_max = hist_data["high"].iloc[:-1].max()
        gapped_down = today_data["open"] / pre_high_max < self.drop_threshold
        rebounded = today_data["close"] / today_data["open"] > self.rise_threshold
        return bool(gapped_down and rebounded)

    def _find_buy_date(self, symbol, default):
        """Return the date of *symbol*'s most recent buy record, or *default*."""
        for record in reversed(self.trade_records):
            if record.get("symbol") == symbol and record.get("action") == "买入":
                return record["date"]
        return default

    def check_sell_conditions(self, symbol, current_date, etf_data):
        """Decide whether the holding in *symbol* should be sold on *current_date*.

        Returns:
            ``(should_sell, reason)``; *reason* is an empty string when
            *should_sell* is False.
        """
        if symbol not in self.positions:
            return False, ""
        today = pd.to_datetime(current_date)
        if not self._has_bar(etf_data, today):
            # No bar today means the ETF cannot be traded today.
            return False, ""

        # Holding days are counted in this ETF's own bars, buy day included.
        # The buy date must come from this symbol's own buy record: the last
        # trade overall may belong to another symbol or be a sell.
        buy_date = pd.to_datetime(self._find_buy_date(symbol, current_date))
        dates = etf_data["date"]
        holding_days = int(((dates >= buy_date) & (dates <= today)).sum())

        data = self.get_recent_data(symbol, current_date, 2, etf_data)
        if data is None or len(data) < 2:
            return False, ""
        yesterday_close = data["close"].iloc[-2]
        today_close = data["close"].iloc[-1]

        if holding_days >= self.max_days:
            return True, f"持仓{holding_days}天（强制卖出）"
        if holding_days >= self.limit_days and today_close < yesterday_close:
            return True, f"持仓{holding_days}天（反弹结束）"
        return False, ""

    # ---------------------- trade execution ----------------------
    def _sell_position(self, symbol, date_str, sell_price, reason):
        """Close the holding in *symbol* at *sell_price*; return the trade record."""
        quantity = self.positions[symbol][1]
        trade_value = sell_price * quantity
        trade_cost = self.calculate_trade_cost(trade_value, is_buy=False)
        net_proceeds = trade_value - trade_cost

        self.cash += net_proceeds
        del self.positions[symbol]

        record = {
            "date": date_str,
            "symbol": symbol,
            "name": self.etf_pool[symbol],
            "action": "卖出",
            "price": sell_price,
            "quantity": quantity,
            "trade_value": trade_value,
            "cost": trade_cost,
            "net_value": net_proceeds,
            "reason": reason
        }
        self.trade_records.append(record)
        return record

    def _try_buy(self, symbol, date_str, etf_data):
        """Buy as many whole lots (100 units) of *symbol* as cash allows."""
        data = self.get_recent_data(symbol, date_str, 1, etf_data)
        if data is None:
            return
        buy_price = data["close"].iloc[0]
        if not buy_price > 0:
            # Also rejects NaN: an unusable price cannot size an order.
            return

        # First estimate from the proportional commission, then shrink lot by
        # lot until price plus fees (including the minimum commission) fits.
        quantity = int((self.cash * (1 - COMMISSION_RATE)) / buy_price / 100) * 100
        while quantity > 0:
            trade_value = buy_price * quantity
            trade_cost = self.calculate_trade_cost(trade_value, is_buy=True)
            if trade_value + trade_cost <= self.cash:
                break
            quantity -= 100
        if quantity <= 0:
            print(f"{date_str} 资金不足，无法买入 {self.etf_pool[symbol]}")
            return

        total_cost = trade_value + trade_cost
        self.cash -= total_cost
        self.positions[symbol] = (buy_price, quantity)

        self.trade_records.append({
            "date": date_str,
            "symbol": symbol,
            "name": self.etf_pool[symbol],
            "action": "买入",
            "price": buy_price,
            "quantity": quantity,
            "trade_value": trade_value,
            "cost": trade_cost,
            "total_cost": total_cost,
            "reason": "满足反弹条件"
        })
        print(f"{date_str} 买入 {self.etf_pool[symbol]} - 满足反弹条件 | 成交价：{buy_price:.2f} | 总成本：{total_cost:.2f}")

    def _record_daily_assets(self, etf_data_dict, date_str, replace_same_day=False):
        """Append (or, for the same date, overwrite) the valuation for *date_str*."""
        position_value = self.calculate_position_value(etf_data_dict, date_str)
        total_asset = self.cash + position_value
        record = {
            "date": date_str,
            "cash": self.cash,
            "position_value": position_value,
            "total_asset": total_asset,
            "return_rate": (total_asset / self.total_assets - 1) * 100  # cumulative, in percent
        }
        if replace_same_day and self.daily_assets and self.daily_assets[-1]["date"] == date_str:
            self.daily_assets[-1] = record
        else:
            self.daily_assets.append(record)
        return record

    def execute_trade(self, date, etf_data_dict):
        """Run one simulated trading day: sells, at most one buy, then valuation.

        Args:
            date: the trading day; must support ``strftime``.
            etf_data_dict: ``{symbol: price frame or None}``.

        A valuation record is appended to ``daily_assets`` on every call,
        whether or not any trade took place.
        """
        date_str = date.strftime("%Y-%m-%d")

        # 1. Buy candidates, collected in pool (priority) order.
        candidates = [
            symbol for symbol in self.etf_pool
            if etf_data_dict.get(symbol) is not None
            and self.check_rebound_conditions(symbol, date_str, etf_data_dict[symbol])
        ]

        # 2. Sells. Iterate over a snapshot because selling mutates positions.
        for symbol in list(self.positions):
            etf_data = etf_data_dict.get(symbol)
            if etf_data is None:
                continue
            sell_flag, reason = self.check_sell_conditions(symbol, date_str, etf_data)
            if not sell_flag:
                continue
            data = self.get_recent_data(symbol, date_str, 1, etf_data)
            if data is None:
                continue
            record = self._sell_position(symbol, date_str, data["close"].iloc[0], reason)
            print(f"{date_str} 卖出 {self.etf_pool[symbol]} - {reason} | 成交价：{record['price']:.2f} | 净收入：{record['net_value']:.2f}")

        # 3. Buy the highest-priority candidate unless it is already held.
        if candidates and candidates[0] not in self.positions:
            self._try_buy(candidates[0], date_str, etf_data_dict.get(candidates[0]))

        # 4. Always record the day's valuation; a failed buy must not skip it.
        self._record_daily_assets(etf_data_dict, date_str)

    def backtest(self, start_date, end_date):
        """Simulate the strategy from *start_date* to *end_date* (inclusive).

        The trading calendar is the union of the dates of every ETF that has
        data, so neither a late-listed nor a failed symbol shortens the run.
        Holdings still open on the last day are liquidated at their latest
        close.

        Returns:
            ``(trade_records_df, daily_assets_df)``.  Both frames are empty
            when no price data is available.
        """
        print(f"开始回测: {start_date} 至 {end_date}")
        etf_data_dict = self.batch_get_etf_data(start_date, end_date)

        start_ts, end_ts = pd.to_datetime(start_date), pd.to_datetime(end_date)
        all_dates = set()
        for frame in etf_data_dict.values():
            if frame is None or frame.empty:
                continue
            in_range = (frame["date"] >= start_ts) & (frame["date"] <= end_ts)
            all_dates.update(frame.loc[in_range, "date"])
        trade_dates = sorted(all_dates)
        if not trade_dates:
            print("回测失败：无有效数据")
            # Same shape as the success path so callers can always unpack.
            return pd.DataFrame(), pd.DataFrame()

        # Day-by-day simulation.
        for date in trade_dates:
            self.execute_trade(date, etf_data_dict)

        # End of backtest: liquidate whatever is still held.
        final_date = trade_dates[-1].strftime("%Y-%m-%d")
        for symbol in list(self.positions):
            sell_price = self._last_close(etf_data_dict.get(symbol), final_date)
            if sell_price is None:
                continue
            record = self._sell_position(symbol, final_date, sell_price, "回测结束平仓")
            print(f"{final_date} 平仓 {self.etf_pool[symbol]} | 成交价：{record['price']:.2f} | 净收入：{record['net_value']:.2f}")

        # Overwrite the last day's record so liquidation costs are reflected
        # without putting the same date on the curve twice.
        final_record = self._record_daily_assets(etf_data_dict, final_date, replace_same_day=True)
        final_total_asset = final_record["total_asset"]

        print(f"\n回测完成！")
        print(f"初始资金：{self.total_assets:.2f} 元")
        print(f"最终资产：{final_total_asset:.2f} 元")
        print(f"总收益：{final_total_asset - self.total_assets:.2f} 元")
        print(f"累计收益率：{(final_total_asset / self.total_assets - 1) * 100:.2f}%")

        return pd.DataFrame(self.trade_records), pd.DataFrame(self.daily_assets)

    def plot_results(self, daily_assets_df):
        """Plot total assets and cumulative return rate from *daily_assets_df*.

        The frame needs the columns date, total_asset and return_rate.  The
        caller's frame is left unmodified.
        """
        if daily_assets_df is None or daily_assets_df.empty:
            print("无资产数据可绘制")
            return

        # assign() returns a new frame, so the date conversion stays local.
        curve = daily_assets_df.assign(date=pd.to_datetime(daily_assets_df["date"]))

        plt.figure(figsize=(14, 8))

        # Panel 1: total asset value
        plt.subplot(2, 1, 1)
        plt.plot(curve["date"], curve["total_asset"],
                 color="#2E86AB", linewidth=2, label="资产总值")
        plt.axhline(y=self.total_assets, color="#A23B72", linestyle="--",
                    label=f"初始资金（{self.total_assets:.0f}元）")
        plt.title("ETF反弹策略资产总值变化", fontsize=14, fontweight="bold")
        plt.ylabel("资产总值（元）", fontsize=12)
        plt.grid(True, alpha=0.3)
        plt.legend()

        # Panel 2: cumulative return rate
        plt.subplot(2, 1, 2)
        plt.plot(curve["date"], curve["return_rate"],
                 color="#F18F01", linewidth=2, label="累计收益率")
        plt.axhline(y=0, color="#C73E1D", linestyle="--", label="盈亏平衡线")
        plt.title("ETF反弹策略累计收益率", fontsize=14, fontweight="bold")
        plt.xlabel("日期", fontsize=12)
        plt.ylabel("累计收益率（%）", fontsize=12)
        plt.grid(True, alpha=0.3)
        plt.legend()

        plt.tight_layout()
        plt.show()


if __name__ == "__main__":
    # Build the strategy and run the backtest.
    strategy = ETFReboundStrategy()

    # Backtest window: the most recent 1900 calendar days, ending today.
    # "now" is taken once so both ends of the window derive from the same instant.
    now = datetime.now()
    end_date = now.strftime("%Y-%m-%d")
    start_date = (now - timedelta(days=1900)).strftime("%Y-%m-%d")

    result = strategy.backtest(start_date, end_date)
    # backtest may hand back a lone empty DataFrame when no data was available;
    # normalise to the (trades, daily assets) pair instead of crashing on unpack.
    if isinstance(result, tuple):
        trade_records_df, daily_assets_df = result
    else:
        trade_records_df, daily_assets_df = result, pd.DataFrame()

    # Plot the asset-value and cumulative-return curves.
    strategy.plot_results(daily_assets_df)

    # Optional: save the results to CSV
    # daily_assets_df.to_csv("策略每日资产记录.csv", index=False, encoding="utf-8-sig")
    # trade_records_df.to_csv("策略交易记录.csv", index=False, encoding="utf-8-sig")