"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
"""
import bleach
from constants import SAFE_HTML_ATTRIBUTES, SAFE_HTML_TAGS
from django.db.models import Q
from drf_spectacular.utils import extend_schema_field, extend_schema_serializer
from fsm.serializer_fields import FSMStateField
from label_studio_sdk.label_interface import LabelInterface
from label_studio_sdk.label_interface.control_tags import (
    BrushLabelsTag,
    BrushTag,
    ChoicesTag,
    DateTimeTag,
    EllipseLabelsTag,
    EllipseTag,
    HyperTextLabelsTag,
    KeyPointLabelsTag,
    KeyPointTag,
    LabelsTag,
    NumberTag,
    ParagraphLabelsTag,
    PolygonLabelsTag,
    PolygonTag,
    RatingTag,
    RectangleLabelsTag,
    RectangleTag,
    TaxonomyTag,
    TextAreaTag,
    TimeSeriesLabelsTag,
    VideoRectangleTag,
)
from projects.models import Project, ProjectImport, ProjectOnboarding, ProjectReimport, ProjectSummary
from rest_flex_fields import FlexFieldsModelSerializer
from rest_framework import serializers
from rest_framework.serializers import SerializerMethodField
from tasks.models import Task
from users.serializers import UserSimpleSerializer


@extend_schema_field({'type': 'object', 'additionalProperties': True})
class OpenApiObjectJSONField(serializers.JSONField):
    """
    JSON field whose OpenAPI schema is always a generic object.

    The ``extend_schema_field`` decorator pins the schema to
    ``{'type': 'object', 'additionalProperties': True}``. Without it, drf-spectacular may
    produce a schema with only metadata (e.g. nullable/readOnly/description) and omit
    ``type``/``$ref``, which breaks some OpenAPI doc renderers.

    Runtime (de)serialization is inherited unchanged from ``serializers.JSONField``;
    this class only affects the generated schema.
    """


class CreatedByFromContext:
    """Field default that reads ``created_by`` from the serializer context.

    Because ``requires_context`` is True, DRF calls the default with the bound field,
    which gives access to ``serializer_field.context``. The result is ``None`` when the
    context has no ``created_by`` entry.
    """

    requires_context = True

    def __call__(self, serializer_field):
        return serializer_field.context.get('created_by')

    def __repr__(self):
        # Stable repr (as DRF's CurrentUserDefault does) so serializer reprs don't embed a memory address
        return f'{self.__class__.__name__}()'


@extend_schema_serializer(deprecate_fields=['show_ground_truth_first'])
class ProjectSerializer(FlexFieldsModelSerializer):
    """Read/write serializer for ``Project``.

    The count fields below (``task_number``, ``total_annotations_number``, ...) are read from
    queryset annotations. Serialize instances that come from a queryset which adds them, e.g.
    ``Project.objects.with_counts()``; each count field declares ``default=None``.

    ``queue_total`` / ``queue_done`` are computed for the user returned by ``user_id``.
    """

    task_number = serializers.IntegerField(default=None, read_only=True, help_text='Total task number in project')
    total_annotations_number = serializers.IntegerField(
        default=None,
        read_only=True,
        help_text='Total annotations number in project including '
        'skipped_annotations_number and ground_truth_number.',
    )
    # NOTE(review): this help_text describes annotation counters rather than predictions -- confirm and fix wording.
    total_predictions_number = serializers.IntegerField(
        default=None,
        read_only=True,
        help_text='Total predictions number in project including '
        'skipped_annotations_number, ground_truth_number, and '
        'useful_annotation_number.',
    )
    useful_annotation_number = serializers.IntegerField(
        default=None,
        read_only=True,
        help_text='Useful annotation number in project not including '
        'skipped_annotations_number and ground_truth_number. '
        'Total annotations = annotation_number + '
        'skipped_annotations_number + ground_truth_number',
    )
    ground_truth_number = serializers.IntegerField(
        default=None, read_only=True, help_text='Honeypot annotation number in project'
    )
    skipped_annotations_number = serializers.IntegerField(
        default=None, read_only=True, help_text='Skipped by collaborators annotation number in project'
    )
    num_tasks_with_annotations = serializers.IntegerField(
        default=None, read_only=True, help_text='Tasks with annotations count'
    )

    created_by = UserSimpleSerializer(default=CreatedByFromContext(), help_text='Project owner')

    control_weights = OpenApiObjectJSONField(
        required=False, allow_null=True, help_text='Dict of weights for each control tag in metric calculation.'
    )
    parsed_label_config = OpenApiObjectJSONField(
        default=None, read_only=True, help_text='JSON-formatted labeling configuration'
    )
    # Written through to_internal_value (mapped onto min_annotations_to_start_training);
    # SerializerMethodField itself is always read-only in DRF.
    start_training_on_annotation_update = SerializerMethodField(
        default=None, read_only=False, help_text='Start model training after any annotations are submitted or updated'
    )
    config_has_control_tags = SerializerMethodField(
        default=None, read_only=True, help_text='Flag to detect is project ready for labeling'
    )
    config_suitable_for_bulk_annotation = serializers.SerializerMethodField(
        default=None, read_only=True, help_text='Flag to detect is project ready for bulk annotation'
    )
    finished_task_number = serializers.IntegerField(default=None, read_only=True, help_text='Finished tasks')

    queue_total = serializers.SerializerMethodField()
    queue_done = serializers.SerializerMethodField()
    state = FSMStateField(read_only=True)  # FSM state - automatically uses annotation if present

    @property
    def user_id(self):
        """ID of the user the per-user queue counters are computed for.

        Uses ``context['request'].user.id``. When there is no request in the context, it falls
        back to the first entry of ``context['user_cache']``. NOTE(review): presumably keyed by
        user id; an empty cache raises StopIteration here.
        """
        try:
            return self.context['request'].user.id
        except KeyError:
            return next(iter(self.context['user_cache']))

    @staticmethod
    def get_config_has_control_tags(project) -> bool:
        """Return True when ``project.get_parsed_config()`` is non-empty (presumably one entry per control tag)."""
        return len(project.get_parsed_config()) > 0

    @staticmethod
    def get_config_suitable_for_bulk_annotation(project) -> bool:
        """Return True if the project's labeling config allows bulk annotation.

        The config is rejected (False) when any of the following holds:

        - it contains any of the disallowed tag classes listed below;
        - a Choices/Taxonomy/DateTime/Number/Rating/TextArea tag has ``perRegion="true"`` or
          ``perItem="true"``;
        - a Choices or Taxonomy tag sets a ``value`` attribute at all;
        - a Taxonomy tag has ``labeling="true"`` or any ``apiUrl``.
        """
        li = LabelInterface(project.label_config)

        # Tags that must not appear anywhere in the config
        disallowed_tags = (
            LabelsTag,
            BrushTag,
            BrushLabelsTag,
            EllipseTag,
            EllipseLabelsTag,
            KeyPointTag,
            KeyPointLabelsTag,
            PolygonTag,
            PolygonLabelsTag,
            RectangleTag,
            RectangleLabelsTag,
            HyperTextLabelsTag,
            ParagraphLabelsTag,
            TimeSeriesLabelsTag,
            VideoRectangleTag,
        )
        if any(li.find_tags_by_class(tag_class) for tag_class in disallowed_tags):
            return False

        # perRegion/perItem must be off for these tags; Choices/Taxonomy must not set `value` at all
        allowed_tags_for_checks = (ChoicesTag, TaxonomyTag, DateTimeTag, NumberTag, RatingTag, TextAreaTag)
        for tag_class in allowed_tags_for_checks:
            for tag in li.find_tags_by_class(tag_class):
                per_region = tag.attr.get('perRegion', 'false').lower() == 'true'
                per_item = tag.attr.get('perItem', 'false').lower() == 'true'
                if per_region or per_item:
                    return False
                if tag_class in (ChoicesTag, TaxonomyTag) and 'value' in tag.attr:
                    return False

        # Taxonomy must not set labeling="true" or an apiUrl
        for tag in li.find_tags_by_class(TaxonomyTag):
            if tag.attr.get('labeling', 'false').lower() == 'true':
                return False
            if tag.attr.get('apiUrl') is not None:
                return False

        return True

    @staticmethod
    def get_parsed_label_config(project):
        """Return ``project.get_parsed_config()``.

        NOTE(review): ``parsed_label_config`` is declared as a JSONField, not a
        SerializerMethodField, so DRF does not call this method for that field.
        """
        return project.get_parsed_config()

    def get_start_training_on_annotation_update(self, instance) -> bool:
        """True when ``min_annotations_to_start_training`` is truthy (non-zero)."""
        # FIXME: remake this logic with start_training_on_annotation_update
        return bool(instance.min_annotations_to_start_training)

    def to_internal_value(self, data):
        """Validate input, then post-process fields that need the raw request value.

        - ``start_training_on_annotation_update`` is coerced with ``int()`` and stored as
          ``min_annotations_to_start_training``. An unconvertible value raises ValidationError.
        - A string ``expert_instruction`` is sanitized with ``bleach.clean`` against the
          ``SAFE_HTML_TAGS`` / ``SAFE_HTML_ATTRIBUTES`` allow-lists.
        """
        # FIXME: remake this logic with start_training_on_annotation_update
        raw_data = data
        data = super().to_internal_value(data)

        if 'start_training_on_annotation_update' in raw_data:
            try:
                data['min_annotations_to_start_training'] = int(raw_data['start_training_on_annotation_update'])
            except (TypeError, ValueError):
                # int() on e.g. None or "true" would otherwise surface as a 500 instead of a 400
                raise serializers.ValidationError(
                    {'start_training_on_annotation_update': ['A valid integer or boolean is required.']}
                ) from None

        if 'expert_instruction' in raw_data:
            raw_instruction = raw_data['expert_instruction']
            # bleach.clean() raises TypeError on non-str input (e.g. an explicit null); in that
            # case keep the value already validated by the field in super().to_internal_value().
            if isinstance(raw_instruction, str):
                data['expert_instruction'] = bleach.clean(
                    raw_instruction, tags=SAFE_HTML_TAGS, attributes=SAFE_HTML_ATTRIBUTES
                )

        return data

    def validate_color(self, value):
        """Ensure the color is a ``#RRGGBB`` hex string, e.g. ``"#FF4C25"``.

        DRF only passes ``None`` here when the field allows null, so ``None`` is returned unchanged.

        Raises:
            serializers.ValidationError: if the value is not ``#`` followed by six ASCII hex digits.
        """
        if value is None:
            return value
        # int(x, 16) would also accept '+', '-', '_', whitespace, a '0x' prefix and non-ASCII
        # digits, so check every character explicitly.
        if len(value) == 7 and value.startswith('#') and all(ch in '0123456789abcdefABCDEF' for ch in value[1:]):
            return value
        raise serializers.ValidationError('Color must be in "#RRGGBB" format')

    class Meta:
        model = Project
        extra_kwargs = {
            'memberships': {'required': False},
            'title': {'required': False},
            'created_by': {'required': False},
        }
        fields = [
            'id',
            'title',
            'description',
            'label_config',
            'expert_instruction',
            'show_instruction',
            'show_skip_button',
            'enable_empty_annotation',
            'show_annotation_history',
            'organization',
            'color',
            'maximum_annotations',
            'is_published',
            'model_version',
            'is_draft',
            'created_by',
            'created_at',
            'min_annotations_to_start_training',
            'start_training_on_annotation_update',
            'show_collab_predictions',
            'num_tasks_with_annotations',
            'task_number',
            'useful_annotation_number',
            'ground_truth_number',
            'skipped_annotations_number',
            'total_annotations_number',
            'total_predictions_number',
            'sampling',
            'show_ground_truth_first',
            'annotator_evaluation_enabled',
            'show_overlap_first',
            'overlap_cohort_percentage',
            'task_data_login',
            'task_data_password',
            'control_weights',
            'parsed_label_config',
            'evaluate_predictions_automatically',
            'config_has_control_tags',
            'skip_queue',
            'reveal_preannotations_interactively',
            'pinned_at',
            'finished_task_number',
            'queue_total',
            'queue_done',
            'config_suitable_for_bulk_annotation',
            'state',
        ]

    def validate_label_config(self, value):
        """Validate the labeling config.

        On create this uses ``Project.validate_label_config``. On update it uses the instance's
        ``validate_config``, which presumably also checks compatibility with existing data.
        """
        if self.instance is None:
            # No project created yet
            Project.validate_label_config(value)
        else:
            # Existing project is updated
            self.instance.validate_config(value)
        return value

    def validate_model_version(self, value):
        """Ensure a changed, non-empty ``model_version`` refers to something that exists.

        On update, a new value must match the ``title`` of one of the project's ML backends or
        the ``model_version`` of one of its predictions. Unchanged or empty values and creates
        are not checked.

        Raises:
            serializers.ValidationError: if neither an ML backend nor a prediction matches.
        """
        project = self.instance

        if project is None or project.model_version == value or value == '':
            return value

        # Two short-circuiting EXISTS queries instead of a cross-model UNION
        version_exists = (
            project.ml_backends.filter(title=value).exists()
            or project.predictions.filter(project=project, model_version=value).exists()
        )
        if not version_exists:
            raise serializers.ValidationError(
                "Model version doesn't exist either as live model or as static predictions."
            )
        return value

    def update(self, instance, validated_data):
        """Update the project; turning ``show_collab_predictions`` off clears ``model_version``.

        A ``model_version`` sent in the same payload is still applied afterwards by ``super().update``.
        """
        if validated_data.get('show_collab_predictions') is False:
            instance.model_version = ''

        return super().update(instance, validated_data)

    def get_queue_total(self, project) -> int:
        """Count tasks that are unlabeled and not annotated by the user, or annotated by the user."""
        user_id = self.user_id
        remain = project.tasks.filter(
            (Q(is_labeled=False) & ~Q(annotations__completed_by_id=user_id)) | Q(annotations__completed_by_id=user_id)
        ).distinct()
        return remain.count()

    def get_queue_done(self, project) -> int:
        """Count distinct project tasks annotated by the user.

        With ``skip_queue == REQUEUE_FOR_ME``, annotations with ``was_cancelled=True`` are not counted.
        """
        tasks_filter = {
            'project': project,
            'annotations__completed_by_id': self.user_id,
        }

        if project.skip_queue == project.SkipQueue.REQUEUE_FOR_ME:
            tasks_filter['annotations__was_cancelled'] = False

        return Task.objects.filter(**tasks_filter).distinct().count()


class ProjectCountsSerializer(ProjectSerializer):
    """``ProjectSerializer`` restricted to ``id`` and the project count fields.

    The count fields are inherited from ``ProjectSerializer`` and are read from queryset
    annotations, so instances should come from e.g. ``Project.objects.with_counts()``.
    """

    class Meta:
        model = Project
        fields = [
            'id',
            'task_number',
            'finished_task_number',
            'total_predictions_number',
            'total_annotations_number',
            'num_tasks_with_annotations',
            'useful_annotation_number',
            'ground_truth_number',
            'skipped_annotations_number',
        ]


class ProjectOnboardingSerializer(serializers.ModelSerializer):
    """Serializes every field of ``ProjectOnboarding``."""

    class Meta:
        model = ProjectOnboarding
        fields = '__all__'


class ProjectLabelConfigSerializer(serializers.Serializer):
    """Validates a standalone labeling config without needing a ``Project`` instance."""

    # Reuses the help text of the Project.label_config model field
    label_config = serializers.CharField(help_text=Project.label_config.field.help_text)

    def validate_label_config(self, config):
        # Delegates to the model-level validator; its return value is ignored, so it presumably
        # signals an invalid config by raising.
        Project.validate_label_config(config)
        return config


class ProjectSummarySerializer(serializers.ModelSerializer):
    """Serializes every field of ``ProjectSummary``."""

    class Meta:
        model = ProjectSummary
        fields = '__all__'


class ProjectImportSerializer(serializers.ModelSerializer):
    """Serializes ``ProjectImport`` records with the explicitly listed fields."""

    class Meta:
        model = ProjectImport
        # Field order here is the order of keys in the serialized output
        fields = [
            'id',
            'project',
            'preannotated_from_fields',
            'commit_to_project',
            'return_task_ids',
            'status',
            'url',
            'error',
            'created_at',
            'updated_at',
            'finished_at',
            'task_count',
            'annotation_count',
            'prediction_count',
            'duration',
            'file_upload_ids',
            'could_be_tasks_list',
            'found_formats',
            'data_columns',
            'tasks',
            'task_ids',
        ]


class ProjectReimportSerializer(serializers.ModelSerializer):
    """Serializes ``ProjectReimport`` records with the explicitly listed fields."""

    class Meta:
        model = ProjectReimport
        # Field order here is the order of keys in the serialized output
        fields = [
            'id',
            'project',
            'status',
            'error',
            'task_count',
            'annotation_count',
            'prediction_count',
            'duration',
            'file_upload_ids',
            'files_as_tasks_list',
            'found_formats',
            'data_columns',
        ]


class ProjectModelVersionExtendedSerializer(serializers.Serializer):
    """One model-version entry: its name, a count, and the latest timestamp."""

    model_version = serializers.CharField()
    # NOTE(review): presumably the number of predictions with this version -- confirm against the view
    count = serializers.IntegerField()
    # NOTE(review): presumably the most recent prediction time for this version -- confirm
    latest = serializers.DateTimeField()


class ProjectModelVersionParamsSerializer(serializers.Serializer):
    """Optional parameters (presumably query params) for listing a project's model versions."""

    extended = serializers.BooleanField(required=False, default=False)
    include_live_models = serializers.BooleanField(required=False, default=False)
    # NOTE(review): no min_value is enforced, and None presumably means "no limit" -- confirm in the view
    limit = serializers.IntegerField(required=False, default=None)


class GetFieldsSerializer(serializers.Serializer):
    """Parameters for listing projects: count fields to include, pinned-status filter, and search term."""

    include = serializers.CharField(
        required=False,
        help_text=(
            'Comma-separated list of count fields to include in the response to optimize performance. '
            'Available fields: task_number, finished_task_number, total_predictions_number, '
            'total_annotations_number, num_tasks_with_annotations, useful_annotation_number, '
            'ground_truth_number, skipped_annotations_number. If not specified, all count fields are included.'
        ),
    )
    filter = serializers.CharField(
        required=False,
        default='all',
        help_text=(
            "Filter projects by pinned status. Use 'pinned_only' to return only pinned projects, "
            "'exclude_pinned' to return only non-pinned projects, or 'all' to return all projects."
        ),
    )
    search = serializers.CharField(
        required=False, default=None, help_text='Search term for project title and description'
    )

    def validate_include(self, value):
        """Split the comma-separated ``include`` value into a list of field names.

        Whitespace around each name is stripped, so ``"a, b"`` yields ``['a', 'b']``.
        """
        if value is not None:
            value = [name.strip() for name in value.split(',')]
        return value

    def validate_filter(self, value):
        """Accept only the documented pinned-status filters.

        Raises:
            serializers.ValidationError: if the value is not 'all', 'pinned_only' or 'exclude_pinned'.
        """
        allowed = ('all', 'pinned_only', 'exclude_pinned')
        if value not in allowed:
            # Previously fell through and silently validated to None
            raise serializers.ValidationError(f"Invalid filter value. Expected one of: {', '.join(allowed)}.")
        return value
