#!/usr/bin/env python3
"""
Phase 1: Build order flow graph from Sierra tick data.

This script:
1. Loads Sierra tick data from SQLite
2. Computes order flow metrics (delta, imbalance, spread)
3. Builds price level nodes with aggregate metrics
4. Creates temporal edges between levels
5. Detects simple anomalies (absorption, squeeze, exhaustion)
"""

import sqlite3
import json
from pathlib import Path
from datetime import datetime
from collections import defaultdict
from typing import Dict, List, Optional
import statistics


class SierraTickLoader:
    """Load Sierra Chart tick data from SQLite."""

    def __init__(self, db_path: Path):
        self.db_path = db_path
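
    # Assumed 'ticks' table schema, inferred from the columns read in
    # load_recent_ticks(): symbol, sequence, timestamp, price, volume,
    # bid, ask, type, collected_at. The actual collector schema may differ.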

    def load_recent_ticks(self, limit: int = 10000) -> List[Dict]:
        """Load recent ticks from database."""
        print(f"Loading ticks from {self.db_path}...")

        try:
            conn = sqlite3.connect(self.db_path)
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()

            # Get recent ticks
            cursor.execute("""
                SELECT * FROM ticks
                ORDER BY timestamp DESC
                LIMIT ?
            """, (limit,))

            rows = cursor.fetchall()
            conn.close()

            # Convert to list of dicts
            ticks = []
            for row in rows:
                ticks.append({
                    'symbol': row['symbol'],
                    'sequence': row['sequence'],
                    'timestamp': row['timestamp'],
                    'price': row['price'],
                    'volume': row['volume'],
                    'bid': row['bid'] if row['bid'] is not None else row['price'],
                    'ask': row['ask'] if row['ask'] is not None else row['price'],
                    'type': row['type'],
                    'collected_at': row['collected_at']
                })

            # Sort by timestamp ascending
            ticks.sort(key=lambda x: x['timestamp'])

            print(f"  ✓ Loaded {len(ticks)} ticks")
            if ticks:
                print(f"    Period: {datetime.fromtimestamp(ticks[0]['timestamp'])} to {datetime.fromtimestamp(ticks[-1]['timestamp'])}")

            return ticks

        except Exception as e:
            print(f"  ✗ Error loading ticks: {e}")
            import traceback
            traceback.print_exc()
            return []


class OrderFlowMetrics:
    """Compute order flow metrics from tick data."""

    PRICE_LEVEL_PRECISION = 2  # Round prices to 2 decimal places (0.01 increments)

    @staticmethod
    def round_to_level(price: float) -> float:
        """Round price to standard level."""
        return round(price, OrderFlowMetrics.PRICE_LEVEL_PRECISION)

    @staticmethod
    def compute_metrics(ticks: List[Dict], level_price: float) -> Dict:
        """Compute order flow metrics for a price level."""
        if not ticks:
            return {}

        # Filter ticks at this price level
        level_ticks = [t for t in ticks if OrderFlowMetrics.round_to_level(t['price']) == level_price]

        if not level_ticks:
            return {}

        # Compute basic metrics
        total_volume = sum(t['volume'] for t in level_ticks)

        # Compute bid/ask volume
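        # Convention: trades at or below the bid count as seller-initiated (bid_volume),
        # trades at or above the ask as buyer-initiated (ask_volume). Ticks with no quote
        # fall back to bid == ask == price (see SierraTickLoader) and are counted on both
        # sides, so they contribute nothing to delta.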
        bid_volume = sum(t['volume'] for t in level_ticks if t['price'] <= t['bid'])
        ask_volume = sum(t['volume'] for t in level_ticks if t['price'] >= t['ask'])

        # Compute delta (ask volume - bid volume, i.e. buying pressure minus selling pressure)
        delta = ask_volume - bid_volume

        # Compute imbalance in [-1, +1]; sign convention matches delta (positive = buy-side dominance)
        total_bid_ask = bid_volume + ask_volume
        if total_bid_ask > 0:
            imbalance = (ask_volume - bid_volume) / total_bid_ask
        else:
            imbalance = 0.0
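        # e.g. bid_volume 30, ask_volume 10 -> imbalance (10 - 30) / 40 = -0.5 (sell-side dominant)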

        # Compute spread
        if len(level_ticks) >= 2:
            spreads = [abs(t['ask'] - t['bid']) for t in level_ticks]
            avg_spread = statistics.mean(spreads)
        else:
            avg_spread = 0.0

        # Count ticks (activity level)
        tick_count = len(level_ticks)

        # Time range
        if len(level_ticks) >= 2:
            time_range = level_ticks[-1]['timestamp'] - level_ticks[0]['timestamp']
        else:
            time_range = 0

        return {
            'level_price': level_price,
            'tick_count': tick_count,
            'total_volume': total_volume,
            'bid_volume': bid_volume,
            'ask_volume': ask_volume,
            'delta': delta,
            'imbalance': imbalance,
            'avg_spread': avg_spread,
            'time_range': time_range,
            'first_tick': level_ticks[0]['timestamp'],
            'last_tick': level_ticks[-1]['timestamp']
        }


class AnomalyDetector:
    """Detect order flow anomalies."""

    @staticmethod
    def detect_absorption(metrics: Dict, avg_volume: float, avg_delta: float) -> Optional[Dict]:
        """
        Detect absorption: High volume but price fails to break through.

        Signs:
        - Volume > 2x average
        - Delta shows rejection (price returns to level)
        - Price stuck at level (time range > average)
        """
        if not metrics:
            return None

        volume_ratio = metrics['total_volume'] / avg_volume if avg_volume > 0 else 0

        # Absorption criteria
        if volume_ratio >= 2.0:
            # Determine direction (absorbing bids or asks?)
            if metrics['delta'] < 0:  # More selling pressure
                direction = "bullish"  # Absorbing sellers = bullish signal
            else:
                direction = "bearish"  # Absorbing buyers = bearish signal

            # Severity based on volume ratio
            severity = min(volume_ratio / 5.0, 1.0)  # Cap at 1.0
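            # e.g. volume_ratio 3.0 -> severity min(3.0 / 5.0, 1.0) = 0.6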

            return {
                'type': 'ABSORPTION',
                'direction': direction,
                'severity': severity,
                'volume_ratio': volume_ratio,
                'reason': f"Volume {volume_ratio:.1f}x average, {'bearish' if metrics['delta'] > 0 else 'bullish'} pressure being absorbed"
            }

        return None

    @staticmethod
    def detect_squeeze(metrics: Dict, avg_spread: float, avg_volume: float, avg_delta: float) -> Optional[Dict]:
        """
        Detect squeeze: Thin book + high delta = volatility explosion coming.

        Signs:
        - Spread < 0.5x average (thin liquidity)
        - Volume < 0.3x average (low participation)
        - |Delta| > 2x average (high pressure)
        """
        if not metrics:
            return None

        spread_ratio = metrics['avg_spread'] / avg_spread if avg_spread > 0 else 1.0
        volume_ratio = metrics['total_volume'] / avg_volume if avg_volume > 0 else 1.0
        delta_ratio = abs(metrics['delta']) / avg_delta if avg_delta > 0 else 0

        # Squeeze criteria
        if spread_ratio < 0.5 and volume_ratio < 0.3 and delta_ratio > 2.0:
            # Determine direction from delta
            if metrics['delta'] > 0:
                direction = "bullish"
            else:
                direction = "bearish"

            # Severity based on how extreme the squeeze is
            severity = min((1.0 - spread_ratio) * delta_ratio / 5.0, 1.0)
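            # e.g. spread_ratio 0.4, delta_ratio 3.0 -> severity min(0.6 * 3.0 / 5.0, 1.0) = 0.36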

            return {
                'type': 'SQUEEZE',
                'direction': direction,
                'severity': severity,
                'spread_ratio': spread_ratio,
                'delta_ratio': delta_ratio,
                'reason': f"Squeeze: spread {spread_ratio:.1%} of avg, delta {delta_ratio:.1f}x avg"
            }

        return None

    @staticmethod
    def detect_exhaustion(metrics: Dict, previous_metrics: Dict) -> Optional[Dict]:
        """
        Detect exhaustion: Order depletion at key level.

        Signs:
        - Tick count decreasing rapidly
        - Volume drying up
        - Imbalance shifting
        """
        if not metrics or not previous_metrics:
            return None

        # Check if activity is decreasing
        if metrics['tick_count'] < previous_metrics['tick_count'] * 0.5:
            # Determine direction from imbalance shift
            imbalance_change = metrics['imbalance'] - previous_metrics['imbalance']

            if imbalance_change > 0.2:
                direction = "bullish"  # Buyers stepping in
            elif imbalance_change < -0.2:
                direction = "bearish"  # Sellers stepping in
            else:
                direction = "neutral"

            # Severity based on how much activity dropped
            severity = 1.0 - (metrics['tick_count'] / previous_metrics['tick_count'])
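            # e.g. tick_count falling from 40 to 10 -> severity 1.0 - 10 / 40 = 0.75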

            return {
                'type': 'EXHAUSTION',
                'direction': direction,
                'severity': severity,
                'tick_count_ratio': metrics['tick_count'] / previous_metrics['tick_count'],
                'reason': f"Exhaustion: activity dropped to {metrics['tick_count'] / previous_metrics['tick_count']:.1%} of previous"
            }

        return None
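
    # All three detectors share the same contract: given per-level metrics (plus the
    # dataset-wide averages computed in OrderFlowGraphBuilder.build_from_ticks), they
    # return either None or a dict with 'type', 'direction', 'severity' and a
    # human-readable 'reason', which the builder stores as an Anomaly node's properties.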


class OrderFlowGraphBuilder:
    """Build order flow graph from tick data."""

    def __init__(self, output_dir: Path):
        self.output_dir = output_dir
        self.entities = []
        self.relationships = []

    def build_from_ticks(self, ticks: List[Dict]) -> Dict:
        """Build order flow graph from tick data."""
        print("\n" + "="*80)
        print("BUILDING ORDER FLOW GRAPH")
        print("="*80)

        if not ticks:
            print("No ticks to process")
            return {'entities': [], 'relationships': []}

        # Group ticks by price level
        print("\nGrouping ticks by price level...")
        level_ticks = defaultdict(list)
        for tick in ticks:
            level_price = OrderFlowMetrics.round_to_level(tick['price'])
            level_ticks[level_price].append(tick)

        print(f"  ✓ Found {len(level_ticks)} price levels")

        # Compute metrics for each level
        print("\nComputing order flow metrics...")
        level_metrics = {}
        for level_price, ticks_at_level in level_ticks.items():
            metrics = OrderFlowMetrics.compute_metrics(ticks_at_level, level_price)
            if metrics:
                level_metrics[level_price] = metrics

        print(f"  ✓ Computed metrics for {len(level_metrics)} levels")

        # Compute averages for anomaly detection
        volumes = [m['total_volume'] for m in level_metrics.values()]
        deltas = [abs(m['delta']) for m in level_metrics.values()]
        spreads = [m['avg_spread'] for m in level_metrics.values()]

        avg_volume = statistics.mean(volumes) if volumes else 0
        avg_delta = statistics.mean(deltas) if deltas else 0
        avg_spread = statistics.mean(spreads) if spreads else 0

        print(f"\nAverages:")
        print(f"  Volume: {avg_volume:.1f}")
        print(f"  Delta: {avg_delta:.1f}")
        print(f"  Spread: {avg_spread:.2f}")

        # Create entities (price levels)
        print("\nCreating price level nodes...")
        for level_price, metrics in sorted(level_metrics.items()):
            entity = {
                'id': f"price_level_{level_price}",
                'type': 'PriceLevel',
                'name': f"Price_{level_price}",
                'properties': {
                    'price': level_price,
                    'tick_count': metrics['tick_count'],
                    'total_volume': metrics['total_volume'],
                    'bid_volume': metrics['bid_volume'],
                    'ask_volume': metrics['ask_volume'],
                    'delta': metrics['delta'],
                    'imbalance': metrics['imbalance'],
                    'avg_spread': metrics['avg_spread'],
                    'first_tick': metrics['first_tick'],
                    'last_tick': metrics['last_tick']
                }
            }
            self.entities.append(entity)

        print(f"  ✓ Created {len(self.entities)} price level nodes")

        # Detect anomalies
        print("\nDetecting anomalies...")
        anomalies_found = 0

        # Order levels by ascending price (used for adjacency checks and NEXT_LEVEL edges below)
        sorted_levels = sorted(level_metrics.keys())

        for i, level_price in enumerate(sorted_levels):
            metrics = level_metrics[level_price]

            # Check for absorption
            anomaly = AnomalyDetector.detect_absorption(metrics, avg_volume, avg_delta)
            if anomaly:
                anomaly_id = f"anomaly_{level_price}_{anomaly['type']}"

                # Create anomaly entity
                anomaly_entity = {
                    'id': anomaly_id,
                    'type': 'Anomaly',
                    'name': f"{anomaly['type']}_{level_price}",
                    'properties': anomaly
                }
                self.entities.append(anomaly_entity)

                # Create relationship
                self.relationships.append({
                    'from': f"price_level_{level_price}",
                    'to': anomaly_id,
                    'type': 'HAS_ANOMALY',
                    'properties': {
                        'detected_at': datetime.now().isoformat()
                    }
                })

                anomalies_found += 1
                print(f"  ✓ {anomaly['type']} at {level_price}: {anomaly['direction']} (severity: {anomaly['severity']:.2f})")

            # Check for squeeze
            anomaly = AnomalyDetector.detect_squeeze(metrics, avg_spread, avg_volume, avg_delta)
            if anomaly:
                anomaly_id = f"anomaly_{level_price}_{anomaly['type']}"

                anomaly_entity = {
                    'id': anomaly_id,
                    'type': 'Anomaly',
                    'name': f"{anomaly['type']}_{level_price}",
                    'properties': anomaly
                }
                self.entities.append(anomaly_entity)

                self.relationships.append({
                    'from': f"price_level_{level_price}",
                    'to': anomaly_id,
                    'type': 'HAS_ANOMALY',
                    'properties': {
                        'detected_at': datetime.now().isoformat()
                    }
                })

                anomalies_found += 1
                print(f"  ✓ {anomaly['type']} at {level_price}: {anomaly['direction']} (severity: {anomaly['severity']:.2f})")

            # Check for exhaustion (compare with the adjacent lower price level)
            if i > 0:
                prev_price = sorted_levels[i - 1]
                prev_metrics = level_metrics[prev_price]

                anomaly = AnomalyDetector.detect_exhaustion(metrics, prev_metrics)
                if anomaly:
                    anomaly_id = f"anomaly_{level_price}_{anomaly['type']}"

                    anomaly_entity = {
                        'id': anomaly_id,
                        'type': 'Anomaly',
                        'name': f"{anomaly['type']}_{level_price}",
                        'properties': anomaly
                    }
                    self.entities.append(anomaly_entity)

                    self.relationships.append({
                        'from': f"price_level_{level_price}",
                        'to': anomaly_id,
                        'type': 'HAS_ANOMALY',
                        'properties': {
                            'detected_at': datetime.now().isoformat()
                        }
                    })

                    anomalies_found += 1
                    print(f"  ✓ {anomaly['type']} at {level_price}: {anomaly['direction']} (severity: {anomaly['severity']:.2f})")

        print(f"\n  ✓ Detected {anomalies_found} anomalies")

        # Create NEXT_LEVEL edges between adjacent price levels
        print("\nCreating NEXT_LEVEL edges between adjacent price levels...")
        for i in range(len(sorted_levels) - 1):
            from_level = sorted_levels[i]
            to_level = sorted_levels[i + 1]
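            # price_gap records the distance to the next observed level,
            # e.g. 2034.55 -> 2034.60 gives a gap of 0.05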

            self.relationships.append({
                'from': f"price_level_{from_level}",
                'to': f"price_level_{to_level}",
                'type': 'NEXT_LEVEL',
                'properties': {
                    'price_gap': to_level - from_level
                }
            })

        print(f"  ✓ Created {len(sorted_levels) - 1} temporal edges")

        # Build graph
        graph = {
            'metadata': {
                'created_at': datetime.now().isoformat(),
                'version': '0.1.0',
                'phase': '1',
                'entity_count': len(self.entities),
                'relationship_count': len(self.relationships),
                'source': 'Sierra tick data'
            },
            'entities': self.entities,
            'relationships': self.relationships
        }

        return graph

    def save_graph(self, graph: Dict, filename: str):
        """Save graph to JSON file."""
        output_path = self.output_dir / filename

        with open(output_path, 'w') as f:
            json.dump(graph, f, indent=2)

        print(f"\n✓ Saved graph to {output_path}")
        print(f"  Entities: {len(graph['entities'])}")
        print(f"  Relationships: {len(graph['relationships'])}")


def main():
    """Main Phase 1 script."""
    print("="*80)
    print("ORDER FLOW GRAPH - PHASE 1")
    print("="*80)

    # Paths
    sierra_db = Path("/home/ubuntu/.hermes/workspace/tick_collector_api/ticks.db")
    output_dir = Path("/home/ubuntu/.hermes/workspace/projects/ORDER_FLOW_GRAPH/data")

    # Create output directory
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load Sierra tick data
    loader = SierraTickLoader(sierra_db)
    ticks = loader.load_recent_ticks(limit=10000)

    if not ticks:
        print("\n✗ No ticks loaded. Exiting.")
        return

    # Build order flow graph
    builder = OrderFlowGraphBuilder(output_dir)
    graph = builder.build_from_ticks(ticks)

    # Save graph
    builder.save_graph(graph, "order_flow_graph_phase1.json")

    # Print summary
    print("\n" + "="*80)
    print("PHASE 1 COMPLETE")
    print("="*80)

    anomaly_count = len([e for e in graph['entities'] if e['type'] == 'Anomaly'])
    price_level_count = len([e for e in graph['entities'] if e['type'] == 'PriceLevel'])

    print(f"\nGraph Statistics:")
    print(f"  Price levels: {price_level_count}")
    print(f"  Anomalies: {anomaly_count}")
    print(f"  Relationships: {len(graph['relationships'])}")

    print(f"\nNext steps:")
    print(f"  1. Query the graph: python3 scripts/query_order_flow.py")
    print(f"  2. Generate signals: python3 scripts/generate_signals.py")
    print(f"  3. Visualize anomalies: python3 scripts/visualize_order_flow.py")

    print("\n" + "="*80)


if __name__ == "__main__":
    main()
