"""
Unit tests for the streams app's Celery tasks.
"""

import os
import shutil
import tempfile
from unittest.mock import patch, MagicMock

from celery.exceptions import Retry
from django.contrib.auth import get_user_model
from django.test import TestCase
from django.utils import timezone

from apps.streams.models import Channel, StreamSession, HLSSegment
from apps.streams.tasks import (
    start_stream_capture, stop_stream_capture, process_new_segment,
    stream_health_check, cleanup_old_segments, send_stream_status_notification
)

# Resolve the project's configured user model once via Django's
# get_user_model(), so setUp creates users through that model's manager.
User = get_user_model()


class StreamTasksTest(TestCase):
    """Test cases for the stream-related Celery tasks.

    Every test starts with a fresh user, a private temporary directory and a
    ``Channel`` whose ``output_directory`` is that directory. Collaborating
    services (capture service, jingle detector, notifications, psutil, os)
    are patched per test so the tasks run synchronously in-process.
    """

    def setUp(self):
        self.user = User.objects.create_user(
            username='testuser',
            email='test@example.com',
            password='testpass123'
        )
        self.temp_dir = tempfile.mkdtemp()
        # Registered immediately instead of in tearDown: unittest skips
        # tearDown when setUp raises, but still runs registered cleanups, so
        # the directory cannot leak if the Channel creation below fails.
        self.addCleanup(shutil.rmtree, self.temp_dir, ignore_errors=True)

        self.channel = Channel.objects.create(
            name='Test Channel',
            slug='test-channel',
            hls_url='https://example.com/stream.m3u8',
            output_directory=self.temp_dir,
            created_by=self.user
        )

    def _create_session(self, **overrides):
        """Create a ``StreamSession`` on ``self.channel``.

        Defaults to ``status='active'`` and ``started_at=timezone.now()``;
        any field (e.g. ``process_id``) can be supplied or overridden through
        keyword arguments.
        """
        fields = {'status': 'active', 'started_at': timezone.now()}
        fields.update(overrides)
        return StreamSession.objects.create(channel=self.channel, **fields)

    def _create_segment(self, session):
        """Create sequence-1 ``HLSSegment`` for *session* inside the temp dir.

        The path lives under the per-test directory rather than a shared
        location such as ``/tmp`` so that stray files from other runs cannot
        influence the result.
        """
        filename = 'segment_001.ts'
        return HLSSegment.objects.create(
            session=session,
            sequence_number=1,
            filename=filename,
            file_path=os.path.join(self.temp_dir, filename),
            duration=6.0
        )

    @patch('apps.streams.services.StreamCaptureService.start_capture')
    @patch('apps.notifications.tasks.send_notification_via_rule.delay')
    def test_start_stream_capture_success(self, mock_notification, mock_capture):
        """Test successful stream capture start task."""
        # The capture service hands back a session object on success.
        mock_session = MagicMock()
        mock_session.id = 'test-session-id'
        mock_capture.return_value = mock_session

        result = start_stream_capture(str(self.channel.id))

        self.assertTrue(result['success'])
        self.assertEqual(result['session_id'], 'test-session-id')
        mock_capture.assert_called_once()
        mock_notification.assert_called_once()

    @patch('apps.streams.services.StreamCaptureService.start_capture')
    def test_start_stream_capture_failure(self, mock_capture):
        """Test stream capture start failure."""
        mock_capture.return_value = None  # Capture failed

        result = start_stream_capture(str(self.channel.id))

        self.assertFalse(result['success'])
        self.assertIn('error', result)

    @patch('apps.streams.services.StreamCaptureService.stop_capture')
    @patch('apps.notifications.tasks.send_notification_via_rule.delay')
    def test_stop_stream_capture_success(self, mock_notification, mock_stop):
        """Test successful stream capture stop task."""
        # An active session must exist for the task to have something to stop.
        self._create_session()
        mock_stop.return_value = True

        result = stop_stream_capture(str(self.channel.id))

        self.assertTrue(result['success'])
        mock_stop.assert_called_once()
        mock_notification.assert_called_once()

    def test_stop_stream_capture_no_active_session(self):
        """Test stopping stream with no active session."""
        result = stop_stream_capture(str(self.channel.id))

        self.assertFalse(result['success'])
        self.assertIn('no active session', result['error'].lower())

    @patch('apps.jingles.services.JingleDetector.process_detection')
    def test_process_new_segment_with_detection(self, mock_process):
        """Test processing new segment with jingle detection."""
        segment = self._create_segment(self._create_session())

        # The detector reports a match.
        mock_detection = MagicMock()
        mock_detection.id = 'detection-id'
        mock_process.return_value = mock_detection

        result = process_new_segment(str(segment.id))

        self.assertTrue(result['success'])
        self.assertTrue(result['detection_found'])
        mock_process.assert_called_once()

    @patch('apps.jingles.services.JingleDetector.process_detection')
    def test_process_new_segment_no_detection(self, mock_process):
        """Test processing new segment without detection."""
        segment = self._create_segment(self._create_session())

        mock_process.return_value = None  # No detection

        result = process_new_segment(str(segment.id))

        self.assertTrue(result['success'])
        self.assertFalse(result['detection_found'])

    @patch('psutil.process_iter')
    def test_stream_health_check_healthy(self, mock_process_iter):
        """Test stream health check with healthy streams."""
        self._create_session(process_id=12345)

        # psutil reports a running process with the session's PID.
        mock_process = MagicMock()
        mock_process.info = {'pid': 12345, 'name': 'ffmpeg'}
        mock_process_iter.return_value = [mock_process]

        result = stream_health_check()

        self.assertTrue(result['success'])
        self.assertEqual(result['healthy_streams'], 1)
        self.assertEqual(result['unhealthy_streams'], 0)

    @patch('psutil.process_iter')
    def test_stream_health_check_unhealthy(self, mock_process_iter):
        """Test stream health check with dead processes."""
        session = self._create_session(process_id=99999)  # Non-existent PID

        # psutil reports no running processes at all.
        mock_process_iter.return_value = []

        result = stream_health_check()

        self.assertTrue(result['success'])
        self.assertEqual(result['healthy_streams'], 0)
        self.assertEqual(result['unhealthy_streams'], 1)

        # Session should be marked as failed
        session.refresh_from_db()
        self.assertEqual(session.status, 'failed')

    @patch('os.listdir')
    @patch('os.path.getmtime')
    @patch('os.remove')
    def test_cleanup_old_segments_task(self, mock_remove, mock_getmtime, mock_listdir):
        """Test cleanup old segments task."""
        # Two files, both last modified 2 hours ago (older than max_age_hours=1).
        mock_listdir.return_value = ['segment_001.ts', 'segment_002.ts']
        mock_getmtime.return_value = timezone.now().timestamp() - 7200

        result = cleanup_old_segments(str(self.channel.id), max_age_hours=1)

        self.assertTrue(result['success'])
        self.assertEqual(result['files_cleaned'], 2)
        self.assertEqual(mock_remove.call_count, 2)

    @patch('apps.notifications.tasks.send_notification_via_rule.delay')
    def test_send_stream_status_notification(self, mock_notification):
        """Test sending stream status notification."""
        # One active session should yield exactly one notification.
        self._create_session()

        result = send_stream_status_notification()

        self.assertTrue(result['success'])
        self.assertEqual(result['notifications_sent'], 1)
        mock_notification.assert_called_once()

    def test_task_retry_on_failure(self):
        """A database error while loading the channel must escape the task.

        The task is expected to retry on failure. When a Celery task is
        invoked directly (not through a worker), ``Task.retry`` re-raises the
        original exception if one was passed and raises ``Retry`` otherwise,
        so either is accepted. Any other exception type would mean the task
        failed for an unrelated reason, which the previous bare
        ``assertRaises(Exception)`` silently accepted.
        """
        with patch('apps.streams.models.Channel.objects.get') as mock_get:
            mock_get.side_effect = Exception("Database error")

            with self.assertRaises(Exception) as ctx:
                start_stream_capture(str(self.channel.id))

        raised = ctx.exception
        self.assertTrue(
            isinstance(raised, Retry) or 'Database error' in str(raised),
            msg=f'unexpected exception from start_stream_capture: {raised!r}'
        )
