#!/usr/bin/env python3
"""
Phase 1 (Revised): Build order flow graph from Binance order book data.

This script:
1. Loads Binance order book snapshots from SQLite
2. Computes order flow metrics (bid/ask imbalance, wall strength)
3. Builds price level nodes with real volume data
4. Detects anomalies (absorption, squeeze, exhaustion)
5. Saves the graph to JSON for downstream jump-in/jump-out signal generation
"""

import sqlite3
import json
from pathlib import Path
from datetime import datetime
from collections import defaultdict
from typing import Dict, List
import statistics


class BinanceOrderBookLoader:
    """Load Binance order book snapshots (and their price levels) from SQLite."""

    def __init__(self, db_path: Path):
        # Path to the SQLite database produced by the order book collector.
        self.db_path = db_path

    @staticmethod
    def _load_levels(cursor, snapshot_id: int, side: str) -> Dict:
        """Return {price: volume} for one side ('bid' or 'ask') of a snapshot.

        Bids are queried best-first (price DESC) and asks best-first
        (price ASC); the ordering does not affect the resulting dict but
        mirrors the natural book orientation.
        """
        # `side` is internal ('bid'/'ask'), so interpolating the ORDER BY
        # direction is safe; the user-supplied values stay parameterized.
        order = "DESC" if side == "bid" else "ASC"
        cursor.execute(f"""
            SELECT price, volume FROM order_book_levels
            WHERE snapshot_id = ? AND side = ?
            ORDER BY price {order}
        """, (snapshot_id, side))
        return {price: volume for price, volume in cursor.fetchall()}

    def load_recent_snapshots(self, limit: int = 10) -> List[Dict]:
        """Load the `limit` most recent snapshots, returned oldest-first.

        Each snapshot dict carries the row columns plus 'bid_levels' and
        'ask_levels' ({price: volume}). On any database error, prints a
        traceback and returns an empty list (best-effort CLI behavior).
        """
        print(f"Loading snapshots from {self.db_path}...")

        conn = None
        try:
            conn = sqlite3.connect(self.db_path)
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()

            # Most recent snapshots by id; re-sorted chronologically below.
            cursor.execute("""
                SELECT * FROM order_book_snapshots
                ORDER BY id DESC
                LIMIT ?
            """, (limit,))
            rows = cursor.fetchall()

            snapshots = []
            for row in rows:
                snapshot_id = row['id']
                snapshots.append({
                    'id': row['id'],
                    'symbol': row['symbol'],
                    'timestamp': row['timestamp'],
                    'update_id': row['update_id'],
                    'bid_levels': self._load_levels(cursor, snapshot_id, 'bid'),
                    'ask_levels': self._load_levels(cursor, snapshot_id, 'ask'),
                    'total_bid_volume': row['total_bid_volume'],
                    'total_ask_volume': row['total_ask_volume'],
                })

            # Oldest first, so downstream consumers see time order.
            snapshots.sort(key=lambda s: s['timestamp'])

            print(f"  ✓ Loaded {len(snapshots)} snapshots")

            if snapshots:
                # datetime is already imported at module level.
                print(f"    Period: {datetime.fromtimestamp(snapshots[0]['timestamp'])} to {datetime.fromtimestamp(snapshots[-1]['timestamp'])}")

            return snapshots

        except Exception as e:
            print(f"  ✗ Error loading snapshots: {e}")
            import traceback
            traceback.print_exc()
            return []
        finally:
            # BUGFIX: the original leaked the connection on the error path.
            if conn is not None:
                conn.close()


class OrderFlowAnalyzer:
    """Analyze order flow aggregated across Binance order book snapshots."""

    @staticmethod
    def _fold_level(entry: Dict, side: str, volume: float, timestamp) -> None:
        """Accumulate one (side, volume) observation into a price's entry.

        Note: snapshots_seen increments once per side occurrence, so a price
        present on both sides of one snapshot counts twice (original behavior).
        """
        entry[f'{side}_volume'] += volume
        entry[f'{side}_count'] += 1
        entry['snapshots_seen'] += 1
        if entry['first_seen'] is None:
            entry['first_seen'] = timestamp
        entry['last_seen'] = timestamp

    def compute_price_level_metrics(self, snapshots: List[Dict], price_precision: int = 2) -> Dict:
        """Compute order flow metrics for each price level across all snapshots.

        Args:
            snapshots: snapshot dicts with 'timestamp', 'bid_levels' and
                'ask_levels' ({price: volume}).
            price_precision: decimal places used to bucket prices.

        Returns:
            {rounded_price: metrics dict} with volumes, delta, imbalance
            (-1 = all asks .. +1 = all bids), wall_strength, counts,
            persistence (fraction of snapshots the level appeared in),
            and first/last seen timestamps.
        """
        print("\nComputing price level metrics...")

        # BUGFIX: guard empty input — the original raised ZeroDivisionError
        # computing persistence when len(snapshots) == 0.
        if not snapshots:
            print("  ✓ Computed metrics for 0 price levels")
            return {}

        # Per-rounded-price accumulator.
        price_metrics = defaultdict(lambda: {
            'bid_volume': 0.0,
            'ask_volume': 0.0,
            'bid_count': 0,
            'ask_count': 0,
            'snapshots_seen': 0,
            'first_seen': None,
            'last_seen': None
        })

        for snapshot in snapshots:
            timestamp = snapshot['timestamp']
            for side in ('bid', 'ask'):
                for price, volume in snapshot[f'{side}_levels'].items():
                    entry = price_metrics[round(price, price_precision)]
                    self._fold_level(entry, side, volume, timestamp)

        # Derive per-level metrics from the raw accumulators.
        levels_data = {}
        for price, metrics in price_metrics.items():
            total_volume = metrics['bid_volume'] + metrics['ask_volume']
            delta = metrics['ask_volume'] - metrics['bid_volume']

            # Imbalance: +1 (all bids) to -1 (all asks).
            if total_volume > 0:
                imbalance = (metrics['bid_volume'] - metrics['ask_volume']) / total_volume
            else:
                imbalance = 0.0

            # Wall strength: the larger single-side volume.
            wall_strength = max(metrics['bid_volume'], metrics['ask_volume'])

            # Persistence: how often this level appears across snapshots.
            persistence = metrics['snapshots_seen'] / len(snapshots)

            levels_data[price] = {
                'price': price,
                'bid_volume': metrics['bid_volume'],
                'ask_volume': metrics['ask_volume'],
                'total_volume': total_volume,
                'delta': delta,
                'imbalance': imbalance,
                'wall_strength': wall_strength,
                'bid_count': metrics['bid_count'],
                'ask_count': metrics['ask_count'],
                'snapshots_seen': metrics['snapshots_seen'],
                'persistence': persistence,
                'first_seen': metrics['first_seen'],
                'last_seen': metrics['last_seen']
            }

        print(f"  ✓ Computed metrics for {len(levels_data)} price levels")

        return levels_data


class AnomalyDetector:
    """Detect order flow anomalies in per-price-level Binance metrics."""

    @staticmethod
    def detect_absorption(metrics: Dict, avg_volume: float, avg_wall: float) -> List[Dict]:
        """
        Detect absorption: Large order wall that persists across snapshots.

        Signs:
        - Wall strength > 2x average
        - Persistence > 0.5 (appears in >50% of snapshots)
        - High total volume

        Returns a list with zero or one anomaly dicts.
        """
        anomalies = []

        # BUGFIX: require avg_wall > 0 — with a zero average the original
        # trigger condition was trivially true and the ratio/severity
        # divisions below raised ZeroDivisionError.
        if avg_wall > 0 and metrics['wall_strength'] >= avg_wall * 2.0 and metrics['persistence'] > 0.5:
            # Direction follows the heavier side of the wall.
            if metrics['bid_volume'] > metrics['ask_volume']:
                direction = "bullish"  # Strong bid wall absorbing sellers
                wall_side = "bid"
                wall_size = metrics['bid_volume']
            else:
                direction = "bearish"  # Strong ask wall absorbing buyers
                wall_side = "ask"
                wall_size = metrics['ask_volume']

            # Severity saturates at 5x the average wall.
            severity = min(metrics['wall_strength'] / (avg_wall * 5.0), 1.0)

            anomalies.append({
                'type': 'ABSORPTION',
                'direction': direction,
                'severity': severity,
                'wall_side': wall_side,
                'wall_size': wall_size,
                'wall_ratio': metrics['wall_strength'] / avg_wall,
                'persistence': metrics['persistence'],
                'reason': f"{wall_side.upper()} wall {wall_size:.2f} ({metrics['wall_strength']/avg_wall:.1f}x avg) at ${metrics['price']:.2f}, persists {metrics['persistence']:.0%}"
            })

        return anomalies

    @staticmethod
    def detect_squeeze(metrics: Dict, avg_volume: float, avg_wall: float) -> List[Dict]:
        """
        Detect squeeze: Thin order book with high imbalance.

        Signs:
        - Total volume < 0.3x average (low participation)
        - |Imbalance| > 0.5 (strong pressure one side)

        Returns a list with zero or one anomaly dicts. (avg_wall is unused
        but kept for signature parity with detect_absorption.)
        """
        anomalies = []

        # With avg_volume == 0 the threshold is 0 and the condition is
        # false, so the division below cannot hit a zero denominator.
        if metrics['total_volume'] < avg_volume * 0.3 and abs(metrics['imbalance']) > 0.5:
            # Direction from imbalance sign: negative means ask-heavy.
            if metrics['imbalance'] < 0:  # Negative = more asks
                direction = "bullish"  # Squeezed to upside
            else:
                direction = "bearish"  # Squeezed to downside

            severity = min(abs(metrics['imbalance']), 1.0)

            anomalies.append({
                'type': 'SQUEEZE',
                'direction': direction,
                'severity': severity,
                'volume_ratio': metrics['total_volume'] / avg_volume,
                'imbalance': metrics['imbalance'],
                'reason': f"Squeeze at ${metrics['price']:.2f}: low volume ({metrics['total_volume']/avg_volume:.1%} avg), strong {direction} pressure (imbalance: {metrics['imbalance']:.2f})"
            })

        return anomalies

    @staticmethod
    def detect_exhaustion(metrics: Dict, prev_metrics: Dict) -> List[Dict]:
        """
        Detect exhaustion: Wall size decreasing rapidly.

        Signs:
        - Wall strength drops > 50% from previous snapshot

        prev_metrics may be falsy (None/{}), in which case nothing is detected.
        """
        anomalies = []

        if prev_metrics:
            # Relative change; treated as 0 when the previous wall was empty.
            wall_change = (metrics['wall_strength'] - prev_metrics['wall_strength']) / prev_metrics['wall_strength'] if prev_metrics['wall_strength'] > 0 else 0

            if wall_change < -0.5:  # 50%+ drop
                # Direction follows the heavier remaining side.
                if metrics['bid_volume'] > metrics['ask_volume']:
                    direction = "bullish"
                else:
                    direction = "bearish"

                severity = abs(wall_change)

                anomalies.append({
                    'type': 'EXHAUSTION',
                    'direction': direction,
                    'severity': severity,
                    'wall_change': wall_change,
                    'reason': f"Exhaustion at ${metrics['price']:.2f}: wall dropped {wall_change:.1%}"
                })

        return anomalies


class OrderFlowGraphBuilder:
    """Build an order flow graph (price-level nodes + anomaly nodes/edges)."""

    def __init__(self, output_dir: Path):
        self.output_dir = output_dir  # directory save_graph() writes into
        self.entities = []            # accumulated graph nodes
        self.relationships = []       # accumulated graph edges

    def _record_anomaly(self, price: float, anomaly: Dict, seq: int) -> None:
        """Append one anomaly entity plus its HAS_ANOMALY edge, and log it.

        `seq` is the running anomaly counter, used to keep ids unique when
        the same anomaly type fires on multiple prices.
        """
        anomaly_id = f"anomaly_{price}_{anomaly['type']}_{seq}"

        self.entities.append({
            'id': anomaly_id,
            'type': 'Anomaly',
            'name': f"{anomaly['type']}_{price}",
            'properties': anomaly
        })

        self.relationships.append({
            'from': f"price_level_{price}",
            'to': anomaly_id,
            'type': 'HAS_ANOMALY',
            'properties': {
                'detected_at': datetime.now().isoformat()
            }
        })

        print(f"  ✓ {anomaly['type']} at ${price:.2f}: {anomaly['direction']} (severity: {anomaly['severity']:.2f})")
        print(f"    {anomaly['reason']}")

    def build_from_snapshots(self, snapshots: List[Dict]) -> Dict:
        """Build the order flow graph from Binance snapshots.

        Returns a dict with 'metadata', 'entities' and 'relationships'
        (or just empty 'entities'/'relationships' when snapshots is empty).
        """
        print("\n" + "="*80)
        print("BUILDING ORDER FLOW GRAPH (BINANCE DATA)")
        print("="*80)

        if not snapshots:
            print("No snapshots to process")
            return {'entities': [], 'relationships': []}

        # Aggregate raw snapshots into per-price metrics.
        analyzer = OrderFlowAnalyzer()
        levels_data = analyzer.compute_price_level_metrics(snapshots)

        # Book-wide averages used as anomaly thresholds.
        volumes = [m['total_volume'] for m in levels_data.values()]
        walls = [m['wall_strength'] for m in levels_data.values()]
        avg_volume = statistics.mean(volumes) if volumes else 0
        avg_wall = statistics.mean(walls) if walls else 0

        print(f"\nAverages:")
        print(f"  Total volume: {avg_volume:.2f}")
        print(f"  Wall strength: {avg_wall:.2f}")

        # One PriceLevel node per rounded price.
        print("\nCreating price level nodes...")
        for price, metrics in sorted(levels_data.items()):
            self.entities.append({
                'id': f"price_level_{price}",
                'type': 'PriceLevel',
                'name': f"Price_{price}",
                'properties': metrics
            })
        print(f"  ✓ Created {len(levels_data)} price level nodes")

        # Run all detectors per price, recording in detection order.
        print("\nDetecting anomalies...")
        detector = AnomalyDetector()
        anomalies_found = 0
        sorted_prices = sorted(levels_data.keys())

        for i, price in enumerate(sorted_prices):
            metrics = levels_data[price]

            found = detector.detect_absorption(metrics, avg_volume, avg_wall)
            found += detector.detect_squeeze(metrics, avg_volume, avg_wall)

            # NOTE(review): "exhaustion" here compares a level against the
            # *adjacent price level*, not the same level in an earlier
            # snapshot as the detector's docstring implies — confirm intent.
            if i > 0:
                prev_metrics = levels_data[sorted_prices[i - 1]]
                found += detector.detect_exhaustion(metrics, prev_metrics)

            for anomaly in found:
                self._record_anomaly(price, anomaly, anomalies_found)
                anomalies_found += 1

        print(f"\n  ✓ Detected {anomalies_found} anomalies")

        return {
            'metadata': {
                'created_at': datetime.now().isoformat(),
                'version': '0.2.0',
                'phase': '1',
                'data_source': 'Binance Order Book',
                'entity_count': len(self.entities),
                'relationship_count': len(self.relationships)
            },
            'entities': self.entities,
            'relationships': self.relationships
        }

    def save_graph(self, graph: Dict, filename: str):
        """Serialize the graph to <output_dir>/<filename> as indented JSON."""
        output_path = self.output_dir / filename

        # BUGFIX: create the output directory if missing so a fresh
        # checkout doesn't crash on the first save.
        self.output_dir.mkdir(parents=True, exist_ok=True)

        with open(output_path, 'w') as f:
            json.dump(graph, f, indent=2)

        print(f"\n✓ Saved graph to {output_path}")
        print(f"  Entities: {len(graph['entities'])}")
        print(f"  Relationships: {len(graph['relationships'])}")


def main(
    db_path: Path = Path("/home/ubuntu/.hermes/workspace/projects/ORDER_FLOW_GRAPH/data/binance_order_book.db"),
    output_dir: Path = Path("/home/ubuntu/.hermes/workspace/projects/ORDER_FLOW_GRAPH/data"),
    limit: int = 10,
):
    """Run Phase 1: load Binance snapshots, build the graph, save to JSON.

    Args:
        db_path: SQLite database of collected order book snapshots
            (defaults preserved from the original hard-coded paths).
        output_dir: directory where the graph JSON is written.
        limit: number of most recent snapshots to load.
    """
    print("="*80)
    print("ORDER FLOW GRAPH - PHASE 1 (BINANCE DATA)")
    print("="*80)

    # Load Binance order book snapshots.
    loader = BinanceOrderBookLoader(db_path)
    snapshots = loader.load_recent_snapshots(limit=limit)

    if not snapshots:
        print("\n✗ No snapshots loaded. Exiting.")
        return

    # Build and persist the order flow graph.
    builder = OrderFlowGraphBuilder(output_dir)
    graph = builder.build_from_snapshots(snapshots)
    builder.save_graph(graph, "order_flow_graph_binance.json")

    # Print summary statistics.
    print("\n" + "="*80)
    print("PHASE 1 COMPLETE (BINANCE DATA)")
    print("="*80)

    anomaly_count = sum(1 for e in graph['entities'] if e['type'] == 'Anomaly')
    price_level_count = sum(1 for e in graph['entities'] if e['type'] == 'PriceLevel')

    print("\nGraph Statistics:")
    print(f"  Price levels: {price_level_count}")
    print(f"  Anomalies: {anomaly_count}")
    print(f"  Relationships: {len(graph['relationships'])}")

    print("\nNext steps:")
    print("  1. Generate signals: python3 scripts/generate_signals_binance.py")
    print("  2. Query the graph: python3 scripts/query_order_flow.py")

    print("\n" + "="*80)


if __name__ == "__main__":
    main()
