"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
"""
import logging
from typing import List

from django.conf import settings
from django.db import models
from django.utils.translation import gettext_lazy as _
from tasks.models import PredictionMeta

logger = logging.getLogger(__name__)


class ModelProviders(models.TextChoices):
    """Providers that a model provider connection can be configured for.

    Every member is declared as a ``(stored value, translatable label)`` pair.
    """

    OPENAI = ('OpenAI', _('OpenAI'))
    AZURE_OPENAI = ('AzureOpenAI', _('AzureOpenAI'))
    AZURE_AI_FOUNDRY = ('AzureAIFoundry', _('AzureAIFoundry'))
    VERTEX_AI = ('VertexAI', _('VertexAI'))
    GEMINI = ('Gemini', _('Gemini'))
    ANTHROPIC = ('Anthropic', _('Anthropic'))
    CUSTOM = ('Custom', _('Custom'))


class ModelProviderConnectionScopes(models.TextChoices):
    """Scopes selectable for a model provider connection.

    Every member is declared as a ``(stored value, translatable label)`` pair.
    """

    ORG = ('Organization', _('Organization'))
    USER = ('User', _('User'))
    MODEL = ('Model', _('Model'))


class ModelProviderConnection(models.Model):
    """Connection settings, ownership and budget bookkeeping for a model provider.

    The fields hold provider credentials/endpoints, the owning organization and
    creator, and a set of nullable budget columns. The methods cover a role
    based permission check and budget accounting helpers.
    """

    # --- Provider selection and credentials ---------------------------------

    # One of ``ModelProviders``; OpenAI unless stated otherwise.
    provider = models.CharField(
        max_length=255,
        choices=ModelProviders.choices,
        default=ModelProviders.OPENAI,
    )

    api_key = models.TextField(
        verbose_name=_('api_key'),
        null=True,
        blank=True,
        help_text='Model provider API key',
    )

    auth_token = models.TextField(
        verbose_name=_('auth_token'),
        null=True,
        blank=True,
        help_text='Model provider Auth token',
    )

    deployment_name = models.CharField(
        max_length=512,
        null=True,
        blank=True,
        help_text='Azure OpenAI deployment name',
    )

    endpoint = models.CharField(
        max_length=512,
        null=True,
        blank=True,
        help_text='Azure OpenAI endpoint',
    )

    google_application_credentials = models.TextField(
        verbose_name=_('google application credentials'),
        null=True,
        blank=True,
        help_text='The content of GOOGLE_APPLICATION_CREDENTIALS json file',
    )

    google_project_id = models.CharField(
        verbose_name=_('google project id'),
        max_length=255,
        null=True,
        blank=True,
        help_text='Google project ID',
    )

    google_location = models.CharField(
        verbose_name=_('google location'),
        max_length=255,
        null=True,
        blank=True,
        help_text='Google project location',
    )

    # Stored as a plain string of at most 4096 characters.
    cached_available_models = models.CharField(
        max_length=4096,
        null=True,
        blank=True,
        help_text='List of available models from the provider',
    )

    # --- Scope and ownership ------------------------------------------------

    # One of ``ModelProviderConnectionScopes``; organization-wide by default.
    scope = models.CharField(
        max_length=255,
        choices=ModelProviderConnectionScopes.choices,
        default=ModelProviderConnectionScopes.ORG,
    )

    # Rows are removed together with their organization (CASCADE).
    organization = models.ForeignKey(
        'organizations.Organization',
        on_delete=models.CASCADE,
        related_name='model_provider_connections',
        null=True,
    )

    # Nulled out (not deleted) when the creating user is removed.
    created_by = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        related_name='created_model_provider_connections',
        on_delete=models.SET_NULL,
        null=True,
    )

    # Future work - add foreign key for modelinterface / modelinstance

    # --- Timestamps ---------------------------------------------------------

    created_at = models.DateTimeField(
        verbose_name=_('created at'),
        auto_now_add=True,
    )

    updated_at = models.DateTimeField(
        verbose_name=_('updated at'),
        auto_now=True,
    )

    # --- Visibility ---------------------------------------------------------

    is_internal = models.BooleanField(
        verbose_name=_('is_internal'),
        default=False,
        help_text='Whether the model provider connection is internal, not visible to the user',
        null=True,
        blank=True,
    )

    # --- Budget tracking (every column is nullable and defaults to None) ----

    budget_limit = models.FloatField(
        verbose_name=_('budget_limit'),
        null=True,
        blank=True,
        default=None,
        help_text='Budget limit for the model provider connection (null if unlimited)',
    )

    budget_last_reset_date = models.DateTimeField(
        verbose_name=_('budget_last_reset_date'),
        null=True,
        blank=True,
        default=None,
        help_text='Date and time the budget was last reset',
    )

    budget_reset_period = models.CharField(
        verbose_name=_('budget_reset_period'),
        max_length=20,
        choices=[
            ('Monthly', 'Monthly'),
            ('Yearly', 'Yearly'),
        ],
        null=True,
        blank=True,
        default=None,
        help_text='Budget reset period for the model provider connection (null if not reset)',
    )

    budget_total_spent = models.FloatField(
        verbose_name=_('budget_total_spent'),
        null=True,
        blank=True,
        default=None,
        help_text='Tracked total budget spent for the given provider connection within the current budget period',
    )

    budget_alert_threshold = models.FloatField(
        verbose_name=_('budget_alert_threshold'),
        null=True,
        blank=True,
        default=None,
        help_text='Budget alert threshold for the given provider connection',
    )

    def has_permission(self, user):
        """Tell whether ``user`` may act on this connection.

        The user needs one of the administrator, owner or manager flags and
        their active organization has to be this connection's organization.

        Returns:
            The falsy role value when no role flag is set, otherwise the result
            of comparing ``user.active_organization_id`` to ``self.organization_id``.
        """
        # This will need to be updated if we ever use this model in LSO, since
        # `is_owner` and `is_administrator` only exist in LSE.
        holds_role = user.is_administrator or user.is_owner or user.is_manager
        if not holds_role:
            return holds_role
        return user.active_organization_id == self.organization_id

    def update_budget_total_spent_from_predictions_meta(self, predictions_meta: List[PredictionMeta]):
        """Add the cost of ``predictions_meta`` to ``budget_total_spent`` and save it.

        Args:
            predictions_meta: Items whose ``total_cost`` is summed; a falsy
                ``total_cost`` (for example ``None``) counts as 0.

        Only the ``budget_total_spent`` column is written.
        """
        # NOTE(review): the new total is computed from the in-memory value, so
        # concurrent callers could overwrite each other's increments - confirm
        # whether callers serialize access.
        batch_cost = sum(entry.total_cost or 0 for entry in predictions_meta)
        # The column is nullable; treating None as 0 here avoids a db migration.
        spent_so_far = self.budget_total_spent or 0
        self.budget_total_spent = spent_so_far + batch_cost
        self.save(update_fields=['budget_total_spent'])

    def has_reached_budget_limit(self):
        """Report whether an internal connection has spent more than its limit.

        Returns:
            ``True`` only when the connection is internal, both
            ``budget_total_spent`` and ``budget_limit`` are truthy, and the
            amount spent is strictly greater than the limit (this case is also
            logged at info level); ``False`` in every other situation.
        """
        if not self.is_internal:
            return False

        spent = self.budget_total_spent
        if not spent:
            return False

        limit = self.budget_limit
        if not limit:
            return False

        if spent > limit:
            logger.info(
                f'Model connection {self.id} has reached the budget limit: '
                f'{self.budget_total_spent} > {self.budget_limit}'
            )
            return True

        return False
