"""Tests for the cache_labels action."""

from collections import Counter

import pytest
from data_manager.actions.cache_labels import cache_labels_job
from django.contrib.auth import get_user_model
from projects.models import Project
from tasks.models import Annotation, Prediction, Task


@pytest.mark.django_db
@pytest.mark.parametrize(
    'source, control_tag, with_counters, expected_cache_column, use_predictions',
    [
        # Test case 1: Annotations, control tag 'ALL', with counters
        ('annotations', 'ALL', 'Yes', 'cache_all', False),
        # Test case 2: Annotations, specific control tag, with counters
        ('annotations', 'label', 'Yes', 'cache_label', False),
        # Test case 3: Annotations, control tag 'ALL', without counters
        ('annotations', 'ALL', 'No', 'cache_all', False),
        # Test case 4: Predictions, control tag 'ALL', with counters
        ('predictions', 'ALL', 'Yes', 'cache_predictions_all', True),
    ],
)
def test_cache_labels_job(source, control_tag, with_counters, expected_cache_column, use_predictions):
    """Run ``cache_labels_job`` on a small project and verify the cached label summary.

    Three tasks are created, each carrying a single annotation or prediction
    (depending on ``use_predictions``) whose only region uses the control tag
    ``label``. After the job runs, every task's ``data`` must contain
    ``expected_cache_column``, and its value must equal the comma-separated,
    sorted summary of that task's labels recomputed here: ``"<label>: <count>"``
    entries when ``with_counters`` is ``'Yes'`` (case-insensitive), bare unique
    label names otherwise.
    """
    # Initialize a test user and project
    User = get_user_model()
    test_user = User.objects.create(username='test_user')
    project = Project.objects.create(title='Test Project', created_by=test_user)

    # Create a few tasks
    tasks = [Task.objects.create(project=project, data={'text': f'This is task {i}'}) for i in range(3)]

    # Attach one annotation or prediction to each task; labels alternate between Label_1 and Label_2
    for i, task in enumerate(tasks):
        result = [
            {
                'from_name': 'label',  # Control tag used in the result
                'to_name': 'text',
                'type': 'labels',
                'value': {'labels': [f'Label_{i%2+1}']},
            }
        ]
        if use_predictions:
            Prediction.objects.create(task=task, project=project, result=result, model_version='v1')
        else:
            Annotation.objects.create(task=task, project=project, completed_by=test_user, result=result)

    # Prepare the request data and run the job over every task of the project
    request_data = {'source': source, 'control_tag': control_tag, 'with_counters': with_counters}
    queryset = Task.objects.filter(project=project)
    cache_labels_job(project, queryset, request_data=request_data)

    # Loop-invariant choices, resolved once instead of per task
    source_model = Prediction if use_predictions else Annotation
    count_labels = with_counters.lower() == 'yes'

    for task in tasks:
        task.refresh_from_db()

        # The expected cache column must have been added to task.data
        assert expected_cache_column in task.data
        cached_labels = task.data[expected_cache_column]
        assert cached_labels is not None

        # Recompute the expected labels from the stored annotations/predictions
        all_labels = []
        for source_obj in source_model.objects.filter(task=task):
            for region in source_obj.result:
                if control_tag != 'ALL' and control_tag != region.get('from_name'):
                    continue
                # Only the first non-empty list of strings in the region value is taken as its labels
                for candidate in region.get('value', {}).values():
                    if isinstance(candidate, list) and candidate and isinstance(candidate[0], str):
                        all_labels.extend(candidate)
                        break

        if count_labels:
            entries = (f'{label}: {count}' for label, count in Counter(all_labels).items())
        else:
            entries = set(all_labels)
        expected_cache = ', '.join(sorted(entries))

        assert cached_labels == expected_cache
