#!/usr/bin/env python3
"""
Build order flow graph and generate signals for all assets.
"""

import json
import sqlite3
import statistics
from pathlib import Path
from datetime import datetime
from typing import Dict, List


class MultiAssetSignalGenerator:
    """Generate order-flow trading signals for every asset in a snapshot DB.

    Reads order-book snapshots from SQLite, aggregates per-price-level flow
    metrics, flags ABSORPTION / SQUEEZE anomalies, and writes one JSON signal
    file per symbol into ``output_dir``.
    """

    def __init__(self, db_path: Path, output_dir: Path):
        # db_path: SQLite file containing order_book_snapshots / order_book_levels.
        # output_dir: directory receiving signals_<symbol>.json (caller must create it).
        self.db_path = db_path
        self.output_dir = output_dir

    def get_available_assets(self) -> List[str]:
        """Return the sorted list of distinct symbols present in the database."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT DISTINCT symbol
                FROM order_book_snapshots
                ORDER BY symbol
            """)
            return [row[0] for row in cursor.fetchall()]
        finally:
            # Close even on query failure (original leaked the connection on error).
            conn.close()

    def load_snapshots(self, symbol: str, limit: int = 100) -> List[Dict]:
        """Load the most recent ``limit`` snapshots for ``symbol``.

        Returns a list of snapshot dicts sorted by timestamp ascending, each
        carrying bid/ask price->volume maps plus the stored volume totals.
        """
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row  # access columns by name
        try:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT * FROM order_book_snapshots
                WHERE symbol = ?
                ORDER BY timestamp DESC
                LIMIT ?
            """, (symbol, limit))
            rows = cursor.fetchall()

            snapshots = []
            for row in rows:
                snapshot_id = row['id']

                # Bid levels, best (highest) price first.
                cursor.execute("""
                    SELECT price, volume FROM order_book_levels
                    WHERE snapshot_id = ? AND side = 'bid'
                    ORDER BY price DESC
                """, (snapshot_id,))
                # 'lvl' avoids shadowing the outer 'row' loop variable.
                bid_levels = {lvl[0]: lvl[1] for lvl in cursor.fetchall()}

                # Ask levels, best (lowest) price first.
                cursor.execute("""
                    SELECT price, volume FROM order_book_levels
                    WHERE snapshot_id = ? AND side = 'ask'
                    ORDER BY price ASC
                """, (snapshot_id,))
                ask_levels = {lvl[0]: lvl[1] for lvl in cursor.fetchall()}

                snapshots.append({
                    'id': row['id'],
                    'symbol': row['symbol'],
                    'timestamp': row['timestamp'],
                    'bid_levels': bid_levels,
                    'ask_levels': ask_levels,
                    'total_bid_volume': row['total_bid_volume'],
                    'total_ask_volume': row['total_ask_volume']
                })
        finally:
            conn.close()

        # Query returned newest-first; callers expect chronological order.
        snapshots.sort(key=lambda x: x['timestamp'])

        return snapshots

    def compute_price_level_metrics(self, snapshots: List[Dict], price_precision: int = 2) -> Dict:
        """Aggregate order-flow metrics per rounded price level.

        For each price (rounded to ``price_precision`` decimals) accumulates
        bid/ask volume and occurrence counts across all snapshots, then derives:
        imbalance ((bid-ask)/total, in [-1, 1]), delta (ask-bid), wall_strength
        (max one-sided volume), and persistence (fraction of snapshots in which
        the level appeared, in [0, 1]).
        """
        from collections import defaultdict

        price_metrics = defaultdict(lambda: {
            'bid_volume': 0.0,
            'ask_volume': 0.0,
            'bid_count': 0,
            'ask_count': 0,
            'snapshots_seen': 0,
            'first_seen': None,
            'last_seen': None
        })

        for snapshot in snapshots:
            timestamp = snapshot['timestamp']
            # Rounded prices present anywhere in THIS snapshot. A level quoted on
            # both sides must count once toward persistence; the original code
            # incremented snapshots_seen per side, letting persistence exceed 1.0.
            seen_this_snapshot = set()

            # Process bid levels
            for price, volume in snapshot['bid_levels'].items():
                rounded_price = round(price, price_precision)
                metrics = price_metrics[rounded_price]

                metrics['bid_volume'] += volume
                metrics['bid_count'] += 1
                seen_this_snapshot.add(rounded_price)

                if metrics['first_seen'] is None:
                    metrics['first_seen'] = timestamp
                metrics['last_seen'] = timestamp

            # Process ask levels
            for price, volume in snapshot['ask_levels'].items():
                rounded_price = round(price, price_precision)
                metrics = price_metrics[rounded_price]

                metrics['ask_volume'] += volume
                metrics['ask_count'] += 1
                seen_this_snapshot.add(rounded_price)

                if metrics['first_seen'] is None:
                    metrics['first_seen'] = timestamp
                metrics['last_seen'] = timestamp

            # One increment per snapshot regardless of how many sides quoted it.
            for rounded_price in seen_this_snapshot:
                price_metrics[rounded_price]['snapshots_seen'] += 1

        # Compute derived metrics
        levels_data = {}
        for price, metrics in price_metrics.items():
            total_volume = metrics['bid_volume'] + metrics['ask_volume']

            if total_volume > 0:
                imbalance = (metrics['bid_volume'] - metrics['ask_volume']) / total_volume
            else:
                imbalance = 0.0

            wall_strength = max(metrics['bid_volume'], metrics['ask_volume'])
            persistence = metrics['snapshots_seen'] / len(snapshots) if snapshots else 0

            levels_data[price] = {
                'price': price,
                'bid_volume': metrics['bid_volume'],
                'ask_volume': metrics['ask_volume'],
                'total_volume': total_volume,
                'delta': metrics['ask_volume'] - metrics['bid_volume'],
                'imbalance': imbalance,
                'wall_strength': wall_strength,
                'persistence': persistence
            }

        return levels_data

    def detect_anomalies(self, levels_data: Dict) -> List[Dict]:
        """Detect ABSORPTION (large persistent wall) and SQUEEZE (thin, one-sided
        level) anomalies relative to the average volume/wall across levels.

        Returns a list of anomaly dicts sorted by price; severity is in [0, 1].
        """
        anomalies = []

        if not levels_data:
            return anomalies

        volumes = [m['total_volume'] for m in levels_data.values()]
        walls = [m['wall_strength'] for m in levels_data.values()]

        avg_volume = statistics.mean(volumes) if volumes else 0
        avg_wall = statistics.mean(walls) if walls else 0

        for price, metrics in sorted(levels_data.items()):
            # Absorption detection. Guard avg_wall > 0: the original divided by
            # zero (severity and the reason string) when every wall was empty.
            if avg_wall > 0 and metrics['wall_strength'] >= avg_wall * 2.0 and metrics['persistence'] > 0.5:
                if metrics['bid_volume'] > metrics['ask_volume']:
                    direction = "bullish"
                else:
                    direction = "bearish"

                # Scale severity against a 5x-average wall; cap at 1.0.
                severity = min(metrics['wall_strength'] / (avg_wall * 5.0), 1.0)

                anomalies.append({
                    'type': 'ABSORPTION',
                    'direction': direction,
                    'severity': severity,
                    'price': price,
                    'metrics': metrics,
                    'reason': f"{'Bid' if metrics['bid_volume'] > metrics['ask_volume'] else 'Ask'} wall {metrics['wall_strength']:.2f} ({metrics['wall_strength']/avg_wall:.1f}x avg) at ${price:.2f}"
                })

            # Squeeze detection: unusually thin level with strong one-sided pressure.
            if metrics['total_volume'] < avg_volume * 0.3 and abs(metrics['imbalance']) > 0.5:
                # Negative imbalance = ask-heavy; price tends to get pushed up through it.
                if metrics['imbalance'] < 0:
                    direction = "bullish"
                else:
                    direction = "bearish"

                severity = min(abs(metrics['imbalance']), 1.0)

                anomalies.append({
                    'type': 'SQUEEZE',
                    'direction': direction,
                    'severity': severity,
                    'price': price,
                    'metrics': metrics,
                    'reason': f"Squeeze at ${price:.2f}: low volume, strong {direction} pressure"
                })

        return anomalies

    def generate_signals(self, anomalies: List[Dict], current_price: float, max_distance: float = 20.0) -> List[Dict]:
        """Turn anomalies into actionable signals with entry/target/stop prices.

        Anomalies further than ``max_distance`` from ``current_price`` are
        skipped (previously a hard-coded 20). SQUEEZE anomalies only produce a
        signal above severity 0.8. Signals with zero stop distance are dropped
        instead of raising ZeroDivisionError as the original did.
        """
        signals = []

        for anomaly in anomalies:
            entry_price = anomaly['price']
            severity = anomaly['severity']

            # Skip if too far from current price
            if abs(entry_price - current_price) > max_distance:
                continue

            if anomaly['type'] == 'ABSORPTION':
                direction = anomaly['direction']
                wall_size = anomaly['metrics']['wall_strength']

                # Expected move scales with wall size, capped at 5 price units.
                price_impact = min(wall_size / 5.0, 5.0)

                if direction == 'bullish':
                    target = entry_price + price_impact
                    stop = entry_price - (price_impact * 0.25)
                else:
                    target = entry_price - price_impact
                    stop = entry_price + (price_impact * 0.25)

                risk = abs(entry_price - stop)
                if risk == 0:
                    # Zero-size wall -> no stop distance; unusable signal.
                    continue

                signals.append({
                    'type': 'ABSORPTION',
                    'direction': direction,
                    'entry_price': round(entry_price, 2),
                    'target_price': round(target, 2),
                    'stop_price': round(stop, 2),
                    'confidence': round(min(severity + 0.15, 0.95), 2),
                    'risk_reward': round(abs(target - entry_price) / risk, 2),
                    'reason': anomaly['reason']
                })

            elif anomaly['type'] == 'SQUEEZE' and severity > 0.8:
                direction = anomaly['direction']
                price_impact = 3.0

                if direction == 'bullish':
                    target = entry_price + price_impact
                    stop = entry_price - 0.5
                else:
                    target = entry_price - price_impact
                    stop = entry_price + 0.5

                signals.append({
                    'type': 'SQUEEZE',
                    'direction': direction,
                    'entry_price': round(entry_price, 2),
                    'target_price': round(target, 2),
                    'stop_price': round(stop, 2),
                    'confidence': round(min(severity + 0.10, 0.85), 2),
                    'risk_reward': round(abs(target - entry_price) / abs(entry_price - stop), 2),
                    'reason': anomaly['reason'],
                    'note': 'Wait for breakout confirmation'
                })

        return signals

    def process_asset(self, symbol: str):
        """Run the full pipeline for one symbol and write its signal JSON file."""
        print(f"\nProcessing {symbol}...")

        # Load snapshots
        snapshots = self.load_snapshots(symbol, limit=100)

        if not snapshots:
            print(f"  ✗ No snapshots found for {symbol}")
            return

        print(f"  ✓ Loaded {len(snapshots)} snapshots")

        # Compute metrics
        levels_data = self.compute_price_level_metrics(snapshots)

        print(f"  ✓ Computed metrics for {len(levels_data)} price levels")

        # Detect anomalies
        anomalies = self.detect_anomalies(levels_data)

        print(f"  ✓ Detected {len(anomalies)} anomalies")

        # Proxy for current price: median of all observed price levels.
        # NOTE(review): this is the median of the book's levels, not a trade
        # price — confirm acceptable for the distance filter below.
        if levels_data:
            current_price = statistics.median(levels_data.keys())
        else:
            current_price = 0

        # Generate signals
        signals = self.generate_signals(anomalies, current_price)

        print(f"  ✓ Generated {len(signals)} signals")

        # Save signals
        output_file = self.output_dir / f"signals_{symbol.lower()}.json"

        output_data = {
            'metadata': {
                'generated_at': datetime.now().isoformat(),
                'symbol': symbol,
                'signal_count': len(signals),
                'current_price': current_price,
                'snapshots_processed': len(snapshots)
            },
            'signals': signals
        }

        # Explicit encoding keeps output portable across locales.
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2)

        print(f"  ✓ Saved to {output_file}")

    def process_all_assets(self):
        """Discover every symbol in the DB and process each one in turn."""
        print("="*80)
        print("MULTI-ASSET SIGNAL GENERATOR")
        print("="*80)

        # Get available assets
        assets = self.get_available_assets()

        print(f"\nFound {len(assets)} assets: {', '.join(assets)}")

        # Process each asset
        for asset in assets:
            self.process_asset(asset)

        print("\n" + "="*80)
        print("✓ All assets processed!")
        print("="*80)


def main():
    """Script entry point: prepare the output directory, carry over the
    legacy XAUTUSDT signal file if present, then run the generator."""
    db_path = Path("/home/ubuntu/.hermes/workspace/projects/ORDER_FLOW_GRAPH/data/binance_multi_asset.db")
    output_dir = Path("/home/ubuntu/.hermes/workspace/projects/ORDER_FLOW_GRAPH/outputs/data")

    # Ensure the destination tree exists before anything writes into it.
    output_dir.mkdir(parents=True, exist_ok=True)

    # Preserve previously generated XAUTUSDT signals, if any.
    xau_signals = Path("/home/ubuntu/.hermes/workspace/projects/ORDER_FLOW_GRAPH/data/signals_tradeable.json")
    if xau_signals.exists():
        import shutil
        shutil.copy(xau_signals, output_dir / "signals_xautusdt.json")
        print("✓ Copied existing XAUTUSDT signals")

    # Run the full multi-asset pipeline.
    MultiAssetSignalGenerator(db_path, output_dir).process_all_assets()


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
