"""
WebSocket Routes
Real-time ride tracking and notifications
"""
from typing import Dict, Set, Optional, Any
from datetime import datetime
import logging
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from sqlalchemy import select

from app.database import async_session_maker
from app.services.auth_service import AuthService
from app.models.user import UserRole
from app.models.driver import Driver
from app.models.ride import Ride, RideStatus

# Module logger for connection lifecycle and send-failure messages.
logger = logging.getLogger(__name__)

# Router that carries the WebSocket endpoints declared in this module.
router = APIRouter()


class ConnectionManager:
    """In-memory registry of WebSocket connections used for real-time updates.

    Four groups of sockets are tracked:

    * ``ride_connections`` - subscribers to a specific ride, keyed by ride id.
    * ``driver_connections`` - one location-tracking socket per driver id.
    * ``user_connections`` - notification sockets per user id (a user may hold
      several at once, e.g. multiple tabs/devices).
    * ``driver_broadcast_connections`` - drivers subscribed to the global
      ride-request feed.

    The sending methods iterate over a *snapshot* of the target sockets. Each
    ``await send_json(...)`` yields to the event loop, and other handlers may
    connect or disconnect sockets meanwhile. Iterating the live set would then
    raise ``RuntimeError: Set changed size during iteration``.
    """

    def __init__(self):
        # ride_id -> sockets subscribed to that ride's updates
        self.ride_connections: Dict[int, Set[WebSocket]] = {}

        # driver_id -> that driver's tracking socket (a reconnect replaces it)
        self.driver_connections: Dict[int, WebSocket] = {}

        # user_id -> notification sockets (supports multiple tabs/devices)
        self.user_connections: Dict[int, Set[WebSocket]] = {}
        # Sockets of drivers receiving the global ride-request stream
        self.driver_broadcast_connections: Set[WebSocket] = set()

    async def connect_to_ride(self, websocket: WebSocket, ride_id: int):
        """Accept *websocket* and subscribe it to updates for *ride_id*."""
        await websocket.accept()
        self.ride_connections.setdefault(ride_id, set()).add(websocket)
        logger.info(f"WebSocket connected to ride {ride_id}")

    def disconnect_from_ride(self, websocket: WebSocket, ride_id: int):
        """Unsubscribe *websocket* from *ride_id*; drop the ride entry once empty.

        Safe to call when the ride has no entry or the socket is not registered.
        """
        subscribers = self.ride_connections.get(ride_id)
        if subscribers is not None:
            subscribers.discard(websocket)
            if not subscribers:
                del self.ride_connections[ride_id]
        logger.info(f"WebSocket disconnected from ride {ride_id}")

    async def connect_driver(self, websocket: WebSocket, driver_id: int):
        """Accept *websocket* and register it as *driver_id*'s tracking socket.

        Any previously registered socket for the same driver is replaced (it is
        not closed here).
        """
        await websocket.accept()
        self.driver_connections[driver_id] = websocket
        logger.info(f"Driver {driver_id} connected for tracking")

    def disconnect_driver(self, driver_id: int, websocket: Optional[WebSocket] = None):
        """Stop tracking *driver_id*.

        Args:
            driver_id: Driver whose tracking socket should be unregistered.
            websocket: If given, the entry is removed only when it still maps to
                this exact socket. A handler for a superseded connection then
                cannot evict the driver's newer connection. ``None`` keeps the
                original unconditional behavior.
        """
        current = self.driver_connections.get(driver_id)
        if current is not None and (websocket is None or current is websocket):
            del self.driver_connections[driver_id]
        logger.info(f"Driver {driver_id} disconnected")

    async def connect_user(self, websocket: WebSocket, user_id: int):
        """Register *websocket* for *user_id*'s notifications.

        The handshake is accepted only if it is still pending. This lets a
        socket already accepted by another ``connect_*`` call be registered too.
        """
        if websocket.client_state.name == "CONNECTING":
            await websocket.accept()
        self.user_connections.setdefault(user_id, set()).add(websocket)
        logger.info(f"User {user_id} connected for notifications")

    def disconnect_user(self, user_id: int, websocket: Optional[WebSocket] = None):
        """Unregister notifications for *user_id*.

        Args:
            user_id: User whose sockets are affected.
            websocket: Remove only this socket. ``None`` removes every socket
                registered for the user.
        """
        if user_id in self.user_connections:
            if websocket is None:
                del self.user_connections[user_id]
            else:
                self.user_connections[user_id].discard(websocket)
                if not self.user_connections[user_id]:
                    del self.user_connections[user_id]
        logger.info(f"User {user_id} disconnected")

    async def connect_driver_broadcast(self, websocket: WebSocket):
        """Subscribe a driver socket to global ride-request events.

        The handshake is accepted only if it is still pending.
        """
        if websocket.client_state.name == "CONNECTING":
            await websocket.accept()
        self.driver_broadcast_connections.add(websocket)

    def disconnect_driver_broadcast(self, websocket: WebSocket):
        """Unsubscribe *websocket* from the global ride-request stream (no-op if absent)."""
        self.driver_broadcast_connections.discard(websocket)

    async def broadcast_to_ride(self, ride_id: int, message: dict):
        """Send *message* to every socket subscribed to *ride_id*.

        Sockets whose send raises are unregistered from the ride afterwards.
        """
        # Iterate a snapshot; see the class docstring for why.
        subscribers = list(self.ride_connections.get(ride_id, ()))
        failed: Set[WebSocket] = set()
        for websocket in subscribers:
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.error(f"Error broadcasting to ride {ride_id}: {e}")
                failed.add(websocket)

        # disconnect_from_ride tolerates a ride entry removed while we were
        # awaiting, and deletes the entry once its last socket is gone.
        for ws in failed:
            self.disconnect_from_ride(ws, ride_id)

    async def send_to_user(self, user_id: int, message: dict):
        """Send *message* to every notification socket of *user_id*.

        Sockets whose send raises are unregistered afterwards.
        """
        failed: Set[WebSocket] = set()
        for ws in list(self.user_connections.get(user_id, ())):
            try:
                await ws.send_json(message)
            except Exception as e:
                logger.error(f"Error sending to user {user_id}: {e}")
                failed.add(ws)
        for ws in failed:
            self.disconnect_user(user_id, ws)

    async def broadcast_to_drivers(self, message: dict):
        """Send *message* to every socket on the global driver broadcast stream.

        Sockets whose send raises are unsubscribed afterwards.
        """
        failed: Set[WebSocket] = set()
        for ws in list(self.driver_broadcast_connections):
            try:
                await ws.send_json(message)
            except Exception as e:
                logger.error(f"Error broadcasting to drivers: {e}")
                failed.add(ws)
        for ws in failed:
            self.disconnect_driver_broadcast(ws)


# Global connection manager shared by the WebSocket endpoints and the
# notify_* / publish_ride_event helpers in this module.
# NOTE(review): connections live in this process's memory only; if the app runs
# with several worker processes, each has its own manager — confirm deployment.
manager = ConnectionManager()


@router.websocket("/ws/ride/{ride_id}")
async def websocket_ride_tracking(
    websocket: WebSocket,
    ride_id: int,
    token: str = Query(...)
):
    """
    WebSocket endpoint for real-time ride tracking.

    Connect to receive:
    - Driver location updates
    - Ride status changes
    - ETA updates

    Send:
    - Driver location updates (for drivers), relayed to every subscriber of the ride
    - ``{"type": "ping"}`` keep-alives, answered with ``{"type": "pong"}``

    The socket is closed with code 4001 when the token cannot be decoded.
    """
    # Verify token
    payload = AuthService.decode_token(token)
    if not payload:
        await websocket.close(code=4001, reason="Invalid token")
        return

    user_id = int(payload.get("sub", 0))
    role = payload.get("role", "")

    await manager.connect_to_ride(websocket, ride_id)

    try:
        while True:
            data = await websocket.receive_json()
            message_type = data.get("type")

            if message_type == "location_update" and role == "driver":
                # NOTE(review): any token with role "driver" may publish a location
                # for any ride_id; nothing here checks that this driver is the one
                # assigned to the ride. Consider verifying before broadcasting.
                location_update = {
                    "type": "driver_location",
                    "ride_id": ride_id,
                    "latitude": data.get("latitude"),
                    "longitude": data.get("longitude"),
                    "heading": data.get("heading"),
                    "timestamp": datetime.utcnow().isoformat()
                }
                await manager.broadcast_to_ride(ride_id, location_update)

            elif message_type == "ping":
                # Keep-alive ping
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.exception("WebSocket error for ride %s (user=%s): %s", ride_id, user_id, e)
    finally:
        # Runs on normal disconnect, errors and task cancellation alike, so the
        # socket never lingers in the ride's subscriber set.
        manager.disconnect_from_ride(websocket, ride_id)


@router.websocket("/ws/driver/{driver_id}")
async def websocket_driver_tracking(
    websocket: WebSocket,
    driver_id: int,
    token: str = Query(...)
):
    """
    WebSocket endpoint for driver location tracking.

    Drivers send their location updates here. Closes with 4001 for an
    undecodable token and 4003 when the token's role is not "driver".
    ``{"type": "ping"}`` messages are answered with ``{"type": "pong"}``.
    """
    # Verify token
    payload = AuthService.decode_token(token)
    if not payload:
        await websocket.close(code=4001, reason="Invalid token")
        return

    if payload.get("role") != "driver":
        await websocket.close(code=4003, reason="Not a driver")
        return

    # NOTE(review): driver_id comes from the URL and is not checked against the
    # token's subject, so one driver could register under another's id.

    await manager.connect_driver(websocket, driver_id)

    try:
        while True:
            data = await websocket.receive_json()
            message_type = data.get("type")

            if message_type == "location":
                # Update driver location in database
                # TODO: Store location update
                pass

            elif message_type == "ping":
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.exception("WebSocket error for driver %s: %s", driver_id, e)
    finally:
        # A reconnect for the same driver_id replaces this socket in the manager.
        # Only unregister while the entry still points at *this* socket, so a
        # stale handler cannot evict the driver's live connection.
        if manager.driver_connections.get(driver_id) is websocket:
            manager.disconnect_driver(driver_id)


@router.websocket("/ws/notifications")
async def websocket_notifications(
    websocket: WebSocket,
    token: str = Query(...)
):
    """
    WebSocket endpoint for real-time notifications.

    Receive:
    - New ride requests (for drivers)
    - Ride status updates (for passengers)
    - New bids (for passengers)
    - General notifications

    The socket is closed with code 4001 when the token cannot be decoded.
    ``{"type": "ping"}`` messages are answered with ``{"type": "pong"}``.
    """
    # Verify token
    payload = AuthService.decode_token(token)
    if not payload:
        await websocket.close(code=4001, reason="Invalid token")
        return

    user_id = int(payload.get("sub", 0))

    await manager.connect_user(websocket, user_id)

    try:
        while True:
            data = await websocket.receive_json()

            if data.get("type") == "ping":
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.exception("WebSocket error for user %s: %s", user_id, e)
    finally:
        # Also runs on task cancellation (a BaseException that `except Exception`
        # does not catch), so the socket never stays registered for the user.
        manager.disconnect_user(user_id, websocket)


@router.websocket("/ws/rides/live")
async def websocket_live_rides(
    websocket: WebSocket,
    token: str = Query(...)
):
    """
    Real-time ride feed for passenger/driver apps.
    Pushes ride events: requested, bidding, accepted, started, completed, cancelled.

    On connect, the socket is registered for the user's notifications, plus the
    global driver broadcast when the token's role is the driver role. A
    ``ride_snapshot`` message with the user's current active ride (or ``None``)
    is sent first. ``{"type": "ping"}`` messages are answered with
    ``{"type": "pong"}``. Closes with 4001 for an undecodable token.
    """
    payload = AuthService.decode_token(token)
    if not payload:
        await websocket.close(code=4001, reason="Invalid token")
        return

    user_id = int(payload.get("sub", 0))
    role = payload.get("role")
    is_driver = role == UserRole.DRIVER.value

    if is_driver:
        await manager.connect_driver_broadcast(websocket)
    await manager.connect_user(websocket, user_id)

    try:
        # The snapshot query and initial send are inside the try so that a DB
        # error or an early client disconnect still unregisters the socket below.
        active = await _get_active_ride_snapshot(user_id, role)
        await websocket.send_json(
            {
                "type": "ride_snapshot",
                "active_ride": active,
                "timestamp": datetime.utcnow().isoformat(),
            }
        )

        while True:
            data = await websocket.receive_json()
            if data.get("type") == "ping":
                await websocket.send_json({"type": "pong"})
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.exception("WebSocket error for live rides user=%s: %s", user_id, e)
    finally:
        manager.disconnect_user(user_id, websocket)
        if is_driver:
            manager.disconnect_driver_broadcast(websocket)


# Helper functions for other parts of the application to use

async def notify_ride_update(ride_id: int, status: str, data: Optional[dict] = None):
    """Send a ``ride_status`` message to every socket subscribed to *ride_id*.

    Args:
        ride_id: Ride whose subscribers receive the message.
        status: Status value forwarded verbatim in the ``status`` key.
        data: Optional extra payload placed under ``data`` (``None`` if omitted).
    """
    message = {
        "type": "ride_status",
        "ride_id": ride_id,
        "status": status,
        "data": data,
        "timestamp": datetime.utcnow().isoformat()
    }
    await manager.broadcast_to_ride(ride_id, message)


async def notify_new_bid(ride_id: int, passenger_id: int, bid_data: dict):
    """Publish a ``new_bid`` event for *ride_id*.

    The identical message goes first to every socket subscribed to the ride,
    then to every notification socket registered for *passenger_id*.
    """
    bid_event = dict(
        type="new_bid",
        ride_id=ride_id,
        bid=bid_data,
        timestamp=datetime.utcnow().isoformat(),
    )
    await manager.broadcast_to_ride(ride_id, bid_event)
    await manager.send_to_user(passenger_id, bid_event)


async def notify_driver_location(ride_id: int, latitude: float, longitude: float, heading: Optional[float] = None):
    """Send a ``driver_location`` message to every socket subscribed to *ride_id*.

    Args:
        ride_id: Ride whose subscribers receive the update.
        latitude: Driver latitude, forwarded verbatim.
        longitude: Driver longitude, forwarded verbatim.
        heading: Optional heading, forwarded verbatim (``None`` if omitted).
    """
    message = {
        "type": "driver_location",
        "ride_id": ride_id,
        "latitude": latitude,
        "longitude": longitude,
        "heading": heading,
        "timestamp": datetime.utcnow().isoformat()
    }
    await manager.broadcast_to_ride(ride_id, message)


def _serialize_ride_status(value: Any) -> Any:
    if hasattr(value, "value"):
        return value.value
    return value


def _serialize_datetime(value: Any) -> Any:
    if isinstance(value, datetime):
        return value.isoformat()
    return value


def _ride_payload(ride: Ride, event: str, extra: Optional[dict] = None) -> dict:
    """Build the ``ride_event`` message describing *ride* for WebSocket clients.

    Args:
        ride: Ride whose current fields are copied into the ``ride`` key.
        event: Lifecycle event name stored under ``event``.
        extra: Optional additional data, included under ``data`` only when truthy.

    Returns:
        A dict with keys ``type``, ``event``, ``ride`` and ``timestamp``, plus
        ``data`` when *extra* is truthy.
    """
    ride_fields = {
        "id": ride.id,
        "ride_code": ride.ride_code,
        "status": _serialize_ride_status(ride.status),
        "passenger_id": ride.passenger_id,
        "driver_id": ride.driver_id,
        "pickup_address": ride.pickup_address,
        "dropoff_address": ride.dropoff_address,
        "estimated_fare": ride.estimated_fare,
        "accepted_bid_amount": ride.accepted_bid_amount,
    }
    # Lifecycle timestamps go through _serialize_datetime (ISO strings for
    # datetimes, other values such as None unchanged).
    for field_name in (
        "requested_at",
        "accepted_at",
        "started_at",
        "completed_at",
        "cancelled_at",
    ):
        ride_fields[field_name] = _serialize_datetime(getattr(ride, field_name))
    ride_fields["cancelled_by"] = ride.cancelled_by

    envelope = {
        "type": "ride_event",
        "event": event,
        "ride": ride_fields,
        "timestamp": datetime.utcnow().isoformat(),
    }
    if extra:
        envelope["data"] = extra
    return envelope


async def publish_ride_event(ride: Ride, event: str, extra: Optional[dict] = None):
    """Fan a ride lifecycle event out to every interested WebSocket.

    Recipients, in order:
    1. sockets subscribed to the ride itself;
    2. the passenger's notification sockets;
    3. the assigned driver's user sockets, when a driver with a user id is set;
    4. the global driver broadcast stream, only for ``"ride_requested"``.
    """
    envelope = _ride_payload(ride, event, extra)
    await manager.broadcast_to_ride(ride.id, envelope)
    await manager.send_to_user(ride.passenger_id, envelope)

    # NOTE(review): `ride.driver` looks like an ORM relationship; if it is not
    # already loaded, lazy-loading under an async session may fail — confirm
    # callers eager-load it.
    assigned_driver = ride.driver
    if assigned_driver and assigned_driver.user_id:
        await manager.send_to_user(assigned_driver.user_id, envelope)

    if event == "ride_requested":
        await manager.broadcast_to_drivers(envelope)


async def _get_active_ride_snapshot(user_id: int, role: str) -> Optional[dict]:
    """Summarize the user's most recently created active ride, if any.

    When *role* equals ``UserRole.DRIVER.value``, the user's ``Driver`` row is
    looked up first. Only rides assigned to that driver in ACCEPTED,
    DRIVER_ARRIVED or STARTED are considered. Any other role is treated as a
    passenger, who additionally sees rides still REQUESTED, SEARCHING or
    BIDDING. The newest match by ``created_at`` is returned.

    Returns:
        A dict with ``id``, ``ride_code``, ``status``, ``pickup_address``,
        ``dropoff_address``, ``driver_id`` and ``passenger_id``. Returns ``None``
        when nothing matches or, for drivers, when no ``Driver`` row exists.
    """
    in_progress = [
        RideStatus.ACCEPTED,
        RideStatus.DRIVER_ARRIVED,
        RideStatus.STARTED,
    ]

    async with async_session_maker() as db:
        if role == UserRole.DRIVER.value:
            lookup = await db.execute(select(Driver).where(Driver.user_id == user_id))
            driver = lookup.scalar_one_or_none()
            if not driver:
                return None
            ownership = Ride.driver_id == driver.id
            wanted_statuses = in_progress
        else:
            ownership = Ride.passenger_id == user_id
            wanted_statuses = [
                RideStatus.REQUESTED,
                RideStatus.SEARCHING,
                RideStatus.BIDDING,
                *in_progress,
            ]

        newest_active = (
            select(Ride)
            .where(ownership, Ride.status.in_(wanted_statuses))
            .order_by(Ride.created_at.desc())
            .limit(1)
        )
        ride = (await db.execute(newest_active)).scalar_one_or_none()
        if not ride:
            return None

        return {
            "id": ride.id,
            "ride_code": ride.ride_code,
            "status": _serialize_ride_status(ride.status),
            "pickup_address": ride.pickup_address,
            "dropoff_address": ride.dropoff_address,
            "driver_id": ride.driver_id,
            "passenger_id": ride.passenger_id,
        }
