from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, desc, and_, or_
from datetime import datetime, timedelta

from app.utils.helpers import build_order_by
from typing import List, Optional

from app.models.user import User, UserRole
from app.models.driver import Driver
from app.models.ride import Ride, RideStatus
from app.models.payment import Payment, TransactionStatus
from app.schemas.dashboard import (
    DashboardStatsResponse, 
    RevenueChartPoint, 
    ActiveRideLocation,
    DriverPerformance,
    RideStatsResponse
)

from sqlalchemy.orm import selectinload, aliased

class DashboardService:
    def __init__(self, db: AsyncSession):
        """Bind the service to an async SQLAlchemy session.

        Args:
            db: Session used by every query this service issues; it is also
                committed by ``update_driver_status``.
        """
        self.db: AsyncSession = db

    async def get_stats(self) -> DashboardStatsResponse:
        """Return the headline counters shown on the admin dashboard.

        Four aggregate queries are issued one after another:

        * passengers -- users whose role is ``UserRole.PASSENGER``;
        * drivers -- every ``Driver`` row;
        * active rides -- rides created on the server's current local date
          whose status is ACCEPTED, DRIVER_ARRIVED or STARTED;
        * revenue -- ``SUM(Payment.total)`` over COMPLETED payments, converted
          from cents to a float currency amount.

        An aggregate that comes back NULL (e.g. SUM over no rows) counts as 0.
        """
        passenger_total = await self.db.scalar(
            select(func.count(User.id)).where(User.role == UserRole.PASSENGER)
        ) or 0

        driver_total = await self.db.scalar(select(func.count(Driver.id))) or 0

        # Server-local calendar date, compared with the SQL DATE() of created_at.
        current_day = datetime.now().date()
        in_progress_statuses = [
            RideStatus.ACCEPTED,
            RideStatus.DRIVER_ARRIVED,
            RideStatus.STARTED,
        ]
        active_today = await self.db.scalar(
            select(func.count(Ride.id)).where(
                func.date(Ride.created_at) == current_day,
                Ride.status.in_(in_progress_statuses),
            )
        ) or 0

        revenue_in_cents = await self.db.scalar(
            select(func.sum(Payment.total)).where(
                Payment.status == TransactionStatus.COMPLETED
            )
        ) or 0

        return DashboardStatsResponse(
            total_passengers=passenger_total,
            total_drivers=driver_total,
            active_rides=active_today,
            total_revenue=float(revenue_in_cents) / 100.0,
        )

    async def get_revenue_chart(self, days: int = 365) -> List[RevenueChartPoint]:
        """Build the revenue chart series for the last ``days`` days.

        Completed payments created inside the window are bucketed in Python.
        When ``days > 60`` the buckets are calendar months, labelled ``"%b"``
        (e.g. ``"Jan"``). Otherwise they are calendar days, labelled
        ``"%b %d"`` (e.g. ``"Jan 05"``). Every bucket in the window is
        pre-filled, so periods with no payments are reported as zeros. Points
        are returned oldest first.

        Buckets are keyed by year as well as month/day. A window that covers
        the same month in two different years (the default 365 days always
        does) keeps those months as separate, chronologically ordered points
        instead of merging them. Such a series therefore starts and ends with
        the same short month label.

        Args:
            days: Length of the look-back window, ending now.

        Returns:
            One ``RevenueChartPoint`` per bucket. ``revenue`` and
            ``commission`` are converted from cents to currency units.
            ``rides`` is the number of completed payments in the bucket.
        """
        end_date = datetime.now()
        start_date = end_date - timedelta(days=days)

        query = select(Payment.created_at, Payment.total, Payment.commission_amount).where(
            Payment.status == TransactionStatus.COMPLETED,
            Payment.created_at >= start_date,
        )
        rows = (await self.db.execute(query)).all()

        group_by_month = days > 60
        label_format = "%b" if group_by_month else "%b %d"

        def bucket_key(moment) -> tuple:
            # Year-aware key: a "%b"-only key would fold e.g. Jan 2024 into Jan 2025.
            if group_by_month:
                return (moment.year, moment.month)
            return (moment.year, moment.month, moment.day)

        # Pre-fill the timeline day by day; dict insertion order keeps the
        # buckets chronological.
        data: dict = {}
        current = start_date
        while current <= end_date:
            key = bucket_key(current)
            if key not in data:
                data[key] = {
                    "label": current.strftime(label_format),
                    "revenue": 0.0,
                    "rides": 0,
                    "commission": 0.0,
                }
            current += timedelta(days=1)

        for created_at, amount_cents, commission_cents in rows:
            if not created_at:
                continue
            bucket = data.get(bucket_key(created_at))
            if bucket is None:
                # Not a pre-filled bucket (e.g. created after end_date): skip it.
                continue
            bucket["revenue"] += float(amount_cents or 0) / 100.0
            bucket["rides"] += 1
            bucket["commission"] += float(commission_cents or 0) / 100.0

        return [
            RevenueChartPoint(
                date=bucket["label"],
                revenue=bucket["revenue"],
                rides=bucket["rides"],
                commission=bucket["commission"],
            )
            for bucket in data.values()
        ]

    async def get_user_growth(self) -> List["UserGrowthPoint"]:
        """Count new user sign-ups per calendar month over the last 365 days.

        Users of any role whose ``created_at`` falls inside the window are
        counted. Every month in the window is pre-filled with 0. Points are
        returned oldest first and labelled with the short month name
        (``"%b"``).

        Months are keyed by year as well. The month at the start of the window
        and the current month (same name, different years) therefore stay
        separate points instead of being merged into the first slot, and the
        series can start and end with the same label.
        """
        from app.schemas.dashboard import UserGrowthPoint

        end_date = datetime.now()
        start_date = end_date - timedelta(days=365)

        query = select(User.created_at).where(User.created_at >= start_date)
        result = await self.db.execute(query)
        rows = result.scalars().all()

        # (year, month) -> [label, count]; insertion order is chronological.
        counts: dict = {}
        current = start_date
        while current <= end_date:
            key = (current.year, current.month)
            if key not in counts:
                counts[key] = [current.strftime("%b"), 0]
            current += timedelta(days=1)

        for created_at in rows:
            if created_at is None:
                continue
            bucket = counts.get((created_at.year, created_at.month))
            if bucket is not None:
                bucket[1] += 1

        return [
            UserGrowthPoint(date=label, count=count)
            for label, count in counts.values()
        ]

    async def get_active_rides(self) -> List[ActiveRideLocation]:
        """Return map markers for rides that are currently in progress.

        Only rides in STARTED or DRIVER_ARRIVED status that have a driver
        attached are included. Each marker uses the driver's last known
        coordinates (``current_latitude`` / ``current_longitude``); a missing
        or falsy coordinate is reported as ``0.0``.

        NOTE(review): ``get_stats`` also counts ACCEPTED rides as active, but
        they are not shown here -- confirm this is intended.
        """
        statement = (
            select(Ride)
            .options(selectinload(Ride.driver).selectinload(Driver.user))
            .where(Ride.status.in_([RideStatus.STARTED, RideStatus.DRIVER_ARRIVED]))
        )
        rides = (await self.db.execute(statement)).scalars().all()

        # driver and driver.user are eager-loaded by the selectinload above.
        return [
            ActiveRideLocation(
                ride_id=ride.id,
                driver_name=ride.driver.user.full_name if ride.driver.user else "Unknown Driver",
                lat=ride.driver.current_latitude or 0.0,
                lng=ride.driver.current_longitude or 0.0,
                status=ride.status,
            )
            for ride in rides
            if ride.driver
        ]

    async def get_driver_performance(self, limit: int = 10) -> List[DriverPerformance]:
        """Return up to ``limit`` drivers ordered by average rating, highest first.

        Args:
            limit: Maximum number of drivers to return.

        Returns:
            One ``DriverPerformance`` per driver, with earnings converted from
            cents to currency units. A driver whose ``user`` relation is
            missing is reported as ``"Unknown"`` with no profile picture.
        """
        # NOTE(review): if average_rating is nullable, PostgreSQL sorts NULLs
        # first under DESC, so unrated drivers would top this list -- confirm.
        query = (
            select(Driver)
            .options(selectinload(Driver.user))
            .order_by(desc(Driver.average_rating))
            .limit(limit)
        )
        result = await self.db.execute(query)

        performance_list = []
        for d in result.scalars().all():
            user = d.user
            performance_list.append(DriverPerformance(
                id=d.id,
                name=user.full_name if user else "Unknown",
                # Guarded the same way as `name`: user may be absent.
                profile_picture=user.profile_picture if user else None,
                completed_rides=d.total_rides,
                earnings=(d.total_earnings or 0) / 100.0,
                rating=d.average_rating,
                status=d.status,
            ))
        return performance_list

    async def get_passengers(
        self,
        page: int = 1,
        limit: int = 10,
        search: Optional[str] = None,
        status: Optional[str] = None,
        sort_by: str = "created_at",
        order: str = "desc"
    ):
        """Return one page of passengers for the admin passenger table.

        Args:
            page: 1-based page number. Values below 1 are treated as 1 when
                computing the offset.
            limit: Page size.
            search: Case-insensitive substring matched against first name,
                last name, phone and email.
            status: ``"active"``, ``"blocked"`` or ``"suspended"``
                (case-insensitive). Any other value applies no status filter.
            sort_by: ``User`` column to sort by (whitelisted by ``build_order_by``).
            order: Sort direction passed to ``build_order_by``.

        Returns:
            ``PassengerListResponse`` with the page items and pagination
            metadata. ``rides_count`` counts COMPLETED rides. ``rating`` is
            the average ``passenger_rating``, or 5.0 when there is none. The
            wallet balance is converted from cents.
        """
        from app.schemas.dashboard import PassengerListItem, PassengerListResponse
        from app.models.payment import Wallet
        from app.models.rating import Rating

        # Clamp so page <= 0 cannot produce a negative OFFSET.
        offset = max(page - 1, 0) * limit

        # One list of WHERE conditions drives both the page query and the count
        # query, so `total` always matches the filtered rows.
        conditions = [User.role == UserRole.PASSENGER]
        if search:
            pattern = f"%{search}%"
            conditions.append(or_(
                User.first_name.ilike(pattern),
                User.last_name.ilike(pattern),
                User.phone.ilike(pattern),
                User.email.ilike(pattern),
            ))
        status_key = status.lower() if status else None
        if status_key == "active":
            conditions.append(and_(User.is_active == True, User.is_blocked == False))  # noqa: E712
        elif status_key == "blocked":
            conditions.append(User.is_blocked == True)  # noqa: E712
        elif status_key == "suspended":
            conditions.append(and_(User.is_active == False, User.is_blocked == False))  # noqa: E712

        # Completed rides per passenger.
        rides_subquery = (
            select(Ride.passenger_id, func.count(Ride.id).label("rides_count"))
            .where(Ride.status == RideStatus.COMPLETED)
            .group_by(Ride.passenger_id)
            .subquery()
        )

        # Average rating received by each passenger, ignoring unrated rides.
        rating_subquery = (
            select(Ride.passenger_id, func.avg(Rating.passenger_rating).label("avg_rating"))
            .join(Rating, Ride.id == Rating.ride_id)
            .where(Rating.passenger_rating != None)  # noqa: E711
            .group_by(Ride.passenger_id)
            .subquery()
        )

        # Outer joins keep passengers with no wallet, no completed rides or no ratings.
        # NOTE(review): if a user can own several Wallet rows, this join repeats
        # the passenger while the count query does not -- confirm one wallet per user.
        query = (
            select(
                User,
                Wallet.balance,
                func.coalesce(rides_subquery.c.rides_count, 0).label("rides_count"),
                func.coalesce(rating_subquery.c.avg_rating, 0.0).label("avg_rating"),
            )
            .join(Wallet, User.id == Wallet.user_id, isouter=True)
            .join(rides_subquery, User.id == rides_subquery.c.passenger_id, isouter=True)
            .join(rating_subquery, User.id == rating_subquery.c.passenger_id, isouter=True)
            .where(*conditions)
        )

        count_query = select(func.count(User.id)).where(*conditions)
        total = await self.db.scalar(count_query) or 0

        order_clause = build_order_by(User, sort_by, order, {"created_at", "updated_at", "email", "first_name", "last_name"})
        query = query.order_by(order_clause).offset(offset).limit(limit)
        rows = (await self.db.execute(query)).all()

        items = []
        for user, balance, rides_count, avg_rating in rows:
            if user.is_blocked:
                p_status = "Blocked"
            elif not user.is_active:
                p_status = "Suspended"
            else:
                p_status = "Active"

            items.append(PassengerListItem(
                id=user.id,
                name=user.full_name,
                pax_id=f"PAX-{10000 + user.id}",
                phone=user.phone or "N/A",
                rides_count=rides_count,
                # avg_rating is coalesced to 0.0 above, so "no ratings" shows as 5.0.
                rating=float(avg_rating) if avg_rating else 5.0,
                status=p_status,
                wallet_balance=float(balance or 0) / 100.0,
                profile_picture=user.profile_picture,
            ))

        pages = (total + limit - 1) // limit if limit > 0 else 0
        return PassengerListResponse(
            items=items,
            total=total,
            page=page,
            limit=limit,
            pages=pages,
            has_next=page < pages,
            has_prev=page > 1,
        )

    async def get_driver_verifications(
        self,
        page: int = 1,
        limit: int = 10,
        search: Optional[str] = None,
        status: Optional[str] = None,
        sort_by: str = "created_at",
        order: str = "desc"
    ):
        """Return one page of driver applications for the verification table.

        Args:
            page: 1-based page number. Values below 1 are treated as 1 when
                computing the offset.
            limit: Page size.
            search: Case-insensitive substring matched against the driver's
                first name, last name, phone and email.
            status: Driver status filter (case-insensitive):
                * ``"all"`` applies no filter.
                * ``"pending"`` also matches ``"under_review"``, because those
                  drivers are displayed as "Pending".
                * Any other value is compared, lower-cased, with
                  ``Driver.status``.
            sort_by: ``User`` column to sort by (whitelisted by ``build_order_by``).
            order: Sort direction passed to ``build_order_by``.

        Returns:
            ``DriverVerificationResponse`` with the page items and pagination metadata.
        """
        from app.schemas.dashboard import DriverVerificationItem, DriverVerificationResponse
        from app.models.vehicle import Vehicle

        # Clamp so page <= 0 cannot produce a negative OFFSET.
        offset = max(page - 1, 0) * limit

        # Shared WHERE conditions for the page query and the count query.
        conditions = [User.role == UserRole.DRIVER]
        if search:
            pattern = f"%{search}%"
            conditions.append(or_(
                User.first_name.ilike(pattern),
                User.last_name.ilike(pattern),
                User.phone.ilike(pattern),
                User.email.ilike(pattern),
            ))
        status_key = status.lower() if status else None
        if status_key == "pending":
            conditions.append(Driver.status.in_(["pending", "under_review"]))
        elif status_key and status_key != "all":
            conditions.append(Driver.status == status_key)

        query = (
            select(User, Driver, Vehicle.make, Vehicle.model)
            .join(Driver, User.id == Driver.user_id)
            # Outer join: drivers without a current vehicle are still listed.
            .join(Vehicle, Driver.current_vehicle_id == Vehicle.id, isouter=True)
            .where(*conditions)
        )

        count_query = (
            select(func.count(Driver.id))
            .join(User, Driver.user_id == User.id)
            .where(*conditions)
        )
        total = await self.db.scalar(count_query) or 0

        order_clause = build_order_by(User, sort_by, order, {"created_at", "updated_at", "email", "first_name", "last_name"})
        query = query.order_by(order_clause).offset(offset).limit(limit)
        rows = (await self.db.execute(query)).all()

        items = []
        for user, driver, make, model_name in rows:
            if driver.status == "approved":
                d_status = "Approved"
            elif driver.status == "rejected":
                d_status = "Rejected"
            else:
                # "pending", "under_review" and any other status.
                d_status = "Pending"

            vehicle_desc = f"{make} {model_name}" if make and model_name else "No Vehicle"

            items.append(DriverVerificationItem(
                # NOTE(review): id/drv_id come from the *user* id, while
                # update_driver_status looks drivers up by Driver.id -- confirm
                # the client never passes this id there.
                id=user.id,
                name=user.full_name,
                drv_id=f"DRV-{10000 + user.id}",
                phone=user.phone or "N/A",
                vehicle_type=vehicle_desc,
                document_status=d_status,
                applied_on=user.created_at.strftime("%d %b, %Y") if user.created_at else "N/A",
                profile_picture=user.profile_picture,
            ))

        pages = (total + limit - 1) // limit if limit > 0 else 0
        return DriverVerificationResponse(
            items=items,
            total=total,
            page=page,
            limit=limit,
            pages=pages,
            has_next=page < pages,
            has_prev=page > 1,
        )

    async def get_ride_stats(self) -> RideStatsResponse:
        """Return the ride counters for the rides overview header.

        * all -- every ride;
        * ongoing -- status ACCEPTED, DRIVER_ARRIVED or STARTED;
        * upcoming -- status REQUESTED, SEARCHING or BIDDING, *or* any ride
          with ``is_scheduled`` set, whatever its status;
        * completed -- status COMPLETED.

        Each counter is a separate COUNT query, run in the order listed above.
        A NULL result counts as 0.

        NOTE(review): because scheduled rides count as upcoming regardless of
        status, the upcoming figure can overlap ongoing/completed and can
        differ from the "upcoming" tab of ``get_rides_list`` -- confirm intent.
        """
        def ride_count(*criteria):
            # COUNT(ride.id), optionally narrowed by the given WHERE criteria.
            statement = select(func.count(Ride.id))
            return statement.where(*criteria) if criteria else statement

        ongoing_statuses = [RideStatus.ACCEPTED, RideStatus.DRIVER_ARRIVED, RideStatus.STARTED]
        upcoming_statuses = [RideStatus.REQUESTED, RideStatus.SEARCHING, RideStatus.BIDDING]

        statements = (
            ride_count(),
            ride_count(Ride.status.in_(ongoing_statuses)),
            ride_count(or_(
                Ride.status.in_(upcoming_statuses),
                Ride.is_scheduled == True,  # noqa: E712
            )),
            ride_count(Ride.status == RideStatus.COMPLETED),
        )

        totals = []
        for statement in statements:
            totals.append(await self.db.scalar(statement) or 0)
        all_count, ongoing_count, upcoming_count, completed_count = totals

        return RideStatsResponse(
            all_rides=all_count,
            ongoing_rides=ongoing_count,
            upcoming_rides=upcoming_count,
            completed_rides=completed_count,
        )

    async def get_rides_list(
        self,
        page: int = 1,
        limit: int = 10,
        search: Optional[str] = None,
        status: Optional[str] = None,
        sort_by: str = "created_at",
        order: str = "desc"
    ):
        """Return one page of rides for the admin rides table.

        Args:
            page: 1-based page number. Values below 1 are treated as 1 when
                computing the offset.
            limit: Page size.
            search: Case-insensitive substring matched against ride code,
                pickup address and drop-off address.
            status: Tab filter (case-insensitive):
                * ``"live"`` -- STARTED;
                * ``"completed"`` -- COMPLETED;
                * ``"upcoming"`` -- REQUESTED, SEARCHING or BIDDING;
                * ``"ongoing"`` -- ACCEPTED or DRIVER_ARRIVED.
                Any other value applies no status filter.
            sort_by: ``Ride`` column to sort by (whitelisted by ``build_order_by``).
            order: Sort direction passed to ``build_order_by``.

        Returns:
            ``RideListResponse`` with the page items, the filtered total, the
            page number and the page size (as ``size``).
        """
        from app.schemas.dashboard import RideListItem, RideListResponse

        # Clamp so page <= 0 cannot produce a negative OFFSET.
        offset = max(page - 1, 0) * limit

        # WHERE conditions shared by the page query and the count query. They
        # are collected in a list rather than tested with `if clause:`:
        # SQLAlchemy clauses have no usable truth value (an `==` clause is
        # falsy and an `in_` clause raises TypeError).
        conditions = []
        if search:
            pattern = f"%{search}%"
            conditions.append(or_(
                Ride.ride_code.ilike(pattern),
                Ride.pickup_address.ilike(pattern),
                Ride.dropoff_address.ilike(pattern),
            ))
        status_key = status.lower() if status else None
        if status_key == "live":
            conditions.append(Ride.status == RideStatus.STARTED)
        elif status_key == "completed":
            conditions.append(Ride.status == RideStatus.COMPLETED)
        elif status_key == "upcoming":
            conditions.append(Ride.status.in_([RideStatus.REQUESTED, RideStatus.SEARCHING, RideStatus.BIDDING]))
        elif status_key == "ongoing":
            conditions.append(Ride.status.in_([RideStatus.ACCEPTED, RideStatus.DRIVER_ARRIVED]))

        # Driver and driver.user are eager-loaded for the name/picture columns.
        query = (
            select(Ride)
            .options(selectinload(Ride.driver).selectinload(Driver.user))
            .where(*conditions)
        )

        count_query = select(func.count(Ride.id)).where(*conditions)
        total = await self.db.scalar(count_query) or 0

        order_clause = build_order_by(Ride, sort_by, order, {"created_at", "updated_at", "requested_at", "id"})
        query = query.order_by(order_clause).offset(offset).limit(limit)
        rides = (await self.db.execute(query)).scalars().all()

        def shorten(address: Optional[str]) -> str:
            # Truncate long addresses to 30 characters for the table cell.
            if address is None:
                return "N/A"
            return address[:30] + "..." if len(address) > 30 else address

        items = []
        for r in rides:
            if r.status == RideStatus.STARTED:
                status_display = "Live"
            elif r.status == RideStatus.COMPLETED:
                status_display = "Completed"
            elif r.status in (RideStatus.ACCEPTED, RideStatus.DRIVER_ARRIVED):
                status_display = "Ongoing"
            else:
                status_display = "Upcoming"

            driver_name = "Searching..."
            driver_pic = None
            if r.driver and r.driver.user:
                driver_name = r.driver.user.full_name
                driver_pic = r.driver.user.profile_picture

            moment = r.scheduled_at or r.requested_at

            items.append(RideListItem(
                id=r.id,
                trip_id=f"#{r.ride_code}" if r.ride_code else f"#t{r.id}",
                driver_name=driver_name,
                driver_profile_picture=driver_pic,
                start_address=shorten(r.pickup_address),
                destination_address=shorten(r.dropoff_address),
                date_time=moment.strftime("%d %b %y, %I:%M %p") if moment else "N/A",
                passengers=r.passenger_count,
                status=status_display,
            ))

        return RideListResponse(
            items=items,
            total=total,
            page=page,
            size=limit,
        )

    async def get_transactions(
        self,
        page: int = 1,
        limit: int = 10,
        search: Optional[str] = None,
        status: Optional[str] = None,
        sort_by: str = "created_at",
        order: str = "desc"
    ):
        """Return one page of payments for the admin transactions table.

        Args:
            page: 1-based page number. Values below 1 are treated as 1 when
                computing the offset.
            limit: Page size.
            search: Case-insensitive substring matched against the gateway
                payment id, the ride code, and the passenger's and driver's
                first and last names.
            status: ``"completed"``, ``"pending"`` or ``"failed"``
                (case-insensitive). ``"All"`` or any other value applies no
                status filter.
            sort_by: ``Payment`` column to sort by (whitelisted by ``build_order_by``).
            order: Sort direction passed to ``build_order_by``.

        Returns:
            ``TransactionListResponse`` with the page items and pagination
            metadata. ``driver_amount`` and ``commission_amount`` are treated
            as cents and rendered as currency, the same cents convention used
            for ``Payment.total``/``commission_amount`` in the chart queries.
        """
        from app.schemas.dashboard import TransactionListItem, TransactionListResponse

        # Clamp so page <= 0 cannot produce a negative OFFSET.
        offset = max(page - 1, 0) * limit

        query = select(Payment).options(
            selectinload(Payment.ride),
            selectinload(Payment.passenger),
            selectinload(Payment.driver).selectinload(Driver.user),
        )
        # Same joins/filters as `query`, selecting only ids for the total count.
        count_source = select(Payment.id)

        if search:
            pattern = f"%{search}%"
            DriverUser = aliased(User)
            ride_code = (
                select(Ride.ride_code)
                .where(Ride.id == Payment.ride_id)
                .correlate(Payment)
                .scalar_subquery()
            )
            search_filter = or_(
                Payment.gateway_payment_id.ilike(pattern),
                ride_code.ilike(pattern),
                User.first_name.ilike(pattern),
                User.last_name.ilike(pattern),
                DriverUser.first_name.ilike(pattern),
                DriverUser.last_name.ilike(pattern),
            )

            def with_search(stmt):
                # Outer joins: a payment with no passenger or no driver yet
                # must still be findable by its transaction id or ride code.
                return (
                    stmt.outerjoin(Payment.passenger)
                    .outerjoin(Payment.driver)
                    .outerjoin(DriverUser, Driver.user_id == DriverUser.id)
                    .where(search_filter)
                )

            query = with_search(query)
            count_source = with_search(count_source)

        status_key = status.lower() if status else None
        status_filter = None
        if status_key == "completed":
            status_filter = Payment.status == TransactionStatus.COMPLETED
        elif status_key == "pending":
            status_filter = Payment.status == TransactionStatus.PENDING
        elif status_key == "failed":
            status_filter = Payment.status == TransactionStatus.FAILED
        # `is not None`: an SQLAlchemy `==` clause evaluates falsy, so a plain
        # truthiness check would silently skip the filter.
        if status_filter is not None:
            query = query.where(status_filter)
            count_source = count_source.where(status_filter)

        count_query = select(func.count()).select_from(count_source.subquery())
        total = await self.db.scalar(count_query) or 0

        order_clause = build_order_by(Payment, sort_by, order, {"created_at", "updated_at", "paid_at", "total", "id"})
        query = query.order_by(order_clause).offset(offset).limit(limit)
        payments = (await self.db.execute(query)).scalars().all()

        items = []
        for p in payments:
            driver_name = "N/A"
            if p.driver and p.driver.user:
                driver_name = p.driver.user.full_name

            moment = p.paid_at or p.created_at

            items.append(TransactionListItem(
                id=p.id,
                transaction_id=p.gateway_payment_id or f"TXN-{p.id}",
                date_time=moment.strftime("%d %b %Y • %I:%M %p") if moment else "N/A",
                ride_id=f"RIDE-{p.ride.ride_code}" if p.ride else "N/A",
                passenger_name=p.passenger.full_name if p.passenger else "N/A",
                driver_name=driver_name,
                payment_method=p.payment_method.value.title() if p.payment_method else "Cash",
                status=p.status.value.title(),
                # Amounts are stored in cents (see docstring); convert for display.
                driver_earning=f"${float(p.driver_amount or 0) / 100.0:.2f}",
                admin_commission=f"${float(p.commission_amount or 0) / 100.0:.2f}",
                action="View",
            ))

        pages = (total + limit - 1) // limit if limit > 0 else 0
        return TransactionListResponse(
            items=items,
            total=total,
            page=page,
            limit=limit,
            pages=pages,
            has_next=page < pages,
            has_prev=page > 1,
        )

    async def update_driver_status(self, driver_id: int, status: str) -> dict:
        """Set a driver's status and commit the session.

        When the new status equals ``DriverStatus.APPROVED``, ``verified_at``
        is set to the SQL ``now()`` expression, which the database evaluates
        when the UPDATE is flushed.

        Args:
            driver_id: Primary key of the ``Driver`` row (looked up by ``Driver.id``).
            status: New status value, stored as given.

        Returns:
            A confirmation dict echoing ``status``.

        Raises:
            HTTPException: 404 when no driver matches ``driver_id``.

        NOTE(review): ``status`` is not validated against ``DriverStatus``
        here -- confirm the caller restricts it.
        """
        from app.models.driver import Driver, DriverStatus
        from fastapi import HTTPException

        lookup = select(Driver).where(Driver.id == driver_id)
        driver = (await self.db.execute(lookup)).scalar_one_or_none()
        if not driver:
            raise HTTPException(status_code=404, detail="Driver not found")

        driver.status = status
        if status == DriverStatus.APPROVED:
            # TODO: also record verified_by from the acting admin user.
            driver.verified_at = func.now()

        await self.db.commit()
        return {"message": "Driver status updated successfully", "status": status}
