"""Celery tasks for Reporting and Analytics app."""

import logging
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from typing import Dict, List, Optional

from celery import shared_task
from django.utils import timezone
from django.db import transaction
from django.db.models import Avg, Count, Sum
from django.core.mail import send_mail
from django.conf import settings
from django.template.loader import render_to_string

from .models import (
    AnalyticsRegion,
    AnalyticsTarget,
    SfrAnalytics,
    MarketShare,
    VerificationRecord,
    PredictionModel,
    SfrPrediction,
    AdbreakPrediction,
    ActivityLog,
    RealTimeAdbreak,
)
from apps.channels.models import Channel
from apps.campaigns.models import Campaign
from apps.advertisers.models import Brand

logger = logging.getLogger(__name__)


@shared_task(bind=True, max_retries=3)
def process_sfr_analytics_data(self, data_batch: List[Dict], source: str = 'api'):
    """
    Process a batch of SFR analytics data.
    
    Args:
        data_batch: List of analytics data dictionaries
        source: Data source identifier
    """
    try:
        processed_count = 0
        error_count = 0
        
        for data in data_batch:
            try:
                # Per-record savepoint: a failing row is rolled back on its own
                # instead of marking the whole batch's transaction as broken.
                with transaction.atomic():
                    # Extract and validate data
                    channel_id = data.get('channel_id')
                    region_id = data.get('region_id')
                    target_id = data.get('target_id')
                    
                    # Get or create related objects
                    channel = Channel.objects.get(id=channel_id) if channel_id else None
                    region = AnalyticsRegion.objects.get(id=region_id) if region_id else None
                    target = AnalyticsTarget.objects.get(id=target_id) if target_id else None
                    
                    # Create or update analytics record
                    analytics, created = SfrAnalytics.objects.update_or_create(
                        measurement_date=data['measurement_date'],
                        measurement_time=data.get('measurement_time', time(0, 0)),
                        channel=channel,
                        region=region,
                        target=target,
                        indicator=data['indicator'],
                        sfr_channel_name=data.get('sfr_channel_name', ''),
                        defaults={
                            'value': data['value'],
                            'percentage': data.get('percentage'),
                            'quality_score': data.get('quality_score'),
                            'data_source': source,
                            'metadata': data.get('metadata', {}),
                        }
                    )
                    
                    processed_count += 1
                    
                    if created:
                        logger.info(f"Created new SFR analytics record: {analytics.id}")
                    else:
                        logger.info(f"Updated SFR analytics record: {analytics.id}")
                        
            except Exception as e:
                error_count += 1
                logger.error(f"Error processing analytics data: {e}")
                continue
        
        logger.info(f"Processed {processed_count} analytics records, {error_count} errors")
        return {
            'processed': processed_count,
            'errors': error_count,
            'source': source
        }
        
    except Exception as exc:
        logger.error(f"Failed to process SFR analytics batch: {exc}")
        raise self.retry(exc=exc, countdown=60 * (self.request.retries + 1))


@shared_task(bind=True, max_retries=3)
def calculate_market_share(self, measurement_date: str, region_id: Optional[int] = None):
    """
    Calculate market share for all channels on a specific date.
    
    Args:
        measurement_date: Date to calculate market share for (YYYY-MM-DD)
        region_id: Optional region filter
    """
    try:
        date_obj = datetime.strptime(measurement_date, '%Y-%m-%d').date()
        
        # Get regions to process (channels come from the analytics rows themselves)
        if region_id:
            regions = AnalyticsRegion.objects.filter(id=region_id)
        else:
            regions = AnalyticsRegion.objects.filter(is_active=True)
        
        # Get targets
        targets = AnalyticsTarget.objects.filter(is_active=True)
        
        processed_count = 0
        
        for region in regions:
            for target in targets:
                # Get analytics data for the date, region, and target
                analytics_data = SfrAnalytics.objects.filter(
                    measurement_date=date_obj,
                    region=region,
                    target=target,
                    indicator='audience_share'  # Assuming this indicator for market share
                ).select_related('channel')
                
                if not analytics_data.exists():
                    continue
                
                # Calculate total audience
                total_audience = analytics_data.aggregate(
                    total=Sum('value')
                )['total'] or 0
                
                if total_audience == 0:
                    continue
                
                # Calculate market share for each channel
                channel_data = []
                for analytics in analytics_data:
                    if analytics.channel:
                        market_share_pct = (analytics.value / total_audience) * 100
                        channel_data.append({
                            'channel': analytics.channel,
                            'users': analytics.value,
                            'market_share': market_share_pct
                        })
                
                # Sort by market share and assign rankings
                channel_data.sort(key=lambda x: x['market_share'], reverse=True)
                
                for rank, data in enumerate(channel_data, 1):
                    MarketShare.objects.update_or_create(
                        measurement_date=date_obj,
                        channel=data['channel'],
                        region=region,
                        target=target,
                        defaults={
                            'market_share_percentage': data['market_share'],
                            'ranking': rank,
                            'total_users': total_audience,
                            'channel_users': data['users'],
                            'tool_name': 'automated_calculation',
                        }
                    )
                    processed_count += 1
        
        logger.info(f"Calculated market share for {processed_count} records")
        return {
            'processed': processed_count,
            'date': measurement_date,
            'region_id': region_id
        }
        
    except Exception as exc:
        logger.error(f"Failed to calculate market share: {exc}")
        raise self.retry(exc=exc, countdown=60 * (self.request.retries + 1))


@shared_task(bind=True, max_retries=3)
def process_verification_records(self, records_batch: List[Dict]):
    """
    Process a batch of verification records.
    
    Args:
        records_batch: List of verification record dictionaries
    """
    try:
        processed_count = 0
        error_count = 0
        
        for record_data in records_batch:
            try:
                # Per-record savepoint, as in process_sfr_analytics_data:
                # isolate failures so one bad row doesn't break the batch.
                with transaction.atomic():
                    # Get related objects
                    campaign = None
                    advertiser = None
                    
                    if record_data.get('campaign_id'):
                        campaign = Campaign.objects.get(id=record_data['campaign_id'])
                        advertiser = campaign.advertiser
                    elif record_data.get('advertiser_id'):
                        advertiser = Brand.objects.get(id=record_data['advertiser_id'])
                    
                    # Create or update verification record
                    verification, created = VerificationRecord.objects.update_or_create(
                        traffic_id=record_data['traffic_id'],
                        broadcast_date=record_data['broadcast_date'],
                        defaults={
                            'broadcast_time': record_data.get('broadcast_time'),
                            'network_name': record_data.get('network_name', ''),
                            'zone_name': record_data.get('zone_name', ''),
                            'spot_id': record_data.get('spot_id', ''),
                            'campaign': campaign,
                            'advertiser': advertiser,
                            'status': record_data.get('status', 'pending'),
                            'verification_complete': record_data.get('verification_complete', ''),
                            'metadata': record_data.get('metadata', {}),
                        }
                    )
                    
                    processed_count += 1
                    
                    if created:
                        logger.info(f"Created verification record: {verification.id}")
                    else:
                        logger.info(f"Updated verification record: {verification.id}")
                        
            except Exception as e:
                error_count += 1
                logger.error(f"Error processing verification record: {e}")
                continue
        
        logger.info(f"Processed {processed_count} verification records, {error_count} errors")
        return {
            'processed': processed_count,
            'errors': error_count
        }
        
    except Exception as exc:
        logger.error(f"Failed to process verification records: {exc}")
        raise self.retry(exc=exc, countdown=60 * (self.request.retries + 1))


@shared_task(bind=True, max_retries=3)
def generate_predictions(self, model_id: int, prediction_date: str, prediction_type: str = 'sfr'):
    """
    Generate predictions using a specific model.
    
    Args:
        model_id: ID of the prediction model to use
        prediction_date: Date to generate predictions for (YYYY-MM-DD)
        prediction_type: Type of prediction ('sfr' or 'adbreak')
    """
    try:
        model = PredictionModel.objects.get(id=model_id, is_active=True)
        date_obj = datetime.strptime(prediction_date, '%Y-%m-%d').date()
        
        if prediction_type == 'sfr':
            return _generate_sfr_predictions(model, date_obj)
        elif prediction_type == 'adbreak':
            return _generate_adbreak_predictions(model, date_obj)
        else:
            raise ValueError(f"Unknown prediction type: {prediction_type}")
            
    except Exception as exc:
        logger.error(f"Failed to generate predictions: {exc}")
        raise self.retry(exc=exc, countdown=60 * (self.request.retries + 1))


def _generate_sfr_predictions(model: PredictionModel, prediction_date: date) -> Dict:
    """Generate SFR predictions for a specific date."""
    predictions_created = 0
    
    # Get channels, regions, and targets
    channels = Channel.objects.filter(is_active=True)
    regions = AnalyticsRegion.objects.filter(is_active=True)
    targets = AnalyticsTarget.objects.filter(is_active=True)
    
    # Get historical data for the model
    historical_days = 30  # Use last 30 days for prediction
    start_date = prediction_date - timedelta(days=historical_days)
    
    for channel in channels:
        for region in regions:
            for target in targets:
                for indicator in ['audience_share', 'rating', 'reach']:
                    # Get historical data
                    historical_data = SfrAnalytics.objects.filter(
                        measurement_date__range=[start_date, prediction_date - timedelta(days=1)],
                        channel=channel,
                        region=region,
                        target=target,
                        indicator=indicator
                    ).order_by('measurement_date')
                    
                    data_points = historical_data.count()
                    if data_points < 7:  # Need at least 7 days of data
                        continue
                    
                    # Simple prediction heuristic (a real implementation would use
                    # an ML model): nudge the historical average by 10% of the
                    # recent trend, i.e. predicted = avg * (1 + 0.1 * trend).
                    avg_value = historical_data.aggregate(avg=Avg('value'))['avg']
                    recent_trend = _calculate_trend(historical_data)
                    
                    # Compute in float: avg_value may be a Decimal, and mixing
                    # Decimal with float literals raises TypeError; a DecimalField
                    # coerces the float back on save.
                    predicted_value = float(avg_value) * (1.0 + recent_trend * 0.1)
                    confidence_score = min(0.95, 0.5 + (data_points / 60.0))
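                    # Worked example with illustrative numbers:
                    #   avg_value = 10.0, recent_trend = +0.2 (second half 20% above first)
                    #   predicted_value = 10.0 * (1 + 0.2 * 0.1) = 10.2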
                    
                    # Create prediction
                    SfrPrediction.objects.update_or_create(
                        model=model,
                        prediction_date=prediction_date,
                        channel=channel,
                        region=region,
                        target=target,
                        indicator=indicator,
                        defaults={
                            'predicted_value': predicted_value,
                            'confidence_score': confidence_score,
                            'model_version': model.version,
                            'metadata': {
                                'historical_days': historical_days,
                                'data_points': data_points,
                                'avg_historical': float(avg_value),
                                'trend': recent_trend
                            }
                        }
                    )
                    predictions_created += 1
    
    logger.info(f"Generated {predictions_created} SFR predictions")
    return {'predictions_created': predictions_created, 'type': 'sfr'}


def _generate_adbreak_predictions(model: PredictionModel, prediction_date: date) -> Dict:
    """Generate ad break predictions for a specific date."""
    predictions_created = 0
    
    # Get channels
    channels = Channel.objects.filter(is_active=True)
    
    # Generate predictions for each hour of the day
    for channel in channels:
        for hour in range(24):
            # Make the timestamp timezone-aware so saving it doesn't trigger
            # Django's naive-datetime warnings under USE_TZ.
            prediction_datetime = timezone.make_aware(
                datetime.combine(prediction_date, time(hour, 0))
            )
            
            # Get historical ad break data
            historical_data = RealTimeAdbreak.objects.filter(
                channel=channel,
                start_time__hour=hour,
                start_time__date__gte=prediction_date - timedelta(days=30),
                start_time__date__lt=prediction_date,
                status='completed'
            )
            
            sample_size = historical_data.count()
            if sample_size < 3:  # Need at least 3 historical records
                continue
            
            # Calculate predictions with a single aggregate query instead of three
            stats = historical_data.aggregate(
                avg_duration=Avg('duration_seconds'),
                avg_ad_count=Avg('ad_count'),
                avg_revenue=Avg('total_revenue'),
            )
            avg_duration = stats['avg_duration']
            avg_ad_count = stats['avg_ad_count']
            avg_revenue = stats['avg_revenue']
            
            confidence_score = min(0.95, 0.3 + (sample_size / 30.0))
            
            # Create prediction
            AdbreakPrediction.objects.update_or_create(
                model=model,
                prediction_date=prediction_date,
                prediction_datetime=prediction_datetime,
                channel=channel,
                defaults={
                    'predicted_duration_seconds': int(avg_duration or 0),
                    'predicted_ad_count': int(avg_ad_count or 0),
                    'predicted_revenue': avg_revenue or Decimal('0.00'),
                    'confidence_score': confidence_score,
                    'model_version': model.version,
                    'metadata': {
                        'historical_records': sample_size,
                        'avg_duration': float(avg_duration or 0),
                        'avg_ad_count': float(avg_ad_count or 0),
                        'avg_revenue': float(avg_revenue or 0)
                    }
                }
            )
            predictions_created += 1
    
    logger.info(f"Generated {predictions_created} ad break predictions")
    return {'predictions_created': predictions_created, 'type': 'adbreak'}


def _calculate_trend(queryset) -> float:
    """Calculate trend from historical data (simple linear trend)."""
    values = list(queryset.values_list('value', flat=True))
    if len(values) < 2:
        return 0.0
    
    # Simple trend: relative change between the means of the first and
    # second halves of the series
    midpoint = len(values) // 2
    first_half = sum(values[:midpoint]) / midpoint
    second_half = sum(values[midpoint:]) / (len(values) - midpoint)
    
    if first_half == 0:
        return 0.0
    
    # Cast explicitly: the values may be Decimals, and the declared return
    # type is float
    return float((second_half - first_half) / first_half)


@shared_task
def update_prediction_accuracy():
    """
    Update prediction accuracy by comparing predictions with actual values.
    """
    try:
        updated_count = 0
        
        # Update SFR prediction accuracy
        sfr_predictions = SfrPrediction.objects.filter(
            actual_value__isnull=True,
            prediction_date__lt=timezone.now().date()
        )
        
        for prediction in sfr_predictions:
            # Look for actual analytics data
            actual_data = SfrAnalytics.objects.filter(
                measurement_date=prediction.prediction_date,
                channel=prediction.channel,
                region=prediction.region,
                target=prediction.target,
                indicator=prediction.indicator
            ).first()
            
            if actual_data:
                prediction.actual_value = actual_data.value
                prediction.save(update_fields=['actual_value'])
                updated_count += 1
        
        # Update ad break prediction accuracy
        adbreak_predictions = AdbreakPrediction.objects.filter(
            actual_datetime__isnull=True,
            prediction_datetime__lt=timezone.now() - timedelta(hours=1)
        )
        
        for prediction in adbreak_predictions:
            # Look for actual ad break data
            actual_adbreak = RealTimeAdbreak.objects.filter(
                channel=prediction.channel,
                start_time__date=prediction.prediction_date,
                start_time__hour=prediction.prediction_datetime.hour,
                status='completed'
            ).first()
            
            if actual_adbreak:
                prediction.actual_datetime = actual_adbreak.start_time
                prediction.actual_duration_seconds = actual_adbreak.duration_seconds
                prediction.actual_ad_count = actual_adbreak.ad_count
                prediction.actual_revenue = actual_adbreak.total_revenue
                prediction.save(update_fields=[
                    'actual_datetime', 'actual_duration_seconds',
                    'actual_ad_count', 'actual_revenue',
                ])
                updated_count += 1
        
        logger.info(f"Updated accuracy for {updated_count} predictions")
        return {'updated': updated_count}
        
    except Exception as e:
        logger.error(f"Failed to update prediction accuracy: {e}")
        raise


@shared_task
def cleanup_old_analytics_data(days_to_keep: int = 365):
    """
    Clean up old analytics data to manage database size.
    
    Args:
        days_to_keep: Number of days of data to keep
    """
    try:
        cutoff_date = timezone.now().date() - timedelta(days=days_to_keep)
        
        # Clean up old SFR analytics
        sfr_deleted = SfrAnalytics.objects.filter(
            measurement_date__lt=cutoff_date
        ).delete()[0]
        
        # Clean up old market share data
        market_deleted = MarketShare.objects.filter(
            measurement_date__lt=cutoff_date
        ).delete()[0]
        
        # Clean up old verification records
        verif_deleted = VerificationRecord.objects.filter(
            broadcast_date__lt=cutoff_date
        ).delete()[0]
        
        # Clean up old predictions
        sfr_pred_deleted = SfrPrediction.objects.filter(
            prediction_date__lt=cutoff_date
        ).delete()[0]
        
        adbreak_pred_deleted = AdbreakPrediction.objects.filter(
            prediction_date__lt=cutoff_date
        ).delete()[0]
        
        # Clean up old activity logs
        activity_deleted = ActivityLog.objects.filter(
            activity_date__lt=cutoff_date
        ).delete()[0]
        
        # Clean up old real-time ad breaks
        adbreak_deleted = RealTimeAdbreak.objects.filter(
            start_time__date__lt=cutoff_date
        ).delete()[0]
        
        total_deleted = (
            sfr_deleted + market_deleted + verif_deleted +
            sfr_pred_deleted + adbreak_pred_deleted +
            activity_deleted + adbreak_deleted
        )
        
        logger.info(f"Cleaned up {total_deleted} old records")
        return {
            'total_deleted': total_deleted,
            'sfr_analytics': sfr_deleted,
            'market_share': market_deleted,
            'verification': verif_deleted,
            'sfr_predictions': sfr_pred_deleted,
            'adbreak_predictions': adbreak_pred_deleted,
            'activity_logs': activity_deleted,
            'real_time_adbreaks': adbreak_deleted,
            'cutoff_date': cutoff_date.isoformat()
        }
        
    except Exception as e:
        logger.error(f"Failed to cleanup old analytics data: {e}")
        raise


@shared_task
def generate_daily_reports(report_date: str, recipients: List[str]):
    """
    Generate and send daily analytics reports.
    
    Args:
        report_date: Date for the report (YYYY-MM-DD)
        recipients: List of email addresses to send the report to
    """
    try:
        date_obj = datetime.strptime(report_date, '%Y-%m-%d').date()
        
        # Generate report data
        report_data = {
            'date': date_obj,
            'sfr_summary': _get_sfr_summary(date_obj),
            'market_share_summary': _get_market_share_summary(date_obj),
            'verification_summary': _get_verification_summary(date_obj),
            'prediction_accuracy': _get_prediction_accuracy_summary(date_obj),
        }
        
        # Render email template
        html_content = render_to_string('reporting/daily_report_email.html', report_data)
        text_content = render_to_string('reporting/daily_report_email.txt', report_data)
        
        # Send email
        send_mail(
            subject=f'Daily Analytics Report - {date_obj}',
            message=text_content,
            html_message=html_content,
            from_email=settings.DEFAULT_FROM_EMAIL,
            recipient_list=recipients,
            fail_silently=False,
        )
        
        logger.info(f"Sent daily report for {report_date} to {len(recipients)} recipients")
        return {
            'report_date': report_date,
            'recipients_count': len(recipients),
            'status': 'sent'
        }
        
    except Exception as e:
        logger.error(f"Failed to generate daily report: {e}")
        raise


def _get_sfr_summary(date_obj: date) -> Dict:
    """Get SFR analytics summary for a date."""
    analytics = SfrAnalytics.objects.filter(measurement_date=date_obj)
    
    return {
        'total_records': analytics.count(),
        'channels_count': analytics.values('channel').distinct().count(),
        'avg_quality_score': analytics.aggregate(avg=Avg('quality_score'))['avg'],
        'top_channels': list(
            analytics.filter(indicator='audience_share')
            .values('channel__name', 'value', 'percentage')
            .order_by('-value')[:5]
        ),
    }


def _get_market_share_summary(date_obj: date) -> Dict:
    """Get market share summary for a date."""
    market_share = MarketShare.objects.filter(measurement_date=date_obj)
    
    return {
        'total_records': market_share.count(),
        'top_channels': list(
            market_share.values('channel__name', 'market_share_percentage', 'ranking')
            .order_by('ranking')[:5]
        ),
    }


def _get_verification_summary(date_obj: date) -> Dict:
    """Get verification summary for a date."""
    verifications = VerificationRecord.objects.filter(broadcast_date=date_obj)
    
    status_counts = verifications.values('status').annotate(
        count=Count('id')
    ).order_by('status')
    
    return {
        'total_records': verifications.count(),
        'status_breakdown': list(status_counts),
        'pending_count': verifications.filter(status='pending').count(),
        'verified_count': verifications.filter(status='verified').count(),
        'failed_count': verifications.filter(status='failed').count(),
    }


def _get_prediction_accuracy_summary(date_obj: date) -> Dict:
    """Get prediction accuracy summary for a date."""
    sfr_predictions = SfrPrediction.objects.filter(
        prediction_date=date_obj,
        actual_value__isnull=False
    )
    
    adbreak_predictions = AdbreakPrediction.objects.filter(
        prediction_date=date_obj,
        actual_datetime__isnull=False
    )
    
    return {
        'sfr_predictions_count': sfr_predictions.count(),
        'adbreak_predictions_count': adbreak_predictions.count(),
        'avg_sfr_confidence': sfr_predictions.aggregate(
            avg=Avg('confidence_score')
        )['avg'],
        'avg_adbreak_confidence': adbreak_predictions.aggregate(
            avg=Avg('confidence_score')
        )['avg'],
    }