"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license."""

import logging

from django.conf import settings
from django.db import models
from django.db.models import Q
from django.utils.translation import gettext_lazy as _
from ml_model_providers.models import ModelProviderConnection, ModelProviders
from projects.models import Project
from rest_framework.exceptions import ValidationError
from tasks.models import Annotation, FailedPrediction, Prediction, PredictionMeta

logger = logging.getLogger(__name__)


# skills are partitions of projects (label config + input columns + output columns) into categories of labeling tasks
class SkillNames(models.TextChoices):
    """
    Skill categories offered as the choices of ``ModelInterface.skill_name``.

    For every member the stored value and the translatable label share the same text.
    """

    TEXT_CLASSIFICATION = ('TextClassification', _('TextClassification'))
    NAMED_ENTITY_RECOGNITION = ('NamedEntityRecognition', _('NamedEntityRecognition'))


def validate_string_list(value):
    """
    Validate that ``value`` is a non-empty list whose items are all strings.

    The type check runs before the emptiness check so that falsy values that are
    not lists (``None``, ``''``, ``{}``, ``0``) are reported as having the wrong
    type instead of being described as an empty list.

    Raises:
        ValidationError: if ``value`` is not a list, is an empty list, or contains
            at least one item that is not a string.
    """
    if not isinstance(value, list):
        raise ValidationError('Value must be a list')
    if not value:
        raise ValidationError('list should not be empty')
    if not all(isinstance(item, str) for item in value):
        raise ValidationError('All items in the list must be strings')


class ModelInterface(models.Model):
    """
    A model definition belonging to an organization: its title and description, the
    skill it performs, its input fields and output classes, and its associated projects.
    """

    # Required display name, at most 500 characters
    title = models.CharField(
        _('title'),
        max_length=500,
        blank=False,
        null=False,
        help_text='Model name',
    )

    # Optional free-form description
    description = models.TextField(
        _('description'),
        blank=True,
        null=True,
        help_text='Model description',
    )

    # Creator of the record; set to NULL when that user is deleted
    created_by = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        on_delete=models.SET_NULL,
        null=True,
        related_name='created_models',
    )

    # Set once on insert
    created_at = models.DateTimeField(_('created at'), auto_now_add=True)

    # Refreshed on every save
    updated_at = models.DateTimeField(_('updated at'), auto_now=True)

    # Owning organization; rows are removed together with the organization
    organization = models.ForeignKey(
        'organizations.Organization',
        on_delete=models.CASCADE,
        null=True,
        related_name='model_interfaces',
    )

    # One of the SkillNames values, or NULL
    skill_name = models.CharField(
        max_length=255,
        choices=SkillNames.choices,
        null=True,
    )

    # Both JSON columns are checked by validate_string_list (non-empty list of strings)
    input_fields = models.JSONField(
        default=list,
        validators=[validate_string_list],
    )

    output_classes = models.JSONField(
        default=list,
        validators=[validate_string_list],
    )

    # Projects linked to this model; the relation may be empty
    associated_projects = models.ManyToManyField(
        'projects.Project',
        blank=True,
    )

    def has_permission(self, user):
        """
        Return True when the user's active organization equals this model's organization.
        """
        active_org = user.active_organization
        return active_org == self.organization


class ModelVersion(models.Model):
    """
    Abstract base for a version of a ModelInterface: a title, the prompt to execute and
    the model provider connection to use.
    """

    class Meta:
        abstract = True

    title = models.CharField(_('title'), max_length=500, null=False, blank=False, help_text='Model name')

    # Versions are removed together with their parent ModelInterface
    parent_model = models.ForeignKey(ModelInterface, related_name='model_versions', on_delete=models.CASCADE)

    prompt = models.TextField(_('prompt'), null=False, blank=False, help_text='Prompt to execute')

    # Set to NULL when the connection is deleted
    model_provider_connection = models.ForeignKey(
        ModelProviderConnection, related_name='model_versions', on_delete=models.SET_NULL, null=True
    )

    @property
    def full_title(self):
        """
        Parent model title and version title joined by a double underscore
        """
        return f'{self.parent_model.title}__{self.title}'

    def delete(self, *args, **kwargs):
        """
        Deletes Predictions associated with ModelVersion, then the ModelVersion itself

        Predictions are removed through ModelRun.delete_predictions() for every ModelRun
        that references this version. The result of the parent delete() is returned so
        callers still receive the deletion summary.
        """
        for model_run in ModelRun.objects.filter(model_version=self.id):
            model_run.delete_predictions()
        return super().delete(*args, **kwargs)


class ThirdPartyModelVersion(ModelVersion):
    """
    Concrete ModelVersion that is served by a third-party provider, identified by the
    provider name and the model ID within that provider.
    """

    provider = models.CharField(
        max_length=255,
        choices=ModelProviders.choices,
        default=ModelProviders.OPENAI,
        help_text='The model provider to use e.g. OpenAI',
    )

    provider_model_id = models.CharField(
        max_length=255,
        blank=False,
        null=False,
        help_text='The model ID to use within the given provider, e.g. gpt-3.5',
    )

    # Creator of the record; set to NULL when that user is deleted
    created_by = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        related_name='created_third_party_model_versions',
        on_delete=models.SET_NULL,
        null=True,
    )

    created_at = models.DateTimeField(_('created at'), auto_now_add=True)

    updated_at = models.DateTimeField(_('updated at'), auto_now=True)

    # Owning organization; rows are removed together with the organization
    organization = models.ForeignKey(
        'organizations.Organization', on_delete=models.CASCADE, related_name='third_party_model_versions', null=True
    )

    @property
    def project(self):
        """
        First project associated with the parent model, or None when there is none
        """
        # TODO: can it be just a property of the model version?
        if not self.parent_model:
            return None
        # first() already yields None for an empty relation, so a separate
        # exists() query beforehand is unnecessary
        return self.parent_model.associated_projects.first()

    def has_permission(self, user):
        """
        Return True when the user's active organization equals this version's organization.
        """
        return user.active_organization == self.organization


class ModelRun(models.Model):
    """
    One run of a ThirdPartyModelVersion against a Project.

    Holds the run status, the inference job ID, prediction/task counters and timestamps.
    Deleting a ModelRun first deletes the predictions and failed predictions that
    reference it.
    """

    class ProjectSubset(models.TextChoices):
        ALL = 'All', _('All')
        HASGT = 'HasGT', _('HasGT')
        SAMPLE = 'Sample', _('Sample')

    class FileType(models.TextChoices):
        INPUT = 'Input', _('Input')
        OUTPUT = 'Output', _('Output')

    class ModelRunStatus(models.TextChoices):
        PENDING = 'Pending', _('Pending')
        INPROGRESS = 'InProgress', _('InProgress')
        # Labels are wrapped in _() like every other choice label in this module
        COMPLETED = 'Completed', _('Completed')
        FAILED = 'Failed', _('Failed')
        CANCELED = 'Canceled', _('Canceled')

    # Owning organization; rows are removed together with the organization
    organization = models.ForeignKey(
        'organizations.Organization', on_delete=models.CASCADE, related_name='model_runs', null=True
    )

    project = models.ForeignKey(Project, on_delete=models.CASCADE, related_name='model_runs')

    model_version = models.ForeignKey(ThirdPartyModelVersion, on_delete=models.CASCADE, related_name='model_runs')

    # Creator of the record; set to NULL when that user is deleted
    created_by = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        related_name='model_runs',
        on_delete=models.SET_NULL,
        null=True,
    )

    project_subset = models.CharField(max_length=255, choices=ProjectSubset.choices, default=ProjectSubset.HASGT)

    status = models.CharField(max_length=255, choices=ModelRunStatus.choices, default=ModelRunStatus.PENDING)

    job_id = models.CharField(
        max_length=255,
        null=True,
        blank=True,
        default=None,
        help_text='Job ID for inference job for a ModelRun e.g. Adala job ID',
    )

    total_predictions = models.IntegerField(_('total predictions'), default=0)

    total_correct_predictions = models.IntegerField(_('total correct predictions'), default=0)

    total_tasks = models.IntegerField(_('total tasks'), default=0)

    created_at = models.DateTimeField(_('created at'), auto_now_add=True)

    triggered_at = models.DateTimeField(_('triggered at'), null=True, default=None)

    predictions_updated_at = models.DateTimeField(_('predictions updated at'), null=True, default=None)

    completed_at = models.DateTimeField(_('completed at'), null=True, default=None)

    def has_permission(self, user):
        """
        Return True when the user's active organization equals this run's organization.
        """
        return user.active_organization == self.organization

    def delete_predictions(self):
        """
        Deletes any predictions that have originated from a ModelRun

        Executing a raw SQL query here for speed. This ignores any foreign key relationships
        so if another model has a Prediction fk and set to on_delete=CASCADE for example,
        it will not take effect. The only relationship like this that currently exists
        is in Annotation.parent_prediction, which we are handling here.
        """
        predictions = Prediction.objects.filter(model_run=self.id)
        # Only the primary keys are needed, so avoid instantiating full Prediction objects
        prediction_ids = list(predictions.values_list('id', flat=True))
        # to delete all dependencies where predictions are foreign keys.
        Annotation.objects.filter(parent_prediction__in=prediction_ids).update(parent_prediction=None)
        # Best-effort cleanup of stats rows: any failure (including the stats models not
        # being importable) is logged and the remaining cleanup still runs
        try:
            from stats.models import PredictionPairStats, PredictionStats

            prediction_stats_to_be_deleted = PredictionStats.objects.filter(prediction_to__in=prediction_ids)
            prediction_stats_to_be_deleted.delete()
            prediction_pair_stats_to_be_deleted = PredictionPairStats.objects.filter(
                Q(prediction_to_id__in=prediction_ids) | Q(prediction_from_id__in=prediction_ids)
            )
            prediction_pair_stats_to_be_deleted.delete()
        except Exception as e:
            logger.info(f'PredictionStats or PredictionPairStats model does not exist , exception:{e}')

        # Delete failed predictions. Currently no other model references this, no fk relationships to remove
        failed_predictions = FailedPrediction.objects.filter(model_run=self.id)
        failed_predictions_ids = list(failed_predictions.values_list('id', flat=True))

        # delete predictions meta
        PredictionMeta.objects.filter(prediction__in=prediction_ids).delete()
        PredictionMeta.objects.filter(failed_prediction__in=failed_predictions_ids).delete()

        # remove predictions from db
        predictions._raw_delete(predictions.db)
        failed_predictions._raw_delete(failed_predictions.db)

    def delete(self, *args, **kwargs):
        """
        Deletes Predictions associated with ModelRun, then the ModelRun itself

        The result of the parent delete() is returned so callers still receive the
        deletion summary.
        """
        self.delete_predictions()
        return super().delete(*args, **kwargs)
