"""
Unit tests for jingles app services.
"""

import tempfile
import shutil
import os
from unittest.mock import patch, MagicMock, mock_open
from django.test import TestCase
from django.contrib.auth import get_user_model
from django.utils import timezone
from datetime import timedelta

from apps.streams.models import Channel, StreamSession
from apps.jingles.models import JingleTemplate, JingleDetection, AdBreak
from apps.jingles.services import JingleDetector, AdBreakAnalyzer

# Resolve the project's active user model (honours a custom AUTH_USER_MODEL)
# rather than importing django.contrib.auth.models.User directly.
User = get_user_model()


class JingleDetectorTest(TestCase):
    """Test cases for JingleDetector service.

    Score convention exercised by these tests: ``compare_images`` returns a
    score where *lower* means a closer match. The detection tests expect a
    frame scoring below the template's ``similarity_threshold`` (0.8 here) to
    count as a jingle, and the missing-file test expects the worst score, 1.0.
    """

    def setUp(self):
        self.user = User.objects.create_user(
            username='testuser',
            email='test@example.com',
            password='testpass123'
        )
        # Per-test scratch directory: used as the channel output directory and
        # for any filesystem paths the tests hand to the service; removed in
        # tearDown so nothing leaks into the shared /tmp.
        self.temp_dir = tempfile.mkdtemp()

        self.channel = Channel.objects.create(
            name='Test Channel',
            slug='test-channel',
            hls_url='https://example.com/stream.m3u8',
            output_directory=self.temp_dir,
            created_by=self.user
        )

        self.session = StreamSession.objects.create(
            channel=self.channel,
            status='active',
            started_at=timezone.now()
        )

        self.template = JingleTemplate.objects.create(
            name='Test Jingle',
            slug='test-jingle',
            image_path='/tmp/test_jingle.jpg',
            similarity_threshold=0.8,
            created_by=self.user
        )

        self.detector = JingleDetector()

    def tearDown(self):
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    @patch('cv2.imread')
    @patch('cv2.matchTemplate')
    @patch('cv2.minMaxLoc')
    def test_compare_images_match(self, mock_minmaxloc, mock_matchtemplate, mock_imread):
        """Test image comparison with match.

        OpenCV is fully mocked; the service is expected to reduce the
        matchTemplate result via minMaxLoc and return ``min_val`` as the score.
        """
        mock_imread.return_value = MagicMock()  # stand-in for decoded image data
        mock_matchtemplate.return_value = MagicMock()
        # minMaxLoc returns (min_val, max_val, min_loc, max_loc)
        mock_minmaxloc.return_value = (0.15, 0.95, (0, 0), (100, 100))

        similarity = self.detector.compare_images('/tmp/template.jpg', '/tmp/frame.jpg')

        self.assertEqual(similarity, 0.15)  # Should return min_val
        mock_imread.assert_called()
        mock_matchtemplate.assert_called_once()
        mock_minmaxloc.assert_called_once()

    @patch('cv2.imread')
    def test_compare_images_file_not_found(self, mock_imread):
        """Test image comparison with missing file."""
        mock_imread.return_value = None  # cv2.imread signals an unreadable file with None

        similarity = self.detector.compare_images('/tmp/nonexistent.jpg', '/tmp/frame.jpg')

        # Lower is better, so 1.0 is the worst possible score, i.e. "no match".
        self.assertEqual(similarity, 1.0)

    @patch('subprocess.run')
    def test_extract_iframes_success(self, mock_subprocess):
        """Test I-frame extraction success."""
        mock_subprocess.return_value = MagicMock(returncode=0)

        # Build the output path before os.path.join is patched below; keep it
        # inside the per-test temp dir for consistency with the failure test.
        iframe_dir = os.path.join(self.temp_dir, 'iframes')
        # Pretend the extraction subprocess wrote two frames: directory
        # creation and listing are mocked so no real files are needed.
        with patch('os.makedirs'), \
             patch('os.listdir', return_value=['iframe_001.jpg', 'iframe_002.jpg']), \
             patch('os.path.join', side_effect=lambda *args: '/'.join(args)):

            iframes = self.detector.extract_iframes('/tmp/segment.ts', iframe_dir)

            self.assertEqual(len(iframes), 2)
            self.assertIn('iframe_001.jpg', iframes[0])
            self.assertIn('iframe_002.jpg', iframes[1])

    @patch('subprocess.run')
    def test_extract_iframes_failure(self, mock_subprocess):
        """Test I-frame extraction failure (non-zero subprocess exit code)."""
        mock_subprocess.return_value = MagicMock(returncode=1)

        # os.makedirs is NOT mocked here, so point the service at the per-test
        # temp dir: any directory it creates is removed in tearDown instead of
        # leaking into the shared /tmp.
        iframe_dir = os.path.join(self.temp_dir, 'iframes')
        iframes = self.detector.extract_iframes('/tmp/segment.ts', iframe_dir)

        self.assertEqual(iframes, [])

    @patch.object(JingleDetector, 'extract_iframes')
    @patch.object(JingleDetector, 'compare_images')
    def test_detect_jingle_found(self, mock_compare, mock_extract):
        """Test jingle detection when match is found."""
        mock_extract.return_value = ['/tmp/iframe_001.jpg', '/tmp/iframe_002.jpg']

        # First frame scores 0.7 (below the 0.8 threshold -> match); the
        # second scores 0.9 (above threshold -> no match).
        mock_compare.side_effect = [0.7, 0.9]

        self.detector.load_jingle_templates()

        result = self.detector.detect_jingle('/tmp/segment.ts', self.session)

        self.assertIsNotNone(result)
        jingle_name, iframe_path, similarity, template = result
        self.assertEqual(jingle_name, 'Test Jingle')
        self.assertEqual(similarity, 0.7)
        self.assertEqual(template, self.template)

    @patch.object(JingleDetector, 'extract_iframes')
    @patch.object(JingleDetector, 'compare_images')
    def test_detect_jingle_not_found(self, mock_compare, mock_extract):
        """Test jingle detection when no match is found."""
        mock_extract.return_value = ['/tmp/iframe_001.jpg']

        # Every comparison scores above the 0.8 threshold -> no match.
        mock_compare.return_value = 0.9

        self.detector.load_jingle_templates()

        result = self.detector.detect_jingle('/tmp/segment.ts', self.session)

        self.assertIsNone(result)

    @patch.object(JingleDetector, 'detect_jingle')
    def test_process_detection_success(self, mock_detect):
        """Test detection processing with successful match."""
        # detect_jingle result tuple: (jingle_name, iframe_path, similarity, template)
        mock_detect.return_value = (
            'Test Jingle',
            '/tmp/iframe_001.jpg',
            0.75,
            self.template
        )

        detection = self.detector.process_detection('/tmp/segment.ts', self.session)

        self.assertIsNotNone(detection)
        self.assertEqual(detection.jingle_template, self.template)
        self.assertEqual(detection.confidence_score, 0.75)
        # Only the basename of the segment path is expected to be stored.
        self.assertEqual(detection.segment_filename, 'segment.ts')

    @patch.object(JingleDetector, 'detect_jingle')
    def test_process_detection_no_match(self, mock_detect):
        """Test detection processing with no match."""
        mock_detect.return_value = None

        detection = self.detector.process_detection('/tmp/segment.ts', self.session)

        self.assertIsNone(detection)

    def test_load_jingle_templates(self):
        """Test loading jingle templates (only active ones are loaded)."""
        JingleTemplate.objects.create(
            name='Another Jingle',
            slug='another-jingle',
            image_path='/tmp/another_jingle.jpg',
            is_active=True,
            created_by=self.user
        )

        # Inactive template: must be skipped by the loader.
        JingleTemplate.objects.create(
            name='Inactive Jingle',
            slug='inactive-jingle',
            image_path='/tmp/inactive_jingle.jpg',
            is_active=False,
            created_by=self.user
        )

        self.detector.load_jingle_templates()

        # The setUp template plus 'Another Jingle'.
        self.assertEqual(len(self.detector.jingle_templates), 2)

        # Each loaded entry is indexable with the template name at position 0.
        template_names = [template[0] for template in self.detector.jingle_templates]
        self.assertIn('Test Jingle', template_names)
        self.assertIn('Another Jingle', template_names)
        self.assertNotIn('Inactive Jingle', template_names)


class AdBreakAnalyzerTest(TestCase):
    """Test cases for AdBreakAnalyzer service."""

    def setUp(self):
        self.user = User.objects.create_user(
            username='testuser',
            email='test@example.com',
            password='testpass123'
        )
        # Per-test scratch directory (removed in tearDown) instead of a shared,
        # hard-coded /tmp path that may not exist or may collide across runs.
        self.temp_dir = tempfile.mkdtemp()

        self.channel = Channel.objects.create(
            name='Test Channel',
            slug='test-channel',
            hls_url='https://example.com/stream.m3u8',
            output_directory=self.temp_dir,
            created_by=self.user
        )

        self.session = StreamSession.objects.create(
            channel=self.channel,
            status='active',
            started_at=timezone.now()
        )

        self.template = JingleTemplate.objects.create(
            name='Test Jingle',
            slug='test-jingle',
            image_path='/tmp/test_jingle.jpg',
            created_by=self.user
        )

        self.analyzer = AdBreakAnalyzer()

    def tearDown(self):
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    # -- fixtures ---------------------------------------------------------

    def _create_detection(self, segment_filename, frame_path, *,
                          detection_time=None, confidence_score=0.85, **extra):
        """Create a JingleDetection for ``self.session`` and ``self.template``.

        ``detection_time`` defaults to ``timezone.now()``; any ``extra``
        keyword arguments (e.g. ``is_confirmed=True``) are passed straight
        through to ``JingleDetection.objects.create``.
        """
        if detection_time is None:
            detection_time = timezone.now()
        return JingleDetection.objects.create(
            session=self.session,
            jingle_template=self.template,
            segment_filename=segment_filename,
            detection_time=detection_time,
            confidence_score=confidence_score,
            frame_path=frame_path,
            **extra
        )

    def _create_ad_break(self, start_detection, start_time=None, status='active'):
        """Create an AdBreak on ``self.session`` opened by ``start_detection``.

        ``start_time`` defaults to ``timezone.now()``.
        """
        if start_time is None:
            start_time = timezone.now()
        return AdBreak.objects.create(
            session=self.session,
            start_detection=start_detection,
            start_time=start_time,
            status=status
        )

    # -- tests ------------------------------------------------------------

    def test_start_ad_break(self):
        """Test starting an ad break."""
        detection = self._create_detection('segment_001.ts', '/tmp/frame_001.jpg')

        ad_break = self.analyzer.start_ad_break(detection)

        self.assertIsNotNone(ad_break)
        self.assertEqual(ad_break.session, self.session)
        self.assertEqual(ad_break.start_detection, detection)
        self.assertEqual(ad_break.status, 'active')
        self.assertIsNotNone(ad_break.start_time)

    def test_end_ad_break(self):
        """Test ending an ad break."""
        start_detection = self._create_detection('segment_001.ts', '/tmp/frame_001.jpg')
        ad_break = self._create_ad_break(start_detection)

        # Closing jingle seen three minutes after the opening one.
        end_detection = self._create_detection(
            'segment_010.ts',
            '/tmp/frame_010.jpg',
            detection_time=timezone.now() + timedelta(minutes=3),
            confidence_score=0.90,
        )

        result = self.analyzer.end_ad_break(ad_break, end_detection)

        self.assertTrue(result)
        ad_break.refresh_from_db()
        self.assertEqual(ad_break.status, 'completed')
        self.assertEqual(ad_break.end_detection, end_detection)
        self.assertIsNotNone(ad_break.end_time)

    def test_analyze_detection_start_ad_break(self):
        """Test detection analysis that starts an ad break (no break open yet)."""
        detection = self._create_detection('segment_001.ts', '/tmp/frame_001.jpg')

        result = self.analyzer.analyze_detection(detection)

        self.assertEqual(result['action'], 'start_ad_break')
        self.assertIsNotNone(result['ad_break'])
        self.assertEqual(result['ad_break'].start_detection, detection)

    def test_analyze_detection_end_ad_break(self):
        """Test detection analysis that ends an already-active ad break."""
        started = timezone.now() - timedelta(minutes=3)
        start_detection = self._create_detection(
            'segment_001.ts', '/tmp/frame_001.jpg', detection_time=started
        )
        ad_break = self._create_ad_break(start_detection, start_time=started)

        end_detection = self._create_detection(
            'segment_010.ts', '/tmp/frame_010.jpg', confidence_score=0.90
        )

        result = self.analyzer.analyze_detection(end_detection)

        self.assertEqual(result['action'], 'end_ad_break')
        self.assertEqual(result['ad_break'], ad_break)

    def test_cleanup_stale_ad_breaks(self):
        """Test cleanup of stale ad breaks."""
        # Stale: active for 10 minutes, beyond the 5-minute timeout used below.
        old_time = timezone.now() - timedelta(minutes=10)
        stale_detection = self._create_detection(
            'segment_001.ts', '/tmp/frame_001.jpg', detection_time=old_time
        )
        stale_ad_break = self._create_ad_break(stale_detection, start_time=old_time)

        # Recent: active for 2 minutes, within the timeout -> must be kept.
        recent_time = timezone.now() - timedelta(minutes=2)
        recent_detection = self._create_detection(
            'segment_020.ts', '/tmp/frame_020.jpg', detection_time=recent_time
        )
        recent_ad_break = self._create_ad_break(recent_detection, start_time=recent_time)

        cleaned_count = self.analyzer.cleanup_stale_ad_breaks(self.session, timeout_minutes=5)

        self.assertEqual(cleaned_count, 1)

        stale_ad_break.refresh_from_db()
        self.assertEqual(stale_ad_break.status, 'timed_out')

        recent_ad_break.refresh_from_db()
        self.assertEqual(recent_ad_break.status, 'active')

    def test_get_active_ad_break(self):
        """Test retrieving active ad break for session."""
        # No active ad break initially.
        self.assertIsNone(self.analyzer.get_active_ad_break(self.session))

        start_detection = self._create_detection('segment_001.ts', '/tmp/frame_001.jpg')
        ad_break = self._create_ad_break(start_detection)

        active_break = self.analyzer.get_active_ad_break(self.session)
        self.assertEqual(active_break, ad_break)

    @patch('shutil.rmtree')
    def test_cleanup_detection_files(self, mock_rmtree):
        """Test cleanup of detection frame files older than ``days_old``."""
        # shutil.rmtree is patched so no real directory can be removed even if
        # the service deletes whole frame directories.
        eight_days_ago = timezone.now() - timedelta(days=8)
        self._create_detection(
            'segment_001.ts', '/tmp/frames/frame_001.jpg', detection_time=eight_days_ago
        )
        self._create_detection(
            'segment_002.ts', '/tmp/frames/frame_002.jpg',
            detection_time=eight_days_ago, confidence_score=0.90,
        )

        # Recent detection: inside the 7-day window, must not be cleaned.
        self._create_detection(
            'segment_003.ts', '/tmp/frames/frame_003.jpg', confidence_score=0.88
        )

        with patch('os.path.exists', return_value=True), \
             patch('os.remove') as mock_remove:

            cleaned_count = self.analyzer.cleanup_detection_files(days_old=7)

            # Only the two old detections' files are removed.
            self.assertEqual(cleaned_count, 2)
            self.assertEqual(mock_remove.call_count, 2)

    def test_calculate_detection_statistics(self):
        """Test detection statistics calculation."""
        # NOTE(review): timezone.now().date() is the UTC date when USE_TZ is
        # on; confirm the service filters by the same notion of "today".
        today = timezone.now().date()

        # Confirmed detections
        for i in range(3):
            self._create_detection(
                f'segment_{i:03d}.ts',
                f'/tmp/frame_{i:03d}.jpg',
                confidence_score=0.85,
                is_confirmed=True,
            )

        # False positive
        self._create_detection(
            'segment_004.ts', '/tmp/frame_004.jpg',
            confidence_score=0.70, is_false_positive=True,
        )

        # Unconfirmed detection
        self._create_detection(
            'segment_005.ts', '/tmp/frame_005.jpg', confidence_score=0.80
        )

        stats = self.analyzer.calculate_detection_statistics(self.template, today)

        self.assertEqual(stats['total_detections'], 5)
        self.assertEqual(stats['confirmed_detections'], 3)
        self.assertEqual(stats['false_positives'], 1)
        # Mean over all five detections: (0.85 * 3 + 0.70 + 0.80) / 5
        # = 4.05 / 5 = 0.81. Compared approximately because it is a float
        # (and may be rounded by the service).
        self.assertAlmostEqual(stats['average_confidence'], 0.81, places=2)

    def test_get_detection_trends(self):
        """Test detection trend analysis (one detection per day for 7 days)."""
        base_time = timezone.now()

        for day_offset in range(7):
            self._create_detection(
                f'segment_{day_offset:03d}.ts',
                f'/tmp/frame_{day_offset:03d}.jpg',
                detection_time=base_time - timedelta(days=day_offset),
            )

        trends = self.analyzer.get_detection_trends(self.template, days=7)

        self.assertEqual(len(trends), 7)
        self.assertIn('date', trends[0])
        self.assertIn('detections', trends[0])
        self.assertIn('avg_confidence', trends[0])
