import asyncio
import random
from datetime import datetime, timedelta
from sqlalchemy import text
from app.database import init_db, async_session_maker, engine, Base
from app.models.user import User, UserRole
from app.models.driver import Driver, DriverStatus
from app.models.vehicle import VehicleCategory
from app.models.ride import Ride, RideStatus, RideType
from app.models.payment import Payment, TransactionStatus, PaymentMethod

async def reset_db():
    """Drop every table registered on ``Base.metadata``.

    Foreign-key checks are disabled (MySQL ``FOREIGN_KEY_CHECKS``) while dropping
    so tables can be removed in any order regardless of the references between
    them. They are re-enabled in a ``finally`` so that a failing ``DROP`` does not
    hand the pooled connection back with checks still switched off.
    """
    print("Dropping all tables...")
    # Quote (and schema-qualify) names via the dialect so reserved words or
    # schema-bound tables still produce valid DDL.
    preparer = engine.dialect.identifier_preparer
    async with engine.begin() as conn:
        await conn.execute(text("SET FOREIGN_KEY_CHECKS = 0;"))
        try:
            for table in Base.metadata.tables.values():
                await conn.execute(
                    text(f"DROP TABLE IF EXISTS {preparer.format_table(table)}")
                )
        finally:
            await conn.execute(text("SET FOREIGN_KEY_CHECKS = 1;"))

# Centre point around which all seeded driver/pickup coordinates are scattered.
_CENTER_LAT = 33.6844
_CENTER_LNG = 73.0479


def _jittered(center, spread=0.1):
    """Return ``center`` plus a uniform random offset in [-spread/2, spread/2)."""
    return center + (random.random() - 0.5) * spread


async def _create_category(session):
    """Insert the single "Standard" vehicle category and flush so its id is set."""
    print("Checking/Creating Vehicle Category...")
    category = VehicleCategory(
        name="Standard",
        display_name="Standard Ride",
        description="Regular Ride",
        base_fare=500,
        per_km_rate=100,
        per_minute_rate=50,
        minimum_fare=2000,
    )
    session.add(category)
    await session.flush()  # populate category.id for the rides created later
    return category


async def _create_admin(session):
    """Insert the admin user and return it.

    No password hash is set; per the original notes, admin tokens are generated
    manually rather than through a login flow.
    """
    admin = User(
        email="admin@gmail.com",
        first_name="Admin",
        last_name="User",
        role=UserRole.ADMIN,
        is_active=True,
        is_verified=True,
        profile_picture="https://ui-avatars.com/api/?name=Admin+User&background=0D8ABC&color=fff",
    )
    session.add(admin)
    await session.flush()
    print(f"Created Admin: {admin.email} (ID: {admin.id})")
    return admin


async def _add_driver(session, user, **driver_fields):
    """Persist ``user`` (to obtain its id) and attach a Driver placed near the centre."""
    session.add(user)
    await session.flush()
    driver = Driver(
        user_id=user.id,
        current_latitude=_jittered(_CENTER_LAT),
        current_longitude=_jittered(_CENTER_LNG),
        **driver_fields,
    )
    session.add(driver)
    return driver


async def _create_drivers(session):
    """Create the three design-mock drivers followed by 12 random approved drivers."""
    # Fixed drivers matching the dashboard design mock-up.
    design_drivers = [
        {"status": DriverStatus.APPROVED, "earnings": 4520000, "rides": 200, "rating": 4.5, "pic": "https://randomuser.me/api/portraits/men/32.jpg"},
        {"status": DriverStatus.SUSPENDED, "earnings": 4520000, "rides": 200, "rating": 4.5, "pic": "https://randomuser.me/api/portraits/men/33.jpg"},
        {"status": DriverStatus.APPROVED, "earnings": 4520000, "rides": 200, "rating": 4.5, "pic": "https://randomuser.me/api/portraits/men/34.jpg"},
    ]

    drivers = []
    for i, spec in enumerate(design_drivers):
        user = User(
            email=f"driver_alex_{i}@example.com",
            first_name="Alex",
            last_name="Carry",
            role=UserRole.DRIVER,
            is_active=True,
            is_verified=True,
            profile_picture=spec["pic"],
        )
        drivers.append(await _add_driver(
            session,
            user,
            status=spec["status"],
            # Only approved drivers are shown as online.
            is_online=spec["status"] == DriverStatus.APPROVED,
            total_earnings=spec["earnings"],
            total_rides=spec["rides"],
            average_rating=spec["rating"],
        ))

    for i in range(12):
        user = User(
            email=f"driver{i}@example.com",
            first_name="Driver",
            last_name=f"{i}",
            role=UserRole.DRIVER,
            is_active=True,
            is_verified=True,
            profile_picture=f"https://randomuser.me/api/portraits/men/{i}.jpg",
        )
        drivers.append(await _add_driver(
            session,
            user,
            status=DriverStatus.APPROVED,
            is_online=True,
            total_earnings=random.randint(10000, 500000),
            total_rides=random.randint(50, 500),
            average_rating=round(random.uniform(3.5, 5.0), 1),
        ))

    await session.flush()  # populate ids of the Driver rows added above
    return drivers


async def _create_passengers(session, count=20):
    """Create ``count`` active passenger users and flush so their ids are set."""
    passengers = [
        User(
            email=f"passenger{i}@example.com",
            role=UserRole.PASSENGER,
            is_active=True,
        )
        for i in range(count)
    ]
    session.add_all(passengers)
    await session.flush()
    return passengers


async def _create_active_rides(session, category, passengers, drivers, count=5):
    """Add up to ``count`` in-progress rides (shown on the live map)."""
    print("Creating Active Rides (Map)...")
    # An in-progress ride only makes sense for an online driver; this skips the
    # suspended/offline design driver.
    online_drivers = [d for d in drivers if d.is_online]
    for i, (passenger, driver) in enumerate(zip(passengers[:count], online_drivers[:count])):
        session.add(Ride(
            passenger_id=passenger.id,
            driver_id=driver.id,
            vehicle_category_id=category.id,
            status=random.choice([RideStatus.STARTED, RideStatus.DRIVER_ARRIVED]),
            pickup_address="Pickup Address Here",
            # Same ±0.05° scatter around the centre as the driver positions.
            pickup_latitude=_jittered(_CENTER_LAT),
            pickup_longitude=_jittered(_CENTER_LNG),
            dropoff_address="Dropoff Address Here",
            dropoff_latitude=33.7,
            dropoff_longitude=73.1,
            estimated_fare=1500,
            final_fare=1500,
            ride_code=str(1000 + i),
            ride_type=RideType.ONE_WAY,
        ))


async def _create_ride_history(session, category, passenger, driver, days=365):
    """Add completed rides and cash payments for each day of the last ``days`` days.

    Roughly 70% of days get between 1 and 5 rides with a random fare; each ride
    gets a matching completed Payment dated the same day (used by the revenue chart).
    """
    print("Creating Past Rides & Payments (Chart - Last 12 Months)...")
    end_date = datetime.now()
    day = end_date - timedelta(days=days)

    history = []  # (ride, amount, day) triples awaiting their Payment rows
    while day <= end_date:
        if random.random() > 0.3:  # ~70% of days have rides
            for n in range(random.randint(1, 5)):
                amount = random.randint(1000, 8000)
                ride = Ride(
                    passenger_id=passenger.id,
                    driver_id=driver.id,
                    vehicle_category_id=category.id,
                    status=RideStatus.COMPLETED,
                    pickup_address="Old", pickup_latitude=0, pickup_longitude=0,
                    dropoff_address="Old", dropoff_latitude=0, dropoff_longitude=0,
                    final_fare=amount,
                    ride_code=f"hist_{day.strftime('%Y%j')}_{n}",
                    ride_type=RideType.ONE_WAY,
                    created_at=day,
                    updated_at=day,
                )
                session.add(ride)
                history.append((ride, amount, day))
        day += timedelta(days=1)

    # A single flush populates ids for all historical rides (instead of one flush per ride).
    await session.flush()

    for ride, amount, paid_on in history:
        session.add(Payment(
            ride_id=ride.id,
            passenger_id=passenger.id,
            driver_id=driver.id,
            subtotal=amount,
            total=amount,
            driver_amount=int(amount * 0.85),
            status=TransactionStatus.COMPLETED,
            payment_method=PaymentMethod.CASH,
            paid_at=paid_on,
            created_at=paid_on,  # the chart query groups payments by created_at
        ))


async def seed():
    """Reset the schema and fill it with demo data for the admin dashboard.

    Creates one vehicle category, an admin, 3 design-mock drivers plus 12 random
    drivers, 20 passengers, up to 5 in-progress rides, and about a year of
    completed rides with cash payments. All rows are committed together at the end.
    """
    await reset_db()
    print("Initializing DB...")
    await init_db()

    async with async_session_maker() as session:
        category = await _create_category(session)
        await _create_admin(session)
        drivers = await _create_drivers(session)
        passengers = await _create_passengers(session)
        await _create_active_rides(session, category, passengers, drivers)
        await _create_ride_history(session, category, passengers[0], drivers[0])

        await session.commit()
        print("Seed Complete!")

def _main():
    """Command-line entry point: run the async seeder on a fresh event loop."""
    asyncio.run(seed())


if __name__ == "__main__":
    _main()
