#!/usr/bin/env python3
"""
Collect Binance order book snapshots for order flow analysis.

This script:
1. Connects to Binance WebSocket API
2. Receives real-time depth updates and applies them to an in-memory order book
3. Stores bid/ask volumes at each price level
4. Saves to SQLite for order flow graph processing
"""

import websocket
import json
import sqlite3
from pathlib import Path
from datetime import datetime
from collections import defaultdict
import time


class BinanceOrderBookCollector:
    """Collect real-time Binance order book data.

    Depth-update messages received over the WebSocket are applied to an
    in-memory book (``order_book``), and the book is periodically written to
    SQLite as one summary row plus one row per price level.

    NOTE(review): the book is built only from the incremental updates seen on
    this connection; it is never seeded from a full depth snapshot, so price
    levels that did not change while connected are absent — confirm this is
    acceptable for downstream analysis.
    """

    def __init__(self, symbol: str, output_db: Path, snapshot_interval: int = 10):
        """Set up an empty book for *symbol*.

        Args:
            symbol: Market symbol; stored lower-cased for stream names and
                upper-cased when written to the database.
            output_db: Path of the SQLite file that receives snapshots.
            snapshot_interval: Persist the book after every this-many depth
                updates. Must be at least 1.

        Raises:
            ValueError: If *snapshot_interval* is smaller than 1.
        """
        if snapshot_interval < 1:
            raise ValueError("snapshot_interval must be >= 1")

        self.symbol = symbol.lower()
        self.output_db = output_db
        self.snapshot_interval = snapshot_interval
        self.order_book = {
            'bids': {},  # {price: volume}
            'asks': {}   # {price: volume}
        }
        self.last_update_id = 0
        self.snapshot_count = 0
        # Number of depth updates applied so far; drives the snapshot cadence.
        self._update_count = 0

    def create_database(self):
        """Create the SQLite file and its two tables if they do not exist.

        The parent directory of ``output_db`` is created first, because
        SQLite cannot open a database inside a missing directory.
        """
        print(f"Creating database at {self.output_db}...")

        Path(self.output_db).parent.mkdir(parents=True, exist_ok=True)

        conn = sqlite3.connect(self.output_db)
        try:
            # ``with conn`` commits on success and rolls back on error.
            with conn:
                # One row per saved snapshot (aggregate statistics).
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS order_book_snapshots (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        symbol TEXT NOT NULL,
                        timestamp INTEGER NOT NULL,
                        update_id INTEGER NOT NULL,
                        bid_levels INTEGER NOT NULL,
                        ask_levels INTEGER NOT NULL,
                        total_bid_volume REAL NOT NULL,
                        total_ask_volume REAL NOT NULL,
                        created_at INTEGER NOT NULL
                    )
                """)

                # One row per price level of a snapshot.
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS order_book_levels (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        snapshot_id INTEGER NOT NULL,
                        side TEXT NOT NULL,
                        price REAL NOT NULL,
                        volume REAL NOT NULL,
                        FOREIGN KEY (snapshot_id) REFERENCES order_book_snapshots(id)
                    )
                """)
        finally:
            # Always release the file handle, even if table creation failed.
            conn.close()

        print("  ✓ Database created")

    def save_snapshot(self, timestamp: int) -> int:
        """Write the current in-memory book to the database.

        The summary row and all level rows are written in one transaction, so
        a failure leaves no partially written snapshot behind.

        Args:
            timestamp: Value stored in the snapshot's ``timestamp`` column.

        Returns:
            The ``id`` of the newly inserted ``order_book_snapshots`` row.
        """
        bids = self.order_book['bids']
        asks = self.order_book['asks']

        conn = sqlite3.connect(self.output_db)
        try:
            with conn:
                cursor = conn.execute("""
                    INSERT INTO order_book_snapshots
                    (symbol, timestamp, update_id, bid_levels, ask_levels, total_bid_volume, total_ask_volume, created_at)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    self.symbol.upper(),
                    timestamp,
                    self.last_update_id,
                    len(bids),
                    len(asks),
                    sum(bids.values()),
                    sum(asks.values()),
                    int(time.time()),
                ))
                snapshot_id = cursor.lastrowid

                level_rows = [(snapshot_id, 'bid', price, volume) for price, volume in bids.items()]
                level_rows.extend((snapshot_id, 'ask', price, volume) for price, volume in asks.items())
                conn.executemany("""
                    INSERT INTO order_book_levels (snapshot_id, side, price, volume)
                    VALUES (?, ?, ?, ?)
                """, level_rows)
        finally:
            conn.close()

        # Only count snapshots that were actually committed.
        self.snapshot_count += 1
        return snapshot_id

    @staticmethod
    def _apply_levels(book_side: dict, levels) -> None:
        """Apply ``[price, volume]`` pairs to one side of the book in place."""
        for price, volume in levels:
            price_float = float(price)
            volume_float = float(volume)

            if volume_float == 0:
                # A zero volume means the level is gone; ignore unknown prices.
                book_side.pop(price_float, None)
            else:
                book_side[price_float] = volume_float

    def _print_snapshot_summary(self, snapshot_id: int) -> None:
        """Print level counts, total volumes and the three largest levels per side."""
        bids = self.order_book['bids']
        asks = self.order_book['asks']

        print(f"[Snapshot #{self.snapshot_count}] ID: {snapshot_id}")
        print(f"  Bids: {len(bids)} levels, {sum(bids.values()):.2f} total volume")
        print(f"  Asks: {len(asks)} levels, {sum(asks.values()):.2f} total volume")

        # "Top" here means largest volume, not best price.
        sorted_bids = sorted(bids.items(), key=lambda x: x[1], reverse=True)[:3]
        sorted_asks = sorted(asks.items(), key=lambda x: x[1], reverse=True)[:3]

        print(f"  Top 3 Bids: {sorted_bids}")
        print(f"  Top 3 Asks: {sorted_asks}")

    def on_message(self, ws, message):
        """Handle one WebSocket message.

        Messages whose ``e`` field is ``depthUpdate`` are applied to the
        in-memory book; anything else is ignored. After every
        ``snapshot_interval`` applied updates the book is saved and a summary
        is printed.

        Args:
            ws: The WebSocket app delivering the message (unused).
            message: Raw JSON text of the message.
        """
        try:
            data = json.loads(message)

            if not isinstance(data, dict) or data.get('e') != 'depthUpdate':
                return

            self.last_update_id = data['u']
            self._apply_levels(self.order_book['bids'], data['b'])
            self._apply_levels(self.order_book['asks'], data['a'])

            # Count applied updates instead of testing the exchange's update
            # ID, whose values do not advance by one per message.
            self._update_count += 1
            if self._update_count % self.snapshot_interval != 0:
                return

            timestamp = int(data['E']) // 1000  # Convert ms to seconds
            snapshot_id = self.save_snapshot(timestamp)
            self._print_snapshot_summary(snapshot_id)

        except Exception as e:
            # Callback boundary: report and keep the stream alive rather than
            # letting one bad message or failed write stop collection.
            print(f"  ✗ Error processing message: {e}")

    def on_error(self, ws, error):
        """Handle WebSocket error by printing it."""
        print(f"  ✗ WebSocket error: {error}")

    def on_close(self, ws, close_status_code, close_msg):
        """Handle WebSocket close by printing the number of snapshots saved."""
        print(f"\n✓ WebSocket closed")
        print(f"  Collected {self.snapshot_count} snapshots")

    def on_open(self, ws):
        """Handle WebSocket open by sending a SUBSCRIBE request for the depth stream."""
        print(f"✓ WebSocket connected to Binance")
        print(f"  Symbol: {self.symbol.upper()}")

        # NOTE(review): the connection URL built in start() already names
        # this stream, so this explicit SUBSCRIBE may be redundant — confirm.
        subscribe_msg = {
            "method": "SUBSCRIBE",
            "params": [f"{self.symbol}@depth"],
            "id": 1
        }

        ws.send(json.dumps(subscribe_msg))
        print(f"  ✓ Subscribed to depth stream")

    def start(self, duration_seconds: int = 60):
        """Create the database, connect, and collect until the time is up.

        Args:
            duration_seconds: Seconds after which the connection is closed.
                ``None``, zero or a negative value disables the limit, and
                collection then runs until interrupted or the socket closes.
        """
        import threading  # local import: only needed for the stop timer

        print("="*80)
        print("BINANCE ORDER BOOK COLLECTOR")
        print("="*80)

        # Create database
        self.create_database()

        # Connect to Binance WebSocket
        ws_url = f"wss://stream.binance.com:9443/ws/{self.symbol}@depth"
        print(f"\nConnecting to {ws_url}...")

        ws = websocket.WebSocketApp(
            ws_url,
            on_open=self.on_open,
            on_message=self.on_message,
            on_error=self.on_error,
            on_close=self.on_close
        )

        # Run for specified duration
        print(f"\nCollecting for {duration_seconds} seconds...")
        print("Press Ctrl+C to stop early\n")

        # run_forever() blocks, so the time limit is enforced by a timer
        # thread that closes the socket, which makes run_forever() return.
        stop_timer = None
        if duration_seconds is not None and duration_seconds > 0:
            stop_timer = threading.Timer(duration_seconds, ws.close)
            stop_timer.daemon = True
            stop_timer.start()

        try:
            ws.run_forever()
        except KeyboardInterrupt:
            print(f"\n✓ Collection stopped by user")
            print(f"  Total snapshots: {self.snapshot_count}")
        finally:
            if stop_timer is not None:
                stop_timer.cancel()


def main():
    """Run a collector using the configuration hard-coded in this script."""
    # Destination SQLite file for the collected snapshots.
    db_path = Path("/home/ubuntu/.hermes/workspace/projects/ORDER_FLOW_GRAPH/data/binance_order_book.db")

    # Collect the "xautusdt" market, passing 300 (seconds, i.e. 5 minutes)
    # as the requested duration.
    BinanceOrderBookCollector("xautusdt", db_path).start(300)


# Run the collector only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
