"""
Authentication Service
Handles user authentication, JWT tokens, and OTP verification
"""
import secrets
import string
from datetime import datetime, timedelta
from typing import List, Optional, Tuple
import logging

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, or_, and_
from sqlalchemy.exc import IntegrityError
import bcrypt
from jose import jwt, JWTError

from app.config import settings
from app.models.user import User, UserRole
from app.media.profile_picture import save_profile_picture_from_base64
from app.models.driver import Driver, DriverStatus
from app.models.payment import Wallet
from app.schemas.auth import RegisterRequest

logger = logging.getLogger(__name__)


class AuthService:
    """Authentication service for user management and JWT handling.

    Groups bcrypt password hashing, phone/email normalisation, the OTP
    send/verify flow, registration completion, JWT issuance/refresh and a few
    account-maintenance helpers. Every method is a static/class method; the
    class holds no per-instance state.
    """

    # bcrypt only consumes the first 72 bytes of its input (recent bcrypt
    # releases raise ValueError for longer input), so hashing and verification
    # must truncate identically.
    BCRYPT_MAX_PASSWORD_BYTES = 72

    # Lifetime of a freshly generated OTP code.
    OTP_TTL = timedelta(minutes=5)

    @staticmethod
    def hash_password(password: str) -> str:
        """Hash ``password`` with bcrypt and a fresh salt.

        The UTF-8 encoding is truncated to ``BCRYPT_MAX_PASSWORD_BYTES`` first;
        :meth:`verify_password` applies the same truncation.
        """
        pw = password.encode("utf-8")[: AuthService.BCRYPT_MAX_PASSWORD_BYTES]
        return bcrypt.hashpw(pw, bcrypt.gensalt()).decode("utf-8")

    @staticmethod
    def verify_password(plain_password: str, hashed_password: Optional[str]) -> bool:
        """Return True if ``plain_password`` matches the bcrypt ``hashed_password``.

        The plain password is truncated exactly as in :meth:`hash_password`.
        Without this, passwords longer than 72 bytes could never verify on
        bcrypt versions that reject over-long input. A missing, empty or
        malformed hash yields False instead of raising.
        """
        if plain_password is None or not hashed_password:
            return False
        pw = plain_password.encode("utf-8")[: AuthService.BCRYPT_MAX_PASSWORD_BYTES]
        try:
            return bcrypt.checkpw(pw, hashed_password.encode("utf-8"))
        except (ValueError, TypeError):
            return False

    @staticmethod
    def normalize_phone_e164(phone: Optional[str]) -> str:
        """Return the single canonical phone form used by the OTP flow: ``+`` then digits.

        All non-digit characters are dropped. Numeric input (int/float) is
        converted via ``int()`` first. No length or country-code validation is
        performed.

        Raises:
            ValueError: if ``phone`` is empty/blank or contains no digits.
        """
        if phone is None or not str(phone).strip():
            raise ValueError("Phone is required")
        if isinstance(phone, (int, float)):
            phone = str(int(phone))
        digits = "".join(c for c in str(phone) if c.isdigit())
        if not digits:
            raise ValueError("Invalid phone number")
        return f"+{digits}"

    @staticmethod
    def phone_lookup_values(phone: Optional[str]) -> List[str]:
        """Return ``["+<digits>", "<digits>"]`` for matching phones stored with or without ``+``.

        Returns an empty list for empty input or input without digits; never raises.
        """
        if not phone or not str(phone).strip():
            return []
        digits = "".join(c for c in str(phone) if c.isdigit())
        if not digits:
            return []
        plus_form = f"+{digits}"
        # dict.fromkeys de-duplicates while preserving order (canonical form first).
        return list(dict.fromkeys([plus_form, digits]))

    @staticmethod
    def generate_otp(length: int = 6) -> str:
        """Return a random numeric OTP of ``length`` digits drawn with ``secrets``.

        When ``settings.environment == "development"`` and ``settings.dev_fixed_otp``
        is a 4-10 digit string, that fixed value is returned instead (ignoring
        ``length``).
        """
        fixed = (settings.dev_fixed_otp or "").strip()
        if settings.environment == "development" and fixed.isdigit() and 4 <= len(fixed) <= 10:
            return fixed
        return "".join(secrets.choice(string.digits) for _ in range(length))

    @staticmethod
    def _user_lookup_by_phone_or_email(
        phone: Optional[str],
        email: Optional[str],
    ):
        """Build a ``User`` WHERE clause from phone and/or email, or return None.

        If both phone and email are sent, BOTH must match the same user. This
        avoids Swagger's example email matching a different account than the
        phone you type. Phone matches any of :meth:`phone_lookup_values`;
        email is compared stripped and lower-cased. Returns None when neither
        is usable.
        """
        has_phone = bool(phone and str(phone).strip())
        has_email = bool(email and str(email).strip())
        if not has_phone and not has_email:
            return None
        if has_phone and has_email:
            vals = AuthService.phone_lookup_values(phone)
            if not vals:
                return None
            em = str(email).strip().lower()
            return and_(User.phone.in_(vals), User.email == em)
        if has_phone:
            vals = AuthService.phone_lookup_values(phone)
            return User.phone.in_(vals) if vals else None
        return User.email == str(email).strip().lower()

    @staticmethod
    def _phones_equivalent(stored: Optional[str], request_phone: Optional[str]) -> bool:
        """True when both phones are present and share a digits-only lookup variant."""
        if not stored or not request_phone:
            return False
        a = set(AuthService.phone_lookup_values(stored))
        b = set(AuthService.phone_lookup_values(request_phone))
        return bool(a & b)

    @staticmethod
    def _secure_equals(expected: Optional[str], provided: Optional[str]) -> bool:
        """Constant-time string comparison for secrets (OTP codes, refresh tokens).

        Both values are UTF-8 encoded first. ``compare_digest`` rejects
        non-ASCII ``str`` input, and user-supplied text may contain it.
        ``None`` on either side compares unequal.
        """
        if expected is None or provided is None:
            return False
        return secrets.compare_digest(expected.encode("utf-8"), provided.encode("utf-8"))

    @classmethod
    async def resolve_identity_for_otp(
        cls,
        db: AsyncSession,
        phone: Optional[str] = None,
        email: Optional[str] = None,
        *,
        allow_create: bool = False,
        _retry_on_conflict: bool = True,
    ) -> User:
        """
        Find (and optionally create) the user row for OTP.

        ``allow_create=True`` is used on send-otp only; verify-otp never creates
        a new row.

        Resolution order:

        1. A user matching the phone/email condition (both must match when both
           are given).
        2. Otherwise a user matching phone alone or email alone. Completed
           accounts (``password_hash`` set) must match every identifier sent.
           For stub accounts, a missing phone/email is filled in; a conflicting
           one is rejected.
        3. Otherwise, if ``allow_create``, a new PASSENGER stub row is inserted
           and flushed (not committed).

        ``_retry_on_conflict`` is internal. After an ``IntegrityError`` on
        insert (e.g. a concurrent request created the same row), the lookup is
        retried once. A second conflict raises instead of recursing forever.

        Raises:
            ValueError: missing/invalid identifiers, identifiers that belong to
                different accounts, no matching user when ``allow_create`` is
                False, or a repeated insert conflict.
        """
        has_phone = bool(phone and str(phone).strip())
        has_email = bool(email and str(email).strip())
        if not has_phone and not has_email:
            raise ValueError("Either email or phone is required")

        if has_phone:
            phone = cls.normalize_phone_e164(phone)

        em = str(email).strip().lower() if has_email else None
        vals = cls.phone_lookup_values(phone) if has_phone else []

        cond = cls._user_lookup_by_phone_or_email(phone, email)
        if cond is None:
            raise ValueError("Invalid phone number")
        result = await db.execute(select(User).where(cond))
        user = result.scalar_one_or_none()
        if user is not None:
            return user

        # No exact match: look up each identifier separately.
        # NOTE(review): scalar_one_or_none raises if "+123" and "123" are both
        # stored on different rows; that would indicate pre-existing bad data.
        u_phone = None
        u_email = None
        if vals:
            r = await db.execute(select(User).where(User.phone.in_(vals)))
            u_phone = r.scalar_one_or_none()
        if em:
            r = await db.execute(select(User).where(User.email == em))
            u_email = r.scalar_one_or_none()

        if u_phone and u_email and u_phone.id != u_email.id:
            raise ValueError("Phone and email belong to different accounts")

        candidate = u_phone or u_email

        if candidate is not None:
            if candidate.password_hash is not None:
                # Completed account: every identifier sent must match it exactly.
                if has_phone and has_email:
                    phone_ok = cls._phones_equivalent(candidate.phone, phone)
                    email_ok = (candidate.email or "").lower() == (em or "")
                    if not (phone_ok and email_ok):
                        raise ValueError(
                            "Phone and email do not match one account. "
                            "Use the same details as your login, or send only phone or only email."
                        )
                elif has_phone:
                    if not cls._phones_equivalent(candidate.phone, phone):
                        raise ValueError("User not found")
                else:
                    if (candidate.email or "").lower() != em:
                        raise ValueError("User not found")
                return candidate

            # Stub account (signup not finished): fill in missing contact
            # fields, reject conflicting ones.
            if has_email and em:
                if candidate.email is None:
                    candidate.email = em
                elif candidate.email.lower() != em:
                    raise ValueError("This phone is already linked to another email address")
            if has_phone and vals:
                canon = vals[0]
                if candidate.phone is None:
                    candidate.phone = canon
                elif not cls._phones_equivalent(candidate.phone, phone):
                    raise ValueError("This email is already linked to another phone number")

            await db.flush()
            return candidate

        if not allow_create:
            raise ValueError(
                "No pending signup for this phone. Call POST /auth/send-otp first "
                "with the same number (international format, e.g. +923407279539)."
            )

        kwargs = dict(role=UserRole.PASSENGER, password_hash=None, first_name=None, last_name=None)
        if vals:
            kwargs["phone"] = vals[0]
        if em:
            kwargs["email"] = em
        stub = User(**kwargs)
        db.add(stub)
        try:
            await db.flush()
        except IntegrityError as exc:
            # rollback() discards everything pending in this session, not just the stub.
            await db.rollback()
            if not _retry_on_conflict:
                raise ValueError(
                    "Could not create an account for this phone or email. Please try again."
                ) from exc
            # Most likely a concurrent request inserted the same identity;
            # resolve once more, which should now find that row.
            return await cls.resolve_identity_for_otp(
                db, phone, email, allow_create=True, _retry_on_conflict=False
            )

        await db.refresh(stub)
        return stub

    @classmethod
    def _merge_register_contact_fields(cls, user: User, data: RegisterRequest) -> None:
        """Ensure body phone/email match the verified user, or fill missing fields.

        Raises:
            ValueError: if a provided phone/email differs from one already on ``user``.
        """
        if data.phone and str(data.phone).strip():
            pv = cls.phone_lookup_values(data.phone)
            if pv:
                if user.phone:
                    if not cls._phones_equivalent(user.phone, data.phone):
                        raise ValueError("Phone does not match the verified account")
                else:
                    user.phone = pv[0]
        if data.email and str(data.email).strip():
            em = str(data.email).strip().lower()
            if user.email:
                if user.email.lower() != em:
                    raise ValueError("Email does not match the verified account")
            else:
                user.email = em

    @staticmethod
    async def _assert_contact_uniqueness_for_register(
        db: AsyncSession,
        user: User,
        data: RegisterRequest,
    ) -> None:
        """Pre-check unique contact fields to avoid raw DB 500s on register.

        Only fields that ``user`` does not have yet are checked, because
        those are the ones :meth:`_merge_register_contact_fields` will fill in.
        Phones are matched on all lookup variants, so ``+123`` and ``123``
        count as the same number.

        Raises:
            ValueError: if another user already owns the email or phone.
        """
        if data.email and str(data.email).strip() and not user.email:
            em = str(data.email).strip().lower()
            email_query = select(User.id).where(User.email == em, User.id != user.id).limit(1)
            email_result = await db.execute(email_query)
            if email_result.first() is not None:
                raise ValueError("Email already in use")
        if data.phone and str(data.phone).strip() and not user.phone:
            vals = AuthService.phone_lookup_values(data.phone)
            if vals:
                phone_query = (
                    select(User.id).where(User.phone.in_(vals), User.id != user.id).limit(1)
                )
                phone_result = await db.execute(phone_query)
                if phone_result.first() is not None:
                    raise ValueError("Phone already in use")

    @staticmethod
    def create_access_token(user_id: int, role: str, expires_delta: Optional[timedelta] = None) -> str:
        """Create a signed JWT access token (``type="access"``) for ``user_id``.

        Expiry defaults to ``settings.access_token_expire_minutes``. ``iat``
        and ``exp`` are derived from the same instant (naive UTC).
        """
        if expires_delta is None:
            expires_delta = timedelta(minutes=settings.access_token_expire_minutes)

        now = datetime.utcnow()
        to_encode = {
            "sub": str(user_id),
            "role": role,
            "type": "access",
            "exp": now + expires_delta,
            "iat": now,
        }
        return jwt.encode(to_encode, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)

    @staticmethod
    def create_refresh_token(user_id: int) -> str:
        """Create a signed JWT refresh token (``type="refresh"``).

        The token lives ``settings.refresh_token_expire_days``; ``iat`` and
        ``exp`` are derived from the same instant (naive UTC).
        """
        now = datetime.utcnow()
        to_encode = {
            "sub": str(user_id),
            "type": "refresh",
            "exp": now + timedelta(days=settings.refresh_token_expire_days),
            "iat": now,
        }
        return jwt.encode(to_encode, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)

    @staticmethod
    def decode_token(token: str) -> Optional[dict]:
        """Decode and validate a JWT; return its payload, or None on any ``JWTError``."""
        try:
            payload = jwt.decode(
                token,
                settings.jwt_secret_key,
                algorithms=[settings.jwt_algorithm]
            )
            return payload
        except JWTError:
            return None

    @classmethod
    async def register_user(
        cls,
        db: AsyncSession,
        data: RegisterRequest,
        user: User,
    ) -> Tuple[User, str, str]:
        """
        Complete registration for the JWT-identified user (same row as verify-otp).

        The method:

        - fills profile and contact fields;
        - sets the role;
        - ensures a ``Wallet`` exists, plus a PENDING ``Driver`` row for drivers;
        - commits;
        - issues tokens, storing the refresh token in a second commit.

        Returns:
            Tuple of (User, access_token, refresh_token).

        Raises:
            ValueError: already completed, not OTP-verified, contact conflicts,
                or an ``IntegrityError`` on either commit.
        """
        if user.password_hash is not None:
            raise ValueError("Account already completed — use login")

        if not user.is_verified:
            raise ValueError("Verify OTP before completing registration")

        await cls._assert_contact_uniqueness_for_register(db, user, data)
        cls._merge_register_contact_fields(user, data)

        if data.first_name:
            user.first_name = data.first_name
        if data.last_name:
            user.last_name = data.last_name
        if data.fcm_token:
            user.device_token = data.fcm_token

        # NOTE(review): called even when no picture was sent; presumably the
        # helper returns None/a default for a missing payload — confirm.
        user.profile_picture = save_profile_picture_from_base64(user.id, data.profile_picture_base64)

        user.role = UserRole(data.role)

        wq = await db.execute(select(Wallet).where(Wallet.user_id == user.id))
        if wq.scalar_one_or_none() is None:
            db.add(Wallet(user_id=user.id))

        if data.role == "driver":
            dq = await db.execute(select(Driver).where(Driver.user_id == user.id))
            if dq.scalar_one_or_none() is None:
                db.add(Driver(user_id=user.id, status=DriverStatus.PENDING))

        try:
            await db.commit()
        except IntegrityError:
            await db.rollback()
            raise ValueError("Registration failed due to duplicate or invalid data")
        await db.refresh(user)

        # Registration is already committed; a failure below leaves a
        # registered account that can log in normally.
        access_token = cls.create_access_token(user.id, user.role.value)
        refresh_token = cls.create_refresh_token(user.id)
        user.refresh_token = refresh_token
        try:
            await db.commit()
        except IntegrityError:
            await db.rollback()
            raise ValueError("Could not issue session tokens. Please login again")
        await db.refresh(user)

        return user, access_token, refresh_token

    @classmethod
    async def login(
        cls,
        db: AsyncSession,
        email: Optional[str] = None,
        phone: Optional[str] = None,
        fcm_token: Optional[str] = None
    ) -> Tuple[User, str, str]:
        """
        Issue tokens for an existing user identified by email or phone (no password).

        The user must be OTP-verified, active and not blocked. Identifiers are
        normalised the same way as in the OTP flow: email is stripped and
        lower-cased, and phone matches both ``+digits`` and ``digits``. If
        email and phone resolve to two different accounts, login fails
        instead of raising a database error.

        NOTE(review): no secret is checked here. ``is_verified`` is set once
        by :meth:`verify_otp` and never reset in this module, so anyone who
        knows a verified user's email or phone receives tokens. Confirm that
        the calling route gates this behind a just-verified OTP.

        Returns:
            Tuple of (User, access_token, refresh_token)

        Raises:
            ValueError: unknown/ambiguous identity, unverified, inactive or blocked.
        """
        conditions = []
        em = str(email).strip().lower() if email and str(email).strip() else None
        if em:
            conditions.append(User.email == em)
        phone_vals = cls.phone_lookup_values(phone)
        if phone_vals:
            conditions.append(User.phone.in_(phone_vals))
        if not conditions:
            raise ValueError("Invalid credentials")

        result = await db.execute(select(User).where(or_(*conditions)).limit(2))
        matches = result.scalars().all()
        # Zero rows, or email and phone pointing at different accounts.
        if len(matches) != 1:
            raise ValueError("Invalid credentials")
        user = matches[0]

        if not user.is_verified:
            raise ValueError("Verify OTP before logging in")
        if not user.is_active:
            raise ValueError("Account is not active")
        if user.is_blocked:
            raise ValueError("Account is blocked")

        # Update login info (login_count may be NULL on legacy rows).
        user.last_login_at = datetime.utcnow()
        user.login_count = (user.login_count or 0) + 1
        if fcm_token:
            user.device_token = fcm_token

        # Generate tokens; only the latest refresh token is stored and accepted.
        access_token = cls.create_access_token(user.id, user.role.value)
        refresh_token = cls.create_refresh_token(user.id)
        user.refresh_token = refresh_token

        await db.commit()
        await db.refresh(user)

        return user, access_token, refresh_token

    @classmethod
    async def refresh_tokens(
        cls,
        db: AsyncSession,
        refresh_token: str
    ) -> Tuple[str, str]:
        """
        Rotate tokens: validate ``refresh_token`` and issue a new access/refresh pair.

        The token must be a valid ``type="refresh"`` JWT whose ``sub`` is a
        user id. It must also equal the refresh token currently stored for that
        user (compared in constant time), so each refresh token is single-use.

        Returns:
            Tuple of (new_access_token, new_refresh_token)

        Raises:
            ValueError: invalid/stale token, or inactive/blocked account.
        """
        payload = cls.decode_token(refresh_token)

        if not payload or payload.get("type") != "refresh":
            raise ValueError("Invalid refresh token")

        try:
            user_id = int(payload.get("sub", 0))
        except (TypeError, ValueError):
            raise ValueError("Invalid refresh token") from None

        query = select(User).where(User.id == user_id)
        result = await db.execute(query)
        user = result.scalar_one_or_none()

        if not user or not cls._secure_equals(user.refresh_token, refresh_token):
            raise ValueError("Invalid refresh token")

        if not user.is_active or user.is_blocked:
            raise ValueError("Account is not active")

        # Generate new tokens; storing the new refresh token invalidates the old one.
        new_access_token = cls.create_access_token(user.id, user.role.value)
        new_refresh_token = cls.create_refresh_token(user.id)
        user.refresh_token = new_refresh_token

        await db.commit()

        return new_access_token, new_refresh_token

    @classmethod
    async def send_otp(
        cls,
        db: AsyncSession,
        phone: Optional[str] = None,
        email: Optional[str] = None
    ) -> str:
        """Generate and store an OTP, creating a stub user before register when needed.

        The code expires after ``OTP_TTL``. This method only flushes;
        presumably the caller commits (confirm at the route). The plaintext
        code is logged only in development.

        Returns:
            The generated OTP code.
        """
        user = await cls.resolve_identity_for_otp(db, phone, email, allow_create=True)

        otp = cls.generate_otp()
        user.otp_code = otp
        user.otp_expires_at = datetime.utcnow() + cls.OTP_TTL
        await db.flush()

        # TODO: Actually send OTP via Twilio/Email service
        if settings.environment == "development":
            logger.info("OTP generated for user %s: %s", user.id, otp)
        else:
            # Never write live OTP codes to non-development logs.
            logger.info("OTP generated for user %s", user.id)

        return otp

    @classmethod
    async def verify_otp(
        cls,
        db: AsyncSession,
        phone: Optional[str] = None,
        email: Optional[str] = None,
        otp: str = ""
    ) -> User:
        """Verify an OTP code and mark the user (and the channel used) as verified.

        The submitted code is compared to the stored one in constant time.
        On success the stored code is cleared, so it cannot be reused. Never
        creates a user row, and only flushes.

        NOTE(review): there is no attempt counter here; a 6-digit code can be
        brute-forced within its TTL unless the route layer rate-limits — confirm.

        Raises:
            ValueError: unknown identity, wrong code, or expired code.
        """
        user = await cls.resolve_identity_for_otp(db, phone, email, allow_create=False)

        submitted = (otp or "").strip()
        stored = (user.otp_code or "").strip()
        if not stored or not cls._secure_equals(stored, submitted):
            raise ValueError("Invalid OTP")

        if user.otp_expires_at and user.otp_expires_at < datetime.utcnow():
            raise ValueError("OTP expired")

        # Clear OTP (single use) and mark as verified
        user.otp_code = None
        user.otp_expires_at = None

        if phone:
            user.is_phone_verified = True
        if email:
            user.is_email_verified = True

        user.is_verified = True

        await db.flush()
        await db.refresh(user)

        return user

    @classmethod
    async def get_user_by_id(cls, db: AsyncSession, user_id: int) -> Optional[User]:
        """Return the user with ``user_id``, or None if it does not exist."""
        query = select(User).where(User.id == user_id)
        result = await db.execute(query)
        return result.scalar_one_or_none()

    @classmethod
    async def delete_account(cls, db: AsyncSession, user_id: int) -> None:
        """Delete the user row and commit.

        Whether related rows are removed depends on the model's cascade
        configuration, which is not visible here.

        Raises:
            ValueError: if the user does not exist.
        """
        query = select(User).where(User.id == user_id)
        result = await db.execute(query)
        user = result.scalar_one_or_none()
        if not user:
            raise ValueError("User not found")
        # AsyncSession.delete is a coroutine; without await nothing is deleted.
        await db.delete(user)
        await db.commit()

    @classmethod
    async def logout(cls, db: AsyncSession, user_id: int) -> None:
        """Logout by clearing the stored refresh token and device token.

        This is a no-op if the user does not exist.
        """
        query = select(User).where(User.id == user_id)
        result = await db.execute(query)
        user = result.scalar_one_or_none()

        if user:
            user.refresh_token = None
            user.device_token = None
            await db.commit()

    @classmethod
    async def update_profile(
        cls,
        db: AsyncSession,
        user_id: int,
        data: "UserProfileUpdate"
    ) -> User:
        """Update name, email and profile picture for ``user_id``.

        Only fields that are not None in ``data`` are applied. Email is
        stripped and lower-cased, matching how the rest of this service stores
        and looks up emails. A blank email is ignored.

        Raises:
            ValueError: unknown user, email owned by another user, or a
                constraint violation on commit.
        """
        query = select(User).where(User.id == user_id)
        result = await db.execute(query)
        user = result.scalar_one_or_none()

        if not user:
            raise ValueError("User not found")

        if data.email is not None:
            new_email = str(data.email).strip().lower()
            if new_email and new_email != (user.email or "").lower():
                email_query = (
                    select(User.id).where(User.email == new_email, User.id != user_id).limit(1)
                )
                email_result = await db.execute(email_query)
                if email_result.first() is not None:
                    raise ValueError("Email already in use")
                user.email = new_email

        if data.first_name is not None:
            user.first_name = data.first_name
        if data.last_name is not None:
            user.last_name = data.last_name
        if data.profile_picture_base64 is not None:
            user.profile_picture = save_profile_picture_from_base64(user.id, data.profile_picture_base64)

        try:
            await db.commit()
        except IntegrityError:
            # e.g. a concurrent request claimed the same email after our pre-check.
            await db.rollback()
            raise ValueError("Profile update failed due to duplicate or invalid data")
        await db.refresh(user)
        return user
