"""Enhanced task manager with Kafka integration and comprehensive task tracking."""

import asyncio
import logging
import time
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from collections import defaultdict
import json
import threading
from concurrent.futures import ThreadPoolExecutor

from src.models.task import (
    TaskStatus, TaskPriority, TaskType, TaskStats, ScrapeTaskRequest,
    BulkScrapeTaskRequest, TaskResponse, TaskResult, QueueStats,
    generate_task_id, create_task_stats
)
from src.models.schemas import SearchResult
from src.services.kafka_service import kafka_service
from src.services.database import db_service
from src.services.cache_service import cache_service
from src.scrapers.multi_engine_manager import multi_engine_manager

logger = logging.getLogger(__name__)


class TaskManager:
    """Enhanced task manager with Kafka integration and comprehensive tracking."""
    
    def __init__(self):
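        # In-memory task registry and result store, keyed by task_id.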
        self.tasks: Dict[str, TaskStats] = {}
        self.task_results: Dict[str, TaskResult] = {}
        self.queue_stats = QueueStats(
            total_tasks=0,
            pending_tasks=0,
            processing_tasks=0,
            completed_tasks=0,
            failed_tasks=0,
            queue_size=0,
            average_processing_time=0.0,
            throughput_per_minute=0.0,
            active_workers=0
        )
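        # Per-user aggregates kept in memory only; nothing in this module persists them.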
        self.user_stats: Dict[str, Dict[str, Any]] = defaultdict(dict)
        self._lock = threading.Lock()
        self._running = False
        self._consumer_thread = None
        self._stats_thread = None
        
        # Initialize Kafka connections
        self._init_kafka()
    
    def _init_kafka(self):
        """Initialize Kafka connections."""
        try:
            # Connect producer
            if kafka_service.connect_producer():
                logger.info("Kafka producer connected successfully")
            else:
                logger.warning("Failed to connect Kafka producer")
            
            # Connect consumer with message handler
            if kafka_service.connect_consumer(
                group_id="task_manager_consumer",
                message_handler=self._handle_kafka_message
            ):
                logger.info("Kafka consumer connected successfully")
            else:
                logger.warning("Failed to connect Kafka consumer")
        except Exception as e:
            logger.error(f"Error initializing Kafka: {e}")
    
    def _handle_kafka_message(self, message: Dict[str, Any]):
        """Handle incoming Kafka messages."""
        try:
            message_type = message.get('type')
            
            if message_type == 'scrape_task':
                self._process_scrape_task(message)
            elif message_type == 'bulk_scrape_task':
                self._process_bulk_scrape_task(message)
            elif message_type == 'task_status_update':
                self._update_task_status(message)
            else:
                logger.warning(f"Unknown message type: {message_type}")
        except Exception as e:
            logger.error(f"Error handling Kafka message: {e}")
    
    async def submit_scrape_task(self, request: ScrapeTaskRequest, user_id: Optional[str] = None) -> TaskResponse:
        """Submit a new scrape task."""
        task_id = generate_task_id()
        
        # Create task stats
        task_stats = create_task_stats(task_id, TaskStatus.PENDING)
        
        # Store task
        with self._lock:
            self.tasks[task_id] = task_stats
            self.queue_stats.total_tasks += 1
            self.queue_stats.pending_tasks += 1
        
        # Send to Kafka
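        # (keyed by task_id so, with Kafka's default partitioner, every message
        # for this task lands on the same partition and stays ordered)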
        task_message = {
            'task_id': task_id,
            'type': 'scrape_task',
            'query': request.query,
            'engines': request.engines,
            'max_results': request.max_results,
            'priority': request.priority.value,
            'callback_url': request.callback_url,
            'user_id': user_id or request.user_id,
            'metadata': request.metadata,
            'timeout': request.timeout,
            'retry_count': request.retry_count,
            'created_at': datetime.utcnow().isoformat()
        }
        
        success = await kafka_service.send_message(task_message, key=task_id)
        
        if success:
            task_stats.status = TaskStatus.QUEUED
            logger.info(f"Task {task_id} submitted successfully")
            
            return TaskResponse(
                task_id=task_id,
                status=TaskStatus.QUEUED,
                message="Task submitted successfully",
                created_at=task_stats.created_at,
                queue_position=self.queue_stats.pending_tasks
            )
        else:
            task_stats.status = TaskStatus.FAILED
            task_stats.last_error = "Failed to submit task to queue"
            
            return TaskResponse(
                task_id=task_id,
                status=TaskStatus.FAILED,
                message="Failed to submit task to queue",
                created_at=task_stats.created_at
            )
    
    async def submit_bulk_scrape_task(self, request: BulkScrapeTaskRequest, user_id: Optional[str] = None) -> TaskResponse:
        """Submit a bulk scrape task."""
        task_id = generate_task_id()
        
        # Create task stats
        task_stats = create_task_stats(task_id, TaskStatus.PENDING)
        
        # Store task
        with self._lock:
            self.tasks[task_id] = task_stats
            self.queue_stats.total_tasks += 1
            self.queue_stats.pending_tasks += 1
        
        # Send to Kafka
        task_message = {
            'task_id': task_id,
            'type': 'bulk_scrape_task',
            'queries': request.queries,
            'engines': request.engines,
            'max_results': request.max_results,
            'priority': request.priority.value,
            'callback_url': request.callback_url,
            'user_id': user_id or request.user_id,
            'metadata': request.metadata,
            'parallel_queries': request.parallel_queries,
            'created_at': datetime.utcnow().isoformat()
        }
        
        success = await kafka_service.send_message(task_message, key=task_id)
        
        if success:
            task_stats.status = TaskStatus.QUEUED
            logger.info(f"Bulk task {task_id} submitted successfully")
            
            return TaskResponse(
                task_id=task_id,
                status=TaskStatus.QUEUED,
                message="Bulk task submitted successfully",
                created_at=task_stats.created_at,
                queue_position=self.queue_stats.pending_tasks
            )
        else:
            task_stats.status = TaskStatus.FAILED
            task_stats.last_error = "Failed to submit bulk task to queue"
            
            return TaskResponse(
                task_id=task_id,
                status=TaskStatus.FAILED,
                message="Failed to submit bulk task to queue",
                created_at=task_stats.created_at
            )
    
    def _process_scrape_task(self, message: Dict[str, Any]):
        """Process a scrape task from Kafka."""
        task_id = message.get('task_id')
        
        try:
            # Move the task to PROCESSING and adjust the queue counters atomically
            if task_id in self.tasks:
                with self._lock:
                    self.tasks[task_id].status = TaskStatus.PROCESSING
                    self.tasks[task_id].started_at = datetime.utcnow()
                    self.queue_stats.pending_tasks -= 1
                    self.queue_stats.processing_tasks += 1
            
            # Process the scrape task
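            # asyncio.run() creates a fresh event loop on the Kafka consumer
            # thread and blocks until the scrape coroutine finishes.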
            asyncio.run(self._execute_scrape_task(message))
            
        except Exception as e:
            logger.error(f"Error processing scrape task {task_id}: {e}")
            self._mark_task_failed(task_id, str(e))
    
    async def _execute_scrape_task(self, message: Dict[str, Any]):
        """Execute a scrape task."""
        task_id = message.get('task_id')
        start_time = time.time()
        
        try:
            # Create scrape request
            from src.models.schemas import ScrapeRequest
            scrape_request = ScrapeRequest(
                query=message.get('query'),
                engines=message.get('engines', ['bing']),
                max_results=message.get('max_results', 10)
            )
            
            # Execute scraping
            results = await multi_engine_manager.scrape_multiple_engines_parallel(scrape_request)
            
            # Calculate execution time
            execution_time = time.time() - start_time
            
            # Update task stats
            if task_id in self.tasks:
                task_stats = self.tasks[task_id]
                task_stats.status = TaskStatus.COMPLETED
                task_stats.completed_at = datetime.utcnow()
                task_stats.execution_time = execution_time
                task_stats.total_results = len(results)
                task_stats.progress_percentage = 100.0
                
                # Update engine stats
                successful_engines = list(set(r.engine for r in results))
                task_stats.successful_engines = successful_engines
                
                with self._lock:
                    self.queue_stats.processing_tasks -= 1
                    self.queue_stats.completed_tasks += 1
            
            # Store result
            task_result = TaskResult(
                task_id=task_id,
                status=TaskStatus.COMPLETED,
                query=scrape_request.query,
                engines_used=scrape_request.engines,
                total_results=len(results),
                execution_time=execution_time,
                created_at=self.tasks[task_id].created_at,
                completed_at=datetime.utcnow(),
                results=[{
                    'title': r.title,
                    'url': r.url,
                    'description': r.description,
                    'engine': r.engine,
                    'position': r.position,
                    'timestamp': r.timestamp.isoformat()
                } for r in results],
                stats=self.tasks[task_id]
            )
            
            self.task_results[task_id] = task_result
            
            # Send result to Kafka
            result_message = {
                'task_id': task_id,
                'type': 'scrape_result',
                'status': 'completed',
                'results_count': len(results),
                'execution_time': execution_time,
                'callback_url': message.get('callback_url'),
                'user_id': message.get('user_id'),
                'completed_at': datetime.utcnow().isoformat()
            }
            
            await kafka_service.send_message(result_message, key=task_id)
            
            # Update user stats
            user_id = message.get('user_id')
            if user_id:
                self._update_user_stats(user_id, task_result)
            
            logger.info(f"Task {task_id} completed successfully in {execution_time:.2f}s")
            
        except Exception as e:
            logger.error(f"Error executing scrape task {task_id}: {e}")
            self._mark_task_failed(task_id, str(e))
    
    def _process_bulk_scrape_task(self, message: Dict[str, Any]):
        """Process a bulk scrape task from Kafka."""
        task_id = message.get('task_id')
        
        try:
            # Move the task to PROCESSING and adjust the queue counters atomically
            if task_id in self.tasks:
                with self._lock:
                    self.tasks[task_id].status = TaskStatus.PROCESSING
                    self.tasks[task_id].started_at = datetime.utcnow()
                    self.queue_stats.pending_tasks -= 1
                    self.queue_stats.processing_tasks += 1
            
            # Process the bulk scrape task
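            # As above, run the coroutine to completion on the consumer thread.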
            asyncio.run(self._execute_bulk_scrape_task(message))
            
        except Exception as e:
            logger.error(f"Error processing bulk scrape task {task_id}: {e}")
            self._mark_task_failed(task_id, str(e))
    
    async def _execute_bulk_scrape_task(self, message: Dict[str, Any]):
        """Execute a bulk scrape task."""
        task_id = message.get('task_id')
        queries = message.get('queries', [])
        engines = message.get('engines', ['bing'])
        max_results = message.get('max_results', 10)
        parallel_queries = message.get('parallel_queries', 3)
        
        start_time = time.time()
        all_results = []
        
        try:
            # Process queries in batches
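            # At most `parallel_queries` queries are in flight at any one time.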
            from src.models.schemas import ScrapeRequest
            for i in range(0, len(queries), parallel_queries):
                batch = queries[i:i + parallel_queries]
                batch_tasks = []
                
                for query in batch:
                    scrape_request = ScrapeRequest(
                        query=query,
                        engines=engines,
                        max_results=max_results
                    )
                    
                    task = multi_engine_manager.scrape_multiple_engines_parallel(scrape_request)
                    batch_tasks.append(task)
                
                # Execute batch in parallel
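                # return_exceptions=True returns failures as exception objects
                # instead of aborting the whole batch.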
                batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
                
                for results in batch_results:
                    if isinstance(results, Exception):
                        logger.error(f"Batch task failed: {results}")
                        continue
                    all_results.extend(results)
                
                # Update progress
                if task_id in self.tasks:
                    progress = ((i + len(batch)) / len(queries)) * 100
                    self.tasks[task_id].progress_percentage = progress
            
            # Calculate execution time
            execution_time = time.time() - start_time
            
            # Update task stats
            if task_id in self.tasks:
                task_stats = self.tasks[task_id]
                task_stats.status = TaskStatus.COMPLETED
                task_stats.completed_at = datetime.utcnow()
                task_stats.execution_time = execution_time
                task_stats.total_results = len(all_results)
                task_stats.progress_percentage = 100.0
                
                with self._lock:
                    self.queue_stats.processing_tasks -= 1
                    self.queue_stats.completed_tasks += 1
            
            # Store result
            task_result = TaskResult(
                task_id=task_id,
                status=TaskStatus.COMPLETED,
                query=f"Bulk scrape ({len(queries)} queries)",
                engines_used=engines,
                total_results=len(all_results),
                execution_time=execution_time,
                created_at=self.tasks[task_id].created_at,
                completed_at=datetime.utcnow(),
                results=[{
                    'title': r.title,
                    'url': r.url,
                    'description': r.description,
                    'engine': r.engine,
                    'position': r.position,
                    'timestamp': r.timestamp.isoformat()
                } for r in all_results],
                stats=self.tasks[task_id]
            )
            
            self.task_results[task_id] = task_result
            
            logger.info(f"Bulk task {task_id} completed successfully in {execution_time:.2f}s")
            
        except Exception as e:
            logger.error(f"Error executing bulk scrape task {task_id}: {e}")
            self._mark_task_failed(task_id, str(e))
    
    def _mark_task_failed(self, task_id: str, error: str):
        """Mark a task as failed."""
        if task_id in self.tasks:
            # Remember the status before it is overwritten so the correct
            # queue counter is decremented below.
            previous_status = self.tasks[task_id].status
            self.tasks[task_id].status = TaskStatus.FAILED
            self.tasks[task_id].last_error = error
            self.tasks[task_id].completed_at = datetime.utcnow()
            
            with self._lock:
                if previous_status == TaskStatus.PROCESSING:
                    self.queue_stats.processing_tasks -= 1
                else:
                    self.queue_stats.pending_tasks -= 1
                self.queue_stats.failed_tasks += 1
    
    def _update_user_stats(self, user_id: str, task_result: TaskResult):
        """Update user statistics."""
        if user_id not in self.user_stats:
            self.user_stats[user_id] = {
                'total_tasks': 0,
                'completed_tasks': 0,
                'failed_tasks': 0,
                'total_results': 0,
                'total_execution_time': 0,
                'last_activity': datetime.utcnow()
            }
        
        stats = self.user_stats[user_id]
        stats['total_tasks'] += 1
        stats['last_activity'] = datetime.utcnow()
        
        if task_result.status == TaskStatus.COMPLETED:
            stats['completed_tasks'] += 1
            stats['total_results'] += task_result.total_results
            stats['total_execution_time'] += task_result.execution_time
        else:
            stats['failed_tasks'] += 1
    
    def get_task_status(self, task_id: str) -> Optional[TaskStats]:
        """Get task status."""
        return self.tasks.get(task_id)
    
    def get_task_result(self, task_id: str) -> Optional[TaskResult]:
        """Get task result."""
        return self.task_results.get(task_id)
    
    def get_user_tasks(self, user_id: str, page: int = 1, page_size: int = 20) -> List[TaskStats]:
        """Get tasks for a specific user."""
        # Simplified in-memory version: TaskStats does not carry a user_id, so
        # every task is returned regardless of user. In production this would
        # be a database query filtered by user_id.
        user_tasks = list(self.tasks.values())
        
        start_idx = (page - 1) * page_size
        end_idx = start_idx + page_size
        return user_tasks[start_idx:end_idx]
    
    def get_queue_stats(self) -> QueueStats:
        """Get current queue statistics."""
        return self.queue_stats
    
    def cancel_task(self, task_id: str) -> bool:
        """Cancel a task."""
        if task_id in self.tasks:
            task = self.tasks[task_id]
            if task.status in [TaskStatus.PENDING, TaskStatus.QUEUED]:
                task.status = TaskStatus.CANCELLED
                task.completed_at = datetime.utcnow()
                
                with self._lock:
                    self.queue_stats.pending_tasks -= 1
                
                logger.info(f"Task {task_id} cancelled")
                return True
        return False
    
    def cleanup_old_tasks(self, days: int = 7):
        """Clean up old tasks."""
        cutoff_date = datetime.utcnow() - timedelta(days=days)
        
        tasks_to_remove = []
        for task_id, task_stats in self.tasks.items():
            if task_stats.created_at < cutoff_date:
                tasks_to_remove.append(task_id)
        
        for task_id in tasks_to_remove:
            del self.tasks[task_id]
            if task_id in self.task_results:
                del self.task_results[task_id]
        
        logger.info(f"Cleaned up {len(tasks_to_remove)} old tasks")
    
    def close(self):
        """Close task manager."""
        self._running = False
        if kafka_service:
            kafka_service.close()
        logger.info("Task manager closed")


# Global task manager instance
task_manager = TaskManager()
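

# Usage sketch (comment only, illustrative): how a caller might submit a task
# and inspect it. ScrapeTaskRequest field names are inferred from the payload
# built in submit_scrape_task above and may not match the model exactly.
#
#     from src.models.task import ScrapeTaskRequest
#
#     async def example():
#         request = ScrapeTaskRequest(query="example search", engines=["bing"],
#                                     max_results=10)
#         response = await task_manager.submit_scrape_task(request, user_id="demo-user")
#         print(response.status, response.queue_position)
#         stats = task_manager.get_task_status(response.task_id)
#         result = task_manager.get_task_result(response.task_id)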
