"""Core Middleware

This module contains common middleware classes for security, logging,
and request processing in the Adtlas project.
"""

import ipaddress
import json
import logging
import math
import time
from datetime import datetime, timedelta
from typing import Callable, Optional

from django.conf import settings
from django.contrib.auth import get_user_model
from django.core.cache import cache
from django.core.exceptions import PermissionDenied
from django.http import Http404
from django.http import HttpRequest, HttpResponse, JsonResponse
from django.shortcuts import redirect
from django.urls import reverse
from django.utils import timezone
from django.utils.cache import patch_vary_headers
from django.utils.deprecation import MiddlewareMixin
from django.utils.translation import gettext as _

from .constants import (
    ACTIVITY_LOGIN, ACTIVITY_LOGOUT, ACTIVITY_VIEW, ACTIVITY_ERROR,
    MAX_LOGIN_ATTEMPTS, ACCOUNT_LOCKOUT_DURATION
)
from .models import ActivityLog
from .utils import get_client_ip, get_user_agent, is_ajax_request


# Active user model, resolved through Django's get_user_model().
# NOTE(review): not referenced in this part of the module; it may be imported
# from here by other code, so verify before removing.
User = get_user_model()
# Module-level logger shared by the middleware classes below.
logger = logging.getLogger(__name__)


# ============================================================================
# Security Middleware
# ============================================================================

class SecurityHeadersMiddleware(MiddlewareMixin):
    """Middleware to add baseline security headers to responses.

    Every header is added with ``response.setdefault``. A value already set by
    a view or an inner middleware is therefore kept rather than overwritten.
    Examples are a page-specific ``Content-Security-Policy`` or
    ``X-Frame-Options: SAMEORIGIN`` on an embeddable page.
    """

    # Basic Content-Security-Policy, applied only when DEBUG is off.
    # NOTE(review): 'unsafe-inline' and 'unsafe-eval' weaken script-src
    # considerably; tighten once templates stop relying on inline scripts.
    CONTENT_SECURITY_POLICY = (
        "default-src 'self'; "
        "script-src 'self' 'unsafe-inline' 'unsafe-eval'; "
        "style-src 'self' 'unsafe-inline'; "
        "img-src 'self' data: https:; "
        "font-src 'self' https:; "
        "connect-src 'self';"
    )

    def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
        """Add the security headers that the response does not already carry.

        Args:
            request: HTTP request object
            response: HTTP response object

        Returns:
            The same response object, with any missing security headers added.
        """
        response.setdefault('X-Content-Type-Options', 'nosniff')

        # Honour Django's @xframe_options_exempt decorator, which flags the
        # response with ``xframe_options_exempt = True``.
        if not getattr(response, 'xframe_options_exempt', False):
            response.setdefault('X-Frame-Options', 'DENY')

        # NOTE(review): X-XSS-Protection is deprecated in current browsers
        # (OWASP recommends '0' or omitting it); value kept for compatibility.
        response.setdefault('X-XSS-Protection', '1; mode=block')

        response.setdefault('Referrer-Policy', 'strict-origin-when-cross-origin')

        if not settings.DEBUG:
            response.setdefault('Content-Security-Policy', self.CONTENT_SECURITY_POLICY)

        # HSTS is only sent on secure requests (RFC 6797: user agents ignore
        # it when it arrives over plain HTTP).
        if request.is_secure():
            response.setdefault(
                'Strict-Transport-Security', 'max-age=31536000; includeSubDomains'
            )

        return response


class IPWhitelistMiddleware(MiddlewareMixin):
    """Middleware to restrict access based on an IP whitelist.

    Settings read once at start-up:

    * ``IP_WHITELIST_ENABLED`` (default ``False``): switches the check on.
    * ``IP_WHITELIST`` (default ``[]``): allowed clients. Each entry may be a
      single address (``'203.0.113.7'``) or a network in CIDR notation
      (``'10.0.0.0/8'``, ``'2001:db8::/32'``). Entries that are not valid IP
      syntax are still honoured as exact string matches.

    Setting read on every request:

    * ``IP_WHITELIST_EXEMPT_PATHS`` (default ``[]``): path prefixes that
      bypass the check.

    Nothing is enforced while the check is disabled or the whitelist is empty.
    """

    def __init__(self, get_response: Callable):
        """Initialize middleware.

        Args:
            get_response: Next middleware or view function
        """
        self.get_response = get_response
        self.whitelist = getattr(settings, 'IP_WHITELIST', [])
        self.enabled = getattr(settings, 'IP_WHITELIST_ENABLED', False)
        # Parse the whitelist once here instead of on every request.
        self._allowed_networks, self._allowed_literals = self._parse_whitelist(self.whitelist)
        super().__init__(get_response)

    @staticmethod
    def _parse_whitelist(entries) -> tuple:
        """Split whitelist entries into parsed IP networks and raw literals.

        A bare address parses as a single-host network (/32 or /128), so
        exact-address entries keep working. ``strict=False`` also accepts
        entries with host bits set, e.g. ``'10.0.0.1/24'``.

        Returns:
            ``(networks, literals)``: a list of ``ipaddress`` network objects
            and a list of entries that could not be parsed as IP syntax.
        """
        networks = []
        literals = []
        for entry in entries or ():
            try:
                networks.append(ipaddress.ip_network(str(entry).strip(), strict=False))
            except ValueError:
                literals.append(entry)
        return networks, literals

    def _is_allowed(self, client_ip) -> bool:
        """Return True if *client_ip* matches any whitelist entry."""
        if client_ip in self._allowed_literals:
            return True
        try:
            address = ipaddress.ip_address(client_ip)
        except ValueError:
            # Missing or malformed client IP: only literal entries could match.
            return False
        # An IPv4 address tested against an IPv6 network (or vice versa) is
        # simply "not contained", not an error.
        return any(address in network for network in self._allowed_networks)

    def process_request(self, request: HttpRequest) -> Optional[HttpResponse]:
        """Check if client IP is whitelisted.

        Args:
            request: HTTP request object

        Returns:
            A JSON 403 response for AJAX requests from non-whitelisted IPs,
            None otherwise.

        Raises:
            PermissionDenied: for non-AJAX requests from non-whitelisted IPs.
        """
        if not self.enabled or not self.whitelist:
            return None

        # Skip whitelist check for certain paths
        exempt_paths = getattr(settings, 'IP_WHITELIST_EXEMPT_PATHS', [])
        if any(request.path.startswith(path) for path in exempt_paths):
            return None

        client_ip = get_client_ip(request)

        if self._is_allowed(client_ip):
            return None

        logger.warning(
            "Access denied for IP %s - not in whitelist",
            client_ip,
            extra={'ip': client_ip, 'path': request.path},
        )

        if is_ajax_request(request):
            return JsonResponse({
                'success': False,
                'error': 'Access denied from your IP address'
            }, status=403)
        raise PermissionDenied("Access denied from your IP address")


class RateLimitMiddleware(MiddlewareMixin):
    """Middleware for fixed-window rate limiting of requests.

    Authenticated requests are counted per user id and anonymous ones per
    client IP. Counters are stored in Django's cache under a key that includes
    the index of the current time window. Each window therefore starts from
    zero at a fixed boundary, however often the client keeps calling.

    Settings:
        RATE_LIMIT_ENABLED: on/off switch (default ``True``).
        RATE_LIMIT_DEFAULT: fallback limit per window (default ``100``).
        RATE_LIMIT_WINDOW: window length in seconds (default ``3600``).
        RATE_LIMIT_AUTHENTICATED: per-user limit (default: RATE_LIMIT_DEFAULT).
        RATE_LIMIT_ANONYMOUS: per-IP limit (default: RATE_LIMIT_DEFAULT // 2).
        RATE_LIMIT_EXEMPT_PATHS: path prefixes that are never limited
            (default ``['/admin/', '/api/health/']``).
    """

    def __init__(self, get_response: Callable):
        """Initialize middleware.

        Args:
            get_response: Next middleware or view function
        """
        self.get_response = get_response
        self.enabled = getattr(settings, 'RATE_LIMIT_ENABLED', True)
        self.default_limit = getattr(settings, 'RATE_LIMIT_DEFAULT', 100)
        self.window_size = getattr(settings, 'RATE_LIMIT_WINDOW', 3600)  # 1 hour
        super().__init__(get_response)

    def process_request(self, request: HttpRequest) -> Optional[HttpResponse]:
        """Count the request against its window and block it if over the limit.

        Args:
            request: HTTP request object

        Returns:
            HTTP 429 response if rate limit exceeded, None otherwise
        """
        # A non-positive window cannot be split into buckets; don't limit.
        if not self.enabled or self.window_size <= 0:
            return None

        # Skip rate limiting for certain paths
        exempt_paths = getattr(settings, 'RATE_LIMIT_EXEMPT_PATHS', ['/admin/', '/api/health/'])
        if any(request.path.startswith(path) for path in exempt_paths):
            return None

        key, limit = self._get_key_and_limit(request)

        now = time.time()
        window_index = int(now // self.window_size)
        # Keying by window index gives a true fixed window. Re-setting the
        # counter below cannot stretch its lifetime into a later window, and
        # every new window reads a fresh, empty counter.
        window_key = f"{key}_{window_index}"

        # NOTE: get-then-set is not atomic, so concurrent requests may
        # occasionally under-count; acceptable for a coarse limiter.
        current_count = cache.get(window_key, 0)

        if current_count >= limit:
            # Seconds until the current window closes and the counter resets.
            retry_after = max(1, math.ceil((window_index + 1) * self.window_size - now))
            logger.warning(
                "Rate limit exceeded for %s",
                key,
                extra={'key': key, 'count': current_count, 'limit': limit},
            )
            return self._rate_limited_response(request, retry_after)

        # The TTL only needs to outlive the window; refreshing it is harmless
        # because the key changes when the window does.
        cache.set(window_key, current_count + 1, self.window_size)
        return None

    def _get_key_and_limit(self, request: HttpRequest) -> tuple:
        """Return ``(cache key prefix, request limit)`` for this caller."""
        user = getattr(request, 'user', None)
        if user is not None and user.is_authenticated:
            return (
                f"rate_limit_user_{user.id}",
                getattr(settings, 'RATE_LIMIT_AUTHENTICATED', self.default_limit),
            )
        client_ip = get_client_ip(request)
        return (
            f"rate_limit_ip_{client_ip}",
            getattr(settings, 'RATE_LIMIT_ANONYMOUS', self.default_limit // 2),
        )

    @staticmethod
    def _rate_limited_response(request: HttpRequest, retry_after: int) -> HttpResponse:
        """Build the 429 response (JSON for AJAX) with a ``Retry-After`` header."""
        message = 'Rate limit exceeded. Please try again later.'
        if is_ajax_request(request):
            response = JsonResponse({
                'success': False,
                'error': message,
                'retry_after': retry_after
            }, status=429)
        else:
            response = HttpResponse(message, status=429)
        response['Retry-After'] = str(retry_after)
        return response


# ============================================================================
# Activity Tracking Middleware
# ============================================================================

class ActivityTrackingMiddleware(MiddlewareMixin):
    """Middleware that records each non-exempt request as an ``ActivityLog`` row.

    Settings:
        ACTIVITY_TRACKING_ENABLED: on/off switch (default ``True``).
        ACTIVITY_TRACKING_EXEMPT_PATHS: path prefixes that are never logged
            (default ``['/admin/jsi18n/', '/static/', '/media/', '/favicon.ico']``).
    """

    def __init__(self, get_response: Callable):
        """Initialize middleware.

        Args:
            get_response: Next middleware or view function
        """
        self.get_response = get_response
        self.enabled = getattr(settings, 'ACTIVITY_TRACKING_ENABLED', True)
        super().__init__(get_response)

    def process_request(self, request: HttpRequest) -> None:
        """Record the start time and decide whether this request is tracked.

        Args:
            request: HTTP request object
        """
        if not self.enabled:
            return

        # Store request start time
        request._activity_start_time = time.time()

        # Skip tracking for certain paths
        exempt_paths = getattr(settings, 'ACTIVITY_TRACKING_EXEMPT_PATHS', [
            '/admin/jsi18n/', '/static/', '/media/', '/favicon.ico'
        ])

        request._skip_activity_tracking = any(
            request.path.startswith(path) for path in exempt_paths
        )

    def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
        """Log the completed request unless it is exempt.

        Args:
            request: HTTP request object
            response: HTTP response object

        Returns:
            HTTP response object, unchanged
        """
        # The flag defaults to True, so requests that never reached
        # process_request (e.g. short-circuited earlier) are not logged.
        if not self.enabled or getattr(request, '_skip_activity_tracking', True):
            return response

        # Calculate response time
        start_time = getattr(request, '_activity_start_time', time.time())
        response_time = time.time() - start_time

        # Logging runs synchronously in the request cycle. A failure to write
        # the log must not break the response, so it is reported and swallowed.
        try:
            self._log_activity(request, response, response_time)
        except Exception as exc:
            logger.exception("Failed to log activity: %s", exc)

        return response

    def _log_activity(self, request: HttpRequest, response: HttpResponse, response_time: float) -> None:
        """Create an ``ActivityLog`` entry for this request/response pair.

        Args:
            request: HTTP request object
            response: HTTP response object
            response_time: Response time in seconds
        """
        status_code = response.status_code
        # Any 4xx/5xx outcome is an error, whatever the HTTP method. Checking
        # the method first would record failed GETs as plain views.
        activity_type = ACTIVITY_ERROR if status_code >= 400 else ACTIVITY_VIEW

        extra_data = {
            'method': request.method,
            'status_code': status_code,
            'response_time': round(response_time, 3),
            # Responses without a ``content`` attribute (presumably streaming
            # responses) are recorded with length 0.
            'content_length': len(response.content) if hasattr(response, 'content') else 0,
        }

        if request.GET:
            # NOTE(review): query strings are persisted verbatim; consider
            # redacting sensitive parameters (tokens, e-mails) before storing.
            extra_data['query_params'] = dict(request.GET)

        user = getattr(request, 'user', None)
        ActivityLog.objects.create(
            user=user if user is not None and user.is_authenticated else None,
            activity_type=activity_type,
            description=f"{request.method} {request.path}",
            object_type='http_request',
            ip_address=get_client_ip(request),
            user_agent=get_user_agent(request),
            extra_data=extra_data
        )


# ============================================================================
# Session Management Middleware
# ============================================================================

class SessionTimeoutMiddleware(MiddlewareMixin):
    """Middleware that expires authenticated sessions after inactivity.

    The time of the last tracked request is stored in the session under
    ``'last_activity'`` as an ISO-8601 string. A request arriving more than
    ``SESSION_TIMEOUT`` seconds later flushes the session. AJAX callers then
    get a JSON 401 and everyone else is redirected to ``accounts:login``.

    Settings:
        SESSION_TIMEOUT: inactivity limit in seconds (default 1800).
        SESSION_WARNING_TIME: threshold in seconds used for
            ``request.session_warning`` (default 300).
        SESSION_TIMEOUT_EXEMPT_PATHS: path prefixes that neither check nor
            refresh the timeout.
    """

    def __init__(self, get_response: Callable):
        """Initialize middleware.

        Args:
            get_response: Next middleware or view function
        """
        self.get_response = get_response
        self.timeout = getattr(settings, 'SESSION_TIMEOUT', 1800)  # 30 minutes
        self.warning_time = getattr(settings, 'SESSION_WARNING_TIME', 300)  # 5 minutes
        super().__init__(get_response)

    def process_request(self, request: HttpRequest) -> Optional[HttpResponse]:
        """Expire an idle session or refresh its last-activity timestamp.

        Args:
            request: HTTP request object

        Returns:
            HTTP response if session expired, None otherwise
        """
        user = getattr(request, 'user', None)
        if user is None or not user.is_authenticated:
            return None

        # Skip timeout check for certain paths
        exempt_paths = getattr(settings, 'SESSION_TIMEOUT_EXEMPT_PATHS', [
            '/admin/', '/api/auth/refresh/', '/api/auth/logout/'
        ])

        if any(request.path.startswith(path) for path in exempt_paths):
            return None

        now = timezone.now()
        last_activity_time = self._parse_last_activity(
            request.session.get('last_activity'), now
        )

        if last_activity_time is not None:
            idle_seconds = (now - last_activity_time).total_seconds()

            if idle_seconds > self.timeout:
                request.session.flush()

                # get_username() works for custom user models whose login
                # field is not called ``username``.
                logger.info(
                    "Session expired for user %s",
                    user.get_username(),
                    extra={'user_id': user.id, 'timeout': self.timeout},
                )

                if is_ajax_request(request):
                    return JsonResponse({
                        'success': False,
                        'error': 'Session expired. Please log in again.',
                        'redirect': reverse('accounts:login')
                    }, status=401)
                return redirect('accounts:login')

        # Update last activity
        request.session['last_activity'] = now.isoformat()

        if last_activity_time is not None:
            # NOTE(review): these values describe the time that was left
            # *before* this request renewed the session (it now has the full
            # timeout again). Confirm that is what consumers expect.
            time_remaining = self.timeout - idle_seconds
            request.session_time_remaining = max(0, time_remaining)
            request.session_warning = time_remaining <= self.warning_time

        return None

    @staticmethod
    def _parse_last_activity(value, now):
        """Parse a stored ``last_activity`` value, or return None if unusable.

        Returns None in three cases:

        * the value is missing or empty;
        * it is not ISO-8601;
        * its awareness (naive vs. aware) differs from *now*, for example
          after ``USE_TZ`` was toggled. Subtracting such values raises
          TypeError.

        The caller then restarts the inactivity clock instead of failing every
        request for that user.
        """
        if not value:
            return None
        try:
            parsed = datetime.fromisoformat(value)
        except (TypeError, ValueError):
            return None
        if timezone.is_aware(parsed) != timezone.is_aware(now):
            return None
        return parsed


# ============================================================================
# Request Processing Middleware
# ============================================================================

class RequestProcessingMiddleware(MiddlewareMixin):
    """Attach per-request metadata and add tracing headers to responses.

    Incoming requests receive ``client_ip``, ``user_agent``, ``is_ajax``,
    ``is_mobile``, ``request_id`` and ``start_time`` attributes. Outgoing
    responses receive ``X-Request-ID`` and ``X-Processing-Time`` headers when
    the corresponding request attributes exist.
    """

    def process_request(self, request: HttpRequest) -> None:
        """Annotate the incoming request with derived client metadata.

        Args:
            request: HTTP request object
        """
        request.client_ip = get_client_ip(request)
        request.user_agent = get_user_agent(request)
        # NOTE(review): on Django < 4.0 this attribute shadows the
        # HttpRequest.is_ajax() method; confirm no caller still invokes it.
        request.is_ajax = is_ajax_request(request)
        request.is_mobile = self._is_mobile_request(request)
        request.request_id = self._generate_request_id()
        # Wall-clock start, read back in process_response for the timing header.
        request.start_time = time.time()

    def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
        """Expose the request id and elapsed processing time as headers.

        Args:
            request: HTTP request object
            response: HTTP response object

        Returns:
            The same response object with tracing headers added.
        """
        if hasattr(request, 'request_id'):
            response['X-Request-ID'] = request.request_id

        if hasattr(request, 'start_time'):
            elapsed = time.time() - request.start_time
            response['X-Processing-Time'] = '{:.3f}s'.format(elapsed)

        return response

    def _is_mobile_request(self, request: HttpRequest) -> bool:
        """Guess whether the User-Agent belongs to a mobile device.

        Args:
            request: HTTP request object

        Returns:
            True if any known mobile marker occurs in the lower-cased UA.
        """
        agent = request.META.get('HTTP_USER_AGENT', '').lower()
        for marker in ('mobile', 'android', 'iphone', 'ipad', 'ipod', 'blackberry',
                       'windows phone', 'opera mini', 'opera mobi'):
            if marker in agent:
                return True
        return False

    def _generate_request_id(self) -> str:
        """Return a short random request id (first 8 hex digits of a UUID4).

        Returns:
            An 8-character lowercase hexadecimal string.
        """
        import uuid
        return uuid.uuid4().hex[:8]


# ============================================================================
# Error Handling Middleware
# ============================================================================

class ErrorHandlingMiddleware(MiddlewareMixin):
    """Middleware for centralized error handling.

    Unexpected exceptions are handled as follows:

    * logged with a traceback;
    * recorded as an ``ACTIVITY_ERROR`` entry;
    * answered with a JSON 500 body for AJAX callers.

    ``Http404`` and ``PermissionDenied`` are expected outcomes that Django
    itself renders as 404/403. They are not logged as errors here, and AJAX
    callers receive JSON carrying the matching status instead of a
    misleading 500.
    """

    # (exception class, HTTP status, message shown when DEBUG is off) for
    # exceptions that represent client errors rather than server faults.
    _CLIENT_ERRORS = (
        (Http404, 404, "The requested resource was not found."),
        (PermissionDenied, 403, "You do not have permission to perform this action."),
    )

    def process_exception(self, request: HttpRequest, exception: Exception) -> Optional[HttpResponse]:
        """Handle exceptions raised by views.

        Args:
            request: HTTP request object
            exception: Exception instance

        Returns:
            A JSON error response for AJAX requests, None otherwise (Django's
            default exception handling then takes over).
        """
        for exc_class, status, public_message in self._CLIENT_ERRORS:
            if isinstance(exception, exc_class):
                if is_ajax_request(request):
                    return self._json_error(exception, status, public_message)
                # Let Django render its standard 404/403 response.
                return None

        user = self._authenticated_user(request)
        client_ip = get_client_ip(request)

        # exc_info=exception attaches this exception's traceback explicitly.
        logger.error(
            "Unhandled exception: %s",
            exception,
            exc_info=exception,
            extra={
                'request_path': request.path,
                'request_method': request.method,
                'user': str(user) if user is not None else 'Anonymous',
                'ip': client_ip,
            },
        )

        # Create activity log for errors
        try:
            ActivityLog.objects.create(
                user=user,
                activity_type=ACTIVITY_ERROR,
                description=f"Exception: {exception.__class__.__name__}: {str(exception)}",
                object_type='exception',
                ip_address=client_ip,
                user_agent=get_user_agent(request),
                extra_data={
                    'exception_type': exception.__class__.__name__,
                    'exception_message': str(exception),
                    'request_path': request.path,
                    'request_method': request.method,
                }
            )
        except Exception as log_error:
            logger.exception("Failed to log exception activity: %s", log_error)

        if is_ajax_request(request):
            return self._json_error(
                exception, 500, "An error occurred while processing your request."
            )

        # Let Django handle non-AJAX requests
        return None

    @staticmethod
    def _authenticated_user(request: HttpRequest):
        """Return the authenticated user, or None.

        Resolving ``request.user`` may touch the database. An error handler
        must not raise a second exception, so any failure yields None.
        """
        try:
            user = getattr(request, 'user', None)
            if user is not None and user.is_authenticated:
                return user
        except Exception:
            logger.debug("Could not resolve request.user while handling an exception", exc_info=True)
        return None

    @staticmethod
    def _json_error(exception: Exception, status: int, public_message: str) -> JsonResponse:
        """Build the JSON error body; raw exception text is shown only in DEBUG."""
        error_message = str(exception) if settings.DEBUG else public_message
        return JsonResponse({
            'success': False,
            'error': error_message,
            'error_type': exception.__class__.__name__
        }, status=status)


# ============================================================================
# Maintenance Mode Middleware
# ============================================================================

class MaintenanceModeMiddleware(MiddlewareMixin):
    """Answer requests with HTTP 503 while maintenance mode is switched on.

    ``MAINTENANCE_MODE`` is read once at start-up. While it is truthy, every
    request is rejected except two kinds:

    * requests from authenticated superusers;
    * requests whose path starts with one of ``MAINTENANCE_MODE_EXEMPT_PATHS``.
    """

    def __init__(self, get_response: Callable):
        """Store the next handler and the maintenance flag.

        Args:
            get_response: Next middleware or view function
        """
        self.get_response = get_response
        self.enabled = getattr(settings, 'MAINTENANCE_MODE', False)
        super().__init__(get_response)

    def process_request(self, request: HttpRequest) -> Optional[HttpResponse]:
        """Short-circuit the request with a 503 while maintenance is active.

        Args:
            request: HTTP request object

        Returns:
            A 503 response (JSON for AJAX callers, plain text otherwise), or
            None to let the request continue.
        """
        if not self.enabled:
            return None

        # Authenticated superusers bypass maintenance mode.
        if hasattr(request, 'user'):
            current_user = request.user
            if current_user.is_authenticated and current_user.is_superuser:
                return None

        exempt_prefixes = tuple(getattr(
            settings,
            'MAINTENANCE_MODE_EXEMPT_PATHS',
            ['/admin/', '/api/health/', '/maintenance/'],
        ))
        if request.path.startswith(exempt_prefixes):
            return None

        notice = 'The system is currently under maintenance. Please try again later.'
        if not is_ajax_request(request):
            return HttpResponse(notice, status=503)
        return JsonResponse(
            {'success': False, 'error': notice, 'maintenance_mode': True},
            status=503,
        )


# ============================================================================
# CORS Middleware (Simple Implementation)
# ============================================================================

class SimpleCORSMiddleware(MiddlewareMixin):
    """Simple CORS middleware for API requests (paths under ``/api/``).

    ``CORS_ALLOWED_ORIGINS`` lists the origins allowed to read API responses.
    ``'*'`` (also the default when the setting is absent) allows any origin.

    * Wildcard: ``Access-Control-Allow-Origin: *`` is sent without
      ``Access-Control-Allow-Credentials``. Browsers reject credentialed
      responses that use the wildcard, so advertising credentials is useless.
    * Explicit list: a matching request ``Origin`` is echoed back together
      with ``Access-Control-Allow-Credentials: true``. ``Vary: Origin`` is
      added so shared caches do not serve one origin's CORS headers to
      another.

    NOTE(review): defaulting to ``['*']`` exposes the API to every origin if
    the setting is forgotten; consider an empty default.
    """

    def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse:
        """Add CORS headers to responses for ``/api/`` paths.

        Args:
            request: HTTP request object
            response: HTTP response object

        Returns:
            Modified HTTP response object
        """
        if not request.path.startswith('/api/'):
            return response

        allowed_origins = getattr(settings, 'CORS_ALLOWED_ORIGINS', ['*'])

        if '*' in allowed_origins:
            response['Access-Control-Allow-Origin'] = '*'
        else:
            # The CORS headers now depend on the request's Origin, so caches
            # must key on it. This applies to disallowed origins too.
            patch_vary_headers(response, ('Origin',))
            origin = request.META.get('HTTP_ORIGIN')
            if origin in allowed_origins:
                response['Access-Control-Allow-Origin'] = origin
                response['Access-Control-Allow-Credentials'] = 'true'

        response['Access-Control-Allow-Methods'] = 'GET, POST, PUT, PATCH, DELETE, OPTIONS'
        response['Access-Control-Allow-Headers'] = (
            'Accept, Accept-Language, Content-Language, Content-Type, '
            'Authorization, X-Requested-With, X-CSRFToken'
        )
        response['Access-Control-Max-Age'] = '3600'

        return response

    def process_request(self, request: HttpRequest) -> Optional[HttpResponse]:
        """Answer preflight OPTIONS requests for ``/api/`` without hitting a view.

        Args:
            request: HTTP request object

        Returns:
            HTTP response for OPTIONS requests, None otherwise
        """
        if request.method == 'OPTIONS' and request.path.startswith('/api/'):
            # With MiddlewareMixin's default __call__, process_response also
            # runs on this response afterwards. Re-applying the headers is
            # harmless: plain assignments overwrite, and patch_vary_headers
            # skips duplicates.
            return self.process_response(request, HttpResponse())

        return None