import logging
from datetime import datetime

from core.permissions import ViewClassPermission, all_permissions
from django.utils import timezone
from django.utils.decorators import method_decorator
from drf_spectacular.utils import extend_schema
from jwt_auth.auth import TokenAuthenticationPhaseout
from jwt_auth.models import LSAPIToken, TruncatedLSAPIToken
from jwt_auth.serializers import (
    JWTSettingsSerializer,
    LSAPITokenCreateSerializer,
    LSAPITokenListSerializer,
    TokenRefreshResponseSerializer,
    TokenRotateResponseSerializer,
)
from rest_framework import generics, status
from rest_framework.authentication import SessionAuthentication
from rest_framework.exceptions import APIException
from rest_framework.generics import CreateAPIView
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework_simplejwt.authentication import JWTAuthentication
from rest_framework_simplejwt.exceptions import TokenBackendError, TokenError
from rest_framework_simplejwt.token_blacklist.models import BlacklistedToken, OutstandingToken
from rest_framework_simplejwt.views import TokenRefreshView, TokenViewBase

# Module-level logger, named after this module (used for token validation/blacklist diagnostics).
logger = logging.getLogger(__name__)


class TokenExistsError(APIException):
    """Signal that the user already holds a valid API token.

    Raised when token creation is attempted while an unexpired, non-blacklisted
    token already exists for the user. DRF renders it as HTTP 409 Conflict.
    """

    default_code = 'token_exists'
    default_detail = 'You already have a valid token. Please revoke it before creating a new one.'
    status_code = status.HTTP_409_CONFLICT


@method_decorator(
    name='get',
    decorator=extend_schema(
        tags=['JWT'],
        summary='Retrieve JWT Settings',
        description='Retrieve JWT settings for the currently active organization.',
        extensions={
            'x-fern-sdk-group-name': 'jwt_settings',
            'x-fern-sdk-method-name': 'get',
            'x-fern-audiences': ['public'],
        },
    ),
)
@method_decorator(
    name='post',
    decorator=extend_schema(
        tags=['JWT'],
        summary='Update JWT Settings',
        description='Update JWT settings for the currently active organization.',
        extensions={
            'x-fern-sdk-group-name': 'jwt_settings',
            'x-fern-sdk-method-name': 'update',
            'x-fern-audiences': ['public'],
        },
    ),
)
class JWTSettingsAPI(CreateAPIView):
    # Reads (GET) or updates (POST) the JWT settings object hanging off the
    # requesting user's active organization. GET needs organizations_view,
    # POST needs organizations_change.
    serializer_class = JWTSettingsSerializer
    permission_required = ViewClassPermission(
        GET=all_permissions.organizations_view,
        POST=all_permissions.organizations_change,
    )

    def get_object(self):
        """Return the active organization's JWT settings, enforcing object-level permissions."""
        organization = self.request.user.active_organization
        settings_obj = organization.jwt
        self.check_object_permissions(self.request, settings_obj)
        return settings_obj

    def get(self, request, *args, **kwargs):
        # Serialize the current settings for the response body.
        serializer = self.get_serializer(self.get_object())
        return Response(serializer.data)

    def post(self, request, *args, **kwargs):
        # Validate the payload against the existing settings object and persist it.
        serializer = self.get_serializer(instance=self.get_object(), data=request.data)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        return Response(serializer.data)


class DecoratedTokenRefreshView(TokenRefreshView):
    # simple-jwt's TokenRefreshView re-exposed solely so the POST endpoint can
    # carry OpenAPI / Fern SDK metadata; request handling is inherited unchanged.

    @extend_schema(
        tags=['JWT'],
        summary='Refresh JWT token',
        description='Get a new access token, using a refresh token.',
        extensions={
            'x-fern-sdk-group-name': 'tokens',
            'x-fern-sdk-method-name': 'refresh',
            'x-fern-audiences': ['public'],
        },
        responses={status.HTTP_200_OK: TokenRefreshResponseSerializer},
    )
    def post(self, request, *args, **kwargs):
        """Delegate to ``TokenRefreshView.post``; overridden only to attach the schema decorator."""
        parent_post = super().post
        return parent_post(request, *args, **kwargs)


@method_decorator(
    name='get',
    decorator=extend_schema(
        tags=['JWT'],
        summary='List API tokens',
        description='List all API tokens for the current user.',
        responses={
            status.HTTP_200_OK: LSAPITokenListSerializer,
        },
        extensions={
            'x-fern-sdk-group-name': 'tokens',
            'x-fern-sdk-method-name': 'list',
            'x-fern-audiences': ['public'],
        },
    ),
)
@method_decorator(
    name='post',
    decorator=extend_schema(
        tags=['JWT'],
        summary='Create API token',
        description='Create a new API token for the current user.',
        responses={
            status.HTTP_201_CREATED: LSAPITokenCreateSerializer,
        },
        extensions={
            'x-fern-sdk-group-name': 'tokens',
            'x-fern-sdk-method-name': 'create',
            'x-fern-audiences': ['public'],
        },
    ),
)
class LSAPITokenView(generics.ListCreateAPIView):
    # GET lists the current user's live refresh-type API tokens (truncated);
    # POST issues a new token, refusing with 409 if a valid one already exists.
    permission_required = all_permissions.users_token_any
    token_class = LSAPIToken

    def get_queryset(self):
        """Return all non-expired, non-blacklisted outstanding tokens for the current user.

        The `list` method handles filtering for refresh tokens (as opposed to access tokens),
        since simple-jwt makes it hard to do this at the DB level."""
        # Notably, if the list of non-expired blacklisted tokens ever gets too long
        # (e.g. users from orgs who have not set a token expiration for their org
        # revoke enough tokens for this to blow up), this will become inefficient.
        # Would be ideal to just add a "blacklisted" attr to our own subclass of
        # OutstandingToken so we can check at that level, or just clean up
        # OutstandingTokens that have been blacklisted every so often.

        # Use Django's timezone-aware "now": comparing the aware `expires_at`
        # column against a naive datetime.now() triggers RuntimeWarnings under
        # USE_TZ and relies on naive-time interpretation. Take it once so both
        # filters share the same cutoff.
        now = timezone.now()
        current_blacklisted_tokens = BlacklistedToken.objects.filter(token__expires_at__gt=now).values_list(
            'token_id', flat=True
        )
        return OutstandingToken.objects.filter(user_id=self.request.user.id, expires_at__gt=now).exclude(
            id__in=current_blacklisted_tokens
        )

    def list(self, request, *args, **kwargs):
        all_tokens = self.get_queryset()

        def _maybe_get_token(token: OutstandingToken):
            """Decode a stored token, or return None if it no longer validates."""
            try:
                return TruncatedLSAPIToken(str(token.token))
            except (TokenError, TokenBackendError) as e:  # expired/invalid token
                logger.debug('JWT API token validation failed: %s', e)
                return None

        # Annoyingly, token_type is not stored directly, so we have to filter it here.
        # Shouldn't be many unexpired tokens to iterate through.
        token_objects = [tok for tok in map(_maybe_get_token, all_tokens) if tok is not None]
        refresh_tokens = [tok for tok in token_objects if tok['token_type'] == 'refresh']

        serializer = self.get_serializer(refresh_tokens, many=True)
        return Response(serializer.data)

    def get_serializer_class(self):
        # POST returns the creation serializer; every other method uses the list serializer.
        if self.request.method == 'POST':
            return LSAPITokenCreateSerializer
        return LSAPITokenListSerializer

    def perform_create(self, serializer):
        # Refuse to create a second token while a valid one exists.
        # NOTE(review): unlike `list`, this does not filter by token_type, so any
        # outstanding non-expired, non-blacklisted token for the user blocks
        # creation — confirm that is intended.
        existing_tokens = self.get_queryset()
        if existing_tokens.exists():
            raise TokenExistsError()

        token = self.token_class.for_user(self.request.user)
        serializer.instance = token


class LSTokenBlacklistView(TokenViewBase):
    # Revokes a refresh token by adding it to simple-jwt's blacklist.
    _serializer_class = 'jwt_auth.serializers.LSAPITokenBlacklistSerializer'

    @extend_schema(
        tags=['JWT'],
        summary='Blacklist a JWT refresh token',
        description='Adds a JWT refresh token to the blacklist, preventing it from being used to obtain new access tokens.',
        responses={
            status.HTTP_204_NO_CONTENT: 'Token was successfully blacklisted',
            status.HTTP_404_NOT_FOUND: 'Token is already blacklisted',
        },
        extensions={
            'x-fern-sdk-group-name': 'tokens',
            'x-fern-sdk-method-name': 'blacklist',
            'x-fern-audiences': ['public'],
        },
    )
    def post(self, request, *args, **kwargs):
        blacklist_serializer = self.get_serializer(data=request.data)
        # Validation is what performs the revocation: the inherited simple-jwt
        # serializer calls .blacklist() on the token while validating.
        try:
            blacklist_serializer.is_valid(raise_exception=True)
        except TokenError as err:
            logger.error('Token error occurred while trying to blacklist a token: %s', str(err), exc_info=True)
            return Response({'detail': 'Token is invalid or already blacklisted.'}, status=status.HTTP_404_NOT_FOUND)
        else:
            return Response(status=status.HTTP_204_NO_CONTENT)


class LSAPITokenRotateView(TokenViewBase):
    # Swaps a refresh token for a freshly issued one: the submitted token is
    # blacklisted and a new token is created for the authenticated user.

    # Authentication must be declared explicitly: with our middleware, request.user is
    # not populated correctly before the view runs, so DRF authenticates here itself.
    authentication_classes = [JWTAuthentication, TokenAuthenticationPhaseout, SessionAuthentication]
    permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
    permission_required = all_permissions.users_token_any
    _serializer_class = 'jwt_auth.serializers.LSAPITokenRotateSerializer'
    token_class = LSAPIToken

    @extend_schema(
        tags=['JWT'],
        summary='Rotate JWT refresh token',
        description='Creates a new JWT refresh token and blacklists the current one.',
        responses={
            status.HTTP_200_OK: TokenRotateResponseSerializer,
            status.HTTP_400_BAD_REQUEST: 'Invalid token or token already blacklisted',
        },
        extensions={
            'x-fern-sdk-group-name': 'tokens',
            'x-fern-sdk-method-name': 'rotate',
            'x-fern-audiences': ['public'],
        },
    )
    def post(self, request, *args, **kwargs):
        rotate_serializer = self.get_serializer(data=request.data)
        rotate_serializer.is_valid(raise_exception=True)
        old_token = rotate_serializer.validated_data['refresh']
        # NOTE(review): this view does not itself check that `old_token` belongs to
        # request.user; presumably LSAPITokenRotateSerializer enforces that — confirm.

        # Revoke the submitted token first; a failure aborts the rotation with 400.
        try:
            old_token.blacklist()
        except TokenError:
            return Response({'detail': 'Token is invalid or already blacklisted.'}, status=status.HTTP_400_BAD_REQUEST)

        replacement = self.create_token(request.user)
        payload = {'refresh': replacement.get_full_jwt()}
        return Response(payload, status=status.HTTP_200_OK)

    def create_token(self, user):
        """Issue a new token for ``user`` via ``token_class``; subclasses may override to change the token type."""
        return self.token_class.for_user(user)
