import logging
import sys
from datetime import datetime, timezone
from typing import Optional, Dict, Any

# Add shared modules to path
sys.path.append('/app')

from shared.security import verify_token, is_token_expired
from shared.redis_client import redis_client
from app.models.user import User
from app.models.auth_session import AuthSession
from app.schemas.token import TokenValidationResponse
from bson import ObjectId

logger = logging.getLogger(__name__)


class TokenService:
    """Static helpers for token validation, revocation and session upkeep.

    Every public method is an async ``@staticmethod``. None of them raises:
    unexpected failures are logged and reported through the return value
    (``False``, an ``error`` field, or an ``"error"`` key).
    """

    # Key prefix shared by validate_token() and blacklist_token(); both must
    # build the key the same way or a revoked token would never be found.
    _BLACKLIST_PREFIX = "blacklisted_token:"

    @staticmethod
    def _blacklist_key(token: str) -> str:
        """Return the Redis key under which *token* is recorded as revoked."""
        return f"{TokenService._BLACKLIST_PREFIX}{token}"

    @staticmethod
    def _seconds_until(exp: float, now: Optional[datetime] = None) -> int:
        """Return whole seconds from *now* until the Unix timestamp *exp*.

        *now* defaults to the current UTC time. The result is truncated
        toward zero and is negative once *exp* lies in the past.
        """
        if now is None:
            now = datetime.now(timezone.utc)
        expires_at = datetime.fromtimestamp(exp, timezone.utc)
        return int((expires_at - now).total_seconds())

    @staticmethod
    def _utcnow_naive() -> datetime:
        """Return the current UTC time as a naive datetime.

        Yields the same kind of value as the deprecated ``datetime.utcnow()``
        without calling it.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _session_summary(session: Any, is_current: bool) -> Dict[str, Any]:
        """Build the dict describing one session in get_user_sessions()."""
        return {
            "session_id": str(session.id),
            "user_id": str(session.user_id),
            "created_at": session.created_at,
            "last_accessed": session.last_accessed,
            "expires_at": session.expires_at,
            "device_info": session.device_info,
            "ip_address": session.ip_address,
            "user_agent": session.user_agent,
            "is_current": is_current,
        }

    @staticmethod
    async def validate_token(token: str) -> TokenValidationResponse:
        """Validate a JWT token.

        Checks, in order: the Redis blacklist, ``verify_token``,
        ``is_token_expired``, presence and well-formedness of the ``sub``
        claim, and finally that the user exists and passes ``can_login()``.

        Args:
            token: The encoded token string.

        Returns:
            A ``TokenValidationResponse`` with ``valid=True`` plus
            ``user_id``/``email``/``username`` taken from the payload, or
            ``valid=False`` with an ``error`` message naming the failed check.
        """
        try:
            if await redis_client.exists(TokenService._blacklist_key(token)):
                return TokenValidationResponse(
                    valid=False,
                    error="Token has been revoked"
                )

            payload = verify_token(token)
            if not payload:
                return TokenValidationResponse(
                    valid=False,
                    error="Invalid token signature"
                )

            if is_token_expired(payload):
                return TokenValidationResponse(
                    valid=False,
                    error="Token has expired"
                )

            user_id = payload.get("sub")
            # A subject that is missing or not a valid ObjectId is a payload
            # problem; reject it here instead of letting ObjectId() raise and
            # echo the raw exception text through the generic handler below.
            if not user_id or not ObjectId.is_valid(user_id):
                return TokenValidationResponse(
                    valid=False,
                    error="Invalid token payload"
                )

            user = await User.get(ObjectId(user_id))
            if not user or not user.can_login():
                return TokenValidationResponse(
                    valid=False,
                    error="User not found or inactive"
                )

            return TokenValidationResponse(
                valid=True,
                user_id=user_id,
                email=payload.get("email"),
                username=payload.get("username")
            )

        except Exception as e:
            logger.exception("Unexpected failure while validating token")
            # NOTE(review): the exception text is returned to the caller;
            # confirm this response never reaches untrusted clients.
            return TokenValidationResponse(
                valid=False,
                error=f"Token validation error: {str(e)}"
            )

    @staticmethod
    async def blacklist_token(token: str, reason: str = "manual") -> bool:
        """Blacklist a token until its own expiry time.

        Args:
            token: The encoded token string to revoke.
            reason: Stored as the value of the Redis entry.

        Returns:
            ``True`` when an entry was written. ``False`` when the token does
            not verify, has no ``exp`` claim, has no positive time left, or
            when an unexpected error occurred (which is logged).
        """
        try:
            payload = verify_token(token)
            if not payload:
                return False

            exp = payload.get("exp")
            if not exp:
                return False

            ttl = TokenService._seconds_until(exp)
            if ttl <= 0:
                # No lifetime left, so there is nothing worth storing.
                return False

            # The entry's expiry is the token's remaining lifetime.
            await redis_client.set(
                TokenService._blacklist_key(token),
                reason,
                expire=ttl
            )
            return True

        except Exception:
            # The token itself is deliberately kept out of the log message.
            logger.exception("Failed to blacklist token")
            return False

    @staticmethod
    async def get_user_sessions(user_id: str, current_token: Optional[str] = None) -> Dict[str, Any]:
        """Get all active, unexpired sessions for a user.

        Sessions flagged active whose ``is_expired()`` is true are revoked
        with reason ``"expired"`` and left out of the result.

        Args:
            user_id: String form of the user's ObjectId.
            current_token: Accepted for interface compatibility; not used yet.

        Returns:
            A dict with ``sessions``, ``total_count`` and
            ``current_session_id``. On failure the same keys are returned
            empty, plus an ``error`` message.
        """
        try:
            sessions = await AuthSession.find(
                AuthSession.user_id == ObjectId(user_id),
                # Must stay ``== True``: this builds a query expression, and
                # ``is True`` would evaluate to a plain bool instead.
                AuthSession.is_active == True  # noqa: E712
            ).to_list()

            active_sessions = []
            for session in sessions:
                if session.is_expired():
                    await session.revoke("expired")
                    continue

                # TODO: match current_token to its session. Until that mapping
                # exists, the first live session is reported as current.
                active_sessions.append(
                    TokenService._session_summary(
                        session, is_current=not active_sessions
                    )
                )

            current_session_id = (
                active_sessions[0]["session_id"] if active_sessions else None
            )

            return {
                "sessions": active_sessions,
                "total_count": len(active_sessions),
                "current_session_id": current_session_id
            }

        except Exception as e:
            logger.exception("Failed to list sessions for user %s", user_id)
            return {
                "sessions": [],
                "total_count": 0,
                # Same keys as the success path, so callers can read this
                # key without special-casing errors.
                "current_session_id": None,
                "error": str(e)
            }

    @staticmethod
    async def revoke_session(user_id: str, session_id: str) -> bool:
        """Revoke one session of a user together with its refresh tokens.

        Args:
            user_id: String form of the owning user's ObjectId.
            session_id: String form of the session's ObjectId.

        Returns:
            ``True`` when an active session owned by the user was revoked.
            ``False`` when either id is malformed, no such active session
            exists, or an unexpected error occurred (which is logged).
        """
        try:
            # Imported before any state changes so a failing import cannot
            # leave the session revoked with its refresh tokens still active.
            # Presumably local to avoid a circular import; verify.
            from app.models.refresh_token import RefreshToken

            # Malformed ids are ordinary bad input, not an error worth a
            # stack trace in the log.
            if not (ObjectId.is_valid(session_id) and ObjectId.is_valid(user_id)):
                return False

            session = await AuthSession.find_one(
                AuthSession.id == ObjectId(session_id),
                AuthSession.user_id == ObjectId(user_id)
            )
            if not session or not session.is_active:
                return False

            await session.revoke("manual_revocation")

            refresh_tokens = await RefreshToken.find(
                RefreshToken.session_id == session.id,
                RefreshToken.is_active == True  # noqa: E712 (query expression)
            ).to_list()
            for refresh_token in refresh_tokens:
                await refresh_token.revoke("session_revoked")

            return True

        except Exception:
            logger.exception(
                "Failed to revoke session %s for user %s", session_id, user_id
            )
            return False

    @staticmethod
    async def cleanup_expired_tokens() -> Dict[str, Any]:
        """Revoke expired sessions and refresh tokens (background task).

        Returns:
            A dict with ``expired_sessions_cleaned`` and
            ``expired_tokens_cleaned`` counts, or ``{"error": message}`` when
            the run failed (the failure is also logged).
        """
        try:
            from app.models.refresh_token import RefreshToken

            # One cut-off for both queries. Kept naive because that is what
            # the previous utcnow() comparison passed to the query.
            cutoff = TokenService._utcnow_naive()

            expired_sessions = await AuthSession.find(
                AuthSession.expires_at < cutoff,
                AuthSession.is_active == True  # noqa: E712 (query expression)
            ).to_list()
            for session in expired_sessions:
                await session.revoke("expired")

            expired_tokens = await RefreshToken.find(
                RefreshToken.expires_at < cutoff,
                RefreshToken.is_active == True  # noqa: E712 (query expression)
            ).to_list()
            for refresh_token in expired_tokens:
                await refresh_token.revoke("expired")

            return {
                "expired_sessions_cleaned": len(expired_sessions),
                "expired_tokens_cleaned": len(expired_tokens)
            }

        except Exception as e:
            logger.exception("Expired token/session cleanup failed")
            return {"error": str(e)}
