#!/usr/bin/env python3
"""
Track signal performance and accuracy.

For each signal:
1. Record entry time and price
2. Monitor price movement
3. Detect if target/stop was hit
4. Calculate win rate and performance metrics
"""

import json
import sqlite3
import statistics
from contextlib import closing, contextmanager
from datetime import datetime, timedelta
from pathlib import Path


class SignalTracker:
    """Track signal outcomes and calculate accuracy.

    All state lives in the SQLite database at ``db_path``. Every public method
    opens its own short-lived connection, so an instance holds no open database
    handles between calls.
    """

    def __init__(self, db_path: Path):
        """Remember the database location and ensure the schema exists.

        Args:
            db_path: Path of the SQLite database file (created if missing).
        """
        self.db_path = db_path
        self.create_database()

    @contextmanager
    def _connection(self):
        """Yield a connection that commits on success, rolls back on error, and is always closed."""
        conn = sqlite3.connect(self.db_path)
        try:
            # sqlite3's connection context manager only commits/rolls back;
            # it does not close the connection, hence the outer ``finally``.
            with conn:
                yield conn
        finally:
            conn.close()

    def create_database(self):
        """Create the ``signals``, ``price_checkpoints`` and ``performance_summary`` tables and indexes if absent."""
        with self._connection() as conn:
            # Signals table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS signals (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    asset TEXT NOT NULL,
                    signal_type TEXT NOT NULL,
                    direction TEXT NOT NULL,
                    entry_price REAL NOT NULL,
                    target_price REAL NOT NULL,
                    stop_price REAL NOT NULL,
                    confidence REAL NOT NULL,
                    risk_reward REAL NOT NULL,
                    entry_time TEXT NOT NULL,
                    cfd_entry REAL,
                    cfd_target REAL,
                    cfd_stop REAL,
                    status TEXT DEFAULT 'active',
                    outcome_time TEXT,
                    outcome_price REAL,
                    outcome TEXT,
                    profit_loss REAL,
                    max_favorable REAL,
                    max_adverse REAL,
                    bars_held INTEGER,
                    created_at TEXT NOT NULL
                )
            ''')

            # Price checkpoints table (track price over time)
            conn.execute('''
                CREATE TABLE IF NOT EXISTS price_checkpoints (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    signal_id INTEGER NOT NULL,
                    timestamp TEXT NOT NULL,
                    price REAL NOT NULL,
                    FOREIGN KEY (signal_id) REFERENCES signals(id)
                )
            ''')

            # Performance summary table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS performance_summary (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    asset TEXT NOT NULL,
                    signal_type TEXT,
                    direction TEXT,
                    total_signals INTEGER DEFAULT 0,
                    wins INTEGER DEFAULT 0,
                    losses INTEGER DEFAULT 0,
                    win_rate REAL DEFAULT 0,
                    avg_profit REAL,
                    avg_loss REAL,
                    avg_hold_time INTEGER,
                    best_trade REAL,
                    worst_trade REAL,
                    max_drawdown REAL,
                    last_updated TEXT NOT NULL,
                    UNIQUE(asset, signal_type, direction)
                )
            ''')

            # Indexes
            conn.execute('CREATE INDEX IF NOT EXISTS idx_signals_asset_time ON signals(asset, entry_time)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_signals_status ON signals(status)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_checkpoints_signal ON price_checkpoints(signal_id)')

    def add_signal(self, signal_data: dict, asset: str) -> int:
        """Insert a new ``active`` signal and return its row id.

        Args:
            signal_data: Must contain ``type``, ``direction``, ``entry_price``,
                ``target_price``, ``stop_price``, ``confidence`` and
                ``risk_reward``; ``cfd_entry``/``cfd_target``/``cfd_stop`` are
                optional and stored as NULL when absent.
            asset: Asset symbol the signal belongs to.

        Returns:
            The ``id`` of the inserted ``signals`` row.

        Raises:
            KeyError: If a required key is missing from ``signal_data``.
        """
        # One timestamp so entry_time and created_at agree exactly.
        # Naive local time (datetime.now()), matching the rest of this class.
        now = datetime.now().isoformat()
        with self._connection() as conn:
            cursor = conn.execute('''
                INSERT INTO signals (
                    asset, signal_type, direction,
                    entry_price, target_price, stop_price,
                    confidence, risk_reward, entry_time,
                    cfd_entry, cfd_target, cfd_stop,
                    status, created_at
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ''', (
                asset,
                signal_data['type'],
                signal_data['direction'],
                signal_data['entry_price'],
                signal_data['target_price'],
                signal_data['stop_price'],
                signal_data['confidence'],
                signal_data['risk_reward'],
                now,
                signal_data.get('cfd_entry'),
                signal_data.get('cfd_target'),
                signal_data.get('cfd_stop'),
                'active',
                now,
            ))
            # Read the id while the connection is still open.
            return cursor.lastrowid

    def add_price_checkpoint(self, signal_id: int, price: float):
        """Record ``price`` for ``signal_id`` stamped with the current local time."""
        with self._connection() as conn:
            conn.execute('''
                INSERT INTO price_checkpoints (signal_id, timestamp, price)
                VALUES (?, ?, ?)
            ''', (signal_id, datetime.now().isoformat(), price))

    def check_signal_outcome(self, signal_id: int, current_price: float) -> dict:
        """Close the signal if ``current_price`` reached its target or stop.

        For ``direction == 'bullish'`` a win is ``price >= target`` and a loss
        is ``price <= stop``; any other direction is treated as bearish with the
        comparisons mirrored. The target is checked before the stop.

        Returns:
            ``{'status': 'not_found'}`` if the id does not exist;
            ``{'status': <status>}`` if the signal is no longer active;
            ``{'status': 'closed', 'outcome': 'win'|'loss', 'profit_loss': float}``
            when this call closed it; otherwise ``{'status': 'active'}``.
        """
        with self._connection() as conn:
            row = conn.execute('''
                SELECT entry_price, target_price, stop_price, direction, status
                FROM signals WHERE id = ?
            ''', (signal_id,)).fetchone()

            if not row:
                return {'status': 'not_found'}

            entry_price, target_price, stop_price, direction, status = row

            # Skip if already closed
            if status != 'active':
                return {'status': status}

            if direction == 'bullish':
                hit_target = current_price >= target_price
                hit_stop = current_price <= stop_price
                profit_loss = current_price - entry_price
            else:  # bearish
                hit_target = current_price <= target_price
                hit_stop = current_price >= stop_price
                profit_loss = entry_price - current_price

            if hit_target:
                outcome = 'win'
            elif hit_stop:
                outcome = 'loss'
            else:
                return {'status': 'active'}

            # The status guard makes the close idempotent: if another checker
            # closed this signal between our SELECT and this UPDATE, nothing
            # changes and we must not report a second close.
            cursor = conn.execute('''
                UPDATE signals
                SET outcome = ?, outcome_time = ?, outcome_price = ?,
                    profit_loss = ?, status = 'closed'
                WHERE id = ? AND status = 'active'
            ''', (outcome, datetime.now().isoformat(), current_price, profit_loss, signal_id))

            if cursor.rowcount == 0:
                row = conn.execute('SELECT status FROM signals WHERE id = ?', (signal_id,)).fetchone()
                return {'status': row[0] if row else 'not_found'}

            return {'status': 'closed', 'outcome': outcome, 'profit_loss': profit_loss}

    def update_performance_summary(self):
        """Recompute per (asset, signal_type, direction) stats over closed signals.

        Rows are upserted via ``INSERT OR REPLACE`` on the table's UNIQUE key.
        ``avg_hold_time`` and ``max_drawdown`` are not populated here.
        """
        with self._connection() as conn:
            rows = conn.execute('''
                SELECT asset, signal_type, direction,
                       COUNT(*) as total,
                       SUM(CASE WHEN outcome = 'win' THEN 1 ELSE 0 END) as wins,
                       SUM(CASE WHEN outcome = 'loss' THEN 1 ELSE 0 END) as losses,
                       AVG(CASE WHEN outcome = 'win' THEN profit_loss END) as avg_profit,
                       AVG(CASE WHEN outcome = 'loss' THEN profit_loss END) as avg_loss,
                       MAX(profit_loss) as best_trade,
                       MIN(profit_loss) as worst_trade
                FROM signals
                WHERE status = 'closed'
                GROUP BY asset, signal_type, direction
            ''').fetchall()

            now = datetime.now().isoformat()
            for (asset, signal_type, direction, total, wins, losses,
                 avg_profit, avg_loss, best_trade, worst_trade) in rows:
                win_rate = (wins / total * 100) if total > 0 else 0

                conn.execute('''
                    INSERT OR REPLACE INTO performance_summary
                    (asset, signal_type, direction, total_signals, wins, losses,
                     win_rate, avg_profit, avg_loss, best_trade, worst_trade, last_updated)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''', (asset, signal_type, direction, total, wins, losses,
                      win_rate, avg_profit, avg_loss, best_trade, worst_trade, now))

    def get_performance_report(self) -> dict:
        """Build an overall, per-asset and per-type report over closed signals.

        Returns:
            ``{'message': 'No closed signals yet'}`` when nothing is closed;
            otherwise a dict with totals, ``win_rate`` (percent), average
            win/loss, ``total_profit_loss``, ``profit_factor`` (gross profit /
            gross loss), ``payoff_ratio`` (|avg win / avg loss|), and the
            ``by_asset`` / ``by_type`` breakdown lists.
        """
        with self._connection() as conn:
            row = conn.execute('''
                SELECT
                    COUNT(*) as total_signals,
                    SUM(CASE WHEN outcome = 'win' THEN 1 ELSE 0 END) as total_wins,
                    SUM(CASE WHEN outcome = 'loss' THEN 1 ELSE 0 END) as total_losses,
                    AVG(CASE WHEN outcome = 'win' THEN profit_loss END) as avg_win,
                    AVG(CASE WHEN outcome = 'loss' THEN profit_loss END) as avg_loss,
                    SUM(profit_loss) as total_profit_loss,
                    SUM(CASE WHEN profit_loss > 0 THEN profit_loss ELSE 0 END) as gross_profit,
                    SUM(CASE WHEN profit_loss < 0 THEN -profit_loss ELSE 0 END) as gross_loss
                FROM signals
                WHERE status = 'closed'
            ''').fetchone()

            if not row or row[0] == 0:
                return {'message': 'No closed signals yet'}

            (total_signals, total_wins, total_losses, avg_win, avg_loss,
             total_profit_loss, gross_profit, gross_loss) = row

            win_rate = (total_wins / total_signals * 100) if total_signals > 0 else 0

            report = {
                'total_signals': total_signals,
                'wins': total_wins,
                'losses': total_losses,
                'win_rate': round(win_rate, 2),
                'avg_win': round(avg_win, 2) if avg_win else 0,
                'avg_loss': round(avg_loss, 2) if avg_loss else 0,
                'total_profit_loss': round(total_profit_loss, 2) if total_profit_loss else 0,
                # Standard definition: gross profit / gross loss. This key
                # previously held avg_win/avg_loss, now under 'payoff_ratio'.
                'profit_factor': round(gross_profit / gross_loss, 2) if gross_loss and gross_profit else 0,
                'payoff_ratio': round(abs(avg_win / avg_loss), 2) if avg_loss and avg_win else 0,
            }

            # By asset
            report['by_asset'] = []
            for asset, total, wins, avg_profit, asset_avg_loss in conn.execute('''
                SELECT asset,
                       COUNT(*) as total,
                       SUM(CASE WHEN outcome = 'win' THEN 1 ELSE 0 END) as wins,
                       AVG(CASE WHEN outcome = 'win' THEN profit_loss END) as avg_profit,
                       AVG(CASE WHEN outcome = 'loss' THEN profit_loss END) as avg_loss
                FROM signals
                WHERE status = 'closed'
                GROUP BY asset
            ''').fetchall():
                asset_win_rate = (wins / total * 100) if total > 0 else 0
                report['by_asset'].append({
                    'asset': asset,
                    'total': total,
                    'wins': wins,
                    'win_rate': round(asset_win_rate, 2),
                    'avg_profit': round(avg_profit, 2) if avg_profit else 0,
                    'avg_loss': round(asset_avg_loss, 2) if asset_avg_loss else 0,
                })

            # By signal type
            report['by_type'] = []
            for signal_type, total, wins in conn.execute('''
                SELECT signal_type,
                       COUNT(*) as total,
                       SUM(CASE WHEN outcome = 'win' THEN 1 ELSE 0 END) as wins
                FROM signals
                WHERE status = 'closed'
                GROUP BY signal_type
            ''').fetchall():
                type_win_rate = (wins / total * 100) if total > 0 else 0
                report['by_type'].append({
                    'type': signal_type,
                    'total': total,
                    'wins': wins,
                    'win_rate': round(type_win_rate, 2),
                })

            return report


def main(base_dir: Path = Path('/home/ubuntu/.hermes/workspace/projects/ORDER_FLOW_GRAPH')):
    """Demo signal tracking.

    For each known asset, read ``outputs/data/realtime_summary_<asset>.json``
    under ``base_dir`` and register its ``best_setup`` with the tracker. A setup
    is skipped if an active signal with the same asset and entry price was
    added in the last 5 minutes. Then print the performance report.

    Args:
        base_dir: Project root containing ``data/`` and ``outputs/data/``.
            Defaults to the previously hard-coded location.
    """

    print("="*80)
    print("SIGNAL PERFORMANCE TRACKER")
    print("="*80)

    db_path = base_dir / 'data' / 'signal_performance.db'
    tracker = SignalTracker(db_path)

    print(f"\n✓ Database: {db_path}")
    print("✓ Tables: signals, price_checkpoints, performance_summary")

    # Get latest signals
    assets = ['btcusdt', 'ethusdt', 'xautusdt']

    for asset in assets:
        signal_file = base_dir / 'outputs' / 'data' / f'realtime_summary_{asset}.json'

        if not signal_file.exists():
            continue

        try:
            with open(signal_file, encoding='utf-8') as f:
                data = json.load(f)
        except json.JSONDecodeError as exc:
            # A malformed summary (e.g. read mid-write) must not abort the
            # remaining assets.
            print(f"\n{asset.upper()}: Skipping unreadable summary ({exc})")
            continue

        setup = data.get('best_setup') if isinstance(data, dict) else None
        if not setup:
            # Missing or null best_setup: nothing to track for this asset.
            continue

        # Check if signal already exists (avoid duplicates).
        # entry_time is stored from naive local datetime.now(), whereas
        # SQLite's 'now' is UTC; 'localtime' puts both sides on the same clock
        # so the 5-minute window is not shifted by the host's UTC offset.
        with closing(sqlite3.connect(db_path)) as conn:
            (count,) = conn.execute('''
                SELECT COUNT(*) FROM signals
                WHERE asset = ? AND entry_price = ? AND status = 'active'
                AND datetime(entry_time) > datetime('now', 'localtime', '-5 minutes')
            ''', (asset, setup['entry_price'])).fetchone()

        if count > 0:
            print(f"\n{asset.upper()}: Signal already tracked")
            continue

        # Add signal
        signal_id = tracker.add_signal(setup, asset)
        print(f"\n{asset.upper()}: Added signal #{signal_id}")
        print(f"  {setup['type']} - {setup['direction'].upper()}")
        print(f"  Entry: ${setup['entry_price']:.2f}")
        print(f"  Target: ${setup['target_price']:.2f}")
        print(f"  Stop: ${setup['stop_price']:.2f}")

    # Get performance report
    print("\n" + "="*80)
    print("PERFORMANCE REPORT")
    print("="*80)

    report = tracker.get_performance_report()

    if 'message' in report:
        print(f"\n{report['message']}")
        print("\n📊 Track new signals to see accuracy metrics!")
    else:
        print(f"\nTotal Signals: {report['total_signals']}")
        print(f"Wins: {report['wins']} | Losses: {report['losses']}")
        print(f"Win Rate: {report['win_rate']}%")
        print(f"Total P/L: ${report['total_profit_loss']:.2f}")

        if report['by_asset']:
            print("\nBy Asset:")
            for asset in report['by_asset']:
                print(f"  {asset['asset'].upper()}: {asset['win_rate']}% win rate ({asset['wins']}/{asset['total']})")

        if report['by_type']:
            print("\nBy Type:")
            for sig_type in report['by_type']:
                print(f"  {sig_type['type']}: {sig_type['win_rate']}% win rate ({sig_type['wins']}/{sig_type['total']})")

    print("\n" + "="*80)
    print("✓ Signal tracking initialized!")
    print("="*80)
    print("\nNext steps:")
    print("1. Run signal checker every 1-5 minutes to check outcomes")
    print("2. Update performance summary after each check")
    print("3. Display accuracy in UI")
    print("="*80)


if __name__ == "__main__":
    # Script entry point: run the signal-tracking demo with its default paths.
    main()
