"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
"""
import logging
import os
import urllib

import requests
from core.feature_flags import flag_set
from core.utils.common import load_func
from core.version import get_git_version
from data_export.serializers import ExportDataSerializer
from django.conf import settings
from django.contrib.auth.models import AnonymousUser
from django.db.models import Count
from requests.adapters import HTTPAdapter
from requests.auth import HTTPBasicAuth

from label_studio.core.utils.params import get_env

version = get_git_version()
logger = logging.getLogger(__name__)

CONNECTION_TIMEOUT = float(get_env('ML_CONNECTION_TIMEOUT', 1))  # seconds
TIMEOUT_DEFAULT = float(get_env('ML_TIMEOUT_DEFAULT', 100))  # seconds

TIMEOUT_TRAIN = float(get_env('ML_TIMEOUT_TRAIN', 30))
TIMEOUT_PREDICT = float(get_env('ML_TIMEOUT_PREDICT', 100))
TIMEOUT_HEALTH = float(get_env('ML_TIMEOUT_HEALTH', 1))
TIMEOUT_SETUP = float(get_env('ML_TIMEOUT_SETUP', 3))
TIMEOUT_DUPLICATE_MODEL = float(get_env('ML_TIMEOUT_DUPLICATE_MODEL', 1))
TIMEOUT_DELETE = float(get_env('ML_TIMEOUT_DELETE', 1))
TIMEOUT_TRAIN_JOB_STATUS = float(get_env('ML_TIMEOUT_TRAIN_JOB_STATUS', 1))

# TODO
# we would need to make it configurable on the ML backend side too
PREDICT_URL = 'predict'
HEALTH_URL = 'health'
VALIDATE_URL = 'validate'
SETUP_URL = 'setup'
DUPLICATE_URL = 'duplicate_model'
DELETE_URL = 'delete'
JOB_STATUS_URL = 'job_status'
VERSIONS_URL = 'versions'


class BaseHTTPAPI(object):
    MAX_RETRIES = 2
    HEADERS = {
        'User-Agent': 'heartex/' + (version or ''),
    }

    def __init__(
        self, url, timeout=None, connection_timeout=None, max_retries=None, headers=None, auth_method=None, **kwargs
    ):
        self._url = url
        self._timeout = timeout or TIMEOUT_DEFAULT
        self._connection_timeout = connection_timeout or CONNECTION_TIMEOUT
        self._headers = headers or {}
        self._auth_method = auth_method

        # TODO basic auth parameters must be required for auth_method == 'basic'
        self._basic_auth = (kwargs.get('basic_auth_user'), kwargs.get('basic_auth_pass'))

        self._max_retries = max_retries or self.MAX_RETRIES
        self._sessions = {self._session_key(): self.create_session()}

    def create_session(self):
        session = requests.Session()
        session.headers.update(self.HEADERS)
        session.headers.update(self._headers)
        session.mount('http://', HTTPAdapter(max_retries=self._max_retries))
        session.mount('https://', HTTPAdapter(max_retries=self._max_retries))
        return session

    def _session_key(self):
        return os.getpid()

    @property
    def http(self):
        key = self._session_key()
        if key in self._sessions:
            return self._sessions[key]
        else:
            session = self.create_session()
            self._sessions[key] = session
            return session

    def _prepare_kwargs(self, kwargs):
        # add timeout if it's not presented
        if 'timeout' not in kwargs:
            kwargs['timeout'] = self._connection_timeout, self._timeout

        if self._basic_auth[0] and self._basic_auth[1]:
            kwargs['auth'] = HTTPBasicAuth(*self._basic_auth)

        # add connection timeout if it's not presented
        elif isinstance(kwargs['timeout'], float) or isinstance(kwargs['timeout'], int):
            kwargs['timeout'] = (self._connection_timeout, kwargs['timeout'])

    def request(self, method, *args, **kwargs):
        self._prepare_kwargs(kwargs)
        return self.http.request(method, *args, **kwargs)

    def get(self, *args, **kwargs):
        return self.request('GET', *args, **kwargs)

    def post(self, *args, **kwargs):
        return self.request('POST', *args, **kwargs)


class MLApiResult:
    """
    Class for storing the result of ML API request
    """

    def __init__(self, url='', request='', response=None, headers=None, type='ok', status_code=200):
        self.url = url
        self.request = request
        self.response = {} if response is None else response
        self.headers = {} if headers is None else headers
        self.type = type
        self.status_code = status_code

    @property
    def is_error(self):
        return self.type == 'error'

    @property
    def error_message(self):
        return self.response.get('error')


class MLApi(BaseHTTPAPI):
    """
    Class for ML API connector
    """

    def __init__(self, **kwargs):
        super(MLApi, self).__init__(**kwargs)
        self._validate_request_timeout = 10

    def _get_url(self, url_suffix):
        url = self._url
        if url[-1] != '/':
            url += '/'
        return urllib.parse.urljoin(url, url_suffix)

    def _request(self, url_suffix, request=None, verbose=True, method='POST', *args, **kwargs):
        assert method in ('POST', 'GET')
        url = self._get_url(url_suffix)
        request = request or {}
        headers = dict(self.http.headers)

        response = None
        try:
            if method == 'POST':
                response = self.post(url=url, json=request, *args, **kwargs)
            else:
                response = self.get(url=url, *args, **kwargs)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            error_string = str(e)
            status_code = response.status_code if response is not None else 0
            return MLApiResult(url, request, {'error': error_string}, headers, 'error', status_code=status_code)
        status_code = response.status_code
        try:
            response = response.json()
        except ValueError as e:
            return MLApiResult(
                url=url,
                request=request,
                response={'error': str(e), 'response': response.content},
                headers=headers,
                type='error',
                status_code=status_code,
            )

        return MLApiResult(url=url, request=request, response=response, headers=headers, status_code=status_code)

    def _create_project_uid(self, project):
        time_id = int(project.created_at.timestamp())
        return f'{project.id}.{time_id}'

    def train(self, project, use_ground_truth=False):
        # TODO Replace AnonymousUser with real user from request
        user = AnonymousUser()
        # Identify if feature flag is turned on
        if flag_set('ff_back_dev_1417_start_training_mlbackend_webhooks_250122_long', user):
            request = {
                'action': 'START_TRAINING',
                'project': load_func(settings.WEBHOOK_SERIALIZERS['project'])(instance=project).data,
            }
            return self._request('webhook', request, verbose=False, timeout=TIMEOUT_PREDICT)
        else:
            # get only tasks with annotations
            tasks = project.tasks.annotate(num_annotations=Count('annotations')).filter(num_annotations__gt=0)
            # create serialized tasks with annotations: {"data": {...}, "annotations": [{...}], "predictions": [{...}]}
            tasks_ser = ExportDataSerializer(tasks, many=True).data
            logger.debug(f'{len(tasks_ser)} tasks with annotations are sent to ML backend for training.')
            request = {
                'annotations': tasks_ser,
                'project': self._create_project_uid(project),
                'label_config': project.label_config,
                'params': {'login': project.task_data_login, 'password': project.task_data_password},
            }
            return self._request('train', request, verbose=False, timeout=TIMEOUT_PREDICT)

    def _prep_prediction_req(self, tasks, project, context=None):
        request = {
            'tasks': tasks,
            'project': self._create_project_uid(project),
            'label_config': project.label_config,
            'params': {
                'login': project.task_data_login,
                'password': project.task_data_password,
                'context': context,
            },
        }

        return request

    def make_predictions(self, tasks, project, context=None):
        request = self._prep_prediction_req(tasks, project, context=context)
        return self._request(PREDICT_URL, request, verbose=False, timeout=TIMEOUT_PREDICT)

    def health(self):
        return self._request(HEALTH_URL, method='GET', timeout=TIMEOUT_HEALTH)

    def validate(self, config):
        return self._request(VALIDATE_URL, request={'config': config}, timeout=self._validate_request_timeout)

    def setup(self, project, extra_params=None, **kwargs):
        return self._request(
            SETUP_URL,
            request={
                'project': self._create_project_uid(project),
                'schema': project.label_config,
                'hostname': settings.HOSTNAME if settings.HOSTNAME else ('http://localhost:' + settings.INTERNAL_PORT),
                'access_token': project.created_by.auth_token.key,
                'extra_params': extra_params,
            },
            timeout=TIMEOUT_SETUP,
        )

    def duplicate_model(self, project_src, project_dst):
        return self._request(
            DUPLICATE_URL,
            request={
                'project_src': self._create_project_uid(project_src),
                'project_dst': self._create_project_uid(project_dst),
            },
            timeout=TIMEOUT_DUPLICATE_MODEL,
        )

    def delete(self, project):
        return self._request(
            DELETE_URL, request={'project': self._create_project_uid(project)}, timeout=TIMEOUT_DELETE
        )

    def get_train_job_status(self, train_job):
        return self._request(JOB_STATUS_URL, request={'job': train_job.job_id}, timeout=TIMEOUT_TRAIN_JOB_STATUS)

    def get_versions(self, project):
        return self._request(
            VERSIONS_URL, request={'project': self._create_project_uid(project)}, timeout=TIMEOUT_SETUP, method='GET'
        )


def get_ml_api(project):
    if project.ml_backend_active_connection is None:
        return None
    if project.ml_backend_active_connection.ml_backend is None:
        return None
    return MLApi(
        url=project.ml_backend_active_connection.ml_backend.url,
        timeout=project.ml_backend_active_connection.ml_backend.timeout,
    )
