"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
"""
import base64
import fnmatch
import logging
import mimetypes
import re
from urllib.parse import urlparse

import boto3
from botocore.exceptions import ClientError
from core.utils.params import get_env
from django.conf import settings
from tldextract import TLDExtract

logger = logging.getLogger(__name__)


def get_client_and_resource(
    aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None, region_name=None, s3_endpoint=None
):
    aws_access_key_id = aws_access_key_id or get_env('AWS_ACCESS_KEY_ID')
    aws_secret_access_key = aws_secret_access_key or get_env('AWS_SECRET_ACCESS_KEY')
    aws_session_token = aws_session_token or get_env('AWS_SESSION_TOKEN')
    logger.debug(
        f'Create boto3 session with '
        f'access key id={aws_access_key_id}, '
        f'secret key={aws_secret_access_key[:4] + "..." if aws_secret_access_key else None}, '
        f'session token={aws_session_token}'
    )
    session = boto3.Session(
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        aws_session_token=aws_session_token,
    )
    settings = {'region_name': region_name or get_env('S3_region') or 'us-east-1'}
    s3_endpoint = s3_endpoint or get_env('S3_ENDPOINT')
    if s3_endpoint:
        settings['endpoint_url'] = s3_endpoint
    client = session.client('s3', config=boto3.session.Config(signature_version='s3v4'), **settings)
    resource = session.resource('s3', config=boto3.session.Config(signature_version='s3v4'), **settings)
    return client, resource


def resolve_s3_url(url, client, presign=True, expires_in=3600):
    r = urlparse(url, allow_fragments=False)
    bucket_name = r.netloc
    key = r.path.lstrip('/')

    # Return blob as base64 encoded string if presigned urls are disabled
    if not presign:
        object = client.get_object(Bucket=bucket_name, Key=key)
        content_type = object['ResponseMetadata']['HTTPHeaders']['content-type']
        object_b64 = 'data:' + content_type + ';base64,' + base64.b64encode(object['Body'].read()).decode('utf-8')
        return object_b64

    # Otherwise try to generate presigned url
    try:
        params = {'Bucket': bucket_name, 'Key': key}

        # Override response Content-Type based on file extension so that S3 returns
        # the correct MIME type even if the object was uploaded without one.
        # Without this, S3 may return binary/octet-stream which causes browsers
        # to download files instead of displaying them inline (e.g. images, audio, video).
        content_type, _ = mimetypes.guess_type(key)
        if content_type:
            params['ResponseContentType'] = content_type

        presigned_url = client.generate_presigned_url(ClientMethod='get_object', Params=params, ExpiresIn=expires_in)
    except ClientError as exc:
        logger.warning(f"Can't generate presigned URL. Reason: {exc}")
        return url
    else:
        logger.debug('Presigned URL {presigned_url} generated for {url}'.format(presigned_url=presigned_url, url=url))
        return presigned_url


class AWS(object):
    @classmethod
    def get_blob_metadata(
        cls,
        url: str,
        bucket_name: str,
        client=None,
        aws_access_key_id=None,
        aws_secret_access_key=None,
        aws_session_token=None,
        region_name=None,
        s3_endpoint=None,
    ):
        """
        Get blob metadata by url
        :param url: Object key
        :param bucket_name: AWS bucket name
        :param client: AWS client for batch processing
        :param account_key: Azure account key
        :return: Object metadata dict("name": "value")
        """
        if client is None:
            client, _ = get_client_and_resource(
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                aws_session_token=aws_session_token,
                region_name=region_name,
                s3_endpoint=s3_endpoint,
            )
        object = client.get_object(Bucket=bucket_name, Key=url)
        metadata = dict(object)
        # remove unused fields
        metadata.pop('Body', None)
        metadata.pop('ResponseMetadata', None)
        return metadata

    @classmethod
    def validate_pattern(cls, storage, pattern, glob_pattern=True):
        """
        Validate pattern against S3 Storage
        :param storage: S3 Storage instance
        :param pattern: Pattern to validate
        :param glob_pattern: If True, pattern is a glob pattern, otherwise it is a regex pattern
        :return: Message if pattern is not valid, empty string otherwise
        """
        client, bucket = storage.get_client_and_bucket()
        if glob_pattern:
            pattern = fnmatch.translate(pattern)
        regex = re.compile(pattern)

        if storage.prefix:
            list_kwargs = {'Prefix': storage.prefix.rstrip('/') + '/'}
            if not storage.recursive_scan:
                list_kwargs['Delimiter'] = '/'
            bucket_iter = bucket.objects.filter(**list_kwargs)
        else:
            bucket_iter = bucket.objects

        bucket_iter = bucket_iter.page_size(settings.CLOUD_STORAGE_CHECK_FOR_RECORDS_PAGE_SIZE).all()

        for index, obj in enumerate(bucket_iter):
            key = obj.key
            # skip directories
            if key.endswith('/'):
                logger.debug(key + ' is skipped because it is a folder')
                continue
            if regex and regex.match(key):
                logger.debug(key + ' matches file pattern')
                return ''
        return 'No objects found matching the provided glob pattern'


class S3StorageError(Exception):
    pass


# see https://github.com/john-kurkowski/tldextract?tab=readme-ov-file#note-about-caching
# prevents network call on first use
extractor = TLDExtract(suffix_list_urls=())


def catch_and_reraise_from_none(func):
    """
    For S3 storages - if s3_endpoint is not on a known domain, catch exception and
    raise a new one with the previous context suppressed. See also: https://peps.python.org/pep-0409/
    """

    def wrapper(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except Exception as e:
            if self.s3_endpoint and (
                domain := extractor.extract_urllib(urlparse(self.s3_endpoint)).registered_domain.lower()
            ) not in [trusted_domain.lower() for trusted_domain in settings.S3_TRUSTED_STORAGE_DOMAINS]:
                logger.error(f'Exception from unrecognized S3 domain: {e}', exc_info=True)
                raise S3StorageError(
                    f'Debugging info is not available for s3 endpoints on domain: {domain}. '
                    'Please contact your Label Studio devops team if you require detailed error reporting for this domain.'
                ) from None
            else:
                raise e

    return wrapper
