#!/usr/bin/env python3
"""
Collect Binance order book snapshots for multiple assets.

This script:
1. Connects to Binance WebSocket API
2. Collects real-time order book depth for multiple symbols
3. Stores bid/ask volumes at each price level
4. Saves to SQLite for order flow graph processing
"""

import websocket
import json
import sqlite3
from pathlib import Path
from datetime import datetime
from collections import defaultdict
import time


class BinanceMultiAssetCollector:
    """Collect real-time Binance order book data for multiple assets.

    A local order book (price -> volume per side) is maintained for every
    symbol from the combined ``<symbol>@depth`` diff stream. After every
    ``snapshot_interval``-th depth event received for a symbol, the whole
    book for that symbol is persisted to SQLite.

    NOTE(review): the local book is built only from diff events received
    after connecting. It is never seeded from a REST depth snapshot, and
    gaps in the update-ID sequence are not detected. Price levels that do
    not change during the session are therefore absent. Binance's
    documented procedure for a correct local book seeds it from
    ``GET /api/v3/depth`` and checks ``U``/``u`` continuity.
    """

    def __init__(self, symbols: list, output_db: Path, snapshot_interval: int = 10):
        """Initialise per-symbol state.

        Args:
            symbols: Trading pair names (case-insensitive), e.g. ``"btcusdt"``.
            output_db: Path of the SQLite database file to write to.
            snapshot_interval: Persist a snapshot after every this many depth
                events received for a symbol. Must be >= 1.

        Raises:
            ValueError: If ``snapshot_interval`` is less than 1.
        """
        if snapshot_interval < 1:
            raise ValueError("snapshot_interval must be >= 1")
        self.symbols = [s.lower() for s in symbols]
        self.output_db = output_db
        self.snapshot_interval = snapshot_interval
        # Per symbol: {'bids': {price: volume}, 'asks': {price: volume}}.
        self.order_books = {symbol: {'bids': {}, 'asks': {}} for symbol in self.symbols}
        # Final update ID ('u') of the most recent depth event per symbol.
        self.last_update_ids = {symbol: 0 for symbol in self.symbols}
        # Number of depth events applied per symbol; drives snapshot cadence.
        self.update_counts = {symbol: 0 for symbol in self.symbols}
        # Number of snapshots persisted per symbol.
        self.snapshot_counts = {symbol: 0 for symbol in self.symbols}
        # Running count of snapshots across all symbols (used for log labels).
        self.snapshot_id = 0

    def create_database(self):
        """Create the SQLite tables and indexes if they do not already exist.

        The database's parent directory is created when missing, because
        ``sqlite3.connect`` does not create intermediate directories.
        """
        print(f"Creating database at {self.output_db}...")

        Path(self.output_db).parent.mkdir(parents=True, exist_ok=True)

        conn = sqlite3.connect(self.output_db)
        try:
            cursor = conn.cursor()

            # One row per persisted snapshot, holding aggregate stats.
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS order_book_snapshots (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    symbol TEXT NOT NULL,
                    timestamp INTEGER NOT NULL,
                    update_id INTEGER NOT NULL,
                    bid_levels INTEGER NOT NULL,
                    ask_levels INTEGER NOT NULL,
                    total_bid_volume REAL NOT NULL,
                    total_ask_volume REAL NOT NULL,
                    created_at INTEGER NOT NULL
                )
            """)

            # One row per price level per snapshot.
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS order_book_levels (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    snapshot_id INTEGER NOT NULL,
                    side TEXT NOT NULL,
                    price REAL NOT NULL,
                    volume REAL NOT NULL,
                    FOREIGN KEY (snapshot_id) REFERENCES order_book_snapshots(id)
                )
            """)

            # Create index for faster queries
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_symbol_timestamp
                ON order_book_snapshots(symbol, timestamp)
            """)

            # Levels are fetched per snapshot; without this index every
            # lookup scans the whole levels table.
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_levels_snapshot
                ON order_book_levels(snapshot_id)
            """)

            conn.commit()
        finally:
            conn.close()

        print("  ✓ Database created")

    def save_snapshot(self, symbol: str, timestamp: int) -> int:
        """Persist the current order book for ``symbol`` to the database.

        Args:
            symbol: Lower-case symbol key into ``self.order_books``.
            timestamp: Event time in seconds, stored as given.

        Returns:
            The row id of the new ``order_book_snapshots`` row.
        """
        order_book = self.order_books[symbol]
        bids = order_book['bids']
        asks = order_book['asks']

        conn = sqlite3.connect(self.output_db)
        try:
            # `with conn` commits on success and rolls back on error, so a
            # snapshot row is never stored without its level rows.
            with conn:
                cursor = conn.execute("""
                    INSERT INTO order_book_snapshots
                    (symbol, timestamp, update_id, bid_levels, ask_levels, total_bid_volume, total_ask_volume, created_at)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    symbol.upper(),
                    timestamp,
                    self.last_update_ids[symbol],
                    len(bids),
                    len(asks),
                    sum(bids.values()),
                    sum(asks.values()),
                    int(time.time()),
                ))
                snapshot_id = cursor.lastrowid

                level_sql = """
                    INSERT INTO order_book_levels (snapshot_id, side, price, volume)
                    VALUES (?, ?, ?, ?)
                """
                conn.executemany(
                    level_sql,
                    ((snapshot_id, 'bid', price, volume) for price, volume in bids.items()),
                )
                conn.executemany(
                    level_sql,
                    ((snapshot_id, 'ask', price, volume) for price, volume in asks.items()),
                )
        finally:
            conn.close()

        self.snapshot_counts[symbol] += 1
        return snapshot_id

    @staticmethod
    def _apply_levels(side: dict, levels) -> None:
        """Apply [price, quantity] string pairs to one book side; quantity 0 removes the level."""
        for price, volume in levels:
            price_float = float(price)
            volume_float = float(volume)
            if volume_float == 0:
                side.pop(price_float, None)
            else:
                side[price_float] = volume_float

    def _print_snapshot(self, symbol: str, snapshot_id: int) -> None:
        """Print a summary of the current book for ``symbol`` after it was saved."""
        order_book = self.order_books[symbol]
        print(f"[Snapshot #{self.snapshot_id}] {symbol.upper()}")
        print(f"  ID: {snapshot_id}")
        print(f"  Bids: {len(order_book['bids'])} levels, {sum(order_book['bids'].values()):.2f} total volume")
        print(f"  Asks: {len(order_book['asks'])} levels, {sum(order_book['asks'].values()):.2f} total volume")

    def on_message(self, ws, message):
        """Apply one combined-stream depth message and, when due, persist a snapshot.

        Messages that are not combined-stream envelopes, messages for
        untracked symbols, and events other than ``depthUpdate`` are ignored.
        Errors are printed and swallowed, because this is a WebSocket
        callback boundary and one bad message should not stop collection.
        """
        try:
            data = json.loads(message)
            if 'stream' not in data or 'data' not in data:
                return

            # Stream names look like "btcusdt@depth" -> "btcusdt".
            symbol = data['stream'].split('@')[0]
            if symbol not in self.order_books:
                return

            depth_data = data['data']
            if 'e' not in depth_data or depth_data['e'] != 'depthUpdate':
                return

            self.last_update_ids[symbol] = depth_data['u']
            order_book = self.order_books[symbol]
            self._apply_levels(order_book['bids'], depth_data['b'])
            self._apply_levels(order_book['asks'], depth_data['a'])

            # Cadence is based on the number of events received, not on the
            # update ID: 'u' advances by the number of book changes in an
            # event, so `u % N` does not mean "every N events".
            self.update_counts[symbol] += 1
            if self.update_counts[symbol] % self.snapshot_interval != 0:
                return

            timestamp = int(depth_data['E'] // 1000)  # event time ms -> s
            snapshot_id = self.save_snapshot(symbol, timestamp)
            self.snapshot_id += 1
            self._print_snapshot(symbol, snapshot_id)

        except Exception as e:
            print(f"  ✗ Error processing message: {e}")

    def on_error(self, ws, error):
        """Handle WebSocket error."""
        print(f"  ✗ WebSocket error: {error}")

    def on_close(self, ws, close_status_code, close_msg):
        """Handle WebSocket close by printing per-symbol snapshot totals."""
        print("\n✓ WebSocket closed")

        total_snapshots = sum(self.snapshot_counts.values())
        print(f"  Total snapshots collected: {total_snapshots}")

        for symbol, count in self.snapshot_counts.items():
            print(f"    {symbol.upper()}: {count} snapshots")

    def on_open(self, ws):
        """Handle WebSocket open."""
        print("✓ WebSocket connected to Binance")
        print(f"  Symbols: {', '.join([s.upper() for s in self.symbols])}")
        print("  ✓ Receiving combined depth stream")

    def start(self, duration_seconds: int = 60):
        """Connect and collect order book data for ``duration_seconds``.

        A timer closes the socket once ``duration_seconds`` have elapsed. Pass
        ``None`` or a non-positive value to run until the connection closes
        or Ctrl+C is pressed.
        """
        import threading  # local import: only needed for the stop timer

        print("="*80)
        print("BINANCE MULTI-ASSET ORDER BOOK COLLECTOR")
        print("="*80)

        self.create_database()

        # Combined stream URL: one "<symbol>@depth" entry per symbol.
        streams = "/".join([f"{s}@depth" for s in self.symbols])
        ws_url = f"wss://stream.binance.com:9443/stream?streams={streams}"

        print(f"\nConnecting to {ws_url}...")

        ws = websocket.WebSocketApp(
            ws_url,
            on_open=self.on_open,
            on_message=self.on_message,
            on_error=self.on_error,
            on_close=self.on_close
        )

        timer = None
        if duration_seconds and duration_seconds > 0:
            print(f"\nCollecting for {duration_seconds} seconds...")
            # Closing the socket makes run_forever() return.
            timer = threading.Timer(duration_seconds, ws.close)
            timer.daemon = True
            timer.start()
        else:
            print("\nCollecting until stopped...")
        print("Press Ctrl+C to stop early\n")

        try:
            ws.run_forever()
        except KeyboardInterrupt:
            print("\n✓ Collection stopped by user")
        finally:
            if timer is not None:
                timer.cancel()


def main():
    """Run the collector on the configured symbols and database path.

    Tracks ETH, BTC and tokenised gold against USDT for five minutes and
    writes the snapshots to the project's data directory.
    """
    # Configuration - Add more assets here!
    tracked_symbols = [
        "ethusdt",   # Ethereum
        "btcusdt",   # Bitcoin
        "xautusdt",  # Gold (already have data)
    ]
    run_seconds = 5 * 60
    db_path = Path("/home/ubuntu/.hermes/workspace/projects/ORDER_FLOW_GRAPH/data/binance_multi_asset.db")

    BinanceMultiAssetCollector(tracked_symbols, db_path).start(run_seconds)


if __name__ == "__main__":
    main()
