# Adtlas TV Advertising Platform - Channel Data Import Command
# Management command for importing TV channel data from external sources

import csv
import json
import logging
from typing import Any, Dict, List, Optional, Tuple

from django.core.management.base import BaseCommand, CommandError
from django.db import transaction
from django.utils import timezone

from channels.models import BroadcastNetwork, ChannelCoverage, GeographicZone, TVChannel

# Configure logging for import operations
logger = logging.getLogger('adtlas.channels.import')


class Command(BaseCommand):
    """
    Django management command for importing TV channel data.

    Supports multiple data formats:
    - JSON: Structured channel data with relationships
    - CSV: Tabular channel data
    - XML: Industry-standard EPG formats (not yet implemented)

    Usage:
        python manage.py import_channels --file channels.json --format json
        python manage.py import_channels --file channels.csv --format csv --update
    """

    help = 'Import TV channel data from external sources (JSON, CSV, XML)'

    # Fields every record must supply.  Shared by the CSV loader and the
    # dry-run validator so the two checks cannot drift apart.
    REQUIRED_FIELDS = ('name', 'call_sign', 'frequency')

    def add_arguments(self, parser):
        """
        Define command-line arguments for the import command.

        Args:
            parser: Django's argument parser instance
        """
        parser.add_argument(
            '--file',
            type=str,
            required=True,
            help='Path to the data file to import'
        )

        parser.add_argument(
            '--format',
            type=str,
            choices=['json', 'csv', 'xml'],
            default='json',
            help='Format of the input file (default: json)'
        )

        parser.add_argument(
            '--update',
            action='store_true',
            help='Update existing channels instead of creating new ones'
        )

        parser.add_argument(
            '--dry-run',
            action='store_true',
            help='Perform a dry run without making database changes'
        )

        parser.add_argument(
            '--batch-size',
            type=int,
            default=100,
            help='Number of records to process in each batch (default: 100)'
        )

    def handle(self, *args, **options):
        """
        Main command handler for importing channel data.

        Loads the input file, then processes records in batches.  Each batch
        runs inside its own transaction; in --dry-run mode records are only
        validated and nothing is written.

        Args:
            *args: Positional arguments
            **options: Command options from argument parser

        Raises:
            CommandError: If the file cannot be loaded or the import fails.
        """
        file_path = options['file']
        file_format = options['format']
        update_existing = options['update']
        dry_run = options['dry_run']
        batch_size = options['batch_size']

        logger.info("Starting channel import from %s (format: %s)", file_path, file_format)

        if dry_run:
            self.stdout.write(
                self.style.WARNING('DRY RUN MODE: No database changes will be made')
            )

        # Dispatch table keeps handle() format-agnostic; new formats only
        # need a loader entry here (--format choices already restrict input).
        loaders = {
            'json': self._load_json_data,
            'csv': self._load_csv_data,
            'xml': self._load_xml_data,
        }

        try:
            loader = loaders.get(file_format)
            if loader is None:
                raise CommandError(f"Unsupported format: {file_format}")
            data = loader(file_path)

            total_records = len(data)
            processed = 0
            created = 0
            updated = 0
            errors = 0

            self.stdout.write(f"Processing {total_records} records...")

            # Process data in batches for better performance
            for i in range(0, total_records, batch_size):
                batch = data[i:i + batch_size]

                if not dry_run:
                    # One transaction per batch; per-record savepoints are
                    # handled inside _process_batch.
                    with transaction.atomic():
                        batch_created, batch_updated, batch_errors = self._process_batch(
                            batch, update_existing
                        )
                        created += batch_created
                        updated += batch_updated
                        errors += batch_errors
                else:
                    # Validate data without saving
                    errors += self._validate_batch(batch)

                processed += len(batch)

                # Progress indicator (carriage return keeps it on one line)
                progress = (processed / total_records) * 100
                self.stdout.write(
                    f"Progress: {processed}/{total_records} ({progress:.1f}%)",
                    ending='\r'
                )

            # Final summary
            self.stdout.write('')  # New line after progress

            if dry_run:
                self.stdout.write(
                    self.style.SUCCESS(
                        f"DRY RUN COMPLETE: Validated {processed} records, {errors} errors found"
                    )
                )
            else:
                self.stdout.write(
                    self.style.SUCCESS(
                        f"Import complete: {created} created, {updated} updated, {errors} errors"
                    )
                )

            # Log completion
            logger.info(
                "Channel import completed: %s created, %s updated, %s errors",
                created, updated, errors
            )

        except CommandError:
            # Loader/validation errors are already user-facing CommandErrors;
            # re-raising avoids double-wrapping them as "Import failed: ...".
            raise
        except Exception as e:
            error_msg = f"Import failed: {str(e)}"
            logger.error(error_msg, exc_info=True)
            raise CommandError(error_msg)

    def _load_json_data(self, file_path: str) -> List[Dict[str, Any]]:
        """
        Load channel data from JSON file.

        Accepts either a top-level list of records, or an object with a
        'channels' key holding that list.

        Args:
            file_path: Path to the JSON file

        Returns:
            List of channel data dictionaries

        Raises:
            CommandError: If file cannot be loaded or parsed
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                data = json.load(file)
        except FileNotFoundError:
            raise CommandError(f"File not found: {file_path}")
        except json.JSONDecodeError as e:
            raise CommandError(f"Invalid JSON format: {str(e)}")

        # Validate JSON structure (outside the try so these CommandErrors
        # are never confused with I/O or parse failures).
        if isinstance(data, dict) and 'channels' in data:
            return data['channels']
        if isinstance(data, list):
            return data
        raise CommandError("Invalid JSON structure: expected 'channels' array or direct array")

    def _load_csv_data(self, file_path: str) -> List[Dict[str, Any]]:
        """
        Load channel data from CSV file.

        Expected CSV columns:
        - name: Channel name
        - call_sign: Channel call sign
        - frequency: Broadcast frequency
        - network_name: Network name (optional)
        - zone_name: Geographic zone name (optional)
        - is_active: Active status (true/false, optional)

        Args:
            file_path: Path to the CSV file

        Returns:
            List of channel data dictionaries

        Raises:
            CommandError: If the file is missing, has no usable header,
                lacks required columns, or cannot be parsed.
        """
        try:
            data: List[Dict[str, Any]] = []
            # newline='' is the csv-module-documented way to open input files
            # (prevents newline translation corrupting quoted fields).
            with open(file_path, 'r', encoding='utf-8', newline='') as file:
                reader = csv.DictReader(file)

                # fieldnames is None for an empty file; treat that the same
                # as "all required columns missing" instead of crashing.
                fieldnames = reader.fieldnames or []
                missing_columns = [col for col in self.REQUIRED_FIELDS if col not in fieldnames]
                if missing_columns:
                    raise CommandError(f"Missing required columns: {', '.join(missing_columns)}")

                for row in reader:
                    # Convert string boolean values.  Short rows yield None
                    # for trailing columns (DictReader restval), which we
                    # treat as inactive rather than crashing on .lower().
                    if 'is_active' in row:
                        row['is_active'] = str(row['is_active'] or '').lower() in ('true', '1', 'yes')
                    data.append(row)

            return data

        except FileNotFoundError:
            raise CommandError(f"File not found: {file_path}")
        except CommandError:
            # Re-raise as-is: previously this was caught by the generic
            # handler below and double-wrapped.
            raise
        except Exception as e:
            raise CommandError(f"Error reading CSV file: {str(e)}")

    def _load_xml_data(self, file_path: str) -> List[Dict[str, Any]]:
        """
        Load channel data from XML file (EPG format).

        Args:
            file_path: Path to the XML file

        Returns:
            List of channel data dictionaries

        Raises:
            CommandError: Always — XML support is not yet implemented.
        """
        # XML parsing implementation would go here
        # For now, raise not implemented error
        raise CommandError("XML format support not yet implemented")

    def _process_batch(self, batch: List[Dict[str, Any]], update_existing: bool) -> Tuple[int, int, int]:
        """
        Process a batch of channel records.

        Each record runs inside its own savepoint so that a database error
        on one row does not poison the enclosing batch transaction for the
        rows that follow it.

        Args:
            batch: List of channel data dictionaries
            update_existing: Whether to update existing channels

        Returns:
            Tuple of (created_count, updated_count, error_count)
        """
        created = 0
        updated = 0
        errors = 0

        for record in batch:
            try:
                # Savepoint per record: without it, a failed INSERT marks the
                # whole batch transaction (opened in handle()) as broken and
                # every subsequent query in the batch would fail.
                with transaction.atomic():
                    # Get or create related objects
                    network = self._get_or_create_network(record.get('network_name'))
                    zone = self._get_or_create_zone(record.get('zone_name'))

                    channel_data = {
                        'name': record['name'],
                        'call_sign': record['call_sign'],
                        'frequency': record['frequency'],
                        'network': network,
                        'is_active': record.get('is_active', True),
                        'description': record.get('description', ''),
                        'website_url': record.get('website_url', ''),
                    }

                    if update_existing:
                        # call_sign acts as the natural key for matching
                        # existing channels.
                        channel, was_created = TVChannel.objects.update_or_create(
                            call_sign=record['call_sign'],
                            defaults=channel_data
                        )
                    else:
                        channel = TVChannel.objects.create(**channel_data)
                        was_created = True

                    # Create coverage relationship if zone is specified
                    if zone:
                        ChannelCoverage.objects.get_or_create(
                            channel=channel,
                            zone=zone,
                            defaults={
                                'coverage_percentage': record.get('coverage_percentage', 100.0),
                                'signal_strength': record.get('signal_strength', 'strong')
                            }
                        )

                # Count only after the savepoint committed, so rolled-back
                # records are not reported as created/updated.
                if was_created:
                    created += 1
                else:
                    updated += 1

            except Exception as e:
                logger.error("Error processing record %s: %s", record, e)
                errors += 1

        return created, updated, errors

    def _validate_batch(self, batch: List[Dict[str, Any]]) -> int:
        """
        Validate a batch of records without saving to database.

        Checks the presence of REQUIRED_FIELDS and that 'frequency' parses
        as a number — the constraints _process_batch relies on.

        Args:
            batch: List of channel data dictionaries

        Returns:
            Number of validation errors found
        """
        errors = 0

        for record in batch:
            try:
                # Required fields must be present and non-empty.
                missing_fields = [field for field in self.REQUIRED_FIELDS if not record.get(field)]

                if missing_fields:
                    self.stdout.write(
                        self.style.ERROR(
                            f"Missing required fields in record: {', '.join(missing_fields)}"
                        )
                    )
                    errors += 1
                    continue

                # Frequency must be numeric.
                try:
                    float(record['frequency'])
                except (ValueError, TypeError):
                    self.stdout.write(
                        self.style.ERROR(
                            f"Invalid frequency format: {record.get('frequency')}"
                        )
                    )
                    errors += 1

            except Exception as e:
                # Catch-all so one malformed record (e.g. a non-dict entry
                # in the JSON array) counts as an error instead of aborting.
                self.stdout.write(
                    self.style.ERROR(f"Validation error: {str(e)}")
                )
                errors += 1

        return errors

    def _get_or_create_network(self, network_name: str) -> Optional[BroadcastNetwork]:
        """
        Get or create a broadcast network by name.

        Args:
            network_name: Name of the network (may be empty or None)

        Returns:
            BroadcastNetwork instance, or None if no name was given
        """
        if not network_name:
            return None

        network, created = BroadcastNetwork.objects.get_or_create(
            name=network_name,
            defaults={
                'description': f'Imported network: {network_name}',
                'is_active': True
            }
        )

        if created:
            logger.info("Created new network: %s", network_name)

        return network

    def _get_or_create_zone(self, zone_name: str) -> Optional[GeographicZone]:
        """
        Get or create a geographic zone by name.

        Args:
            zone_name: Name of the zone (may be empty or None)

        Returns:
            GeographicZone instance, or None if no name was given
        """
        if not zone_name:
            return None

        zone, created = GeographicZone.objects.get_or_create(
            name=zone_name,
            defaults={
                'zone_type': 'city',  # Default type for imported zones
                'description': f'Imported zone: {zone_name}',
                'is_active': True
            }
        )

        if created:
            logger.info("Created new zone: %s", zone_name)

        return zone